Files
awoooi/apps/api/src/services/callback_dispatcher.py
Your Name fa0e956c0e
All checks were successful
Code Review / ai-code-review (push) Successful in 10s
CD Pipeline / tests (push) Successful in 59s
CD Pipeline / build-and-deploy (push) Successful in 3m22s
CD Pipeline / post-deploy-checks (push) Successful in 1m19s
fix(mcp): tag legacy provider calls with audit context
2026-05-06 17:18:52 +08:00

718 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Telegram Callback Dispatcher — 分類按鈕統一調度
================================================
Phase 5 Sprint 5.0-5.1 — 2026-04-14 Claude Sonnet 4.6
相關: docs/superpowers/plans/2026-04-14-PHASE-5-category-buttons-completion.md
ADR-079 分類按鈕完整化
職責:
1. 從 callback_action_spec.yaml 載入 action registry
2. 接收 Telegram callback_data (action:incident_id or action:id:ts:rand)
3. 驗證 nonce(寫類按鈕)或 allow info(查類按鈕)
4. 依 spec 呼叫對應 MCP tool
5. Reply 執行結果到原告警卡片(reply_to_message_id)
設計原則:
- Registry pattern — 新增按鈕只需 yaml 一行,無需改 dispatcher code
- 模板變數: {incident_id} / {labels.xxx} / {signals[0].xxx} / {callback.user_id}
- 所有 action 都有 audit log,寫類額外 nonce 驗證 log
- reply_to 原告警 message_id(從 Redis tg_msg:{incident_id})
遵守「禁止 Mock 測試鐵律」: 純邏輯 + MCP dispatch,測試用真實 registry。
"""
from __future__ import annotations
import json
import time
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
import structlog
import yaml
# Module-level structured logger shared by every function in this file.
logger = structlog.get_logger(__name__)
_PROVIDER_ALIASES = {
"k8s": "kubernetes",
"ssh": "ssh_host",
}
def _resolve_provider_name(provider_name: str) -> str:
"""Normalize legacy callback spec provider names to registered MCP providers."""
return _PROVIDER_ALIASES.get(provider_name, provider_name)
# =============================================================================
# Data Types
# =============================================================================
@dataclass
class ActionSpec:
    """One action definition loaded from ``callback_action_spec.yaml``."""

    # Registry key of the action.
    name: str
    # Button label and emoji; both are used as the header of reply messages.
    label: str
    emoji: str
    # Risk level: low | medium | high | critical
    risk: str
    # Callback payload flavour: info | nonce
    callback_format: str
    # Alert category this action is offered for.
    category: str
    # Target provider: k8s | ssh | prometheus | signoz | database | internal
    mcp_provider: str
    # Tool name to run on the provider, and its (possibly templated) params.
    mcp_tool: str
    mcp_params: dict[str, Any]
    # Reply rendering mode: text | code | url | truncated
    reply_format: str
    # Timeout, in seconds, applied to the MCP tool call.
    timeout_sec: int
    # Free-form description of the action.
    description: str
    # Whether the spec marks this action as requiring multi-signature approval.
    requires_multi_sig: bool = False
@dataclass
class DispatchResult:
    """Outcome of one dispatcher run."""

    # True when the action completed successfully.
    success: bool
    # Name of the action that was requested.
    action: str
    # Incident the action was run for.
    incident_id: str
    # Telegram user id of the callback source, when known.
    user_id: int | None
    # Text intended to be sent back as the reply.
    result_text: str
    # Error description for failed runs; None otherwise.
    error: str | None = None
    # Elapsed time of the dispatch in milliseconds.
    duration_ms: float = 0.0
# =============================================================================
# Spec Registry
# =============================================================================
@lru_cache(maxsize=1)
def load_action_registry() -> dict[str, ActionSpec]:
    """Load ``callback_action_spec.yaml`` into an action registry.

    The result is memoised with ``lru_cache(maxsize=1)``, so the file is read
    once per process and later edits are not picked up until the process
    restarts.

    A missing file, an empty/non-mapping YAML document, or a non-mapping
    ``actions`` section all yield an empty registry instead of raising. An
    action whose body is empty or not a mapping is registered with default
    field values.

    Returns:
        Mapping of action name to its ``ActionSpec``.
    """
    spec_path = Path(__file__).parent / "callback_action_spec.yaml"
    if not spec_path.exists():
        logger.warning("callback_action_spec_not_found", path=str(spec_path))
        return {}

    with spec_path.open("r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load() returns None for an empty document and may return a list or
    # scalar; only a mapping can carry an "actions" section.
    if not isinstance(data, dict):
        logger.warning(
            "callback_action_spec_invalid",
            path=str(spec_path),
            root_type=type(data).__name__,
        )
        return {}

    actions = data.get("actions")
    if not isinstance(actions, dict):
        actions = {}

    registry: dict[str, ActionSpec] = {}
    for name, spec_dict in actions.items():
        # "some_action:" with no body parses to None; fall back to defaults.
        if not isinstance(spec_dict, dict):
            spec_dict = {}
        mcp = spec_dict.get("mcp", {}) or {}
        registry[name] = ActionSpec(
            name=name,
            label=spec_dict.get("label", name),
            emoji=spec_dict.get("emoji", ""),
            risk=spec_dict.get("risk", "medium"),
            callback_format=spec_dict.get("callback_format", "info"),
            category=spec_dict.get("category", ""),
            mcp_provider=mcp.get("provider", ""),
            mcp_tool=mcp.get("tool", ""),
            mcp_params=mcp.get("params") or {},
            reply_format=spec_dict.get("reply_format", "text"),
            timeout_sec=int(spec_dict.get("timeout_sec", 10)),
            description=spec_dict.get("description", ""),
            requires_multi_sig=bool(spec_dict.get("requires_multi_sig", False)),
        )

    logger.info("callback_action_registry_loaded", count=len(registry))
    return registry
def get_action_spec(action_name: str) -> ActionSpec | None:
    """Look up one action in the cached registry.

    Args:
        action_name: Registry key of the action.

    Returns:
        The matching ``ActionSpec``, or ``None`` when the name is not registered.
    """
    registry = load_action_registry()
    return registry.get(action_name)
def list_actions_for_category(alert_category: str) -> list[ActionSpec]:
    """Collect every registered action whose category equals ``alert_category``.

    Args:
        alert_category: Category value to match exactly.

    Returns:
        Matching specs, in registry iteration order (empty list when none match).
    """
    matches: list[ActionSpec] = []
    for candidate in load_action_registry().values():
        if candidate.category == alert_category:
            matches.append(candidate)
    return matches
# =============================================================================
# Template Variable Substitution
# =============================================================================
def _resolve_template(template: Any, context: dict) -> Any:
"""
遞迴替換模板變數。
支援:
- {incident_id}
- {labels.xxx} / {labels.xxx.yyy}
- {signals[0].xxx}
- {callback.user_id}
Example:
template = {"host": "{labels.instance}", "lines": 50}
context = {"labels": {"instance": "192.168.0.110"}, "incident_id": "INC-123"}
{"host": "192.168.0.110", "lines": 50}
"""
if isinstance(template, dict):
return {k: _resolve_template(v, context) for k, v in template.items()}
if isinstance(template, list):
return [_resolve_template(v, context) for v in template]
if isinstance(template, str) and "{" in template:
# 找出所有 {xxx} placeholder 並替換
import re
def _repl(m: re.Match) -> str:
key = m.group(1)
val = _lookup_context(key, context)
return str(val) if val is not None else m.group(0)
return re.sub(r"\{([a-zA-Z0-9_.\[\]]+)\}", _repl, template)
return template
def _lookup_context(key: str, context: dict) -> Any:
"""
從 context 查表(支援巢狀 key: labels.instance / signals[0].alert_name
"""
parts = key.replace("[", ".").replace("]", "").split(".")
cur: Any = context
for part in parts:
if part == "":
continue
if isinstance(cur, dict):
cur = cur.get(part)
elif isinstance(cur, list):
try:
cur = cur[int(part)]
except (ValueError, IndexError):
return None
else:
return None
if cur is None:
return None
return cur
# =============================================================================
# Dispatcher (Sprint 5.1)
# =============================================================================
async def dispatch_action(
    action_name: str,
    incident_id: str,
    user_id: int | None = None,
    labels: dict | None = None,
    extra_context: dict | None = None,
) -> DispatchResult:
    """Run a callback action by calling the tool named in its spec.

    The action's params are rendered against a template context, an audit
    log entry is written, and then either an internal handler or the MCP
    provider's tool is executed. Every failure mode is converted into a
    ``DispatchResult`` with ``success=False``; this function does not raise.

    Args:
        action_name: Action name, looked up in the spec registry.
        incident_id: Incident the action relates to.
        user_id: Telegram user id of the callback source.
        labels: Alert labels, exposed to templates as ``{labels.xxx}``.
        extra_context: Extra template context (e.g. ``signals``). Its keys are
            merged last, so they override the built-in context keys.

    Returns:
        A ``DispatchResult`` whose ``result_text`` is meant to be sent as the reply.
    """
    # Local imports, following this file's lazy-import convention.
    import asyncio
    import html

    start = time.perf_counter()
    spec = get_action_spec(action_name)
    if not spec:
        logger.warning("dispatch_action_unknown", action=action_name)
        return DispatchResult(
            success=False,
            action=action_name,
            incident_id=incident_id,
            user_id=user_id,
            result_text="",
            error=f"Unknown action: {action_name}",
            duration_ms=(time.perf_counter() - start) * 1000,
        )

    # Build the template context used to render the spec's params.
    context = {
        "incident_id": incident_id,
        "labels": labels or {},
        "callback": {"user_id": user_id or 0},
        **(extra_context or {}),
    }
    resolved_params = _resolve_template(spec.mcp_params, context)

    # Audit log (written for every action).
    logger.info(
        "dispatch_action_start",
        action=action_name,
        incident_id=incident_id,
        user_id=user_id,
        risk=spec.risk,
        provider=spec.mcp_provider,
        tool=spec.mcp_tool,
        params=resolved_params,
    )

    try:
        # "internal" provider: handled in-process, no MCP call and no timeout.
        if spec.mcp_provider == "internal":
            result_text = await _handle_internal_action(
                spec,
                resolved_params,
                incident_id=incident_id,
                user_id=user_id,
            )
            duration = (time.perf_counter() - start) * 1000
            logger.info("dispatch_action_internal", action=action_name, duration_ms=round(duration, 1))
            return DispatchResult(
                success=True, action=action_name, incident_id=incident_id,
                user_id=user_id, result_text=result_text, duration_ms=duration,
            )

        # MCP registry dispatch (imported lazily inside the try so an import
        # failure is reported as a failed DispatchResult).
        from src.plugins.mcp.registry import get_provider
        from src.services.mcp_audit_context import with_mcp_audit_context

        provider_name = _resolve_provider_name(spec.mcp_provider)
        provider = get_provider(provider_name)
        if not provider:
            duration = (time.perf_counter() - start) * 1000
            return DispatchResult(
                success=False, action=action_name, incident_id=incident_id,
                user_id=user_id,
                result_text=f"{spec.emoji} {spec.label} 失敗MCP provider '{provider_name}' 未註冊",
                error=f"provider_not_found: {provider_name}",
                duration_ms=duration,
            )

        # Pass the rendered params through the audit-context helper together
        # with session / incident / operator identifiers, then run the tool
        # under the spec's timeout.
        audited_params = with_mcp_audit_context(
            resolved_params,
            session_id=f"callback:{incident_id}:{action_name}",
            incident_id=incident_id,
            flywheel_node="operate",
            agent_role="telegram_callback_dispatcher",
            operator_user_id=user_id,
        )
        mcp_result = await asyncio.wait_for(
            provider.execute(spec.mcp_tool, audited_params),
            timeout=float(spec.timeout_sec),
        )
        duration = (time.perf_counter() - start) * 1000

        if mcp_result.success:
            result_text = _format_reply(
                mcp_result.output, spec.reply_format, spec.label, spec.emoji
            )
            logger.info(
                "dispatch_action_success",
                action=action_name,
                incident_id=incident_id,
                provider=spec.mcp_provider,
                tool=spec.mcp_tool,
                duration_ms=round(duration, 1),
            )
            return DispatchResult(
                success=True, action=action_name, incident_id=incident_id,
                user_id=user_id, result_text=result_text, duration_ms=duration,
            )

        # MCP returned success=False. The error is embedded in HTML markup, so
        # it is stringified, truncated and then escaped: a raw "<", ">" or "&"
        # in a tool error would otherwise produce invalid markup. The raw,
        # unescaped error is still reported in DispatchResult.error.
        error_text = html.escape(str(mcp_result.error or "未知錯誤")[:200], quote=False)
        result_text = (
            f"{spec.emoji} <b>{spec.label}</b> 執行失敗\n"
            f"<i>{error_text}</i>"
        )
        logger.warning(
            "dispatch_action_mcp_failed",
            action=action_name,
            incident_id=incident_id,
            error=mcp_result.error,
        )
        return DispatchResult(
            success=False, action=action_name, incident_id=incident_id,
            user_id=user_id, result_text=result_text,
            error=mcp_result.error, duration_ms=duration,
        )
    except asyncio.TimeoutError:
        duration = (time.perf_counter() - start) * 1000
        logger.warning(
            "dispatch_action_timeout",
            action=action_name, incident_id=incident_id,
            timeout_sec=spec.timeout_sec, duration_ms=round(duration, 1),
        )
        return DispatchResult(
            success=False, action=action_name, incident_id=incident_id,
            user_id=user_id,
            result_text=f"{spec.emoji} {spec.label} 超時 ({spec.timeout_sec}s)",
            error="timeout", duration_ms=duration,
        )
    except Exception as e:
        # Boundary handler: any other failure becomes a failed result so the
        # caller can always reply to the user.
        duration = (time.perf_counter() - start) * 1000
        logger.error(
            "dispatch_action_failed",
            action=action_name,
            incident_id=incident_id,
            error=str(e),
            duration_ms=round(duration, 1),
        )
        return DispatchResult(
            success=False,
            action=action_name,
            incident_id=incident_id,
            user_id=user_id,
            result_text=f"{spec.emoji} {spec.label} 執行失敗",
            error=str(e),
            duration_ms=duration,
        )
async def _handle_internal_action(
spec: ActionSpec,
params: dict,
*,
incident_id: str,
user_id: int | None,
) -> str:
"""
Internal actions — 不走 MCP直接產生 URL/文字回覆
Sprint 5.2 (2026-04-14 Claude Sonnet 4.6): 處理 open_signoz / open_flywheel /
build_*_url / secops_authorize 等內部 action
"""
tool = spec.mcp_tool
if tool == "build_signoz_url":
service = params.get("service", "unknown")
url = f"https://signoz.wooo.work/services/{service}"
return f"{spec.emoji} <b>{spec.label}</b>\n{url}"
if tool == "build_flywheel_url":
return f"{spec.emoji} <b>{spec.label}</b>\nhttps://awoooi.wooo.work/flywheel"
if tool == "record_authorization":
recorded = await _record_authorization_audit(
spec=spec,
params=params,
incident_id=incident_id,
user_id=user_id,
)
_user_id = params.get("user_id", user_id or 0)
source = params.get("source", "unknown")
action = params.get("action", "authorize")
suffix = "已寫入審計與時間線" if recorded else "已受理;審計寫入將由後續補償"
return (
f"{spec.emoji} <b>{spec.label}</b>\n"
f"已記錄 user={_user_id} 授權 source={source} action={action}24h 內同源告警將靜默)\n"
f"{suffix}"
)
# 未知的 internal tool
return (
f"{spec.emoji} <b>{spec.label}</b>\n"
f"⚠️ Unknown internal tool: {tool}"
)
async def _record_authorization_audit(
    *,
    spec: ActionSpec,
    params: dict,
    incident_id: str,
    user_id: int | None,
) -> bool:
    """Persist an internal authorization action on a best-effort basis.

    Three independent writes are attempted in order: a Redis key with a
    86400-second expiry, an alert-operation-log record, and a timeline event.
    A failure in one is logged as a warning and does not stop the others.

    Args:
        spec: Action spec supplying name, label, risk, category and multi-sig flag.
        params: Rendered params; ``source``, ``action``, ``source_ip`` and
            ``user_id`` are read.
        incident_id: Incident the authorization belongs to.
        user_id: Telegram user id; falls back to ``params["user_id"]``, then 0.

    Returns:
        True when at least one of the three writes reported success.
    """
    source = str(params.get("source") or "unknown")
    requested_action = str(params.get("action") or spec.name)
    source_ip = str(params.get("source_ip") or "")
    effective_user = user_id or params.get("user_id") or 0
    actor = f"telegram:{effective_user}"
    # Multi-sig or critical-risk actions are recorded with elevated event types.
    elevated = bool(spec.requires_multi_sig or spec.risk == "critical")

    context = {
        "action": spec.name,
        "label": spec.label,
        "risk": spec.risk,
        "category": spec.category,
        "requested_action": requested_action,
        "source": source,
        "source_ip": source_ip,
        "user_id": effective_user,
        "requires_multi_sig": spec.requires_multi_sig,
    }
    wrote_any = False

    # 1) Redis: authorization payload keyed by source.
    try:
        from src.core.redis_client import get_redis

        payload = json.dumps(context, ensure_ascii=False)
        await get_redis().set(f"secops:authorization:{source}", payload, ex=86400)
        wrote_any = True
    except Exception as exc:
        logger.warning(
            "record_authorization_redis_failed",
            incident_id=incident_id,
            source=source,
            error=str(exc),
        )

    # 2) Alert operation log.
    try:
        from src.repositories.alert_operation_log_repository import (
            get_alert_operation_log_repository,
        )

        if elevated:
            event_type = "APPROVAL_ESCALATED"
        else:
            event_type = "USER_ACTION"
        record = await get_alert_operation_log_repository().append(
            event_type,
            incident_id=incident_id,
            actor=actor,
            action_detail=f"telegram_authorization:{requested_action}"[:200],
            success=True,
            context=context,
        )
        # Counts as written only when the repository returned a truthy record.
        if record:
            wrote_any = True
    except Exception as exc:
        logger.warning(
            "record_authorization_aol_failed",
            incident_id=incident_id,
            source=source,
            error=str(exc),
        )

    # 3) Timeline event.
    try:
        from src.services.approval_db import get_timeline_service

        summary = f"action={requested_action} source={source} source_ip={source_ip or 'unknown'}"
        await get_timeline_service().add_event(
            event_type="security",
            status="warning" if elevated else "info",
            title="Telegram authorization recorded",
            description=summary[:500],
            actor=actor,
            actor_role="secops_authorization",
            risk_level=spec.risk,
            incident_id=incident_id,
        )
        wrote_any = True
    except Exception as exc:
        logger.warning(
            "record_authorization_timeline_failed",
            incident_id=incident_id,
            source=source,
            error=str(exc),
        )

    logger.info(
        "record_authorization_audit_complete",
        incident_id=incident_id,
        source=source,
        action=requested_action,
        wrote_any=wrote_any,
    )
    return wrote_any
def _format_reply(
mcp_result: Any, reply_format: str, label: str, emoji: str
) -> str:
"""
依 spec 格式化 reply 文字。
reply_format:
- text: 單行文字
- code: <code>...</code>
- truncated: 截斷到 500 字
- url: 直接返回 URL
"""
header = f"{emoji} <b>{label}</b>"
if reply_format == "url":
return f"{header}\n{mcp_result}"
if reply_format == "code":
return f"{header}\n<code>{str(mcp_result)[:800]}</code>"
if reply_format == "truncated":
text = str(mcp_result)[:500]
if len(str(mcp_result)) > 500:
text += "...\n<i>(已截斷)</i>"
return f"{header}\n<pre>{text}</pre>"
return f"{header}\n{mcp_result}"
# =============================================================================
# B2: LLM Dynamic Action Dispatcher
# 2026-04-27 Claude Sonnet 4.6: B2 — dispatch_llm_action()
# 支援 RecommendedAction 結構化動作的風險閘控 + allowlist 驗證 + 模板渲染
# ADR-082 §B2LLM 動態 MCP 規格派發安全閘
# =============================================================================
import re as _re
def _render_llm_params(params: dict[str, str], context: dict) -> dict[str, str]:
"""
渲染 RecommendedAction.params 模板。
支援兩個命名空間:
- {labels.xxx} → context["labels"]["xxx"]
- {context.xxx} → context["xxx"](如 context.incident_id
- {incident_id} → context["incident_id"](舊式相容)
渲染失敗的 key → 保留原始字串,不 crash。
"""
def _repl(m: _re.Match) -> str:
key = m.group(1)
parts = key.split(".", 1)
try:
if parts[0] == "labels" and len(parts) == 2:
val = (context.get("labels") or {}).get(parts[1])
return str(val) if val is not None else m.group(0)
if parts[0] == "context" and len(parts) == 2:
val = context.get(parts[1])
return str(val) if val is not None else m.group(0)
# 舊式:直接 top-level key如 {incident_id}
val = context.get(key)
return str(val) if val is not None else m.group(0)
except Exception:
return m.group(0)
rendered: dict[str, str] = {}
for k, v in params.items():
if isinstance(v, str) and "{" in v:
try:
rendered[k] = _re.sub(r"\{([a-zA-Z0-9_.]+)\}", _repl, v)
except Exception:
rendered[k] = v
else:
rendered[k] = v
return rendered
def _load_llm_tool_registry() -> dict[str, dict]:
    """Fetch the MCP tool registry from ``solver_agent``.

    The import is done inside the function to avoid a circular import. Any
    failure, in the import or in the loader call, is logged as a warning and
    an empty dict is returned instead of raising.
    """
    try:
        from src.agents.solver_agent import _load_mcp_tool_registry  # noqa: PLC0415

        tool_registry = _load_mcp_tool_registry()
        return tool_registry
    except Exception as exc:
        logger.warning("llm_dispatch_registry_load_failed", error=str(exc))
        return {}
def dispatch_llm_action(
    action: Any,
    context: dict,
) -> dict:
    """B2: safety gate for LLM-proposed MCP actions.

    Checks, in order:
        0. Param types  - every param value must be a ``str``.
        1. Risk gating  - ``critical`` is always rejected; ``high`` is rejected
           unless ``context["confirmed"]`` is truthy.
        2. Allowlist    - ``mcp_tool`` must be a key of the tool registry.
        3. Param render - ``{labels.xxx}`` / ``{context.xxx}`` / ``{incident_id}``.
        4. Nonce        - a random token is generated for allowed medium/high
           actions. This function only generates the token; it does not
           persist it anywhere.

    The risk value is stripped and lower-cased before gating, and a missing
    or empty risk is treated as ``"medium"``, so values such as ``"Critical"``
    or ``" high "`` cannot slip past the exact-match checks.

    Args:
        action: RecommendedAction-like object; ``risk``, ``mcp_tool``,
            ``mcp_provider``, ``name`` and ``params`` are read via ``getattr``.
        context: Execution context (``labels`` / ``incident_id`` / ``confirmed`` ...).

    Returns:
        ``{"ok": True, ...}`` with the rendered call details when execution is
        allowed, otherwise ``{"ok": False, "reason": ...}``.
    """
    import secrets as _secrets  # noqa: PLC0415

    # Normalise before gating: the gate below compares exact lowercase strings.
    risk: str = str(getattr(action, "risk", None) or "medium").strip().lower()
    mcp_tool: str = getattr(action, "mcp_tool", "")
    mcp_provider: str = getattr(action, "mcp_provider", "")
    name: str = getattr(action, "name", "")
    params: dict = dict(getattr(action, "params", {}) or {})

    # -- 0. Param type validation: non-string values are rejected up front so
    #       template rendering only ever sees flat string params.
    if params and not all(isinstance(v, str) for v in params.values()):
        logger.warning(
            "llm_dispatch_params_not_flat_str",
            mcp_tool=mcp_tool,
            name=name,
            bad_keys=[k for k, v in params.items() if not isinstance(v, str)],
        )
        return {"ok": False, "reason": "params_not_flat_str"}

    # -- 1. Risk gating
    if risk == "critical":
        logger.warning(
            "llm_dispatch_critical_rejected",
            mcp_tool=mcp_tool,
            name=name,
            incident_id=context.get("incident_id"),
        )
        return {"ok": False, "reason": "critical_action_rejected"}

    if risk == "high" and not context.get("confirmed"):
        # The pending nonce is returned to the caller only; nothing is stored here.
        pending_nonce = _secrets.token_hex(16)
        logger.info(
            "llm_dispatch_high_risk_pending",
            mcp_tool=mcp_tool,
            name=name,
            incident_id=context.get("incident_id"),
        )
        return {
            "ok": False,
            "reason": "high_risk_requires_confirmation",
            "nonce": pending_nonce,
        }

    # -- 2. Allowlist check
    registry = _load_llm_tool_registry()
    if mcp_tool not in registry:
        logger.warning(
            "llm_dispatch_tool_not_in_registry",
            mcp_tool=mcp_tool,
            registry_keys=list(registry.keys()),
        )
        return {"ok": False, "reason": "tool_not_in_registry"}

    # -- 3. Param template rendering
    rendered_params = _render_llm_params(params, context)

    # -- 4. Nonce for allowed medium/high actions
    # NOTE(review): risk values outside low/medium/high/critical are allowed
    # through without a nonce, as before - confirm whether that is intended.
    nonce: str | None = None
    if risk in ("medium", "high"):
        nonce = _secrets.token_hex(16)

    logger.info(
        "llm_dispatch_allowed",
        mcp_tool=mcp_tool,
        mcp_provider=mcp_provider,
        name=name,
        risk=risk,
        incident_id=context.get("incident_id"),
        has_nonce=nonce is not None,
    )
    return {
        "ok": True,
        "mcp_provider": mcp_provider,
        "mcp_tool": mcp_tool,
        "params": rendered_params,
        "risk": risk,
        "nonce": nonce,
        "button_source": "llm",
    }