"""MCP audit context helpers.

Legacy production callers still use the provider registry directly while
AwoooP MCP Gateway is rolled in by lane. These helpers make those calls
observable without changing execution semantics.
"""

from __future__ import annotations

import hashlib
from typing import Any

# Maximum length of an MCP audit session ID (the legacy DB column length).
MAX_MCP_AUDIT_SESSION_ID_LENGTH = 36

# Short forms for well-known incident stage suffixes.
_STAGE_ALIASES = {
    "pre_decision": "pre",
    "post_execution": "post",
}


def _digest(value: str) -> str:
    """Return the first 8 hex characters of the SHA-1 digest of *value*.

    Used only to keep compacted session IDs distinct. Characters that cannot
    be UTF-8 encoded (e.g. lone surrogates) are dropped before hashing.
    """
    return hashlib.sha1(value.encode("utf-8", errors="ignore")).hexdigest()[:8]


def _compact_with_hash(prefix: str, stable_part: str, raw: str) -> str:
    """Build a ``"<prefix>:<head>:<digest>"`` ID no longer than the max length.

    Args:
        prefix: Label for the ID. Only alphanumerics, ``-`` and ``_`` are
            kept, truncated to 8 characters; ``"mcp"`` is used if nothing
            survives.
        stable_part: Human-readable part; truncated to fit the length budget.
        raw: Full original ID; its digest keeps compacted IDs distinct.

    Returns:
        A string of at most ``MAX_MCP_AUDIT_SESSION_ID_LENGTH`` characters.
    """
    safe_prefix = "".join(
        char for char in prefix if char.isalnum() or char in "-_"
    )[:8] or "mcp"
    digest = _digest(raw)
    # Budget left for the head: total minus prefix, digest and two ":" separators.
    head_limit = (
        MAX_MCP_AUDIT_SESSION_ID_LENGTH - len(safe_prefix) - len(digest) - 2
    )
    head = str(stable_part)[:max(1, head_limit)]
    compacted = f"{safe_prefix}:{head}:{digest}"
    return compacted[:MAX_MCP_AUDIT_SESSION_ID_LENGTH]


def normalize_mcp_audit_session_id(session_id: Any | None) -> str | None:
    """Normalize MCP audit session IDs to the legacy DB column length.

    Rules, applied in order:

    * ``None`` -> ``None``.
    * IDs whose string form fits the limit are returned as ``str(session_id)``.
    * ``incident:<id>:...:<stage>`` -> ``inc:<id>:<stage>``, where ``<stage>``
      uses ``_STAGE_ALIASES`` or else its first 6 characters, if that fits.
    * ``callback:<id>:...`` -> ``cb:<id head>:<digest>``.
    * ``approval:<id>...`` -> ``apr:<id head>:<digest>``.
    * ``mcp_bridge:<id>...`` -> ``bridge:<id head>:<digest>``.
    * Anything else -> ``<first segment[:8]>:<raw head>:<digest>``.
    """
    if session_id is None:
        return None
    raw = str(session_id)
    if len(raw) <= MAX_MCP_AUDIT_SESSION_ID_LENGTH:
        return raw

    parts = raw.split(":")
    if len(parts) >= 3 and parts[0] == "incident":
        stage = _STAGE_ALIASES.get(parts[-1], parts[-1][:6])
        candidate = f"inc:{parts[1]}:{stage}"
        if len(candidate) <= MAX_MCP_AUDIT_SESSION_ID_LENGTH:
            return candidate
        # NOTE(review): an over-long incident ID falls through to the generic
        # fallback below, which repeats "incident" in the head. Left as-is
        # because changing it would alter already-persisted session IDs.
    if len(parts) >= 3 and parts[0] == "callback":
        return _compact_with_hash("cb", parts[1], raw)
    if len(parts) >= 2 and parts[0] == "approval":
        return _compact_with_hash("apr", parts[1], raw)
    if len(parts) >= 2 and parts[0] == "mcp_bridge":
        return _compact_with_hash("bridge", parts[1], raw)
    # str.split always returns at least one element, so parts[0] exists.
    return _compact_with_hash(parts[0], raw, raw)


def build_mcp_audit_context(
    *,
    session_id: str | None = None,
    incident_id: str | None = None,
    flywheel_node: str | None = None,
    agent_role: str | None = None,
    gateway_path: str = "legacy_registry_provider",
    operator_user_id: int | str | None = None,
) -> dict[str, Any]:
    """Build the `_mcp_audit` metadata object carried beside tool params.

    ``gateway_path`` is always present. Every other key is included only when
    its value is not ``None``. ``session_id`` is passed through
    ``normalize_mcp_audit_session_id``.
    """
    context: dict[str, Any] = {
        "gateway_path": gateway_path,
    }
    optional_values = {
        "session_id": normalize_mcp_audit_session_id(session_id),
        "incident_id": incident_id,
        "flywheel_node": flywheel_node,
        "agent_role": agent_role,
        "operator_user_id": operator_user_id,
    }
    context.update(
        {key: value for key, value in optional_values.items() if value is not None}
    )
    return context


def with_mcp_audit_context(
    parameters: dict[str, Any],
    **audit_kwargs: Any,
) -> dict[str, Any]:
    """Return a shallow copy of provider parameters with merged audit context.

    Neither *parameters* nor any existing ``_mcp_audit`` dict is mutated. An
    existing dict-valued ``_mcp_audit`` is copied, and values built from
    *audit_kwargs* override it. A non-dict ``_mcp_audit`` is replaced. Because
    ``gateway_path`` always has a default, an existing ``gateway_path`` is
    overwritten even when the caller does not pass one. A pre-existing
    session ID longer than ``MAX_MCP_AUDIT_SESSION_ID_LENGTH`` is normalized.

    Args:
        parameters: Provider tool parameters.
        **audit_kwargs: Keyword arguments for ``build_mcp_audit_context``.

    Returns:
        A new parameters dict whose ``_mcp_audit`` entry is the merged context.
    """
    audited = dict(parameters)
    existing = audited.get("_mcp_audit")
    merged = dict(existing) if isinstance(existing, dict) else {}
    merged.update(build_mcp_audit_context(**audit_kwargs))
    # A pre-existing context may carry a session ID that never went through
    # normalization. Shorten only over-length values so that short ones (of
    # any type) are preserved exactly as before.
    session_id = merged.get("session_id")
    if (
        session_id is not None
        and len(str(session_id)) > MAX_MCP_AUDIT_SESSION_ID_LENGTH
    ):
        merged["session_id"] = normalize_mcp_audit_session_id(session_id)
    audited["_mcp_audit"] = merged
    return audited