Files
awoooi/apps/api/src/services/mcp_audit_context.py
Your Name 0e2e856f12
All checks were successful
Code Review / ai-code-review (push) Successful in 11s
CD Pipeline / tests (push) Successful in 58s
CD Pipeline / build-and-deploy (push) Successful in 4m39s
CD Pipeline / post-deploy-checks (push) Successful in 1m17s
fix(mcp): normalize audit session ids
2026-05-06 17:40:42 +08:00

103 lines
3.2 KiB
Python

"""MCP audit context helpers.
Legacy production callers still use the provider registry directly while
AwoooP MCP Gateway is rolled in by lane. These helpers make those calls
observable without changing execution semantics.
"""
from __future__ import annotations
import hashlib
from typing import Any
# Longest session ID (measured with ``len()``) that is passed through
# unchanged; anything longer is shortened by the normalizer below.
MAX_MCP_AUDIT_SESSION_ID_LENGTH = 36

# Short replacements for well-known trailing stage names of incident
# session IDs.
_STAGE_ALIASES = dict(
    pre_decision="pre",
    post_execution="post",
)
def _digest(value: str) -> str:
    """Return the first eight hex characters of the SHA-1 of *value*.

    The text is encoded as UTF-8; characters that cannot be encoded are
    dropped (``errors="ignore"``) instead of raising.
    """
    encoded = value.encode("utf-8", errors="ignore")
    full_hex = hashlib.sha1(encoded).hexdigest()
    return full_hex[:8]
def _compact_with_hash(prefix: str, stable_part: str, raw: str) -> str:
    """Build a ``<prefix>:<head>:<digest>`` ID that fits the length limit.

    Args:
        prefix: Label for the leading segment. Only alphanumeric characters,
            ``-`` and ``_`` are kept, at most eight of them; ``"mcp"`` is
            used when nothing is left.
        stable_part: Value whose leading characters form the middle segment
            (converted with ``str()``).
        raw: Full original identifier; ``_digest(raw)`` becomes the trailing
            segment.

    Returns:
        The joined string, cut to ``MAX_MCP_AUDIT_SESSION_ID_LENGTH``
        characters.
    """
    kept_chars = [ch for ch in prefix if ch.isalnum() or ch in "-_"]
    tag = "".join(kept_chars)[:8]
    if not tag:
        tag = "mcp"

    suffix = _digest(raw)

    # Whatever remains after the tag, the digest and the two ":" separators
    # is available for the middle segment; always keep at least one character.
    budget = MAX_MCP_AUDIT_SESSION_ID_LENGTH - len(tag) - len(suffix) - 2
    if budget < 1:
        budget = 1
    middle = str(stable_part)[:budget]

    joined = ":".join((tag, middle, suffix))
    return joined[:MAX_MCP_AUDIT_SESSION_ID_LENGTH]
def normalize_mcp_audit_session_id(session_id: Any | None) -> str | None:
    """Normalize MCP audit session IDs to the legacy DB column length.

    Args:
        session_id: Raw session identifier. Any value other than ``None`` is
            converted with ``str()``.

    Returns:
        ``None`` when ``session_id`` is ``None``. The string itself when it
        is at most ``MAX_MCP_AUDIT_SESSION_ID_LENGTH`` characters long.
        Otherwise a shortened form selected by the first ``:``-separated
        segment:

        * ``incident`` (three or more segments): ``inc:<id>:<stage>`` when
          that fits, where the stage is the last segment mapped through
          ``_STAGE_ALIASES`` or cut to six characters; if it does not fit,
          a hash-compacted ID built from the incident id under the
          ``incident`` prefix.
        * ``callback`` (three or more segments), ``approval`` and
          ``mcp_bridge`` (two or more segments): a hash-compacted ID built
          from the second segment under the prefixes ``cb``, ``apr`` and
          ``bridge``.
        * anything else: a hash-compacted ID built from the raw string,
          prefixed with its first segment.
    """
    if session_id is None:
        return None

    raw = str(session_id)
    if len(raw) <= MAX_MCP_AUDIT_SESSION_ID_LENGTH:
        return raw

    # str.split() with an explicit separator always returns at least one
    # element, so parts[0] needs no emptiness guard.
    parts = raw.split(":")
    kind = parts[0]

    if kind == "incident" and len(parts) >= 3:
        stage = _STAGE_ALIASES.get(parts[-1], parts[-1][:6])
        candidate = f"inc:{parts[1]}:{stage}"
        if len(candidate) <= MAX_MCP_AUDIT_SESSION_ID_LENGTH:
            return candidate
        # The incident id is too long for the readable short form. Compact
        # around the id itself: compacting around the raw string would only
        # repeat "incident:" inside the middle segment and leave room for
        # far fewer characters of the id.
        return _compact_with_hash("incident", parts[1], raw)

    if kind == "callback" and len(parts) >= 3:
        return _compact_with_hash("cb", parts[1], raw)
    if kind == "approval" and len(parts) >= 2:
        return _compact_with_hash("apr", parts[1], raw)
    if kind == "mcp_bridge" and len(parts) >= 2:
        return _compact_with_hash("bridge", parts[1], raw)

    # Unknown shape: keep the leading characters of the raw ID plus its hash.
    return _compact_with_hash(kind, raw, raw)
def build_mcp_audit_context(
    *,
    session_id: str | None = None,
    incident_id: str | None = None,
    flywheel_node: str | None = None,
    agent_role: str | None = None,
    gateway_path: str = "legacy_registry_provider",
    operator_user_id: int | str | None = None,
) -> dict[str, Any]:
    """Build the `_mcp_audit` metadata object carried beside tool params.

    ``gateway_path`` is always present in the result. Every other field is
    included only when its value is not ``None``; ``session_id`` is first
    passed through ``normalize_mcp_audit_session_id``.
    """
    audit: dict[str, Any] = {"gateway_path": gateway_path}

    # Order here fixes the key order of the resulting dict.
    optional_fields = (
        ("session_id", normalize_mcp_audit_session_id(session_id)),
        ("incident_id", incident_id),
        ("flywheel_node", flywheel_node),
        ("agent_role", agent_role),
        ("operator_user_id", operator_user_id),
    )
    for field_name, field_value in optional_fields:
        if field_value is not None:
            audit[field_name] = field_value

    return audit
def with_mcp_audit_context(
    parameters: dict[str, Any],
    **audit_kwargs: Any,
) -> dict[str, Any]:
    """Return a shallow copy of provider parameters with merged audit context.

    An existing ``_mcp_audit`` dict in ``parameters`` is copied and then
    overlaid with the context built from ``audit_kwargs``, so freshly built
    values win on key clashes. A non-dict ``_mcp_audit`` value is replaced.
    Neither ``parameters`` nor its existing ``_mcp_audit`` dict is mutated.
    """
    result = dict(parameters)

    prior = result.get("_mcp_audit")
    base = dict(prior) if isinstance(prior, dict) else {}

    result["_mcp_audit"] = {**base, **build_mcp_audit_context(**audit_kwargs)}
    return result