359 lines
15 KiB
Python
359 lines
15 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
services/mcp_router.py
Operation Ollama-First v5.0 / Phase 10.5 — unified MCP routing

Design principles (ADR-031):
  1. A single HTTP client for the self-hosted MCP stack
     (postgres / omnisearch / firecrawl / filesystem).
  2. Every MCP call is also written to the mcp_calls table
     (including cost_usd / cache_hit / status).
  3. Fail-safe: when an MCP server is unreachable the router returns
     MCPResult(success=False, ...) instead of raising; the caller decides
     its own fallback.
  4. The feature flag MCP_ROUTER_ENABLED defaults to OFF.
     Enable it once docker-compose.mcp.yml is deployed and all four
     health endpoints return 200.

Deployment addresses (aligned with docker-compose.mcp.yml):
  postgres-mcp:    127.0.0.1:3001
  firecrawl-self:  127.0.0.1:3002
  mcp-omnisearch:  127.0.0.1:3003
  filesystem-mcp:  127.0.0.1:3004

Caller interface (example):
    from services.mcp_router import mcp_router
    result = mcp_router.call(
        server='omnisearch',
        tool='tavily_search',
        args={'query': 'momo 母親節促銷'},
        caller='mcp_collector',
    )
"""
|
||
|
||
from __future__ import annotations
|
||
import os
|
||
import time
|
||
import json
|
||
import logging
|
||
import hashlib
|
||
import threading
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import requests
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Feature flag + 配置
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
MCP_ROUTER_ENABLED = os.getenv('MCP_ROUTER_ENABLED', 'false').strip().lower() in ('true', '1', 'yes', 'on')
|
||
|
||
MCP_BASE_HOSTS = {
|
||
'postgres': os.getenv('MCP_POSTGRES_URL', 'http://127.0.0.1:3001'),
|
||
'firecrawl': os.getenv('MCP_FIRECRAWL_URL', 'http://127.0.0.1:3002'),
|
||
'omnisearch': os.getenv('MCP_OMNISEARCH_URL', 'http://127.0.0.1:3003'),
|
||
'filesystem': os.getenv('MCP_FILESYSTEM_URL', 'http://127.0.0.1:3004'),
|
||
}
|
||
|
||
MCP_DEFAULT_TIMEOUT = int(os.getenv('MCP_TIMEOUT_SEC', '30'))
|
||
MCP_CACHE_TTL_SEC = int(os.getenv('MCP_CACHE_TTL_SEC', '3600')) # 1h
|
||
MCP_MAX_RESULT_BYTES = int(os.getenv('MCP_MAX_RESULT_BYTES', '65536')) # 64KB
|
||
|
||
|
||
def is_mcp_router_enabled() -> bool:
    """Return whether the MCP router flag is currently switched on.

    The environment is read on every call, so the flag can be flipped
    without re-importing the module. Accepted "on" values, compared
    case-insensitively after trimming whitespace: true / 1 / yes / on.
    """
    raw_flag = os.getenv('MCP_ROUTER_ENABLED', 'false')
    normalized = raw_flag.strip().lower()
    return normalized in ('true', '1', 'yes', 'on')
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Tool 白名單(caller × server × tool)— 限制 LLM 不能亂打 MCP
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Filesystem tools that only read; no write-type tool is listed here.
READONLY_FILESYSTEM_TOOLS = [
    'list_allowed_directories',
    'list_directory',
    'directory_tree',
    'read_file',
    'read_multiple_files',
    'search_files',
    'get_file_info',
]

# Whitelist layout: caller -> server -> tool names that caller may invoke.
TOOL_REGISTRY: Dict[str, Dict[str, List[str]]] = {
    # mcp_collector replaces Gemini Grounding.
    'mcp_collector': {
        'omnisearch': ['tavily_search', 'exa_search'],
        'firecrawl': ['scrape_url'],
    },
    # Hermes competitor analysis may query the DB and scrape web pages.
    'hermes_analyst': {
        'postgres': ['query'],
        'omnisearch': ['tavily_search'],
        'firecrawl': ['scrape_url'],
    },
    # OpenClaw strategy analysis.
    'openclaw_strategist': {
        'postgres': ['query'],
        'omnisearch': ['tavily_search', 'exa_search'],
    },
    # filesystem-mcp only mounts /data and /logs read-only; reserved for
    # diagnostic tooling that reads files, so write-type tools stay closed.
    'ops_diagnostics': {
        'filesystem': READONLY_FILESYSTEM_TOOLS,
    },
}
|
||
|
||
|
||
def _is_tool_allowed(caller: str, server: str, tool: str) -> bool:
    """Whitelist check against TOOL_REGISTRY.

    Returns True only when ``tool`` is listed under ``server`` for ``caller``.
    A caller or server missing from the registry is rejected.
    """
    servers_for_caller = TOOL_REGISTRY.get(caller, {})
    permitted_tools = servers_for_caller.get(server, [])
    return tool in permitted_tools
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Cache(記憶體 + DB hash 指紋)
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
_memory_cache: Dict[str, Dict[str, Any]] = {}
|
||
_cache_lock = threading.Lock()
|
||
|
||
|
||
def _cache_key(server: str, tool: str, args: Dict[str, Any]) -> str:
|
||
"""穩定排序後 SHA256[:16]"""
|
||
payload = json.dumps({'s': server, 't': tool, 'a': args}, sort_keys=True, ensure_ascii=False)
|
||
return hashlib.sha256(payload.encode('utf-8')).hexdigest()[:16]
|
||
|
||
|
||
def _cache_get(key: str) -> Optional[Dict[str, Any]]:
    """Return the cached payload for ``key``, or None when absent or expired.

    An entry older than MCP_CACHE_TTL_SEC seconds is removed from the cache
    as a side effect of the lookup.
    """
    with _cache_lock:
        cached_entry = _memory_cache.get(key)
        if cached_entry:
            age_sec = time.time() - cached_entry['ts']
            if age_sec <= MCP_CACHE_TTL_SEC:
                return cached_entry['data']
            # Expired: evict so the stale payload is never served.
            _memory_cache.pop(key, None)
        return None
|
||
|
||
|
||
def _cache_set(key: str, data: Dict[str, Any]) -> None:
    """Store ``data`` under ``key`` with the current time as its timestamp.

    The cache is capped at 200 entries: when an insert pushes it over the
    cap, the single entry with the oldest write timestamp is dropped.
    Reads do not refresh the timestamp, so eviction is by write age.
    """
    with _cache_lock:
        _memory_cache[key] = {'data': data, 'ts': time.time()}
        if len(_memory_cache) <= 200:
            return
        stalest_key = min(_memory_cache, key=lambda k: _memory_cache[k]['ts'])
        _memory_cache.pop(stalest_key, None)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# 結果容器
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
@dataclass
class MCPResult:
    """Outcome of one routed MCP call, successful or not."""

    # Required: overall outcome and which server/tool was addressed.
    success: bool
    server: str
    tool: str

    # Response payload; a fresh empty dict per instance when not supplied.
    data: Dict[str, Any] = field(default_factory=dict)

    # True when the payload came from the in-memory cache.
    cache_hit: bool = False

    # Call duration in milliseconds.
    duration_ms: int = 0

    cost_usd: float = 0.0

    # Human-readable failure description; None when nothing went wrong.
    error: Optional[str] = None

    # Size of the serialised payload.
    output_size: int = 0

    # Status label written to the call log (e.g. 'ok', 'error', 'timeout').
    status: Optional[str] = None
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# DB 寫入(fire-and-forget,與 ai_call_logger 同模式)
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
def _async_write_mcp_call(
    caller: str,
    server: str,
    tool: str,
    args: Dict[str, Any],
    result: MCPResult,
    request_id: Optional[str] = None,
) -> None:
    """Insert one row into mcp_calls from a daemon thread (fire-and-forget).

    The function returns as soon as the thread is started. Any failure inside
    the thread (import, session creation, insert) is logged at debug level
    and never reaches the caller.

    Args:
        caller: Name of the component that made the MCP call.
        server: MCP server name.
        tool: Tool name on that server.
        args: Call arguments; only a short hash and the first 10 key names
            are persisted, never the values.
        result: Outcome whose size, duration, status, error, cost and
            cache flag are recorded.
        request_id: Optional correlation id stored with the row.
    """

    def _writer() -> None:
        try:
            from sqlalchemy import text as sa_text
            from database.manager import get_session

            db = get_session()
            try:
                # PII protection: persist a fingerprint and key names only.
                canonical_args = json.dumps(args, sort_keys=True, ensure_ascii=False)
                fingerprint = hashlib.sha1(canonical_args.encode('utf-8')).hexdigest()
                args_redacted = {
                    'hash': fingerprint[:12],
                    'keys': list(args.keys())[:10],
                }

                # An explicit status wins; otherwise derive it from success.
                if result.status:
                    row_status = result.status
                elif result.success:
                    row_status = 'ok'
                else:
                    row_status = 'error'

                row_error = result.error[:4000] if result.error else None

                row = {
                    'caller': caller,
                    'server': server,
                    'tool': tool,
                    'args': json.dumps(args_redacted),
                    'osz': result.output_size,
                    'dur': result.duration_ms,
                    'status': row_status,
                    'err': row_error,
                    'cost': result.cost_usd,
                    'cache': result.cache_hit,
                    'req': request_id,
                }
                db.execute(
                    sa_text("""
                        INSERT INTO mcp_calls (
                            caller, server, tool, input_args, output_size,
                            duration_ms, status, error, cost_usd, cache_hit, request_id
                        ) VALUES (
                            :caller, :server, :tool, CAST(:args AS JSONB), :osz,
                            :dur, :status, :err, :cost, :cache, :req
                        )
                    """),
                    row,
                )
                db.commit()
            except Exception as exc:
                db.rollback()
                logger.debug(f"[MCPRouter] DB write failed: {exc}")
            finally:
                db.close()
        except Exception:
            logger.debug("[MCPRouter] async DB writer bootstrap failed", exc_info=True)

    threading.Thread(target=_writer, daemon=True).start()
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# MCPRouter 主類
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
class MCPRouter:
    """Unified MCP router: whitelisted HTTP client + in-memory cache + DB call log."""

    def call(
        self,
        server: str,
        tool: str,
        args: Dict[str, Any],
        caller: str = 'unknown',
        timeout: Optional[int] = None,
        request_id: Optional[str] = None,
    ) -> MCPResult:
        """Invoke ``tool`` on ``server`` on behalf of ``caller``.

        Failures are reported through the returned MCPResult
        (``success=False`` plus ``error`` / ``status``) so the caller can run
        its own fallback. This covers: flag OFF, tool not whitelisted, unknown
        server, arguments that cannot be serialised to JSON, non-200 HTTP
        status, timeout, and any other exception raised while calling the
        server or handling its response.

        Args:
            server: Key into MCP_BASE_HOSTS.
            tool: Tool name; POSTed to ``<base_url>/tools/<tool>``.
            args: JSON-serialisable arguments sent as the request body.
            caller: Registry identity used for the whitelist check.
            timeout: Per-request timeout in seconds; a falsy value selects
                MCP_DEFAULT_TIMEOUT.
            request_id: Optional correlation id stored in the call log.

        Returns:
            MCPResult. ``status`` is one of 'ok', 'cache_only',
            'rate_limited' (HTTP 429), 'timeout' or 'error'.
            ``output_size`` is always measured in UTF-8 bytes.
        """
        if not is_mcp_router_enabled():
            return MCPResult(
                success=False, server=server, tool=tool,
                error='MCP_ROUTER_ENABLED=false (戰役 v5.0 預設 OFF;待 docker-compose.mcp.yml deploy 後翻 ON)',
                status='error',
            )

        # Whitelist check.
        if not _is_tool_allowed(caller, server, tool):
            return MCPResult(
                success=False, server=server, tool=tool,
                error=f'tool not in registry: caller={caller} server={server} tool={tool}',
                status='error',
            )

        # Server configuration check.
        base_url = MCP_BASE_HOSTS.get(server)
        if not base_url:
            return MCPResult(
                success=False, server=server, tool=tool,
                error=f'unknown server: {server}',
                status='error',
            )

        def _logged(result: MCPResult) -> MCPResult:
            """Queue the call-log row for ``result`` and hand it back."""
            _async_write_mcp_call(caller, server, tool, args, result, request_id)
            return result

        # Building the key serialises args; if they are not JSON-serialisable,
        # report it as a failed result instead of letting the error escape.
        try:
            ckey = _cache_key(server, tool, args)
        except (TypeError, ValueError) as exc:
            return MCPResult(
                success=False, server=server, tool=tool,
                error=f'invalid args: {type(exc).__name__}: {str(exc)[:300]}',
                status='error',
            )

        # Cache hit.
        cached = _cache_get(ckey)
        if cached is not None:
            # Measure in UTF-8 bytes, same unit as the fresh-call path below.
            cached_bytes = json.dumps(cached, ensure_ascii=False).encode('utf-8')
            return _logged(MCPResult(
                success=True, server=server, tool=tool,
                data=cached, cache_hit=True, duration_ms=0,
                output_size=len(cached_bytes),
                status='cache_only',
            ))

        # HTTP call.
        url = f"{base_url.rstrip('/')}/tools/{tool}"
        request_timeout = timeout or MCP_DEFAULT_TIMEOUT
        t0 = time.monotonic()

        try:
            resp = requests.post(url, json=args, timeout=request_timeout)
            duration_ms = int((time.monotonic() - t0) * 1000)

            if resp.status_code != 200:
                status = 'rate_limited' if resp.status_code == 429 else 'error'
                return _logged(MCPResult(
                    success=False, server=server, tool=tool,
                    duration_ms=duration_ms,
                    error=f'HTTP {resp.status_code}: {resp.text[:200]}',
                    status=status,
                ))

            data = resp.json()
            output_bytes = json.dumps(data, ensure_ascii=False).encode('utf-8')
            output_size = len(output_bytes)

            # Size guard.
            if output_size > MCP_MAX_RESULT_BYTES:
                logger.warning(f"[MCPRouter] {server}/{tool} output {output_size} > {MCP_MAX_RESULT_BYTES} bytes; 截斷")
                # Cut on the byte limit, not the character count; a multi-byte
                # character split by the cut is dropped by errors='ignore'.
                preview = output_bytes[:MCP_MAX_RESULT_BYTES].decode('utf-8', errors='ignore')
                data = {'_truncated': True, '_original_bytes': output_size, 'preview': preview}

            # Only successful responses are cached.
            _cache_set(ckey, data)

            return _logged(MCPResult(
                success=True, server=server, tool=tool,
                data=data, cache_hit=False,
                duration_ms=duration_ms, output_size=output_size,
                status='ok',
            ))

        except requests.Timeout:
            duration_ms = int((time.monotonic() - t0) * 1000)
            return _logged(MCPResult(
                success=False, server=server, tool=tool,
                duration_ms=duration_ms, error=f'timeout ({request_timeout}s)',
                status='timeout',
            ))

        except Exception as exc:
            # Fail-safe boundary: connection errors, invalid JSON, etc.
            duration_ms = int((time.monotonic() - t0) * 1000)
            return _logged(MCPResult(
                success=False, server=server, tool=tool,
                duration_ms=duration_ms,
                error=f'{type(exc).__name__}: {str(exc)[:300]}',
                status='error',
            ))

    def health_check(self) -> Dict[str, bool]:
        """Probe ``<base_url>/health`` for every server in MCP_BASE_HOSTS.

        Returns:
            Mapping of server name to True when the endpoint answered HTTP 200
            within 3 seconds, False on any other status or any exception.
        """
        results = {}
        for server, base_url in MCP_BASE_HOSTS.items():
            try:
                resp = requests.get(f"{base_url.rstrip('/')}/health", timeout=3)
                results[server] = resp.status_code == 200
            except Exception:
                results[server] = False
        return results
|
||
|
||
|
||
# 全域單例
|
||
# Module-level singleton shared by all callers.
mcp_router = MCPRouter()


# Names exported by ``from services.mcp_router import *``.
__all__ = [
    'MCPRouter', 'MCPResult', 'mcp_router',
    'is_mcp_router_enabled',
    'TOOL_REGISTRY', 'MCP_BASE_HOSTS',
]
|