306 lines
11 KiB
Python
306 lines
11 KiB
Python
"""MCP audit and daily usage statistics.
|
|
|
|
Every MCP provider call should leave a durable trace that can be joined back to
|
|
incidents, flywheel nodes, and provider health. This service intentionally uses
|
|
raw SQL because the audit schema is additive and may be created by migration
|
|
before every runtime model has been refreshed.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import hashlib
|
|
import time
|
|
import uuid
|
|
from typing import Any
|
|
|
|
import structlog
|
|
from sqlalchemy import text
|
|
|
|
from src.db.base import get_db_context
|
|
from src.services.mcp_audit_context import normalize_mcp_audit_session_id
|
|
|
|
logger = structlog.get_logger(__name__)
|
|
|
|
# Substring markers used by _redact: a mapping key whose lower-cased text
# contains any of these has its value replaced before serialisation.
_REDACT_KEYS = {
    "token",
    "password",
    "secret",
    "api_key",
    "authorization",
    "key",
}

# Project id returned by _extract_project_id when a call carries none.
_LEGACY_PROJECT_ID = "awoooi"

# ``gateway_path`` value that _should_bridge_to_awooop treats as "already
# audited by the real gateway", i.e. no bridge row is written for it.
_REAL_GATEWAY_PATH = "awooop_mcp_gateway"
|
|
|
|
|
|
def infer_flywheel_node(mcp_server: str, tool_name: str) -> str:
    """Map a provider/tool pair onto a flywheel node label.

    The two names are joined as ``"<server>:<tool>"`` and lower-cased. The
    rules below are tried from top to bottom and the first rule that has a
    keyword occurring anywhere in that string decides the result.
    ``"govern"`` is returned when no rule matches.
    """
    # Order is significant: an earlier rule shadows every later one.
    rules = (
        ("detect", ("prometheus", "alert", "metric", "query_range")),
        ("sense", ("logs", "describe", "events", "node", "hpa", "ssh_get")),
        ("reason", ("database", "rag", "knowledge", "history")),
        ("decide", ("blast", "risk", "approval")),
        ("execute", ("restart", "delete", "scale", "rollout", "ssh_restart")),
        ("verify", ("watch", "status", "health", "grafana")),
        ("learn", ("playbook", "km", "embedding")),
    )
    haystack = f"{mcp_server}:{tool_name}".lower()
    for node, keywords in rules:
        for keyword in keywords:
            if keyword in haystack:
                return node
    return "govern"
|
|
|
|
|
|
def _redact(value: Any) -> Any:
    """Return a copy of *value* with secret-looking mapping entries masked.

    Dicts, lists and tuples are walked recursively. A dict entry whose key,
    converted to a lower-cased string, contains any marker from
    ``_REDACT_KEYS`` has its value replaced by the literal ``"<redacted>"``;
    such a value is not descended into. Any other value is returned as-is.

    Containers are rebuilt rather than mutated, so the caller's payload is
    left untouched.
    """
    if isinstance(value, dict):
        return {
            key: (
                "<redacted>"
                if any(marker in str(key).lower() for marker in _REDACT_KEYS)
                else _redact(item)
            )
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_redact(item) for item in value]
    if isinstance(value, tuple):
        # json.dumps encodes tuples as arrays, so a dict nested inside a tuple
        # would otherwise be serialised with its secrets intact.
        return tuple(_redact(item) for item in value)
    return value
|
|
|
|
|
|
def _json_dumps(value: Any) -> str:
    """Serialise *value* to JSON text after passing it through ``_redact``.

    Non-ASCII characters are emitted unescaped, and objects the encoder cannot
    handle natively are converted with ``str``.
    """
    sanitized = _redact(value)
    return json.dumps(sanitized, default=str, ensure_ascii=False)
|
|
|
|
|
|
def _extract_incident_id(parameters: dict[str, Any]) -> str | None:
    """Find the incident id attached to a call's parameters, if any.

    Lookup order: ``parameters["_mcp_audit"]["incident_id"]`` (only when
    ``_mcp_audit`` is a dict), then the top-level ``incident_id`` and
    ``incidentId`` keys. Falsy values are skipped. The first hit is returned
    as a string; ``None`` means nothing was found.
    """
    candidates = []
    nested = parameters.get("_mcp_audit")
    if isinstance(nested, dict):
        candidates.append(nested.get("incident_id"))
    candidates.extend(parameters.get(name) for name in ("incident_id", "incidentId"))

    found = next((candidate for candidate in candidates if candidate), None)
    return None if found is None else str(found)
|
|
|
|
|
|
def _extract_audit_context(parameters: dict[str, Any]) -> dict[str, Any]:
    """Return the ``_mcp_audit`` entry of *parameters* when it is a dict.

    A fresh empty dict is returned when the entry is missing or has any other
    type. When present, the very same dict object is handed back (no copy).
    """
    candidate = parameters.get("_mcp_audit")
    if not isinstance(candidate, dict):
        return {}
    return candidate
|
|
|
|
|
|
def _extract_project_id(
    parameters: dict[str, Any],
    audit_context: dict[str, Any],
) -> str:
    """Resolve the project id for a call.

    Legacy MCP calls predate AwoooP tenancy; keep the bridge explicit.

    ``audit_context["project_id"]`` wins over ``parameters["project_id"]``.
    Falsy values are ignored. When neither yields a value the module-level
    ``_LEGACY_PROJECT_ID`` is returned. Explicit values are converted with
    ``str``.
    """
    explicit = audit_context.get("project_id") or parameters.get("project_id")
    if explicit:
        return str(explicit)
    return _LEGACY_PROJECT_ID
|
|
|
|
|
|
def _extract_run_id(audit_context: dict[str, Any]) -> uuid.UUID | None:
    """Parse ``audit_context["run_id"]`` into a ``uuid.UUID``.

    Returns ``None`` when the key is missing, the value is falsy, or its
    string form is not accepted by ``uuid.UUID``.
    """
    raw = audit_context.get("run_id")
    if raw:
        try:
            return uuid.UUID(str(raw))
        except (TypeError, ValueError):
            # Malformed ids are treated the same as an absent id.
            pass
    return None
|
|
|
|
|
|
def _compact(value: Any, limit: int) -> str | None:
    """Stringify *value* and keep at most its first *limit* characters.

    ``None`` is passed through unchanged instead of becoming ``"None"``.
    """
    return None if value is None else str(value)[:limit]
|
|
|
|
|
|
def _stable_hash(value: Any) -> str:
    """Return the hex SHA-256 digest of a canonical JSON rendering of *value*.

    Keys are sorted so that dicts with equal content hash identically
    regardless of insertion order. Non-ASCII text is kept unescaped and
    objects the JSON encoder cannot handle are converted with ``str``. The
    JSON text is encoded with ``str.encode()``'s default codec before hashing.
    """
    digest = hashlib.sha256()
    digest.update(
        json.dumps(value, default=str, ensure_ascii=False, sort_keys=True).encode()
    )
    return digest.hexdigest()
|
|
|
|
|
|
def _bridge_tool_name(mcp_server: str, tool_name: str) -> str:
    """Build the tool name stored on a bridge audit row.

    The preferred form is ``"legacy:<server>:<tool>"``. When that exceeds 128
    characters the prefix and server are dropped and only the first 128
    characters of the tool name are returned.
    """
    qualified = f"legacy:{mcp_server}:{tool_name}"
    fits = len(qualified) <= 128
    return qualified if fits else str(tool_name)[:128]
|
|
|
|
|
|
def _should_bridge_to_awooop(audit_context: dict[str, Any]) -> bool:
    """Decide whether a call must be mirrored as a bridge audit row.

    Skip real Gateway calls; they already write first-class audit rows.
    Returns ``False`` only when ``audit_context["gateway_path"]`` equals
    ``_REAL_GATEWAY_PATH``. A missing key or any other value yields ``True``.
    """
    came_through_gateway = audit_context.get("gateway_path") == _REAL_GATEWAY_PATH
    return not came_through_gateway
|
|
|
|
|
|
async def _record_awooop_gateway_bridge(
    db: Any,
    *,
    project_id: str,
    run_id: uuid.UUID | None,
    trace_id: str | None,
    agent_id: str | None,
    mcp_server: str,
    tool_name: str,
    input_params: dict[str, Any],
    output_result: Any | None,
    duration_ms: int,
    success: bool,
    error_message: str | None,
    flywheel_node: str | None,
    audit_context: dict[str, Any],
) -> None:
    """Mirror legacy direct provider calls into AwoooP audit as bridge rows.

    Issues a single ``INSERT INTO awooop_mcp_gateway_audit`` through
    ``db.execute``. Payloads are stored only as SHA-256 hashes
    (``_stable_hash``); the output hash is NULL when *output_result* is
    ``None``. ``gate_result`` is a JSON document that marks the row as a
    bridge record with ``policy_enforced`` false. ``block_gate`` is always
    NULL and ``block_reason`` carries the error message (capped at 256
    characters) only for failed calls.

    Nothing is committed here.
    NOTE(review): commit/rollback is presumably handled by whoever owns
    *db* - confirm.
    """
    gate_result = {
        "schema_version": "legacy_mcp_bridge_v1",
        "gateway_path": audit_context.get("gateway_path") or "legacy_registry_provider",
        "policy_enforced": False,
        "not_used_reason": "legacy direct provider path; bridge audit only",
        "legacy_mcp_server": mcp_server,
        "legacy_tool_name": tool_name,
        "flywheel_node": flywheel_node,
    }

    insert_bridge_row = text(
        """
        INSERT INTO awooop_mcp_gateway_audit (
            project_id, run_id, trace_id, agent_id,
            tool_name, input_hash, output_hash, gate_result,
            result_status, block_gate, block_reason, latency_ms
        )
        VALUES (
            :project_id, :run_id, :trace_id, :agent_id,
            :tool_name, :input_hash, :output_hash, CAST(:gate_result AS jsonb),
            :result_status, NULL, :block_reason, :latency_ms
        )
        """
    )

    # Free-text columns are capped with _compact (128 chars for ids, 256 for
    # the failure reason).
    bind_values = {
        "project_id": project_id,
        "run_id": run_id,
        "trace_id": _compact(trace_id, 128),
        "agent_id": _compact(agent_id or "legacy-mcp-provider", 128),
        "tool_name": _bridge_tool_name(mcp_server, tool_name),
        "input_hash": _stable_hash(input_params),
        "output_hash": None if output_result is None else _stable_hash(output_result),
        "gate_result": json.dumps(gate_result, ensure_ascii=False, default=str),
        "result_status": "success" if success else "failed",
        "block_reason": None if success else _compact(error_message, 256),
        "latency_ms": duration_ms,
    }

    await db.execute(insert_bridge_row, bind_values)
|
|
|
|
|
|
async def record_mcp_call(
    *,
    mcp_server: str,
    tool_name: str,
    input_params: dict[str, Any],
    output_result: Any | None,
    duration_ms: int,
    success: bool,
    error_message: str | None,
    session_id: str | None = None,
    flywheel_node: str | None = None,
    incident_id: str | None = None,
    agent_role: str | None = None,
) -> None:
    """Persist one MCP tool call and update daily aggregate stats.

    Explicit keyword arguments take precedence. Missing ones are filled from
    the ``_mcp_audit`` dict inside *input_params*, or derived: a random UUID
    for the session id and ``infer_flywheel_node`` for the flywheel node.

    Inside one ``get_db_context(project_id)`` block the function then:

    1. inserts the call into ``mcp_audit_log`` with redacted JSON payloads;
    2. upserts today's ``mcp_daily_stats`` row (call count, success count,
       running average duration);
    3. writes a bridge row via ``_record_awooop_gateway_bridge`` unless
       ``_should_bridge_to_awooop`` says the call is already audited.

    Auditing is best-effort: any exception raised while writing is logged as
    ``mcp_audit_write_failed`` and swallowed. Errors raised while resolving
    the context values above are not caught.

    NOTE(review): all three statements share one ``db`` handle. Whether a
    failing bridge insert also discards the first two writes depends on how
    ``get_db_context`` manages transactions - confirm.
    """
    audit_context = _extract_audit_context(input_params)

    # Resolve every derived field up front; explicit arguments always win.
    session_id = normalize_mcp_audit_session_id(
        session_id or audit_context.get("session_id") or str(uuid.uuid4())
    )
    flywheel_node = flywheel_node or infer_flywheel_node(mcp_server, tool_name)
    incident_id = incident_id or _extract_incident_id(input_params)
    agent_role = agent_role or audit_context.get("agent_role")
    project_id = _extract_project_id(input_params, audit_context)
    run_id = _extract_run_id(audit_context)
    trace_id = audit_context.get("trace_id") or incident_id or session_id
    agent_id = audit_context.get("agent_id") or agent_role

    audit_insert_sql = """
        INSERT INTO mcp_audit_log (
            session_id, flywheel_node, mcp_server, tool_name,
            input_params, output_result, duration_ms, success,
            error_message, incident_id, agent_role
        )
        VALUES (
            :session_id, :flywheel_node, :mcp_server, :tool_name,
            CAST(:input_params AS jsonb), CAST(:output_result AS jsonb),
            :duration_ms, :success, :error_message, :incident_id,
            :agent_role
        )
        """

    # On conflict the stored average is re-weighted by the existing call
    # count and combined with the new sample (EXCLUDED.avg_duration_ms).
    daily_stats_upsert_sql = """
        INSERT INTO mcp_daily_stats (
            date, mcp_server, tool_name, call_count, success_count,
            avg_duration_ms
        )
        VALUES (
            CURRENT_DATE, :mcp_server, :tool_name, 1,
            CASE WHEN :success THEN 1 ELSE 0 END,
            :duration_ms
        )
        ON CONFLICT (date, mcp_server, tool_name)
        DO UPDATE SET
            call_count = mcp_daily_stats.call_count + 1,
            success_count = mcp_daily_stats.success_count
                + CASE WHEN EXCLUDED.success_count > 0 THEN 1 ELSE 0 END,
            avg_duration_ms = (
                (mcp_daily_stats.avg_duration_ms * mcp_daily_stats.call_count)
                + EXCLUDED.avg_duration_ms
            ) / (mcp_daily_stats.call_count + 1)
        """

    try:
        async with get_db_context(project_id) as db:
            # Payload serialisation stays inside the try block so that an
            # unserialisable payload is logged rather than raised.
            audit_row = {
                "session_id": session_id,
                "flywheel_node": flywheel_node,
                "mcp_server": mcp_server,
                "tool_name": tool_name,
                "input_params": _json_dumps(input_params),
                "output_result": _json_dumps(output_result),
                "duration_ms": duration_ms,
                "success": success,
                "error_message": error_message,
                "incident_id": incident_id,
                "agent_role": agent_role,
            }
            await db.execute(text(audit_insert_sql), audit_row)

            stats_row = {
                "mcp_server": mcp_server,
                "tool_name": tool_name,
                "success": success,
                "duration_ms": duration_ms,
            }
            await db.execute(text(daily_stats_upsert_sql), stats_row)

            if _should_bridge_to_awooop(audit_context):
                await _record_awooop_gateway_bridge(
                    db,
                    project_id=project_id,
                    run_id=run_id,
                    trace_id=trace_id,
                    agent_id=agent_id,
                    mcp_server=mcp_server,
                    tool_name=tool_name,
                    input_params=input_params,
                    output_result=output_result,
                    duration_ms=duration_ms,
                    success=success,
                    error_message=error_message,
                    flywheel_node=flywheel_node,
                    audit_context=audit_context,
                )
    except Exception as exc:
        # Audit failures must never break the MCP call being audited.
        logger.warning(
            "mcp_audit_write_failed",
            mcp_server=mcp_server,
            tool_name=tool_name,
            error=str(exc),
        )
|
|
|
|
|
|
def monotonic_ms() -> int:
    """Return the current ``time.monotonic()`` reading in whole milliseconds.

    The fractional part is truncated by ``int``.
    """
    seconds = time.monotonic()
    return int(seconds * 1000)
|