145 lines
4.1 KiB
Python
145 lines
4.1 KiB
Python
"""
|
|
Shared SSH command helpers for AutoHeal and AiderHeal.
|
|
|
|
The service layer owns allowlists and action semantics; this module only
|
|
builds and runs the SSH command consistently.
|
|
"""
|
|
|
|
import os
|
|
import subprocess
|
|
from dataclasses import dataclass
|
|
from typing import Any, List, Optional, Sequence, Union
|
|
|
|
|
|
# A remote command is accepted in two shapes: a single command string, or a
# sequence of argument parts that build_ssh_command converts with str().
RemoteCommand = Union[
    str,
    Sequence[Any],
]
|
|
|
|
|
|
@dataclass(frozen=True)
class SshExecResult:
    """Immutable record describing one finished SSH invocation.

    Holds the exit status, the captured output streams and the exact local
    argv that was executed, so callers can inspect or log what happened.
    """

    # Exit status of the ssh process; run_ssh_command uses -1 when ssh timed
    # out or could not be executed at all.
    returncode: int
    # Captured standard output.
    stdout: str
    # Captured standard error, or a description of the local failure.
    stderr: str
    # Full local command line, as passed to subprocess.
    argv: List[str]

    @property
    def success(self) -> bool:
        """Whether the invocation finished with exit status zero."""
        return self.returncode == 0
|
|
|
|
|
|
def ensure_ssh_key_permissions(key_path: Optional[str], logger: Optional[Any] = None) -> None:
    """Make an SSH private key accessible only to its owner (best effort).

    If the key grants any group/other permission bits, or its owner cannot
    read it, its mode is set to ``0o600``. A key that is already
    owner-readable with no group/other bits (e.g. ``0o400`` or ``0o600``) is
    left untouched, so a deliberately read-only key is not made writable.

    Filesystem problems never propagate. A missing key, or a failed
    ``stat``/``chmod``, is reported through *logger* (if given) and ignored.

    Args:
        key_path: Path to the private key; ``~`` is expanded. ``None`` or
            ``""`` makes this a no-op.
        logger: Optional logger-like object with a ``warning(msg, *args)``
            method.
    """
    if not key_path:
        return

    safe_key = os.path.expanduser(key_path)

    # EAFP: a single stat both checks existence and yields the current mode,
    # avoiding the exists()/chmod() race of a separate check.
    try:
        mode = os.stat(safe_key).st_mode & 0o777
    except FileNotFoundError:
        if logger:
            logger.warning("SSH key not found: %s", safe_key)
        return
    except (OSError, ValueError) as exc:  # ValueError: e.g. embedded NUL in path
        if logger:
            logger.warning("Failed to secure SSH key: %s", exc)
        return

    # Already private and readable by its owner: nothing to tighten.
    if not (mode & 0o077) and (mode & 0o400):
        return

    try:
        os.chmod(safe_key, 0o600)
    except OSError as exc:
        # e.g. the key belongs to another user; ssh may then refuse to use it.
        if logger:
            logger.warning("Failed to secure SSH key: %s", exc)
|
|
|
|
|
|
def build_ssh_command(
    *,
    host: str,
    user: str,
    command: "RemoteCommand",
    port: int = 22,
    key_path: Optional[str] = None,
    connect_timeout: int = 10,
    jump_host: Optional[str] = None,
    jump_user: Optional[str] = None,
    strict_host_key_checking: str = "no",
    batch_mode: bool = False,
    server_alive_interval: Optional[int] = None,
    server_alive_count_max: Optional[int] = None,
) -> List[str]:
    """Build the local ``ssh`` argv that runs *command* on ``user@host``.

    Nothing is executed. The returned list is meant for
    ``subprocess.run(argv, shell=False)``.

    Args:
        host: Target hostname or address.
        user: Remote login user.
        command: Either a single string, which is passed to the remote shell
            verbatim (the caller is responsible for its quoting), or a
            sequence of argument parts. Each part is converted with ``str()``
            and shell-quoted, so it reaches the remote program as exactly one
            argument even if it contains spaces or shell metacharacters.
        port: Remote SSH port (``-p``).
        key_path: Optional identity file (``-i``); ``~`` is expanded.
        connect_timeout: Value for ``-o ConnectTimeout``.
        jump_host: Optional bastion, passed via ``-J``.
        jump_user: Login user for the bastion. When omitted, the ``-J``
            target is just *jump_host* and ssh's default/configured user
            applies. Ignored without *jump_host*.
        strict_host_key_checking: Value for ``-o StrictHostKeyChecking``.
            NOTE(review): the default ``"no"`` disables host-key verification
            (MITM risk); it is kept for backward compatibility.
        batch_mode: Add ``-o BatchMode=yes`` to disable interactive prompts.
        server_alive_interval: Optional ``-o ServerAliveInterval`` value.
        server_alive_count_max: Optional ``-o ServerAliveCountMax`` value.

    Returns:
        The argv list, starting with ``"ssh"``.
    """
    import shlex  # local import keeps this helper self-contained

    argv = ["ssh", "-p", str(port)]
    if key_path:
        argv.extend(["-i", os.path.expanduser(key_path)])
    argv.extend(["-o", f"StrictHostKeyChecking={strict_host_key_checking}"])
    if batch_mode:
        argv.extend(["-o", "BatchMode=yes"])
    argv.extend(["-o", f"ConnectTimeout={connect_timeout}"])
    if server_alive_interval is not None:
        argv.extend(["-o", f"ServerAliveInterval={server_alive_interval}"])
    if server_alive_count_max is not None:
        argv.extend(["-o", f"ServerAliveCountMax={server_alive_count_max}"])
    if jump_host:
        # ssh -J accepts [user@]host, so a jump host without a user must still
        # be honored rather than silently connecting directly to the target.
        jump_target = f"{jump_user}@{jump_host}" if jump_user else jump_host
        argv.extend(["-J", jump_target])
    argv.append(f"{user}@{host}")

    # "--" ends ssh's option parsing, so a command (or first part) starting
    # with "-" is never interpreted as an ssh option such as -oProxyCommand.
    argv.append("--")
    if isinstance(command, str):
        argv.append(command)
    else:
        # ssh joins trailing arguments with spaces and the remote shell
        # re-splits them; quoting each part preserves argument boundaries.
        argv.extend(shlex.quote(str(part)) for part in command)
    return argv
|
|
|
|
|
|
def _decode_partial_output(data: Optional[Union[bytes, str]]) -> str:
    """Return output captured by a ``TimeoutExpired`` as stripped text."""
    if data is None:
        return ""
    if isinstance(data, bytes):
        # subprocess reports partial output as bytes even when text=True.
        data = data.decode("utf-8", errors="replace")
    return data.strip()


def run_ssh_command(
    *,
    host: str,
    user: str,
    command: RemoteCommand,
    port: int = 22,
    key_path: Optional[str] = None,
    connect_timeout: int = 10,
    command_timeout: int = 60,
    jump_host: Optional[str] = None,
    jump_user: Optional[str] = None,
    strict_host_key_checking: str = "no",
    batch_mode: bool = False,
    server_alive_interval: Optional[int] = None,
    server_alive_count_max: Optional[int] = None,
    cwd: Optional[str] = None,
    logger: Optional[Any] = None,
) -> SshExecResult:
    """Run *command* on ``user@host`` over SSH and capture its output.

    Tightens the key file's permissions, builds the argv with
    ``build_ssh_command`` and executes it locally without a shell. Failures
    while launching or running ssh are returned as a result with
    ``returncode=-1`` instead of being raised.

    Args:
        host, user, command, port, key_path, connect_timeout, jump_host,
        jump_user, strict_host_key_checking, batch_mode,
        server_alive_interval, server_alive_count_max: Forwarded to
            ``build_ssh_command``.
        command_timeout: Wall-clock limit in seconds for the local ssh
            process. On expiry the process is killed and a timeout result is
            returned.
        cwd: Working directory for the local ssh process.
        logger: Optional logger-like object with ``warning(msg, *args)``.

    Returns:
        SshExecResult with stripped stdout/stderr and the executed argv. On
        timeout, stderr starts with ``"SSH timeout after <N>s"``, followed by
        any stderr captured before the kill, and stdout holds any partial
        output.
    """
    ensure_ssh_key_permissions(key_path, logger=logger)
    argv = build_ssh_command(
        host=host,
        user=user,
        command=command,
        port=port,
        key_path=key_path,
        connect_timeout=connect_timeout,
        jump_host=jump_host,
        jump_user=jump_user,
        strict_host_key_checking=strict_host_key_checking,
        batch_mode=batch_mode,
        server_alive_interval=server_alive_interval,
        server_alive_count_max=server_alive_count_max,
    )
    try:
        result = subprocess.run(
            argv,
            shell=False,
            # Never let ssh read (or block on) this process's stdin; same
            # effect as `ssh -n`.
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            # Undecodable remote output must not turn a finished command
            # into an exec error.
            errors="replace",
            cwd=cwd,
            timeout=command_timeout,
        )
    except subprocess.TimeoutExpired as exc:
        message = f"SSH timeout after {command_timeout}s"
        if logger:
            logger.warning("SSH timeout after %ss: %s@%s", command_timeout, user, host)
        partial_err = _decode_partial_output(exc.stderr)
        return SshExecResult(
            returncode=-1,
            stdout=_decode_partial_output(exc.stdout),
            stderr=f"{message}\n{partial_err}" if partial_err else message,
            argv=argv,
        )
    except Exception as exc:  # boundary: report any local failure as a result
        if logger:
            logger.warning("SSH exec error: %s", exc)
        return SshExecResult(returncode=-1, stdout="", stderr=str(exc), argv=argv)

    return SshExecResult(
        returncode=result.returncode,
        stdout=result.stdout.strip(),
        stderr=result.stderr.strip(),
        argv=argv,
    )
|