Files
awoooi/apps/api/tests/test_callback_dispatcher.py
Your Name fa0e956c0e
All checks were successful
Code Review / ai-code-review (push) Successful in 10s
CD Pipeline / tests (push) Successful in 59s
CD Pipeline / build-and-deploy (push) Successful in 3m22s
CD Pipeline / post-deploy-checks (push) Successful in 1m19s
fix(mcp): tag legacy provider calls with audit context
2026-05-06 17:18:52 +08:00

359 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Phase 5 Sprint 5.0-5.1 Callback Dispatcher 單元測試
======================================================
建立: 2026-04-14 台北深夜 Claude Sonnet 4.6
覆蓋:
- callback_action_spec.yaml 載入正確性
- 24 個 action 都能解析
- 模板變數替換_resolve_template
- Context lookuplabels.instance / signals[0].alert_name
- dispatch_action 骨架Sprint 5.0 階段返回 stub
🔴 遵循「禁止 Mock 測試鐵律」: 用真實 spec registry不 mock。
"""
import pytest
from src.plugins.mcp.interfaces import MCPTool, MCPToolProvider, MCPToolResult
from src.services import callback_dispatcher as callback_dispatcher_module
from src.services.callback_dispatcher import (
dispatch_action,
get_action_spec,
list_actions_for_category,
load_action_registry,
_lookup_context,
_resolve_provider_name,
_resolve_template,
)
# =============================================================================
# Registry loading
# =============================================================================
class TestRegistryLoading:
    """Checks on the loaded action registry and on individual action specs."""

    def test_registry_loads_all_24_actions(self):
        """The registry loads with at least 20 entries."""
        loaded = load_action_registry()
        action_count = len(loaded)
        # 10 query + 10 write + 4 secops = 24; only a lower bound of 20 is
        # enforced here.
        assert action_count >= 20, f"expected >= 20 actions, got {action_count}"

    def test_all_actions_have_required_fields(self):
        """Each spec is keyed by its own name and has its core fields set."""
        allowed_risks = ("low", "medium", "high", "critical")
        allowed_formats = ("info", "nonce")
        for action_name, action in load_action_registry().items():
            assert action.name == action_name
            assert action.label
            assert action.risk in allowed_risks
            assert action.callback_format in allowed_formats
            assert action.mcp_provider, f"{action_name} missing mcp_provider"
            assert action.mcp_tool, f"{action_name} missing mcp_tool"

    def test_secops_requires_multi_sig(self):
        """The listed secops actions all carry requires_multi_sig=True."""
        for secops_action in ("secops_isolate", "secops_block_ip", "secops_evict"):
            action = get_action_spec(secops_action)
            failure_note = f"{secops_action} should require multi_sig"
            assert action and action.requires_multi_sig is True, failure_note

    def test_info_actions_dont_need_multi_sig(self):
        """check_process is registered with requires_multi_sig=False."""
        action = get_action_spec("check_process")
        assert action and action.requires_multi_sig is False

    def test_write_actions_use_nonce_format(self):
        """The listed write actions use the "nonce" callback format."""
        write_action_names = (
            "k8s_restart",
            "k8s_scale_up",
            "k8s_rollback",
            "host_restart_service",
        )
        for write_action in write_action_names:
            action = get_action_spec(write_action)
            failure_note = f"{write_action} should use nonce format"
            assert action and action.callback_format == "nonce", failure_note

    def test_query_actions_use_info_format(self):
        """The listed query actions use the "info" callback format."""
        for query_action in ("check_process", "check_port", "open_signoz"):
            action = get_action_spec(query_action)
            failure_note = f"{query_action} should use info format"
            assert action and action.callback_format == "info", failure_note

    def test_legacy_provider_aliases_resolve_to_registered_names(self):
        """Provider aliases map to the expected resolved provider names."""
        expected_by_alias = {
            "k8s": "kubernetes",
            "ssh": "ssh_host",
            "prometheus": "prometheus",
        }
        for alias, resolved in expected_by_alias.items():
            assert _resolve_provider_name(alias) == resolved

    @pytest.mark.asyncio
    async def test_dynamic_button_tools_exist_in_real_providers(self):
        """Specs targeting the k8s/ssh providers reference tools they list."""
        from src.plugins.mcp.providers.k8s_provider import K8sProvider
        from src.plugins.mcp.providers.ssh_provider import SSHProvider

        k8s_tool_names = {tool.name for tool in await K8sProvider().list_tools()}
        ssh_tool_names = {tool.name for tool in await SSHProvider().list_tools()}
        tool_names_by_provider = {"k8s": k8s_tool_names, "ssh": ssh_tool_names}

        for action in load_action_registry().values():
            # Specs pointing at any other provider are not checked here.
            if action.mcp_provider in tool_names_by_provider:
                assert action.mcp_tool in tool_names_by_provider[action.mcp_provider], (
                    f"{action.name} references missing {action.mcp_provider}.{action.mcp_tool}"
                )

    def test_k8s_pod_log_params_match_provider_schema(self):
        """check_pod_logs uses pod_name/tail and not pod/tail_lines."""
        action = get_action_spec("check_pod_logs")
        assert action
        for expected_key in ("pod_name", "tail"):
            assert expected_key in action.mcp_params
        for rejected_key in ("pod", "tail_lines"):
            assert rejected_key not in action.mcp_params
# =============================================================================
# Category filtering
# =============================================================================
class TestCategoryFiltering:
    """Checks on the actions returned by list_actions_for_category()."""

    def test_kubernetes_has_4_write_actions(self):
        """The kubernetes category has at least four "nonce" actions."""
        nonce_count = sum(
            1
            for action in list_actions_for_category("kubernetes")
            if action.callback_format == "nonce"
        )
        assert nonce_count >= 4, \
            f"kubernetes should have at least 4 write actions, got {nonce_count}"

    def test_secops_has_4_actions(self):
        """The secops category has exactly four actions."""
        secops_actions = list_actions_for_category("secops")
        total = len(secops_actions)
        assert total == 4, f"secops should have 4 actions, got {total}"

    def test_host_resource_has_mix(self):
        """host_resource has both "info" and "nonce" actions."""
        host_actions = list_actions_for_category("host_resource")
        assert len(host_actions) >= 2
        assert any(
            action.callback_format == "info" for action in host_actions
        ), "需至少 1 個查類"
        assert any(
            action.callback_format == "nonce" for action in host_actions
        ), "需至少 1 個寫類"

    def test_backup_failure_has_read_only_diagnostics(self):
        """backup_failure includes the three checks and only "info" actions."""
        backup_actions = list_actions_for_category("backup_failure")
        present_names = {action.name for action in backup_actions}
        required_names = {
            "backup_check_host_disk",
            "backup_check_jobs",
            "backup_check_velero",
        }
        assert required_names <= present_names
        assert all(action.callback_format == "info" for action in backup_actions)
# =============================================================================
# Template variable resolution
# =============================================================================
class TestTemplateResolution:
    """Checks on _lookup_context() paths and _resolve_template() substitution."""

    def test_lookup_simple_key(self):
        """A top-level key is returned directly."""
        context = {"incident_id": "INC-123"}
        found = _lookup_context("incident_id", context)
        assert found == "INC-123"

    def test_lookup_nested_labels(self):
        """A dotted path descends one level into a nested dict."""
        context = {"labels": {"instance": "192.168.0.110"}}
        found = _lookup_context("labels.instance", context)
        assert found == "192.168.0.110"

    def test_lookup_deep_nested(self):
        """A dotted path descends through two nested dicts."""
        context = {"labels": {"k8s": {"pod": "api-1"}}}
        found = _lookup_context("labels.k8s.pod", context)
        assert found == "api-1"

    def test_lookup_list_index(self):
        """A ``[0]`` segment indexes into a list before descending further."""
        context = {"signals": [{"alert_name": "KubePodCrashLooping"}]}
        found = _lookup_context("signals[0].alert_name", context)
        assert found == "KubePodCrashLooping"

    def test_lookup_missing_returns_none(self):
        """A path whose final key is absent yields None."""
        context = {"labels": {}}
        found = _lookup_context("labels.instance", context)
        assert found is None

    def test_resolve_template_dict(self):
        """Placeholders in dict values are substituted; other values are kept."""
        template = {"host": "{labels.instance}", "lines": 50}
        context = {"labels": {"instance": "10.0.0.1"}}
        resolved = _resolve_template(template, context)
        assert resolved == {"host": "10.0.0.1", "lines": 50}

    def test_resolve_keeps_unresolved(self):
        """If the context lacks the key, the original {...} is kept (eases debugging)."""
        template = {"host": "{labels.missing}"}
        context = {"labels": {}}
        resolved = _resolve_template(template, context)
        assert resolved == {"host": "{labels.missing}"}

    def test_resolve_string_with_multiple_placeholders(self):
        """Every placeholder in a plain string template is substituted."""
        template = "host={labels.instance} port={labels.port}"
        context = {"labels": {"instance": "10.0.0.1", "port": "9100"}}
        resolved = _resolve_template(template, context)
        assert resolved == "host=10.0.0.1 port=9100"
# =============================================================================
# dispatch_action stub (Sprint 5.0 skeleton)
# =============================================================================
@pytest.mark.asyncio
class TestDispatchActionStub:
    """Checks on dispatch_action() results: unknown names, no-crash paths, timing."""

    async def test_unknown_action_returns_failure(self):
        """An unrecognised action name gives success=False and an "Unknown action" error."""
        outcome = await dispatch_action(
            action_name="unknown_action",
            incident_id="INC-TEST-001",
        )
        assert outcome.success is False
        error_text = outcome.error or ""
        assert "Unknown action" in error_text

    async def test_check_process_graceful_without_mcp(self):
        """Sprint 5.2: with no MCP provider registered, the dispatcher fails gracefully instead of crashing."""
        outcome = await dispatch_action(
            action_name="check_process",
            incident_id="INC-TEST-002",
            user_id=12345,
            labels={"instance": "192.168.0.110"},
        )
        # Original note: with no MCP in the test environment, expect
        # success=False plus a provider_not_found error message. Only the
        # result's shape is asserted here.
        assert isinstance(outcome.success, bool)
        assert outcome.action == "check_process"
        assert outcome.result_text  # some reply text is present

    async def test_k8s_restart_graceful_without_mcp(self):
        """Sprint 5.2: when the k8s provider is not registered, fail gracefully instead of crashing."""
        outcome = await dispatch_action(
            action_name="k8s_restart",
            incident_id="INC-TEST-003",
            labels={"namespace": "awoooi-prod", "deployment": "awoooi-api"},
        )
        # Not crashing counts as a pass; the actual success value depends on
        # the state of the MCP registry.
        assert isinstance(outcome.success, bool)
        assert outcome.action == "k8s_restart"

    async def test_dispatch_includes_duration(self):
        """The result's duration_ms is non-negative and under 5000."""
        outcome = await dispatch_action(
            action_name="check_process",
            incident_id="INC-TEST-004",
            labels={"instance": "10.0.0.1"},
        )
        assert outcome.duration_ms >= 0
        assert outcome.duration_ms < 5000  # the stub should be very fast

    async def test_secops_action_flag_preserved(self):
        """The secops spec exposes requires_multi_sig (for upper-layer Multi-Sig handling)."""
        secops_spec = get_action_spec("secops_isolate")
        assert secops_spec and secops_spec.requires_multi_sig is True
        # Original note: the dispatcher itself does not intercept for
        # multi-sig (that is left to callback_handler); it only records the
        # spec.
# =============================================================================
# Sprint 5.2 — Internal actions (do not go through MCP)
# =============================================================================
@pytest.mark.asyncio
class TestInternalActions:
    """Checks on actions whose dispatch result is asserted to succeed."""

    async def test_open_signoz_returns_url(self):
        """open_signoz succeeds; the reply has the SigNoz host and the service label."""
        outcome = await dispatch_action(
            action_name="open_signoz",
            incident_id="INC-TEST-SZ",
            labels={"service": "awoooi-api"},
        )
        assert outcome.success is True
        reply = outcome.result_text
        assert "signoz.wooo.work" in reply
        assert "awoooi-api" in reply

    async def test_open_flywheel_returns_url(self):
        """open_flywheel succeeds and the reply mentions "flywheel" (any case)."""
        outcome = await dispatch_action(
            action_name="open_flywheel",
            incident_id="INC-TEST-FW",
        )
        assert outcome.success is True
        lowered_reply = outcome.result_text.lower()
        assert "flywheel" in lowered_reply

    async def test_secops_authorize_internal(self):
        """secops_authorize succeeds and the reply includes the user id."""
        outcome = await dispatch_action(
            action_name="secops_authorize",
            incident_id="INC-TEST-SEC",
            user_id=12345,
            labels={"instance": "192.168.0.110"},
        )
        assert outcome.success is True
        assert "12345" in outcome.result_text

    async def test_record_authorization_persists_audit_intent(self, monkeypatch):
        """secops_isolate passes spec, params, incident and user to the audit recorder."""
        recorded = {}

        async def _capture_audit(*, spec, params, incident_id, user_id):
            # Stand-in for the module's _record_authorization_audit: keep the
            # keyword arguments for inspection and report success.
            recorded.update(
                spec=spec,
                params=params,
                incident_id=incident_id,
                user_id=user_id,
            )
            return True

        monkeypatch.setattr(
            callback_dispatcher_module,
            "_record_authorization_audit",
            _capture_audit,
        )

        outcome = await dispatch_action(
            action_name="secops_isolate",
            incident_id="INC-SEC-AUTH",
            user_id=67890,
            labels={"instance": "192.168.0.110"},
        )

        assert outcome.success is True
        assert "已寫入審計與時間線" in outcome.result_text
        assert recorded["spec"].name == "secops_isolate"
        assert recorded["params"]["action"] == "request_network_isolation"
        assert recorded["incident_id"] == "INC-SEC-AUTH"
        assert recorded["user_id"] == 67890
# =============================================================================
# Sprint 5.2 — MCP call failure path (provider not registered)
# =============================================================================
@pytest.mark.asyncio
class TestMcpFailurePath:
    """Checks that dispatching still returns a well-formed result when MCP may be absent."""

    async def test_unregistered_provider_returns_graceful_error(self):
        """When the MCP registry has no matching provider, the dispatcher returns a failure rather than crashing."""
        # Original note: check_process goes through the ssh provider; if the
        # registry is empty in the test environment it returns a failure.
        outcome = await dispatch_action(
            action_name="check_process",
            incident_id="INC-TEST-MCP-FAIL",
            labels={"instance": "192.168.0.110"},
        )
        # The test environment may have no provider registered, so a failure
        # result is acceptable, but it must never crash.
        success_flag = outcome.success
        assert isinstance(success_flag, bool)
        assert outcome.result_text  # a reasonable message is present
class _CaptureMcpProvider(MCPToolProvider):
    """Provider test double that remembers the parameters of its last execute() call."""

    # Provider name this double reports.
    name = "ssh_host"
    # Copy of the parameters given to the most recent execute(); None until then.
    seen_parameters: dict | None = None

    async def list_tools(self) -> list[MCPTool]:
        """Report no tools."""
        return []

    async def execute(self, tool_name: str, parameters: dict) -> MCPToolResult:
        """Store a copy of ``parameters`` and return a fixed successful result."""
        snapshot = dict(parameters)
        self.seen_parameters = snapshot
        canned_output = {"stdout": "ok"}
        return MCPToolResult(success=True, output=canned_output)
@pytest.mark.asyncio
async def test_dispatch_action_injects_mcp_audit_context(monkeypatch):
    """dispatch_action passes an ``_mcp_audit`` dict with the expected fields to the provider."""
    capture = _CaptureMcpProvider()

    def _always_capture(name):
        # Every provider lookup resolves to the capturing double.
        return capture

    monkeypatch.setattr("src.plugins.mcp.registry.get_provider", _always_capture)

    outcome = await dispatch_action(
        action_name="check_process",
        incident_id="INC-CB-AUDIT",
        user_id=12345,
        labels={"instance": "192.168.0.110"},
    )

    assert outcome.success is True
    assert capture.seen_parameters is not None

    audit = capture.seen_parameters["_mcp_audit"]
    expected_fields = {
        "incident_id": "INC-CB-AUDIT",
        "session_id": "callback:INC-CB-AUDIT:check_process",
        "flywheel_node": "operate",
        "agent_role": "telegram_callback_dispatcher",
        "operator_user_id": 12345,
    }
    # Checked in the same order as the fields are listed above.
    for field, expected_value in expected_fields.items():
        assert audit[field] == expected_value