359 lines
14 KiB
Python
359 lines
14 KiB
Python
"""
|
||
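Phase 5 Sprint 5.0-5.2 Callback Dispatcher 單元測試
======================================================

建立: 2026-04-14 台北深夜 Claude Sonnet 4.6

覆蓋:
- callback_action_spec.yaml 載入正確性(必填欄位、risk / callback_format 取值)
- 註冊表 action 皆能解析(原規劃 24 個);secops 需 multi-sig,寫類用 nonce,查類用 info
- 動態按鈕引用的 MCP tool 必須存在於真實 k8s / ssh provider
- 模板變數替換(_resolve_template)
- Context lookup(labels.instance / signals[0].alert_name)
- dispatch_action:未知 action 返回失敗;無 MCP provider 時 graceful failure(不 crash)
- Internal actions(open_signoz / open_flywheel / secops_authorize / 授權審計寫入)
- MCP 呼叫時注入 _mcp_audit 審計 context

🔴 遵循「禁止 Mock 測試鐵律」: 用真實 spec registry,不 mock。
   僅以 monkeypatch 替換有副作用的邊界(_record_authorization_audit 審計寫入、
   src.plugins.mcp.registry.get_provider 的 provider 查找)。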
|
||
"""
|
||
|
||
import pytest
|
||
|
||
from src.plugins.mcp.interfaces import MCPTool, MCPToolProvider, MCPToolResult
|
||
from src.services import callback_dispatcher as callback_dispatcher_module
|
||
from src.services.callback_dispatcher import (
|
||
dispatch_action,
|
||
get_action_spec,
|
||
list_actions_for_category,
|
||
load_action_registry,
|
||
_lookup_context,
|
||
_resolve_provider_name,
|
||
_resolve_template,
|
||
)
|
||
|
||
|
||
# =============================================================================
|
||
# Registry loading
|
||
# =============================================================================
|
||
|
||
|
||
class TestRegistryLoading:
    """Static checks on the action registry loaded from callback_action_spec.yaml."""

    def test_registry_loads_all_24_actions(self):
        """The registry should expose the planned action set.

        Planned baseline: 10 query + 10 write + 4 secops = 24.
        """
        registry = load_action_registry()
        # NOTE(review): the test name and the baseline above say 24, but the bound
        # is 20. Confirm whether the slack is intentional (e.g. actions removed
        # because their MCP tools do not exist) and tighten once confirmed.
        assert len(registry) >= 20, f"expected >= 20 actions, got {len(registry)}"

    def test_all_actions_have_required_fields(self):
        """Every spec is keyed by its own name and carries the mandatory fields."""
        registry = load_action_registry()
        for name, spec in registry.items():
            assert spec.name == name
            assert spec.label
            assert spec.risk in ("low", "medium", "high", "critical")
            assert spec.callback_format in ("info", "nonce")
            assert spec.mcp_provider, f"{name} missing mcp_provider"
            assert spec.mcp_tool, f"{name} missing mcp_tool"

    def test_secops_requires_multi_sig(self):
        """Destructive secops actions must be gated behind multi-signature approval."""
        for sa in ("secops_isolate", "secops_block_ip", "secops_evict"):
            spec = get_action_spec(sa)
            assert spec and spec.requires_multi_sig is True, (
                f"{sa} should require multi_sig"
            )

    def test_info_actions_dont_need_multi_sig(self):
        """Read-only actions are not gated by multi-sig."""
        spec = get_action_spec("check_process")
        assert spec and spec.requires_multi_sig is False

    def test_write_actions_use_nonce_format(self):
        """Mutating actions use the nonce callback format."""
        for wa in ("k8s_restart", "k8s_scale_up", "k8s_rollback", "host_restart_service"):
            spec = get_action_spec(wa)
            assert spec and spec.callback_format == "nonce", (
                f"{wa} should use nonce format"
            )

    def test_query_actions_use_info_format(self):
        """Read-only actions use the info callback format."""
        for qa in ("check_process", "check_port", "open_signoz"):
            spec = get_action_spec(qa)
            assert spec and spec.callback_format == "info", (
                f"{qa} should use info format"
            )

    def test_legacy_provider_aliases_resolve_to_registered_names(self):
        """Legacy provider aliases map to registered names; other names pass through."""
        assert _resolve_provider_name("k8s") == "kubernetes"
        assert _resolve_provider_name("ssh") == "ssh_host"
        assert _resolve_provider_name("prometheus") == "prometheus"

    @pytest.mark.asyncio
    async def test_dynamic_button_tools_exist_in_real_providers(self):
        """Every action bound to the k8s/ssh providers must reference a real tool.

        A spec may name its provider by the legacy alias ("k8s", "ssh") or by the
        registered name ("kubernetes", "ssh_host"). Both sides are normalised with
        _resolve_provider_name so neither spelling is skipped.
        """
        from src.plugins.mcp.providers.k8s_provider import K8sProvider
        from src.plugins.mcp.providers.ssh_provider import SSHProvider

        provider_tools = {
            _resolve_provider_name("k8s"): {
                tool.name for tool in await K8sProvider().list_tools()
            },
            _resolve_provider_name("ssh"): {
                tool.name for tool in await SSHProvider().list_tools()
            },
        }
        registry = load_action_registry()
        checked = 0
        for spec in registry.values():
            provider = _resolve_provider_name(spec.mcp_provider)
            if provider not in provider_tools:
                continue
            checked += 1
            assert spec.mcp_tool in provider_tools[provider], (
                f"{spec.name} references missing {spec.mcp_provider}.{spec.mcp_tool}"
            )
        # Without this guard, a provider-naming mismatch makes the loop skip
        # every spec and the test passes without checking anything.
        assert checked, "no registry action is bound to the k8s/ssh providers"

    def test_k8s_pod_log_params_match_provider_schema(self):
        """check_pod_logs uses the provider's parameter names, not the legacy ones."""
        spec = get_action_spec("check_pod_logs")
        assert spec
        assert "pod_name" in spec.mcp_params
        assert "tail" in spec.mcp_params
        assert "pod" not in spec.mcp_params
        assert "tail_lines" not in spec.mcp_params
|
||
|
||
|
||
# =============================================================================
|
||
# Category filtering
|
||
# =============================================================================
|
||
|
||
|
||
class TestCategoryFiltering:
    """list_actions_for_category returns the expected action mix per category."""

    def test_kubernetes_has_4_write_actions(self):
        """The kubernetes category offers at least four nonce (write) actions."""
        nonce_specs = [
            spec
            for spec in list_actions_for_category("kubernetes")
            if spec.callback_format == "nonce"
        ]
        count = len(nonce_specs)
        assert count >= 4, (
            f"kubernetes should have at least 4 write actions, got {count}"
        )

    def test_secops_has_4_actions(self):
        """The secops category has exactly four actions."""
        secops_specs = list_actions_for_category("secops")
        total = len(secops_specs)
        assert total == 4, f"secops should have 4 actions, got {total}"

    def test_host_resource_has_mix(self):
        """host_resource mixes at least one query (info) and one write (nonce) action."""
        specs = list_actions_for_category("host_resource")
        formats = [spec.callback_format for spec in specs]
        assert len(specs) >= 2
        assert "info" in formats, "需至少 1 個查類"
        assert "nonce" in formats, "需至少 1 個寫類"

    def test_backup_failure_has_read_only_diagnostics(self):
        """backup_failure exposes the three diagnostics, and all of them are read-only."""
        specs = list_actions_for_category("backup_failure")
        required = {
            "backup_check_host_disk",
            "backup_check_jobs",
            "backup_check_velero",
        }
        assert required <= {spec.name for spec in specs}
        assert all(spec.callback_format == "info" for spec in specs)
|
||
|
||
|
||
# =============================================================================
|
||
# Template variable resolution
|
||
# =============================================================================
|
||
|
||
|
||
class TestTemplateResolution:
    """Dotted / indexed context lookup and `{placeholder}` substitution."""

    def test_lookup_simple_key(self):
        """A top-level key is returned directly."""
        context = {"incident_id": "INC-123"}
        found = _lookup_context("incident_id", context)
        assert found == "INC-123"

    def test_lookup_nested_labels(self):
        """A dotted path descends one level into a nested dict."""
        context = {"labels": {"instance": "192.168.0.110"}}
        found = _lookup_context("labels.instance", context)
        assert found == "192.168.0.110"

    def test_lookup_deep_nested(self):
        """A dotted path descends through several nested dicts."""
        context = {"labels": {"k8s": {"pod": "api-1"}}}
        found = _lookup_context("labels.k8s.pod", context)
        assert found == "api-1"

    def test_lookup_list_index(self):
        """`name[i]` indexes into a list before continuing the dotted path."""
        context = {"signals": [{"alert_name": "KubePodCrashLooping"}]}
        found = _lookup_context("signals[0].alert_name", context)
        assert found == "KubePodCrashLooping"

    def test_lookup_missing_returns_none(self):
        """A path that does not exist yields None."""
        context = {"labels": {}}
        found = _lookup_context("labels.instance", context)
        assert found is None

    def test_resolve_template_dict(self):
        """Placeholders inside dict values are substituted; non-strings are untouched."""
        template = {"host": "{labels.instance}", "lines": 50}
        context = {"labels": {"instance": "10.0.0.1"}}
        rendered = _resolve_template(template, context)
        assert rendered == {"host": "10.0.0.1", "lines": 50}

    def test_resolve_keeps_unresolved(self):
        """Missing context keys leave the original `{...}` text in place (for debugging)."""
        template = {"host": "{labels.missing}"}
        context = {"labels": {}}
        rendered = _resolve_template(template, context)
        assert rendered == {"host": "{labels.missing}"}

    def test_resolve_string_with_multiple_placeholders(self):
        """Every placeholder in a single string is substituted."""
        template = "host={labels.instance} port={labels.port}"
        context = {"labels": {"instance": "10.0.0.1", "port": "9100"}}
        rendered = _resolve_template(template, context)
        assert rendered == "host=10.0.0.1 port=9100"
|
||
|
||
|
||
# =============================================================================
|
||
# dispatch_action stub (Sprint 5.0 骨架)
|
||
# =============================================================================
|
||
|
||
|
||
@pytest.mark.asyncio
class TestDispatchActionStub:
    """dispatch_action smoke tests that run against whatever MCP registry is live.

    Nothing is patched here. When a provider is missing, the dispatcher is
    expected to return a result object rather than raise.
    """

    async def test_unknown_action_returns_failure(self):
        """An action name absent from the registry yields a failed result."""
        outcome = await dispatch_action(
            action_name="unknown_action",
            incident_id="INC-TEST-001",
        )
        assert outcome.success is False
        error_text = outcome.error or ""
        assert "Unknown action" in error_text

    async def test_check_process_graceful_without_mcp(self):
        """Sprint 5.2: with no MCP provider registered, check_process fails gracefully."""
        outcome = await dispatch_action(
            action_name="check_process",
            incident_id="INC-TEST-002",
            user_id=12345,
            labels={"instance": "192.168.0.110"},
        )
        # Success depends on the registry state, so only the result shape is pinned.
        assert isinstance(outcome.success, bool)
        assert outcome.action == "check_process"
        assert outcome.result_text  # some user-facing text is always returned

    async def test_k8s_restart_graceful_without_mcp(self):
        """Sprint 5.2: with the k8s provider unregistered, k8s_restart must not raise."""
        outcome = await dispatch_action(
            action_name="k8s_restart",
            incident_id="INC-TEST-003",
            labels={"namespace": "awoooi-prod", "deployment": "awoooi-api"},
        )
        # Not raising is the contract; success depends on the MCP registry state.
        assert isinstance(outcome.success, bool)
        assert outcome.action == "k8s_restart"

    async def test_dispatch_includes_duration(self):
        """The result reports a non-negative wall-clock duration in milliseconds."""
        outcome = await dispatch_action(
            action_name="check_process",
            incident_id="INC-TEST-004",
            labels={"instance": "10.0.0.1"},
        )
        elapsed = outcome.duration_ms
        assert elapsed >= 0
        # Presumably no real MCP round-trip happens in the test environment, so
        # this should be fast. NOTE(review): may be slow if a live provider is registered.
        assert elapsed < 5000

    async def test_secops_action_flag_preserved(self):
        """The secops_isolate spec carries the multi-sig flag for the upper layer."""
        isolate_spec = get_action_spec("secops_isolate")
        assert isolate_spec and isolate_spec.requires_multi_sig is True
        # The dispatcher does not enforce multi-sig itself; callback_handler does.
||
# dispatcher 本身不做 multi-sig 攔截(留給 callback_handler),只記錄 spec
|
||
|
||
|
||
# =============================================================================
|
||
# Sprint 5.2 — Internal actions (不走 MCP)
|
||
# =============================================================================
|
||
|
||
|
||
@pytest.mark.asyncio
class TestInternalActions:
    """Sprint 5.2 actions handled inside the dispatcher, without an MCP call."""

    async def test_open_signoz_returns_url(self):
        """open_signoz returns a SigNoz link that mentions the service label."""
        outcome = await dispatch_action(
            action_name="open_signoz",
            incident_id="INC-TEST-SZ",
            labels={"service": "awoooi-api"},
        )
        text = outcome.result_text
        assert outcome.success is True
        assert "signoz.wooo.work" in text
        assert "awoooi-api" in text

    async def test_open_flywheel_returns_url(self):
        """open_flywheel succeeds and mentions the flywheel (case-insensitive)."""
        outcome = await dispatch_action(
            action_name="open_flywheel",
            incident_id="INC-TEST-FW",
        )
        assert outcome.success is True
        assert "flywheel" in outcome.result_text.lower()

    async def test_secops_authorize_internal(self):
        """secops_authorize succeeds and echoes the operator's user id."""
        outcome = await dispatch_action(
            action_name="secops_authorize",
            incident_id="INC-TEST-SEC",
            user_id=12345,
            labels={"instance": "192.168.0.110"},
        )
        assert outcome.success is True
        assert "12345" in outcome.result_text

    async def test_record_authorization_persists_audit_intent(self, monkeypatch):
        """secops_isolate hands its spec, params and operator to the audit recorder."""
        captured = {}

        # The keyword names must match how the dispatcher calls the real recorder.
        async def fake_record_authorization_audit(*, spec, params, incident_id, user_id):
            captured.update(
                spec=spec,
                params=params,
                incident_id=incident_id,
                user_id=user_id,
            )
            return True

        monkeypatch.setattr(
            callback_dispatcher_module,
            "_record_authorization_audit",
            fake_record_authorization_audit,
        )

        outcome = await dispatch_action(
            action_name="secops_isolate",
            incident_id="INC-SEC-AUTH",
            user_id=67890,
            labels={"instance": "192.168.0.110"},
        )

        assert outcome.success is True
        assert "已寫入審計與時間線" in outcome.result_text
        assert captured["spec"].name == "secops_isolate"
        assert captured["params"]["action"] == "request_network_isolation"
        assert captured["incident_id"] == "INC-SEC-AUTH"
        assert captured["user_id"] == 67890
|
||
|
||
|
||
# =============================================================================
|
||
# Sprint 5.2 — MCP 呼叫失敗路徑(Provider 未註冊)
|
||
# =============================================================================
|
||
|
||
|
||
@pytest.mark.asyncio
class TestMcpFailurePath:
    """Sprint 5.2: MCP dispatch when the target provider may be unregistered."""

    async def test_unregistered_provider_returns_graceful_error(self):
        """A missing MCP provider produces a result object, never an exception."""
        # check_process goes through the ssh provider. An empty registry in the
        # test environment makes it fail, which is acceptable; raising is not.
        outcome = await dispatch_action(
            action_name="check_process",
            incident_id="INC-TEST-MCP-FAIL",
            labels={"instance": "192.168.0.110"},
        )
        assert isinstance(outcome.success, bool)
        assert outcome.result_text  # a human-readable explanation is present
|
||
|
||
|
||
class _CaptureMcpProvider(MCPToolProvider):
    """Test double that poses as the ``ssh_host`` MCP provider.

    It exposes no tools. ``execute`` stores a shallow copy of the parameters it
    receives and always reports success with ``{"stdout": "ok"}``.
    """

    name = "ssh_host"
    # Parameters from the most recent execute() call; None until it is called.
    seen_parameters: dict | None = None

    async def list_tools(self) -> list[MCPTool]:
        return list()

    async def execute(self, tool_name: str, parameters: dict) -> MCPToolResult:
        received = dict(parameters)
        self.seen_parameters = received
        return MCPToolResult(output={"stdout": "ok"}, success=True)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_dispatch_action_injects_mcp_audit_context(monkeypatch):
    """dispatch_action adds an ``_mcp_audit`` context to the MCP tool parameters."""
    capture = _CaptureMcpProvider()
    # Route every provider lookup to the capturing double; `name` is kept as the
    # parameter name in case the registry function is called by keyword.
    monkeypatch.setattr("src.plugins.mcp.registry.get_provider", lambda name: capture)

    outcome = await dispatch_action(
        action_name="check_process",
        incident_id="INC-CB-AUDIT",
        user_id=12345,
        labels={"instance": "192.168.0.110"},
    )

    assert outcome.success is True
    assert capture.seen_parameters is not None

    audit = capture.seen_parameters["_mcp_audit"]
    expected_audit = {
        "incident_id": "INC-CB-AUDIT",
        "session_id": "callback:INC-CB-AUDIT:check_process",
        "flywheel_node": "operate",
        "agent_role": "telegram_callback_dispatcher",
        "operator_user_id": 12345,
    }
    for field_name, expected_value in expected_audit.items():
        assert audit[field_name] == expected_value
|