277 lines
9.0 KiB
Python
277 lines
9.0 KiB
Python
"""
|
||
Model Registry - Phase 12 P1 修復
|
||
=================================
|
||
集中管理 AI 模型配置,消除 hardcode 模型名稱
|
||
|
||
功能:
|
||
- 從 models.json 讀取配置
|
||
- 提供 get_model(provider, purpose) 方法
|
||
- Singleton 模式
|
||
- 支援依賴注入測試
|
||
|
||
版本: v1.0
|
||
建立: 2026-03-26 23:00 (台北時區)
|
||
建立者: Claude Code
|
||
最後修改: 2026-04-09 10:00 (台北時區) — ogt: fallback config 更新為 deepseek-r1:14b + gemma3:4b
|
||
修改者: Claude Code
|
||
"""
|
||
|
||
import json
|
||
from pathlib import Path
|
||
from typing import Protocol
|
||
|
||
import structlog
|
||
|
||
# Module-level structlog logger; ModelRegistry uses it to report config loading and model fallbacks.
logger = structlog.get_logger(__name__)
|
||
|
||
|
||
# =============================================================================
|
||
# Interface (支援 DI 測試)
|
||
# =============================================================================
|
||
|
||
|
||
class IModelRegistry(Protocol):
    """
    Structural (duck-typed) interface of a model registry.

    Code that depends on this type instead of the concrete ModelRegistry can
    be handed any object exposing these four methods, which keeps dependency
    injection in tests simple.
    """

    def get_model(self, provider: str, purpose: str = "default") -> str:
        """Return the model name configured for ``provider`` and ``purpose``."""

    def get_fallback_order(self) -> list[str]:
        """Return the provider fallback order."""

    def get_model_by_complexity(self, complexity: int) -> str:
        """Return the recommended model for the given complexity score."""

    def get_provider_config(self, provider: str) -> dict:
        """Return the full configuration mapping for ``provider``."""
|
||
|
||
|
||
# =============================================================================
|
||
# Implementation
|
||
# =============================================================================
|
||
|
||
|
||
class ModelRegistry:
|
||
"""
|
||
Model Registry 實作
|
||
|
||
從 models.json 讀取配置,提供統一的模型查詢介面
|
||
|
||
Usage:
|
||
registry = get_model_registry()
|
||
model = registry.get_model("ollama", "rca") # -> "qwen2.5:7b-instruct"
|
||
"""
|
||
|
||
def __init__(self, config_path: Path | str | None = None):
|
||
"""
|
||
初始化 ModelRegistry
|
||
|
||
Args:
|
||
config_path: models.json 路徑,None 使用預設路徑
|
||
"""
|
||
if config_path is None:
|
||
# 預設路徑: apps/api/models.json
|
||
config_path = Path(__file__).parent.parent.parent / "models.json"
|
||
elif isinstance(config_path, str):
|
||
config_path = Path(config_path)
|
||
|
||
self._config_path = config_path
|
||
self._config: dict = {}
|
||
self._load_config()
|
||
|
||
# 複雜度對應模型 (從 config 或使用預設)
|
||
self._complexity_map = self._build_complexity_map()
|
||
|
||
def _load_config(self) -> None:
|
||
"""載入 models.json"""
|
||
try:
|
||
with open(self._config_path) as f:
|
||
self._config = json.load(f)
|
||
logger.info(
|
||
"model_registry_loaded",
|
||
path=str(self._config_path),
|
||
providers=list(self._config.get("providers", {}).keys()),
|
||
)
|
||
except FileNotFoundError:
|
||
logger.warning(
|
||
"models_json_not_found",
|
||
path=str(self._config_path),
|
||
using="fallback_defaults",
|
||
)
|
||
self._config = self._get_default_config()
|
||
except json.JSONDecodeError as e:
|
||
logger.error(
|
||
"models_json_parse_error",
|
||
path=str(self._config_path),
|
||
error=str(e),
|
||
)
|
||
self._config = self._get_default_config()
|
||
|
||
def _get_default_config(self) -> dict:
|
||
"""預設配置 (fallback)"""
|
||
return {
|
||
"default_provider": "ollama",
|
||
"fallback_order": ["ollama", "gemini", "claude"],
|
||
# 2026-03-29 ogt: P2-3 加入 Tool Calling Fallback (ADR-036)
|
||
"tool_calling_fallback_order": ["nvidia", "gemini", "claude"],
|
||
"providers": {
|
||
# 2026-04-08 ogt: 切換到 M1 Pro (192.168.0.111),deepseek-r1:14b + gemma3:4b
|
||
"ollama": {
|
||
"models": {
|
||
"default": "deepseek-r1:14b",
|
||
"rca": "deepseek-r1:14b",
|
||
"summary": "gemma3:4b",
|
||
}
|
||
},
|
||
# 2026-03-28 ogt: 更新到 gemini-2.0-flash (1.5 已停用)
|
||
"gemini": {
|
||
"models": {
|
||
"default": "gemini-2.0-flash",
|
||
"rca": "gemini-2.0-flash",
|
||
"summary": "gemini-2.0-flash",
|
||
}
|
||
},
|
||
"claude": {
|
||
"models": {
|
||
"default": "claude-haiku-4-5-20251001",
|
||
"rca": "claude-haiku-4-5-20251001",
|
||
"summary": "claude-haiku-4-5-20251001",
|
||
}
|
||
},
|
||
# 2026-03-29 ogt: P2-3 加入 NVIDIA (ADR-036)
|
||
"nvidia": {
|
||
"models": {
|
||
"default": "nvidia/nemotron-mini-4b-instruct",
|
||
"tool_calling": "nvidia/nemotron-mini-4b-instruct",
|
||
}
|
||
},
|
||
},
|
||
}
|
||
|
||
def _build_complexity_map(self) -> dict[int, str]:
|
||
"""建立複雜度對應模型映射"""
|
||
# 從 config 或使用預設
|
||
ollama_models = self._config.get("providers", {}).get("ollama", {}).get("models", {})
|
||
default_model = ollama_models.get("default", "qwen2.5:7b-instruct")
|
||
summary_model = ollama_models.get("summary", "llama3.2:3b")
|
||
|
||
return {
|
||
1: summary_model, # 簡單任務,快速回應
|
||
2: default_model, # 中等任務
|
||
3: default_model, # 複雜任務
|
||
4: "gemini", # 需要雲端能力
|
||
5: "claude", # 極複雜,需要最強模型
|
||
}
|
||
|
||
def get_model(self, provider: str, purpose: str = "default") -> str:
|
||
"""
|
||
取得模型名稱
|
||
|
||
Args:
|
||
provider: 提供者 (ollama, gemini, claude)
|
||
purpose: 用途 (default, rca, summary)
|
||
|
||
Returns:
|
||
模型名稱
|
||
"""
|
||
providers = self._config.get("providers", {})
|
||
provider_config = providers.get(provider, {})
|
||
models = provider_config.get("models", {})
|
||
|
||
# 優先取用途,fallback 到 default
|
||
model = models.get(purpose) or models.get("default")
|
||
|
||
if not model:
|
||
# 最終 fallback
|
||
# 2026-03-28 ogt: 更新 gemini-2.0-flash
|
||
fallback_map = {
|
||
"ollama": "qwen2.5:7b-instruct",
|
||
"gemini": "gemini-2.0-flash",
|
||
"claude": "claude-haiku-4-5-20251001",
|
||
}
|
||
model = fallback_map.get(provider, provider)
|
||
logger.warning(
|
||
"model_not_found_using_fallback",
|
||
provider=provider,
|
||
purpose=purpose,
|
||
fallback=model,
|
||
)
|
||
|
||
return model
|
||
|
||
def get_fallback_order(self) -> list[str]:
|
||
"""取得備援順序"""
|
||
return self._config.get("fallback_order", ["ollama", "gemini", "claude"])
|
||
|
||
def get_model_by_complexity(self, complexity: int) -> str:
|
||
"""
|
||
依複雜度取得推薦模型
|
||
|
||
Args:
|
||
complexity: 複雜度分數 (1-5)
|
||
|
||
Returns:
|
||
推薦模型名稱
|
||
"""
|
||
# 確保在範圍內
|
||
complexity = max(1, min(5, complexity))
|
||
return self._complexity_map.get(complexity, self.get_model("ollama", "default"))
|
||
|
||
def get_provider_config(self, provider: str) -> dict:
|
||
"""取得 provider 完整配置"""
|
||
return self._config.get("providers", {}).get(provider, {})
|
||
|
||
def get_default_provider(self) -> str:
|
||
"""取得預設 provider"""
|
||
return self._config.get("default_provider", "ollama")
|
||
|
||
def get_provider_options(self, provider: str) -> dict:
|
||
"""取得 provider 的 options"""
|
||
provider_config = self.get_provider_config(provider)
|
||
return provider_config.get("options", {})
|
||
|
||
def get_provider_timeout(self, provider: str) -> int:
|
||
"""取得 provider 的 timeout (秒)"""
|
||
provider_config = self.get_provider_config(provider)
|
||
return provider_config.get("timeout_seconds", 30)
|
||
|
||
|
||
# =============================================================================
|
||
# Singleton
|
||
# =============================================================================
|
||
|
||
# Cached singleton returned by get_model_registry(); None until first use or after reset_model_registry().
_registry: ModelRegistry | None = None
|
||
|
||
|
||
def get_model_registry() -> ModelRegistry:
    """
    Return the shared ModelRegistry, creating it on first use.

    The instance is cached in the module-level ``_registry`` and reused by
    every later call until ``reset_model_registry()`` clears it.
    """
    global _registry
    if _registry is not None:
        return _registry
    _registry = ModelRegistry()
    return _registry
|
||
|
||
|
||
def reset_model_registry() -> None:
    """
    Forget the cached ModelRegistry singleton (intended for tests).

    The next ``get_model_registry()`` call constructs a fresh instance,
    which loads models.json again.
    """
    global _registry
    _registry = None  # rebind the module-level cache, not a local name
|
||
|
||
|
||
# =============================================================================
|
||
# Convenience Functions
|
||
# =============================================================================
|
||
|
||
|
||
def get_model(provider: str, purpose: str = "default") -> str:
|
||
"""便捷函數: 取得模型名稱"""
|
||
return get_model_registry().get_model(provider, purpose)
|
||
|
||
|
||
def get_model_by_complexity(complexity: int) -> str:
    """Shortcut: look up the recommended model for ``complexity`` via the shared registry."""
    registry = get_model_registry()
    return registry.get_model_by_complexity(complexity)
|