Add support for stateful Tor sessions

This commit is contained in:
codeking 2026-01-18 05:14:54 +01:00
parent 616a864c9d
commit 531f5cdf3c
10 changed files with 213 additions and 27 deletions

View file

@ -11,6 +11,7 @@ class Constants:
CONNECTION_RETRY_INTERVAL: Final[int] = int(os.environ.get('CONNECTION_RETRY_INTERVAL', '5')) CONNECTION_RETRY_INTERVAL: Final[int] = int(os.environ.get('CONNECTION_RETRY_INTERVAL', '5'))
MAX_CONNECTION_ATTEMPTS: Final[int] = int(os.environ.get('MAX_CONNECTION_ATTEMPTS', '2')) MAX_CONNECTION_ATTEMPTS: Final[int] = int(os.environ.get('MAX_CONNECTION_ATTEMPTS', '2'))
TOR_BOOTSTRAP_TIMEOUT: Final[int] = int(os.environ.get('TOR_BOOTSTRAP_TIMEOUT', '90'))
HV_CLIENT_PATH: Final[str] = os.environ.get('HV_CLIENT_PATH') HV_CLIENT_PATH: Final[str] = os.environ.get('HV_CLIENT_PATH')
HV_CLIENT_VERSION_NUMBER: Final[str] = os.environ.get('HV_CLIENT_VERSION_NUMBER') HV_CLIENT_VERSION_NUMBER: Final[str] = os.environ.get('HV_CLIENT_VERSION_NUMBER')
@ -46,3 +47,7 @@ class Constants:
HV_PRIVILEGE_POLICY_PATH: Final[str] = f'{SYSTEM_CONFIG_PATH}/sudoers.d/hydra-veil' HV_PRIVILEGE_POLICY_PATH: Final[str] = f'{SYSTEM_CONFIG_PATH}/sudoers.d/hydra-veil'
HV_SESSION_STATE_HOME: Final[str] = f'{HV_STATE_HOME}/sessions' HV_SESSION_STATE_HOME: Final[str] = f'{HV_STATE_HOME}/sessions'
HV_TOR_SESSION_STATE_HOME: Final[str] = f'{HV_SESSION_STATE_HOME}/tor'
HV_TOR_CONTROL_SOCKET_PATH: Final[str] = f'{HV_TOR_SESSION_STATE_HOME}/tor.sock'
HV_TOR_PROCESS_IDENTIFIER_PATH: Final[str] = f'{HV_TOR_SESSION_STATE_HOME}/tor.pid'

View file

@ -26,6 +26,10 @@ class ConnectionTerminationError(Exception):
pass pass
class TorServiceInitializationError(Exception):
pass
class PolicyAssignmentError(Exception): class PolicyAssignmentError(Exception):
pass pass

View file

@ -183,7 +183,7 @@ class ApplicationController:
process = subprocess.Popen(initialization_file_path, env=environment) process = subprocess.Popen(initialization_file_path, env=environment)
session_state = SessionState(session_state.id, session_state.network_port_numbers, [virtual_display_process.pid, process.pid]) session_state.process_ids.extend([virtual_display_process.pid, process.pid])
SessionStateController.update_or_create(session_state) SessionStateController.update_or_create(session_state)
process.wait() process.wait()

View file

@ -1,6 +1,7 @@
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from core.Constants import Constants from core.Constants import Constants
from core.Errors import InvalidSubscriptionError, MissingSubscriptionError, ConnectionUnprotectedError, ConnectionTerminationError, CommandNotFoundError from core.Errors import InvalidSubscriptionError, MissingSubscriptionError, ConnectionUnprotectedError, ConnectionTerminationError, CommandNotFoundError, TorServiceInitializationError
from core.controllers.ConfigurationController import ConfigurationController from core.controllers.ConfigurationController import ConfigurationController
from core.controllers.ProfileController import ProfileController from core.controllers.ProfileController import ProfileController
from core.controllers.SessionStateController import SessionStateController from core.controllers.SessionStateController import SessionStateController
@ -14,10 +15,14 @@ from pathlib import Path
from subprocess import CalledProcessError from subprocess import CalledProcessError
from typing import Union, Optional, Any from typing import Union, Optional, Any
import os import os
import psutil
import random import random
import re import re
import shutil import shutil
import socket import socket
import stem
import stem.control
import stem.process
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
@ -137,8 +142,8 @@ class ConnectionController:
if profile.connection.code == 'tor': if profile.connection.code == 'tor':
port_number = ConnectionController.get_random_available_port_number() port_number = ConnectionController.get_random_available_port_number()
ConnectionController.establish_tor_session_connection(session_directory, port_number) ConnectionController.establish_tor_session_connection(port_number, connection_observer=connection_observer)
session_state.network_port_numbers.append(port_number) session_state.network_port_numbers.tor.append(port_number)
elif profile.connection.code == 'wireguard': elif profile.connection.code == 'wireguard':
@ -147,7 +152,7 @@ class ConnectionController:
port_number = ConnectionController.get_random_available_port_number() port_number = ConnectionController.get_random_available_port_number()
ConnectionController.establish_wireguard_session_connection(profile, session_directory, port_number) ConnectionController.establish_wireguard_session_connection(profile, session_directory, port_number)
session_state.network_port_numbers.append(port_number) session_state.network_port_numbers.wireguard.append(port_number)
if profile.connection.masked: if profile.connection.masked:
@ -155,7 +160,7 @@ class ConnectionController:
proxy_port_number = ConnectionController.get_random_available_port_number() proxy_port_number = ConnectionController.get_random_available_port_number()
ConnectionController.establish_proxy_session_connection(profile, session_directory, port_number, proxy_port_number) ConnectionController.establish_proxy_session_connection(profile, session_directory, port_number, proxy_port_number)
session_state.network_port_numbers.append(proxy_port_number) session_state.network_port_numbers.proxy.append(proxy_port_number)
if not profile.connection.is_unprotected(): if not profile.connection.is_unprotected():
ConnectionController.await_connection(proxy_port_number or port_number, connection_observer=connection_observer) ConnectionController.await_connection(proxy_port_number or port_number, connection_observer=connection_observer)
@ -204,16 +209,118 @@ class ConnectionController:
time.sleep(1.0) time.sleep(1.0)
@staticmethod @staticmethod
def establish_tor_session_connection(session_directory: str, port_number: int): def establish_tor_session_connection(port_number: int, connection_observer: Optional[ConnectionObserver] = None):
if shutil.which('tor') is None: try:
raise CommandNotFoundError('tor')
tor_session_directory = f'{session_directory}/tor' controller = stem.control.Controller.from_socket_file(Constants.HV_TOR_CONTROL_SOCKET_PATH)
Path(tor_session_directory).mkdir(exist_ok=True, mode=0o700) controller.authenticate()
process = subprocess.Popen(('echo', f'DataDirectory {tor_session_directory}/tor\nSocksPort {port_number}'), stdout=subprocess.PIPE) except (FileNotFoundError, stem.SocketError, TypeError, IndexError):
return subprocess.Popen(('tor', '-f', '-'), stdin=process.stdout, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
ConnectionController.establish_tor_connection(connection_observer=connection_observer)
controller = stem.control.Controller.from_socket_file(Constants.HV_TOR_CONTROL_SOCKET_PATH)
controller.authenticate()
socks_port_numbers = [str(port_number) for port_number in controller.get_ports('socks')]
socks_port_numbers.append(str(port_number))
controller.set_conf('SocksPort', socks_port_numbers)
@staticmethod
def terminate_tor_session_connection(port_number: int):
try:
controller = stem.control.Controller.from_socket_file(Constants.HV_TOR_CONTROL_SOCKET_PATH)
controller.authenticate()
socks_port_numbers = [str(port_number) for port_number in controller.get_ports('socks')]
if len(socks_port_numbers) > 1:
socks_port_numbers = [socks_port_number for socks_port_number in socks_port_numbers if socks_port_number != port_number]
controller.set_conf('SocksPort', socks_port_numbers)
else:
controller.set_conf('SocksPort', '0')
except (FileNotFoundError, stem.SocketError, TypeError, IndexError):
pass
@staticmethod
def establish_tor_connection(connection_observer: Optional[ConnectionObserver] = None):
Path(Constants.HV_TOR_SESSION_STATE_HOME).mkdir(exist_ok=True, mode=0o700)
ConnectionController.terminate_tor_connection()
if connection_observer is not None:
connection_observer.notify('tor_bootstrapping')
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(
stem.process.launch_tor_with_config,
config={
'DataDirectory': Constants.HV_TOR_SESSION_STATE_HOME,
'ControlSocket': Constants.HV_TOR_CONTROL_SOCKET_PATH,
'PIDFile': Constants.HV_TOR_PROCESS_IDENTIFIER_PATH,
'SocksPort': '0'
},
init_msg_handler=lambda contents: ConnectionController.__on_tor_initialization_message(contents, connection_observer)
)
try:
future.result(timeout=Constants.TOR_BOOTSTRAP_TIMEOUT)
except FutureTimeoutError:
ConnectionController.terminate_tor_connection()
raise TorServiceInitializationError('The dedicated Tor service could not be initialized.')
if connection_observer is not None:
connection_observer.notify('tor_bootstrapped')
try:
controller = stem.control.Controller.from_socket_file(Constants.HV_TOR_CONTROL_SOCKET_PATH)
controller.authenticate()
except (FileNotFoundError, stem.SocketError, TypeError, IndexError):
raise TorServiceInitializationError('The dedicated Tor service could not be initialized.')
for session_state in SessionStateController.all():
for port_number in session_state.network_port_numbers.tor:
ConnectionController.establish_tor_session_connection(port_number)
@staticmethod
def terminate_tor_connection():
process_identifier_file = Path(Constants.HV_TOR_PROCESS_IDENTIFIER_PATH)
control_socket_file = Path(Constants.HV_TOR_CONTROL_SOCKET_PATH)
try:
process_identifier = int(process_identifier_file.read_text().strip())
except (OSError, ValueError):
process_identifier = None
if process_identifier is not None:
try:
process = psutil.Process(process_identifier)
if process.is_running():
process.terminate()
except psutil.NoSuchProcess:
pass
process_identifier_file.unlink(missing_ok=True)
control_socket_file.unlink(missing_ok=True)
@staticmethod @staticmethod
def establish_wireguard_session_connection(profile: SessionProfile, session_directory: str, port_number: int): def establish_wireguard_session_connection(profile: SessionProfile, session_directory: str, port_number: int):
@ -403,13 +510,13 @@ class ConnectionController:
@staticmethod @staticmethod
def __with_tor_connection(*args, task: Callable[..., Any], connection_observer: Optional[ConnectionObserver] = None, **kwargs): def __with_tor_connection(*args, task: Callable[..., Any], connection_observer: Optional[ConnectionObserver] = None, **kwargs):
session_directory = tempfile.mkdtemp(prefix='hv-')
port_number = ConnectionController.get_random_available_port_number() port_number = ConnectionController.get_random_available_port_number()
process = ConnectionController.establish_tor_session_connection(session_directory, port_number) ConnectionController.establish_tor_session_connection(port_number, connection_observer=connection_observer)
ConnectionController.await_connection(port_number, 5, connection_observer=connection_observer) ConnectionController.await_connection(port_number, connection_observer=connection_observer)
task_output = task(*args, proxies=ConnectionController.get_proxies(port_number), **kwargs) task_output = task(*args, proxies=ConnectionController.get_proxies(port_number), **kwargs)
process.terminate()
ConnectionController.terminate_tor_session_connection(port_number)
return task_output return task_output
@ -465,3 +572,16 @@ class ConnectionController:
return True return True
return False return False
@staticmethod
def __on_tor_initialization_message(contents, connection_observer: Optional[ConnectionObserver] = None):
if connection_observer is not None:
if 'Bootstrapped ' in contents:
progress = (m := re.search(r' (\d{1,3})% ', contents)) and int(m.group(1))
connection_observer.notify('tor_bootstrap_progressing', None, dict(
progress=progress
))

View file

@ -97,6 +97,10 @@ class ProfileController:
session_state = SessionStateController.get(profile.id) session_state = SessionStateController.get(profile.id)
if session_state is not None: if session_state is not None:
for port_number in session_state.network_port_numbers.tor:
ConnectionController.terminate_tor_session_connection(port_number)
session_state.dissolve(session_state.id) session_state.dissolve(session_state.id)
if profile.is_system_profile(): if profile.is_system_profile():
@ -173,7 +177,7 @@ class ProfileController:
if profile.is_session_profile(): if profile.is_session_profile():
session_state = SessionStateController.get_or_new(profile.id) session_state = SessionStateController.get_or_new(profile.id)
return len(session_state.network_port_numbers) > 0 or len(session_state.process_ids) > 0 return len(session_state.network_port_numbers.all) > 0 or len(session_state.process_ids) > 0
if profile.is_system_profile(): if profile.is_system_profile():

View file

@ -21,6 +21,10 @@ class SessionStateController:
def exists(id: int): def exists(id: int):
return SessionState.exists(id) return SessionState.exists(id)
@staticmethod
def all():
return SessionState.all()
@staticmethod @staticmethod
def update_or_create(session_state): def update_or_create(session_state):
session_state.save() session_state.save()

View file

@ -0,0 +1,16 @@
from dataclasses import dataclass, field
@dataclass
class NetworkPortNumbers:
proxy: list[int] = field(default_factory=list)
wireguard: list[int] = field(default_factory=list)
tor: list[int] = field(default_factory=list)
@property
def all(self):
return self.proxy + self.wireguard + self.tor
@property
def isolated(self):
return self.proxy + self.wireguard

View file

@ -1,4 +1,5 @@
from core.Constants import Constants from core.Constants import Constants
from core.models.session.NetworkPortNumbers import NetworkPortNumbers
from dataclasses import dataclass, field from dataclasses import dataclass, field
from dataclasses_json import config, Exclude, dataclass_json from dataclasses_json import config, Exclude, dataclass_json
from json import JSONDecodeError from json import JSONDecodeError
@ -17,7 +18,7 @@ class SessionState:
id: int = field( id: int = field(
metadata=config(exclude=Exclude.ALWAYS) metadata=config(exclude=Exclude.ALWAYS)
) )
network_port_numbers: list[int] = field(default_factory=list) network_port_numbers: NetworkPortNumbers = field(default_factory=NetworkPortNumbers)
process_ids: list[int] = field(default_factory=list) process_ids: list[int] = field(default_factory=list)
def get_state_path(self): def get_state_path(self):
@ -39,25 +40,53 @@ class SessionState:
@staticmethod @staticmethod
def find_by_id(id: int): def find_by_id(id: int):
state_path = SessionState.__get_state_path(id)
try: try:
session_state_file_contents = open(f'{SessionState.__get_state_path(id)}/state.json', 'r').read() session_state_file_contents = Path(f'{state_path}/state.json').read_text()
except FileNotFoundError: except FileNotFoundError:
return None return None
try: try:
session_state = json.loads(session_state_file_contents) session_state = json.loads(session_state_file_contents)
except JSONDecodeError: session_state['id'] = id
# noinspection PyUnresolvedReferences
return SessionState.from_dict(session_state)
except (JSONDecodeError, AttributeError):
shutil.rmtree(Path(state_path), ignore_errors=True)
return None return None
session_state['id'] = id
# noinspection PyUnresolvedReferences
return SessionState.from_dict(session_state)
@staticmethod @staticmethod
def exists(id: int): def exists(id: int):
return os.path.isdir(SessionState.__get_state_path(id)) and re.match(r'^\d+$', str(id)) return os.path.isdir(SessionState.__get_state_path(id)) and re.match(r'^\d+$', str(id))
@staticmethod
def all():
session_states = []
for directory_entry in os.listdir(Constants.HV_SESSION_STATE_HOME):
try:
id = int(directory_entry)
except ValueError:
continue
if SessionState.exists(id):
session_state = SessionState.find_by_id(id)
if session_state is not None:
session_states.append(session_state)
session_states.sort(key=lambda key: key.id)
return session_states
@staticmethod @staticmethod
def dissolve(id: int): def dissolve(id: int):
@ -76,7 +105,7 @@ class SessionState:
associated_process_ids = list(session_state.process_ids) associated_process_ids = list(session_state.process_ids)
network_connections = psutil.net_connections() network_connections = psutil.net_connections()
for network_port_number in session_state.network_port_numbers: for network_port_number in session_state.network_port_numbers.isolated:
for network_connection in network_connections: for network_connection in network_connections:

View file

@ -5,3 +5,6 @@ class ConnectionObserver(BaseObserver):
def __init__(self): def __init__(self):
self.on_connecting = [] self.on_connecting = []
self.on_tor_bootstrapping = []
self.on_tor_bootstrap_progressing = []
self.on_tor_bootstrapped = []

View file

@ -19,6 +19,7 @@ dependencies = [
"pysocks ~= 1.7.1", "pysocks ~= 1.7.1",
"python-dateutil ~= 2.9.0.post0", "python-dateutil ~= 2.9.0.post0",
"requests ~= 2.32.5", "requests ~= 2.32.5",
"stem ~= 1.8.2",
] ]
[project.urls] [project.urls]