Files
awoooi/apps/api/src/services/host_repair_agent.py
OG T 2fe8062fb8 refactor(api): Re-Review S1/S2/S3 改善 — 消除重複+防禦性驗證+測試隔離
S1: 抽取 _execute_and_observe() 公用方法
  - 消除 repair_by_uri 中 3 處重複的 execute+audit+langfuse 邏輯
  - 統一 AuditLog + Langfuse trace 寫入路徑

S2: SSH username 防禦性驗證
  - 新增 validate_ssh_user() + _SSH_USER_RE 正則
  - 在 _ssh_execute() 入口驗證 user 參數
  - 防止 user@host 拼接產生非預期行為
  - 新增 8 個 username 驗證測試

S3: Singleton 測試重置
  - 新增 _reset_for_test() classmethod
  - 避免跨測試狀態污染
  - 新增 2 個 singleton reset 測試

測試: 55/55 全數通過 (原 45 + 新 10)
首席架構師 Re-Review: 91/100  通過,3 個 Suggestion 全數實裝

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-04-07 11:17:40 +08:00

562 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
src/services/host_repair_agent.py
Host Repair Agent — performs host-level repairs over SSH.
2026-04-05 Claude Code: Sprint 3 Host Auto-Repair
2026-04-05 Claude Code: C1 fix — key_path is passed directly to _ssh_execute; no reverse lookup
2026-04-06 Claude Code: Sprint 3 P0-1/P0-2/P0-3/P0-4 Critical Security Fixes
"""
import asyncio
import os
import re
import logging
import shlex
from dataclasses import dataclass
logger = logging.getLogger(__name__)

# SSH connection settings: layer name -> host config (host, login user, private key path).
LAYER_SSH_CONFIG: dict[str, dict] = {
    "docker-110": {
        "host": "192.168.0.110",
        "user": "wooo",
        "key_path": "/etc/repair-ssh/id_ed25519",
    },
    "docker-188": {
        "host": "192.168.0.188",
        "user": "ollama",
        "key_path": "/etc/repair-ssh/id_ed25519",
    },
    "systemd-188": {
        "host": "192.168.0.188",
        "user": "ollama",
        "key_path": "/etc/repair-ssh/id_ed25519",
    },
}

# Component name rule: lowercase alphanumerics plus hyphen, 1-31 characters.
# The end is anchored with \Z rather than $ because $ also matches just before a
# trailing newline, which would let "name\n" pass validation.
_COMPONENT_RE = re.compile(r"^[a-z0-9][a-z0-9-]{0,30}\Z")

SSH_TIMEOUT = 60  # seconds
# =============================================================================
# URI scheme parsing
# 2026-04-06 Claude Code: Sprint 3 T1
# =============================================================================


@dataclass
class SshCommandURI:
    """Parsed form of an SSH_COMMAND URI (``scheme://host_or_layer/payload``)."""

    # One of "openclaw" | "ansible" | "ssh".
    scheme: str
    # e.g. "docker-110" | "192.168.0.188" | "wooo@192.168.0.110".
    host_or_layer: str
    # Component name, playbook filename, or raw command, depending on the scheme.
    payload: str
_SUPPORTED_SCHEMES = {"openclaw", "ansible", "ssh"}

# 2026-04-06 Claude Code: Sprint 3 P0-1 Security Fix — Complete shell metacharacter detection
# Prevents: pipe (|), redirect (>, <), command substitution ($(), ${}, ``), logic ops (;, &)
_SHELL_METACHAR_RE = re.compile(r'[;&|<>`\n]|(\$\{|\$\()')
_MAX_COMMAND_LEN = 512

# 2026-04-07 Claude Code: S2 — defensive SSH username validation
# (lowercase alphanumerics, underscore, hyphen; must start with a letter).
# The end is anchored with \Z rather than $ because $ also matches just before a
# trailing newline, which would let "user\n" pass validation.
_SSH_USER_RE = re.compile(r'^[a-z][a-z0-9_-]{0,31}\Z')

# Ansible control node settings — read from environment variables / ConfigMap.
# 2026-04-06 Claude Code: Sprint 3 T3
ANSIBLE_CONTROL_HOST = os.environ.get("ANSIBLE_CONTROL_NODE_HOST", "192.168.0.188")
ANSIBLE_CONTROL_USER = os.environ.get("ANSIBLE_CONTROL_NODE_USER", "ollama")
ANSIBLE_PLAYBOOKS_PATH = os.environ.get("ANSIBLE_PLAYBOOKS_PATH", "~/openclaw-v5/ansible/playbooks")
KNOWN_HOSTS_PATH = "/etc/repair-known-hosts/known_hosts"

# 2026-04-06 Claude Code: Sprint 3 P0-3 — SSH target whitelist (prevent unauthorized targets)
SSH_TARGET_WHITELIST = {"192.168.0.110", "192.168.0.188"}  # Only docker/ansible control nodes
def validate_ansible_playbook(playbook_name: str) -> None:
    """Reject any playbook name that is not explicitly whitelisted.

    The whitelist is a comma-separated list read from the environment variable
    ``ANSIBLE_PLAYBOOK_WHITELIST`` on every call. Names containing ``/`` or
    ``..`` are always rejected, which guards against path traversal.

    Raises:
        ValueError: the playbook name is not allowed.
    """
    raw_entries = os.environ.get("ANSIBLE_PLAYBOOK_WHITELIST", "").split(",")
    allowed = {entry.strip() for entry in raw_entries if entry.strip()}

    looks_like_path = "/" in playbook_name or ".." in playbook_name
    if not looks_like_path and playbook_name in allowed:
        return

    raise ValueError(
        f"Security Block: '{playbook_name}' not in allowed whitelist. "
        f"Allowed: {sorted(allowed)}"
    )
def validate_ssh_target_host(host: str) -> None:
    """Ensure *host* is one of the whitelisted SSH targets.

    Prevents commands from being executed on hosts that were never authorised.
    2026-04-06 Claude Code: Sprint 3 P0-3

    Raises:
        ValueError: *host* is not in ``SSH_TARGET_WHITELIST``.
    """
    if host in SSH_TARGET_WHITELIST:
        return

    raise ValueError(
        f"Security Block: SSH target '{host}' not in allowed whitelist. "
        f"Allowed: {sorted(SSH_TARGET_WHITELIST)}"
    )
def validate_ssh_user(user: str) -> None:
    """Check that an SSH username matches ``_SSH_USER_RE``.

    Guards the ``user@host`` string that is later handed to ``ssh`` against
    unexpected characters.
    2026-04-07 Claude Code: Re-Review S2 — defensive engineering

    Raises:
        ValueError: the username does not match the allowed pattern.
    """
    if _SSH_USER_RE.match(user) is not None:
        return

    raise ValueError(
        f"Security Block: SSH user '{user}' contains invalid characters. "
        f"Expected: lowercase alphanumeric, underscore, hyphen (1-32 chars)"
    )
def parse_uri_command(command: str) -> SshCommandURI:
    """Split an SSH_COMMAND URI into scheme, host/layer and payload.

    Supported forms:
        openclaw://docker-110/sentry
        ansible://192.168.0.188/vacuum_postgres.yml
        ssh://wooo@192.168.0.110/docker ps

    Everything after the first ``/`` that follows the host is the payload, so
    the payload itself may contain further slashes.

    Raises:
        ValueError: the scheme is missing/unsupported or the payload is missing/empty.
    """
    scheme, scheme_sep, remainder = command.partition("://")
    if not scheme_sep:
        raise ValueError(f"Unsupported scheme: '{command}' (expected scheme://host/payload)")
    if scheme not in _SUPPORTED_SCHEMES:
        raise ValueError(f"Unsupported scheme: '{scheme}' (supported: {_SUPPORTED_SCHEMES})")

    host_or_layer, path_sep, payload = remainder.partition("/")
    if not path_sep:
        raise ValueError(f"Invalid URI '{command}': missing payload after host")
    if not payload:
        raise ValueError(f"Invalid URI '{command}': payload is empty")

    return SshCommandURI(scheme=scheme, host_or_layer=host_or_layer, payload=payload)
def validate_shell_safety(command: str) -> None:
    """Reject an ``ssh://`` payload that is too long or contains shell metacharacters.

    The length check runs first; the metacharacter check uses ``_SHELL_METACHAR_RE``.

    Raises:
        ValueError: the command exceeds ``_MAX_COMMAND_LEN`` or matches the
            metacharacter pattern.
    """
    length = len(command)
    if length > _MAX_COMMAND_LEN:
        raise ValueError(f"Command too long: {length} > {_MAX_COMMAND_LEN}")

    if _SHELL_METACHAR_RE.search(command) is not None:
        raise ValueError(f"Shell metacharacter detected in command: '{command}'")
@dataclass
class HostRepairResult:
    """Outcome of a single repair attempt."""

    # True when the repair was judged successful by the executing path.
    success: bool
    # Layer (or URI scheme) the repair was routed through.
    layer: str
    # Component name, playbook name, or raw command that was targeted.
    component: str
    # Captured command output; empty when nothing was captured.
    output: str = ""
    # Failure description; empty on success.
    error: str = ""
def get_ssh_config_for_layer(layer: str) -> dict:
    """Look up the SSH connection settings for *layer* in ``LAYER_SSH_CONFIG``.

    Layers whose name starts with ``k8s`` are rejected because they are
    handled through kubectl rather than SSH.

    Raises:
        ValueError: *layer* is a k8s layer or is not configured.
    """
    # startswith("k8s") also covers the exact name "k8s".
    if layer.startswith("k8s"):
        raise ValueError(f"Layer '{layer}' uses kubectl, not SSH")

    if layer not in LAYER_SSH_CONFIG:
        raise ValueError(f"Unknown layer: '{layer}'")
    return LAYER_SSH_CONFIG[layer]
def build_repair_command(component: str) -> str:
    """Build the ``repair:<component>`` command string.

    The component name must match ``_COMPONENT_RE``; this is what prevents
    command injection through the component name.

    Raises:
        ValueError: the component name does not match the allowed pattern.
    """
    if _COMPONENT_RE.match(component) is None:
        raise ValueError(f"Invalid component name: '{component}'")
    return "repair:" + component
class HostRepairAgent:
"""透過 SSH 執行主機層修復命令。
2026-04-06 Claude Code: Sprint 3 P0-4 — Singleton pattern ensures in-process locks persist
"""
_instance: "HostRepairAgent | None" = None
def __new__(cls) -> "HostRepairAgent":
"""Singleton: return shared instance across all calls."""
if cls._instance is None:
cls._instance = super().__new__(cls)
# Initialize only once
cls._instance._in_process_locks = {}
return cls._instance
@classmethod
def _reset_for_test(cls) -> None:
"""測試專用:重置 singleton 狀態,避免跨測試污染。
2026-04-07 Claude Code: Re-Review S3
"""
cls._instance = None
def __init__(self) -> None:
# in-process 鎖表 — key: lock_key → asyncio.Lock
# 2026-04-06 Claude Code: Sprint 3 T4
# Only initialize if not already set (singleton pattern)
if not hasattr(self, '_in_process_locks'):
self._in_process_locks: dict[str, asyncio.Lock] = {}
def _get_in_process_lock(self, lock_key: str) -> asyncio.Lock:
"""取得或建立指定 key 的 in-process 鎖。"""
if lock_key not in self._in_process_locks:
self._in_process_locks[lock_key] = asyncio.Lock()
return self._in_process_locks[lock_key]
async def repair(self, layer: str, component: str) -> HostRepairResult:
"""執行修復並回傳結果。"""
try:
config = get_ssh_config_for_layer(layer)
command = build_repair_command(component)
except ValueError as e:
return HostRepairResult(
success=False,
layer=layer,
component=component,
error=str(e),
)
try:
output = await self._ssh_execute(
host=config["host"],
user=config["user"],
key_path=config["key_path"],
command=command,
)
except asyncio.TimeoutError:
return HostRepairResult(
success=False,
layer=layer,
component=component,
error=f"SSH timeout after {SSH_TIMEOUT}s",
)
except Exception as e:
return HostRepairResult(
success=False,
layer=layer,
component=component,
error=str(e),
)
success = output.startswith("REPAIR_OK:")
return HostRepairResult(
success=success,
layer=layer,
component=component,
output=output,
error="" if success else output,
)
async def _write_audit_log(
self,
uri: str,
success: bool,
output: str,
error: str | None,
duration_ms: int,
) -> None:
"""寫入 SSH_COMMAND 稽核日誌到 PostgreSQL。
2026-04-06 Claude Code: Sprint 3 T5
"""
try:
from src.db.base import get_db_context
from src.db.models import AuditLog
async with get_db_context() as db:
audit = AuditLog(
approval_id="auto_repair", # nullable=False — SSH_COMMAND 不走 approval flow
operation_type="SSH_COMMAND",
target_resource=uri[:200],
namespace="host-layer",
success=success,
error_message=error,
k8s_response={"output": output[:1000]} if output else None,
executed_by="auto_repair",
execution_duration_ms=duration_ms,
dry_run_passed=True,
dry_run_message=None,
)
db.add(audit)
await db.commit()
logger.info("ssh_command_audit_written", uri=uri, success=success)
except Exception as e:
logger.error("ssh_command_audit_failed", uri=uri, error=str(e))
# Do not re-raise — audit failure must not affect repair result
async def _execute_and_observe(
self,
command: str,
uri: SshCommandURI,
execute_fn,
) -> HostRepairResult:
"""
執行修復 + AuditLog + Langfuse trace 的公用邏輯。
2026-04-07 Claude Code: Re-Review S1 — 消除 repair_by_uri 中 3 處重複
"""
import time as _time
_start = _time.monotonic()
result = await execute_fn()
duration_ms = int((_time.monotonic() - _start) * 1000)
# AuditLog (fire and forget)
try:
await self._write_audit_log(
uri=command,
success=result.success,
output=result.output,
error=result.error or None,
duration_ms=duration_ms,
)
except Exception:
pass
# Langfuse trace (fire and forget)
try:
from src.services.langfuse_client import get_langfuse
lf = get_langfuse()
if lf:
trace = lf.trace(name="ssh_command_repair")
trace.span(
name=f"{uri.scheme}_execute",
input={"uri": command},
output={"success": result.success, "output": result.output[:500] if result.output else ""},
metadata={"duration_ms": duration_ms, "scheme": uri.scheme},
)
lf.flush()
except Exception as lf_err:
logger.debug("langfuse_trace_skipped", error=str(lf_err))
return result
async def repair_by_uri(self, command: str, approved: bool = False) -> HostRepairResult:
"""
根據 URI scheme 路由至對應的執行路徑。
2026-04-06 Claude Code: Sprint 3 T3
2026-04-06 Claude Code: Sprint 3 T4 — Redis 冪等鎖防止重複修復
2026-04-06 Claude Code: Sprint 3 T5 — AuditLog + Langfuse Trace
2026-04-07 Claude Code: Re-Review S1 — 抽取 _execute_and_observe 消除重複
"""
try:
uri = parse_uri_command(command)
except ValueError as e:
return HostRepairResult(success=False, layer="", component="", error=str(e))
# 冪等鎖 — 防止同一 component 並發修復
# 雙層鎖: in-process asyncio.Lock (必定生效) + Redis 分散式鎖 (best-effort)
# 2026-04-06 Claude Code: Sprint 3 T4
lock_key = f"repair_lock:ssh_command:{uri.scheme}:{uri.host_or_layer}:{uri.payload}"
in_process_lock = self._get_in_process_lock(lock_key)
async def _execute() -> HostRepairResult:
if uri.scheme == "openclaw":
return await self._execute_openclaw(uri.host_or_layer, uri.payload)
if uri.scheme == "ansible":
try:
validate_ansible_playbook(uri.payload)
except ValueError as e:
return HostRepairResult(success=False, layer="ansible", component=uri.payload, error=str(e))
return await self._execute_ansible(uri.host_or_layer, uri.payload)
if uri.scheme == "ssh":
if not approved:
return HostRepairResult(
success=False,
layer="ssh",
component=uri.payload,
error="ssh:// scheme requires_approval=True — must be explicitly approved",
)
try:
validate_shell_safety(uri.payload)
except ValueError as e:
return HostRepairResult(success=False, layer="ssh", component=uri.payload, error=str(e))
return await self._execute_ssh_direct(uri.host_or_layer, uri.payload)
return HostRepairResult(success=False, layer="", component="", error=f"Unhandled scheme: {uri.scheme}")
# in-process 鎖: locked() 代表正在進行,立即拒絕重複
if in_process_lock.locked():
return HostRepairResult(
success=False,
layer=uri.scheme,
component=uri.payload,
error=f"Repair already running for {uri.scheme}://{uri.host_or_layer}/{uri.payload}",
)
async with in_process_lock:
# Redis 分散式鎖 (best-effort跨 Pod)
# blocking_timeout=0: 立即失敗,不等待
try:
from src.core.redis_client import RedisLock
redis_lock: RedisLock | None = RedisLock(lock_key, timeout=120, blocking_timeout=0)
except Exception:
redis_lock = None # Redis 未連線fail open
# Redis 無法取得 → fail open直接執行
if redis_lock is None:
return await self._execute_and_observe(command, uri, _execute)
try:
acquired = await redis_lock.acquire()
except Exception:
# Redis 不可用fail open
return await self._execute_and_observe(command, uri, _execute)
if not acquired:
# Redis 鎖已被其他 Pod 持有
return HostRepairResult(
success=False,
layer=uri.scheme,
component=uri.payload,
error=f"Repair already running for {uri.scheme}://{uri.host_or_layer}/{uri.payload}",
)
try:
result = await self._execute_and_observe(command, uri, _execute)
finally:
try:
await redis_lock.release()
except Exception:
pass
return result
async def _execute_openclaw(self, layer: str, component: str) -> HostRepairResult:
"""openclaw:// — 呼叫現有的 repair(layer, component) 邏輯"""
return await self.repair(layer=layer, component=component)
async def _execute_ansible(self, _control_host: str, playbook_name: str) -> HostRepairResult:
"""
ansible:// — SSH 至控制節點,執行 ansible-playbook。
2026-04-06 Claude Code: Sprint 3 T3
2026-04-06 Claude Code: Sprint 3 P0-2 — shlex.quote() prevents path injection
注意: 強制使用 ConfigMap 的控制節點,忽略 URI 中的 host (安全設計)
"""
host = ANSIBLE_CONTROL_HOST
user = ANSIBLE_CONTROL_USER
# Important fix: 驗證 ConfigMap 的控制節點也在白名單內,防止環境變數被篡改繞過白名單
try:
validate_ssh_target_host(host)
except ValueError as e:
return HostRepairResult(
success=False, layer="ansible", component=playbook_name,
error=f"Ansible control host validation failed: {e}",
)
playbook_path = f"{ANSIBLE_PLAYBOOKS_PATH}/{playbook_name}"
# P0-2: Quote playbook_path to prevent shell injection if path contains special chars
ssh_command = f"ansible-playbook {shlex.quote(playbook_path)}"
try:
output = await self._ssh_execute(
host=host,
user=user,
key_path="/etc/repair-ssh/id_ed25519",
command=ssh_command,
)
except asyncio.TimeoutError:
return HostRepairResult(
success=False, layer="ansible", component=playbook_name,
error=f"Ansible SSH timeout after {SSH_TIMEOUT}s",
)
except Exception as e:
return HostRepairResult(
success=False, layer="ansible", component=playbook_name,
error=str(e),
)
success = "REPAIR_OK" in output or "ok=" in output
return HostRepairResult(
success=success,
layer="ansible",
component=playbook_name,
output=output,
error="" if success else output,
)
async def _execute_ssh_direct(self, host_user: str, command: str) -> HostRepairResult:
"""
ssh:// — 直接執行 SSH 命令(需明確 approved=True
host_user 格式: "wooo@192.168.0.110"
2026-04-06 Claude Code: Sprint 3 T3
2026-04-06 Claude Code: Sprint 3 P0-3 — Validate target host in whitelist
"""
if "@" in host_user:
user, host = host_user.split("@", 1)
else:
return HostRepairResult(
success=False, layer="ssh", component=command,
error=f"Invalid host_user format '{host_user}' (expected user@host)",
)
# P0-3: Validate SSH target is in whitelist
try:
validate_ssh_target_host(host)
except ValueError as e:
return HostRepairResult(
success=False, layer="ssh", component=command,
error=str(e),
)
try:
output = await self._ssh_execute(
host=host,
user=user,
key_path="/etc/repair-ssh/id_ed25519",
command=command,
)
except asyncio.TimeoutError:
return HostRepairResult(
success=False, layer="ssh", component=command,
error=f"SSH timeout after {SSH_TIMEOUT}s",
)
except Exception as e:
return HostRepairResult(success=False, layer="ssh", component=command, error=str(e))
success = not output.startswith("ERROR")
return HostRepairResult(
success=success,
layer="ssh",
component=command,
output=output,
error="" if success else output,
)
async def _ssh_execute(self, host: str, user: str, key_path: str, command: str) -> str:
"""執行 SSH 命令,回傳 stdout。key_path 由呼叫方傳入,不反查。
2026-04-07 Claude Code: Re-Review S2 — 加 user 防禦性驗證
"""
validate_ssh_user(user)
import time
deadline = time.monotonic() + SSH_TIMEOUT
proc = await asyncio.wait_for(
asyncio.create_subprocess_exec(
"ssh",
"-i", key_path,
"-o", "StrictHostKeyChecking=yes",
"-o", f"UserKnownHostsFile={KNOWN_HOSTS_PATH}",
"-o", "BatchMode=yes",
"-o", f"ConnectTimeout={SSH_TIMEOUT}",
f"{user}@{host}",
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
),
timeout=SSH_TIMEOUT,
)
remaining = max(1.0, deadline - time.monotonic())
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=remaining)
output = stdout.decode().strip()
logger.info("SSH repair %s@%s %s%s", user, host, command, output)
return output