Files
awoooi/apps/api/src/services/mcp_audit_service.py
Your Name 94d006eac8
All checks were successful
Code Review / ai-code-review (push) Successful in 10s
CD Pipeline / tests (push) Successful in 1m4s
CD Pipeline / build-and-deploy (push) Successful in 3m22s
CD Pipeline / post-deploy-checks (push) Successful in 1m14s
feat(awooop): bridge legacy mcp audit into gateway timeline
2026-05-12 23:44:19 +08:00

306 lines
11 KiB
Python

"""MCP audit and daily usage statistics.
Every MCP provider call should leave a durable trace that can be joined back to
incidents, flywheel nodes, and provider health. This service intentionally uses
raw SQL because the audit schema is additive and may be created by migration
before every runtime model has been refreshed.
"""
from __future__ import annotations
import json
import hashlib
import time
import uuid
from typing import Any
import structlog
from sqlalchemy import text
from src.db.base import get_db_context
from src.services.mcp_audit_context import normalize_mcp_audit_session_id
logger = structlog.get_logger(__name__)

# Substring markers: a dict key whose lower-cased text contains any of these
# has its value masked before audit payloads are serialized.
_REDACT_KEYS = {
    "api_key",
    "authorization",
    "key",
    "password",
    "secret",
    "token",
}
# Project id assigned to legacy MCP calls that carry no explicit project_id.
_LEGACY_PROJECT_ID = "awoooi"
# gateway_path value marking calls that went through the real AwoooP gateway,
# which already writes its own first-class audit rows.
_REAL_GATEWAY_PATH = "awooop_mcp_gateway"
# Ordered (node, keywords) rules. The first rule with a keyword occurring in
# "<server>:<tool>" wins, so earlier rules take precedence on overlaps
# (e.g. "restart_node" matches "node" under "sense" before "restart").
_FLYWHEEL_NODE_RULES = (
    ("detect", ("prometheus", "alert", "metric", "query_range")),
    ("sense", ("logs", "describe", "events", "node", "hpa", "ssh_get")),
    ("reason", ("database", "rag", "knowledge", "history")),
    ("decide", ("blast", "risk", "approval")),
    ("execute", ("restart", "delete", "scale", "rollout", "ssh_restart")),
    ("verify", ("watch", "status", "health", "grafana")),
    ("learn", ("playbook", "km", "embedding")),
)


def infer_flywheel_node(mcp_server: str, tool_name: str) -> str:
    """Infer the flywheel node for a provider/tool pair.

    Matching is a case-insensitive substring search over
    ``"<mcp_server>:<tool_name>"``. Rules are tried in order and the first
    hit decides the node. Pairs that match no rule fall back to ``"govern"``.
    """
    haystack = f"{mcp_server}:{tool_name}".lower()
    for node, keywords in _FLYWHEEL_NODE_RULES:
        if any(keyword in haystack for keyword in keywords):
            return node
    return "govern"
def _redact(
    value: Any,
    *,
    markers: set[str] | frozenset[str] | None = None,
) -> Any:
    """Return *value* with secret-looking dict entries masked.

    Dict keys are matched case-insensitively by substring against *markers*
    (default: the module-level ``_REDACT_KEYS``). A matching entry's value
    becomes ``"<redacted>"``. Dicts, lists and tuples are rebuilt
    recursively; any other value is returned as-is.

    Tuples are walked as well because ``json.dumps`` serializes them as
    arrays, just like lists. Skipping them would let a secret nested inside a
    tuple reach the audit log verbatim.
    """
    keys = _REDACT_KEYS if markers is None else markers
    if isinstance(value, dict):
        return {
            key: (
                "<redacted>"
                if any(marker in str(key).lower() for marker in keys)
                else _redact(item, markers=keys)
            )
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        items = [_redact(item, markers=keys) for item in value]
        # Preserve the container kind; both serialize to a JSON array.
        return items if isinstance(value, list) else tuple(items)
    return value
def _json_dumps(value: Any) -> str:
    """Serialize *value* to a JSON string after redacting secret-looking keys.

    Non-ASCII characters are emitted as-is, and objects ``json`` cannot
    encode natively are converted with ``str``.
    """
    scrubbed = _redact(value)
    return json.dumps(scrubbed, default=str, ensure_ascii=False)
def _extract_incident_id(parameters: dict[str, Any]) -> str | None:
    """Return the incident id attached to a tool call, if any.

    An ``incident_id`` inside the ``_mcp_audit`` context dict wins. After
    that, the top-level ``incident_id`` and then ``incidentId`` parameters are
    consulted. Falsy values are skipped and the winner is coerced to ``str``.
    """
    context = parameters.get("_mcp_audit")
    from_context = context.get("incident_id") if isinstance(context, dict) else None
    candidates = (
        from_context,
        parameters.get("incident_id"),
        parameters.get("incidentId"),
    )
    for candidate in candidates:
        if candidate:
            return str(candidate)
    return None
def _extract_audit_context(parameters: dict[str, Any]) -> dict[str, Any]:
    """Return the ``_mcp_audit`` dict embedded in the tool parameters.

    A missing key or any non-dict value yields a new empty dict, so callers
    can always call ``.get`` on the result.
    """
    context = parameters.get("_mcp_audit")
    if not isinstance(context, dict):
        return {}
    return context
def _extract_project_id(
    parameters: dict[str, Any],
    audit_context: dict[str, Any],
) -> str:
    """Resolve the project id for a legacy MCP call.

    Legacy MCP calls predate AwoooP tenancy, so the bridge is explicit:
    - ``project_id`` from the audit context wins;
    - otherwise a top-level ``project_id`` parameter is used;
    - otherwise the fixed legacy project id.

    Falsy values are ignored, and the result is always a ``str``.
    """
    candidate = audit_context.get("project_id") or parameters.get("project_id")
    if candidate:
        return str(candidate)
    return _LEGACY_PROJECT_ID
def _extract_run_id(audit_context: dict[str, Any]) -> uuid.UUID | None:
    """Parse ``run_id`` from the audit context into a ``uuid.UUID``.

    Returns ``None`` when the value is missing, falsy, or not a valid UUID.
    """
    raw = audit_context.get("run_id")
    if not raw:
        return None
    try:
        parsed = uuid.UUID(str(raw))
    except (TypeError, ValueError):
        parsed = None
    return parsed
def _compact(value: Any, limit: int) -> str | None:
    """Stringify *value* and cut it to at most *limit* characters.

    ``None`` passes through unchanged, so it can still be bound as SQL NULL.
    """
    if value is not None:
        return str(value)[:limit]
    return None
def _stable_hash(value: Any) -> str:
    """Return the hex SHA-256 digest of *value*'s canonical JSON form.

    Dict keys are sorted so logically equal payloads hash equally. Values
    ``json`` cannot encode natively are converted with ``str``. The value is
    hashed as given, without redaction.

    Dicts that mix key types (e.g. ``{1: ..., "a": ...}``) cannot be sorted,
    and ``sort_keys=True`` raises ``TypeError`` for them. In that case the
    digest falls back to insertion order, so hashing does not fail on payloads
    that ``json.dumps`` can otherwise encode.
    """
    try:
        canonical = json.dumps(value, sort_keys=True, ensure_ascii=False, default=str)
    except TypeError:
        canonical = json.dumps(value, ensure_ascii=False, default=str)
    return hashlib.sha256(canonical.encode()).hexdigest()
def _bridge_tool_name(mcp_server: str, tool_name: str, *, limit: int = 128) -> str:
    """Return the ``tool_name`` stored on AwoooP bridge rows.

    The name is namespaced as ``legacy:<mcp_server>:<tool_name>`` and cut to
    *limit* characters. The 128 default presumably mirrors the
    ``awooop_mcp_gateway_audit.tool_name`` column width; confirm against the
    migration.

    The combined string is truncated rather than replaced by the bare tool
    name. This keeps the ``legacy:`` prefix, so an over-long bridge row cannot
    masquerade as a first-class gateway tool name.
    """
    return f"legacy:{mcp_server}:{tool_name}"[:limit]
def _should_bridge_to_awooop(audit_context: dict[str, Any]) -> bool:
    """Return True when a call needs a mirrored AwoooP audit row.

    Calls whose ``gateway_path`` marks the real AwoooP gateway already write
    first-class audit rows, so mirroring them would duplicate the record.
    """
    already_audited = audit_context.get("gateway_path") == _REAL_GATEWAY_PATH
    return not already_audited
async def _record_awooop_gateway_bridge(
    db: Any,
    *,
    project_id: str,
    run_id: uuid.UUID | None,
    trace_id: str | None,
    agent_id: str | None,
    mcp_server: str,
    tool_name: str,
    input_params: dict[str, Any],
    output_result: Any | None,
    duration_ms: int,
    success: bool,
    error_message: str | None,
    flywheel_node: str | None,
    audit_context: dict[str, Any],
) -> None:
    """Mirror one legacy direct-provider call into ``awooop_mcp_gateway_audit``.

    The row is tagged as a bridge record: ``gate_result`` carries
    ``schema_version="legacy_mcp_bridge_v1"`` and ``policy_enforced=False``,
    plus the original server/tool names and flywheel node.

    Only ``_stable_hash`` digests of the input and output are stored, not the
    payloads themselves. A missing output yields a NULL ``output_hash``.

    Failed calls record ``result_status="failed"`` with the error message,
    truncated to 256 characters, as ``block_reason``.

    The statement runs on the caller's *db* handle; this function does not
    commit.
    """
    gate_payload = {
        "schema_version": "legacy_mcp_bridge_v1",
        "gateway_path": audit_context.get("gateway_path") or "legacy_registry_provider",
        "policy_enforced": False,
        "not_used_reason": "legacy direct provider path; bridge audit only",
        "legacy_mcp_server": mcp_server,
        "legacy_tool_name": tool_name,
        "flywheel_node": flywheel_node,
    }
    if success:
        status, block_reason = "success", None
    else:
        status, block_reason = "failed", _compact(error_message, 256)

    input_hash = _stable_hash(input_params)
    output_hash = None if output_result is None else _stable_hash(output_result)

    bind_params = {
        "project_id": project_id,
        "run_id": run_id,
        "trace_id": _compact(trace_id, 128),
        "agent_id": _compact(agent_id or "legacy-mcp-provider", 128),
        "tool_name": _bridge_tool_name(mcp_server, tool_name),
        "input_hash": input_hash,
        "output_hash": output_hash,
        "gate_result": json.dumps(gate_payload, ensure_ascii=False, default=str),
        "result_status": status,
        "block_reason": block_reason,
        "latency_ms": duration_ms,
    }
    statement = text(
        """
        INSERT INTO awooop_mcp_gateway_audit (
            project_id, run_id, trace_id, agent_id,
            tool_name, input_hash, output_hash, gate_result,
            result_status, block_gate, block_reason, latency_ms
        )
        VALUES (
            :project_id, :run_id, :trace_id, :agent_id,
            :tool_name, :input_hash, :output_hash, CAST(:gate_result AS jsonb),
            :result_status, NULL, :block_reason, :latency_ms
        )
        """
    )
    await db.execute(statement, bind_params)
async def record_mcp_call(
    *,
    mcp_server: str,
    tool_name: str,
    input_params: dict[str, Any],
    output_result: Any | None,
    duration_ms: int,
    success: bool,
    error_message: str | None,
    session_id: str | None = None,
    flywheel_node: str | None = None,
    incident_id: str | None = None,
    agent_role: str | None = None,
) -> None:
    """Persist one MCP tool call and update daily aggregate stats.

    Three writes happen:
    - a row is inserted into ``mcp_audit_log``;
    - today's ``mcp_daily_stats`` row for ``(mcp_server, tool_name)`` is
      upserted;
    - unless the call already went through the real AwoooP gateway, the call
      is mirrored into ``awooop_mcp_gateway_audit``.

    Missing context is filled from the ``_mcp_audit`` dict inside
    *input_params*: session id (a random UUID as a last resort), incident id,
    agent role, project id, run id, trace id and agent id. The flywheel node
    is inferred from the server/tool names when not supplied.

    The legacy rows and the AwoooP bridge row are written in separate
    ``get_db_context`` blocks. On PostgreSQL, a failing statement aborts the
    whole transaction. Sharing one transaction therefore let an optional
    mirror failure discard the primary audit row, for example when the
    additive ``awooop_mcp_gateway_audit`` table is not migrated yet.

    ``Exception``s raised while writing are logged as
    ``mcp_audit_write_failed`` warnings, with a ``stage`` field naming the
    failing write, and are not re-raised.
    """
    audit_context = _extract_audit_context(input_params)
    session_id = normalize_mcp_audit_session_id(
        session_id or audit_context.get("session_id") or str(uuid.uuid4())
    )
    flywheel_node = flywheel_node or infer_flywheel_node(mcp_server, tool_name)
    incident_id = incident_id or _extract_incident_id(input_params)
    agent_role = agent_role or audit_context.get("agent_role")
    project_id = _extract_project_id(input_params, audit_context)
    run_id = _extract_run_id(audit_context)
    trace_id = audit_context.get("trace_id") or incident_id or session_id
    agent_id = audit_context.get("agent_id") or agent_role

    # Stage 1: primary legacy audit row plus daily aggregate, one transaction.
    try:
        async with get_db_context(project_id) as db:
            await db.execute(
                text(
                    """
                    INSERT INTO mcp_audit_log (
                        session_id, flywheel_node, mcp_server, tool_name,
                        input_params, output_result, duration_ms, success,
                        error_message, incident_id, agent_role
                    )
                    VALUES (
                        :session_id, :flywheel_node, :mcp_server, :tool_name,
                        CAST(:input_params AS jsonb), CAST(:output_result AS jsonb),
                        :duration_ms, :success, :error_message, :incident_id,
                        :agent_role
                    )
                    """
                ),
                {
                    "session_id": session_id,
                    "flywheel_node": flywheel_node,
                    "mcp_server": mcp_server,
                    "tool_name": tool_name,
                    "input_params": _json_dumps(input_params),
                    "output_result": _json_dumps(output_result),
                    "duration_ms": duration_ms,
                    "success": success,
                    "error_message": error_message,
                    "incident_id": incident_id,
                    "agent_role": agent_role,
                },
            )
            # The SET expressions read the pre-update row, so the running
            # average uses the old call_count on both sides of the division.
            await db.execute(
                text(
                    """
                    INSERT INTO mcp_daily_stats (
                        date, mcp_server, tool_name, call_count, success_count,
                        avg_duration_ms
                    )
                    VALUES (
                        CURRENT_DATE, :mcp_server, :tool_name, 1,
                        CASE WHEN :success THEN 1 ELSE 0 END,
                        :duration_ms
                    )
                    ON CONFLICT (date, mcp_server, tool_name)
                    DO UPDATE SET
                        call_count = mcp_daily_stats.call_count + 1,
                        success_count = mcp_daily_stats.success_count
                            + CASE WHEN EXCLUDED.success_count > 0 THEN 1 ELSE 0 END,
                        avg_duration_ms = (
                            (mcp_daily_stats.avg_duration_ms * mcp_daily_stats.call_count)
                            + EXCLUDED.avg_duration_ms
                        ) / (mcp_daily_stats.call_count + 1)
                    """
                ),
                {
                    "mcp_server": mcp_server,
                    "tool_name": tool_name,
                    "success": success,
                    "duration_ms": duration_ms,
                },
            )
    except Exception as exc:
        logger.warning(
            "mcp_audit_write_failed",
            stage="legacy",
            mcp_server=mcp_server,
            tool_name=tool_name,
            error=str(exc),
        )

    if not _should_bridge_to_awooop(audit_context):
        return

    # Stage 2: best-effort AwoooP mirror in its own transaction, so its
    # failure cannot roll back the primary audit row written above.
    try:
        async with get_db_context(project_id) as db:
            await _record_awooop_gateway_bridge(
                db,
                project_id=project_id,
                run_id=run_id,
                trace_id=trace_id,
                agent_id=agent_id,
                mcp_server=mcp_server,
                tool_name=tool_name,
                input_params=input_params,
                output_result=output_result,
                duration_ms=duration_ms,
                success=success,
                error_message=error_message,
                flywheel_node=flywheel_node,
                audit_context=audit_context,
            )
    except Exception as exc:
        logger.warning(
            "mcp_audit_write_failed",
            stage="awooop_bridge",
            mcp_server=mcp_server,
            tool_name=tool_name,
            error=str(exc),
        )
def monotonic_ms() -> int:
    """Return the monotonic clock reading in whole milliseconds.

    Only differences between two readings are meaningful. Suitable for
    measuring ``duration_ms`` around a provider call.
    """
    seconds = time.monotonic()
    return int(seconds * 1000)