103 lines
3.2 KiB
Python
103 lines
3.2 KiB
Python
"""MCP audit context helpers.
|
|
|
|
Legacy production callers still use the provider registry directly while
|
|
AwoooP MCP Gateway is rolled in by lane. These helpers make those calls
|
|
observable without changing execution semantics.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
from typing import Any
|
|
|
|
MAX_MCP_AUDIT_SESSION_ID_LENGTH = 36

# Long incident stage names and the short tokens used for them in the
# readable compacted form ``inc:<id>:<stage>``.
_STAGE_ALIASES = {
    "pre_decision": "pre",
    "post_execution": "post",
}

# Stages without an alias are kept verbatim in the readable incident form
# only when they are at most this long (longer ones would need truncating,
# which makes distinct stages collide).
_MAX_VERBATIM_STAGE_LENGTH = 6


def _digest(value: str) -> str:
    """Return the first 8 hex characters of the SHA-1 of *value*.

    *value* is UTF-8 encoded with ``errors="ignore"``, so characters that
    cannot be encoded (lone surrogates) are dropped before hashing.
    """
    return hashlib.sha1(value.encode("utf-8", errors="ignore")).hexdigest()[:8]


def _compact_with_hash(prefix: str, stable_part: str, raw: str) -> str:
    """Build ``<prefix>:<head>:<digest>`` no longer than the column limit.

    Args:
        prefix: Label for the ID. Only alphanumeric, ``-`` and ``_``
            characters are kept, capped at 8 characters; ``"mcp"`` is used
            when nothing survives the filter.
        stable_part: Human-readable portion, truncated to the room left
            after the prefix, the digest and the two ``:`` separators
            (at least one character is always kept).
        raw: The full original ID. Its digest makes collisions between
            distinct inputs unlikely even when their heads are identical.

    Returns:
        A string of at most ``MAX_MCP_AUDIT_SESSION_ID_LENGTH`` characters.
    """
    safe_prefix = "".join(
        char for char in prefix if char.isalnum() or char in "-_"
    )[:8] or "mcp"
    digest = _digest(raw)
    # Reserve room for the prefix, the digest and the two ":" separators.
    head_limit = (
        MAX_MCP_AUDIT_SESSION_ID_LENGTH - len(safe_prefix) - len(digest) - 2
    )
    head = str(stable_part)[:max(1, head_limit)]
    compacted = f"{safe_prefix}:{head}:{digest}"
    return compacted[:MAX_MCP_AUDIT_SESSION_ID_LENGTH]


def normalize_mcp_audit_session_id(session_id: Any | None) -> str | None:
    """Normalize MCP audit session IDs to the legacy DB column length.

    ``None`` is returned as ``None``. Any other value is converted with
    ``str()``; if the result is at most ``MAX_MCP_AUDIT_SESSION_ID_LENGTH``
    characters it is returned unchanged. Longer IDs are compacted:

    * ``incident:<id>:<stage>`` becomes the readable ``inc:<id>:<stage>``
      (``pre_decision`` -> ``pre``, ``post_execution`` -> ``post``) when
      that form fits and loses no information;
    * ``callback:<id>:...`` / ``approval:<id>...`` / ``mcp_bridge:<id>...``
      become ``cb:`` / ``apr:`` / ``bridge:`` + truncated id + digest;
    * everything else becomes ``<first segment>:<truncated id>:<digest>``.

    The result never exceeds ``MAX_MCP_AUDIT_SESSION_ID_LENGTH`` characters.
    """
    if session_id is None:
        return None
    raw = str(session_id)
    if len(raw) <= MAX_MCP_AUDIT_SESSION_ID_LENGTH:
        return raw

    # str.split always returns at least one element, so parts[0] exists.
    parts = raw.split(":")
    kind = parts[0]

    if kind == "incident" and len(parts) >= 3:
        stage = parts[-1]
        # The readable form carries no digest, so it must be lossless:
        # exactly ``incident:<id>:<stage>`` (no middle segments dropped) and
        # a stage that is either a known alias or a short literal that is
        # not itself an alias token. Otherwise e.g. ``...:execution_started``
        # and ``...:execution_finished`` would both become ``inc:<id>:execut``.
        if stage in _STAGE_ALIASES:
            short_stage: str | None = _STAGE_ALIASES[stage]
        elif (
            len(stage) <= _MAX_VERBATIM_STAGE_LENGTH
            and stage not in _STAGE_ALIASES.values()
        ):
            short_stage = stage
        else:
            short_stage = None
        if len(parts) == 3 and short_stage is not None:
            candidate = f"inc:{parts[1]}:{short_stage}"
            if len(candidate) <= MAX_MCP_AUDIT_SESSION_ID_LENGTH:
                return candidate
        # Lossy or too long: fall through to the hashed generic form.

    if kind == "callback" and len(parts) >= 3:
        return _compact_with_hash("cb", parts[1], raw)
    if kind == "approval" and len(parts) >= 2:
        return _compact_with_hash("apr", parts[1], raw)
    if kind == "mcp_bridge" and len(parts) >= 2:
        return _compact_with_hash("bridge", parts[1], raw)

    # Generic fallback: first segment as the label, whole raw ID as the head.
    return _compact_with_hash(kind, raw, raw)
|
|
|
|
|
|
def build_mcp_audit_context(
    *,
    session_id: str | None = None,
    incident_id: str | None = None,
    flywheel_node: str | None = None,
    agent_role: str | None = None,
    gateway_path: str = "legacy_registry_provider",
    operator_user_id: int | str | None = None,
) -> dict[str, Any]:
    """Assemble the ``_mcp_audit`` metadata dict carried beside tool params.

    All arguments are keyword-only. ``gateway_path`` is always present and
    comes first. The remaining fields are appended in a fixed order, and only
    when their value is not ``None``. ``session_id`` is passed through
    ``normalize_mcp_audit_session_id`` before that check.
    """
    ordered_fields = (
        ("session_id", normalize_mcp_audit_session_id(session_id)),
        ("incident_id", incident_id),
        ("flywheel_node", flywheel_node),
        ("agent_role", agent_role),
        ("operator_user_id", operator_user_id),
    )
    audit: dict[str, Any] = {"gateway_path": gateway_path}
    for field_name, field_value in ordered_fields:
        if field_value is None:
            continue  # omit unset fields rather than storing explicit None
        audit[field_name] = field_value
    return audit
|
|
|
|
|
|
def with_mcp_audit_context(
    parameters: dict[str, Any],
    **audit_kwargs: Any,
) -> dict[str, Any]:
    """Return a shallow copy of *parameters* with ``_mcp_audit`` merged in.

    An existing ``_mcp_audit`` dict is copied and then overlaid with
    ``build_mcp_audit_context(**audit_kwargs)``; freshly built keys win on
    conflict. A missing or non-dict ``_mcp_audit`` value is replaced. The
    caller's ``parameters`` and its existing ``_mcp_audit`` dict are not
    modified; other nested values are shared with the input.
    """
    copied = dict(parameters)
    prior_audit = copied.get("_mcp_audit")
    if not isinstance(prior_audit, dict):
        prior_audit = {}
    copied["_mcp_audit"] = {**prior_audit, **build_mcp_audit_context(**audit_kwargs)}
    return copied
|