"""
Shared SSH command helpers for AutoHeal and AiderHeal.

The service layer owns allowlists and action semantics; this module only
builds and runs the SSH command consistently.
"""

import os
import subprocess
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Union

# A remote command is either one complete command line (str) or a sequence of
# parts that are each stringified and passed to ssh as separate arguments.
RemoteCommand = Union[str, Sequence[Any]]


@dataclass(frozen=True)
class SshExecResult:
    """Outcome of a single ``ssh`` invocation."""

    # Exit status reported by the ssh process; -1 when the process could not
    # be run to completion (timeout or launch error).
    returncode: int
    # Captured standard output of the ssh process.
    stdout: str
    # Captured standard error, or a description of the local failure.
    stderr: str
    # The exact argument vector that was executed (or attempted).
    argv: List[str]

    @property
    def success(self) -> bool:
        """True when ``returncode`` is 0."""
        return self.returncode == 0


def ensure_ssh_key_permissions(key_path: Optional[str], logger: Optional[Any] = None) -> None:
    """Best-effort ``chmod 0600`` of an SSH private key file.

    Does nothing when ``key_path`` is falsy. A missing file or a failed
    ``chmod`` is reported through ``logger.warning`` (when a logger is given)
    instead of being raised.

    Args:
        key_path: Path to the private key; ``~`` is expanded.
        logger: Optional object with a ``warning(msg, *args)`` method.
    """
    if not key_path:
        return
    safe_key = os.path.expanduser(key_path)
    if not os.path.exists(safe_key):
        if logger:
            logger.warning("SSH key not found: %s", safe_key)
        return
    try:
        os.chmod(safe_key, 0o600)
    except OSError as exc:
        # Covers PermissionError and the file vanishing after the exists() check.
        if logger:
            logger.warning("Failed to secure SSH key: %s", exc)


def build_ssh_command(
    *,
    host: str,
    user: str,
    command: RemoteCommand,
    port: int = 22,
    key_path: Optional[str] = None,
    connect_timeout: int = 10,
    jump_host: Optional[str] = None,
    jump_user: Optional[str] = None,
    strict_host_key_checking: str = "no",
    batch_mode: bool = False,
    server_alive_interval: Optional[int] = None,
    server_alive_count_max: Optional[int] = None,
) -> List[str]:
    """Build the ``ssh`` argv that runs ``command`` on ``user@host``.

    The argv is intended to be executed without a local shell. Argument order
    is: port, identity file, host-key policy, batch mode, connect timeout,
    keep-alive settings, jump host, destination, ``--``, then the command.

    Args:
        host: Target host name or address.
        user: Remote login user. NOTE(review): not validated here; a value
            beginning with ``-`` would make the destination argument look
            like an ssh option, so callers should validate it.
        command: Either a complete remote command line (``str``, passed as a
            single argument) or a sequence whose parts are each converted
            with ``str()`` and passed as separate arguments. ssh joins all
            command arguments with spaces before the remote shell interprets
            them, so sequence parts are NOT shell-quoted by this function.
        port: Target SSH port.
        key_path: Private key file (``~`` expanded); omitted when falsy.
        connect_timeout: Value for the ``ConnectTimeout`` option, in seconds.
        jump_host: Optional bastion host passed via ``-J``.
        jump_user: User for the jump host. When omitted, ``-J`` receives the
            bare host and ssh resolves the user as for any destination.
            Ignored when ``jump_host`` is not set.
        strict_host_key_checking: Value for ``StrictHostKeyChecking``. The
            default ``"no"`` does not require the host key to be known.
        batch_mode: Add ``BatchMode=yes`` to disable interactive prompts.
        server_alive_interval: ``ServerAliveInterval`` value, when not None.
        server_alive_count_max: ``ServerAliveCountMax`` value, when not None.

    Returns:
        The argument vector, starting with ``"ssh"``.
    """
    argv = ["ssh", "-p", str(port)]
    if key_path:
        argv.extend(["-i", os.path.expanduser(key_path)])
    argv.extend(["-o", f"StrictHostKeyChecking={strict_host_key_checking}"])
    if batch_mode:
        argv.extend(["-o", "BatchMode=yes"])
    argv.extend(["-o", f"ConnectTimeout={connect_timeout}"])
    if server_alive_interval is not None:
        argv.extend(["-o", f"ServerAliveInterval={server_alive_interval}"])
    if server_alive_count_max is not None:
        argv.extend(["-o", f"ServerAliveCountMax={server_alive_count_max}"])
    if jump_host:
        # -J accepts [user@]host; previously a jump host given without a user
        # was silently dropped, making ssh connect directly to the target.
        jump_target = f"{jump_user}@{jump_host}" if jump_user else jump_host
        argv.extend(["-J", jump_target])
    argv.append(f"{user}@{host}")
    # "--" ends ssh option parsing for both command forms, so a command that
    # begins with "-" is sent to the remote side instead of being read as an
    # ssh option.
    argv.append("--")
    if isinstance(command, str):
        argv.append(command)
    else:
        argv.extend(str(part) for part in command)
    return argv
def run_ssh_command(
    *,
    host: str,
    user: str,
    command: RemoteCommand,
    port: int = 22,
    key_path: Optional[str] = None,
    connect_timeout: int = 10,
    command_timeout: int = 60,
    jump_host: Optional[str] = None,
    jump_user: Optional[str] = None,
    strict_host_key_checking: str = "no",
    batch_mode: bool = False,
    server_alive_interval: Optional[int] = None,
    server_alive_count_max: Optional[int] = None,
    cwd: Optional[str] = None,
    logger: Optional[Any] = None,
) -> SshExecResult:
    """Run ``command`` on ``user@host`` over SSH and capture the result.

    The key file (if any) is first passed to ``ensure_ssh_key_permissions``.
    The argv is built with ``build_ssh_command`` and executed directly, with
    no local shell. Failures while running ``ssh`` are not raised. A timeout
    or a launch error is returned as an ``SshExecResult`` with
    ``returncode == -1`` and a description in ``stderr``.

    Args:
        host, user, command, port, key_path, connect_timeout, jump_host,
        jump_user, strict_host_key_checking, batch_mode,
        server_alive_interval, server_alive_count_max:
            Forwarded unchanged to ``build_ssh_command``.
        command_timeout: Wall-clock limit in seconds for the whole local
            ``ssh`` process (connection plus remote command).
        cwd: Working directory for the local ``ssh`` process.
        logger: Optional object with a ``warning(msg, *args)`` method.

    Returns:
        SshExecResult holding the exit status, whitespace-stripped
        stdout/stderr, and the argv that was executed.
    """
    ensure_ssh_key_permissions(key_path, logger=logger)
    argv = build_ssh_command(
        host=host,
        user=user,
        command=command,
        port=port,
        key_path=key_path,
        connect_timeout=connect_timeout,
        jump_host=jump_host,
        jump_user=jump_user,
        strict_host_key_checking=strict_host_key_checking,
        batch_mode=batch_mode,
        server_alive_interval=server_alive_interval,
        server_alive_count_max=server_alive_count_max,
    )
    try:
        result = subprocess.run(
            argv,
            shell=False,
            # Without this, ssh forwards this process's own stdin to the remote
            # command, which can consume it or block until the timeout.
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            # Remote output is arbitrary bytes; an undecodable byte must not
            # turn a completed command into a reported launch failure.
            errors="replace",
            cwd=cwd,
            timeout=command_timeout,
        )
    except subprocess.TimeoutExpired:
        if logger:
            logger.warning("SSH timeout after %ss for %s@%s", command_timeout, user, host)
        return SshExecResult(
            returncode=-1,
            stdout="",
            stderr=f"SSH timeout after {command_timeout}s",
            argv=argv,
        )
    except Exception as exc:
        # Boundary: e.g. ssh binary missing or invalid cwd. Report it as a
        # result instead of raising, as callers expect an SshExecResult.
        if logger:
            logger.warning("SSH exec error: %s", exc)
        return SshExecResult(returncode=-1, stdout="", stderr=str(exc), argv=argv)
    return SshExecResult(
        returncode=result.returncode,
        stdout=result.stdout.strip(),
        stderr=result.stderr.strip(),
        argv=argv,
    )