Files
awoooi/apps/api/src/services/model_registry.py
Your Name dccdcdbaf5
All checks were successful
CD Pipeline / build-and-deploy (push) Successful in 9m45s
fix(flywheel): unblock action safety and Claude fallback
2026-04-29 21:51:18 +08:00

277 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Model Registry - Phase 12 P1 修復
=================================
集中管理 AI 模型配置,消除 hardcode 模型名稱
功能:
- 從 models.json 讀取配置
- 提供 get_model(provider, purpose) 方法
- Singleton 模式
- 支援依賴注入測試
版本: v1.0
建立: 2026-03-26 23:00 (台北時區)
建立者: Claude Code
最後修改: 2026-04-09 10:00 (台北時區) — ogt: fallback config 更新為 deepseek-r1:14b + gemma3:4b
修改者: Claude Code
"""
import json
from pathlib import Path
from typing import Protocol
import structlog
logger = structlog.get_logger(__name__)
# =============================================================================
# Interface (supports dependency-injection in tests)
# =============================================================================
class IModelRegistry(Protocol):
    """Structural interface of a model registry, used for dependency injection."""

    def get_model(self, provider: str, purpose: str = "default") -> str:
        """Return the model name for ``provider`` and ``purpose``."""
        ...

    def get_fallback_order(self) -> list[str]:
        """Return the provider fallback order."""
        ...

    def get_model_by_complexity(self, complexity: int) -> str:
        """Return the recommended model for a complexity score."""
        ...

    def get_provider_config(self, provider: str) -> dict:
        """Return the full configuration of ``provider``."""
        ...
# =============================================================================
# Implementation
# =============================================================================
class ModelRegistry:
"""
Model Registry 實作
從 models.json 讀取配置,提供統一的模型查詢介面
Usage:
registry = get_model_registry()
model = registry.get_model("ollama", "rca") # -> "qwen2.5:7b-instruct"
"""
def __init__(self, config_path: Path | str | None = None):
"""
初始化 ModelRegistry
Args:
config_path: models.json 路徑None 使用預設路徑
"""
if config_path is None:
# 預設路徑: apps/api/models.json
config_path = Path(__file__).parent.parent.parent / "models.json"
elif isinstance(config_path, str):
config_path = Path(config_path)
self._config_path = config_path
self._config: dict = {}
self._load_config()
# 複雜度對應模型 (從 config 或使用預設)
self._complexity_map = self._build_complexity_map()
def _load_config(self) -> None:
"""載入 models.json"""
try:
with open(self._config_path) as f:
self._config = json.load(f)
logger.info(
"model_registry_loaded",
path=str(self._config_path),
providers=list(self._config.get("providers", {}).keys()),
)
except FileNotFoundError:
logger.warning(
"models_json_not_found",
path=str(self._config_path),
using="fallback_defaults",
)
self._config = self._get_default_config()
except json.JSONDecodeError as e:
logger.error(
"models_json_parse_error",
path=str(self._config_path),
error=str(e),
)
self._config = self._get_default_config()
def _get_default_config(self) -> dict:
"""預設配置 (fallback)"""
return {
"default_provider": "ollama",
"fallback_order": ["ollama", "gemini", "claude"],
# 2026-03-29 ogt: P2-3 加入 Tool Calling Fallback (ADR-036)
"tool_calling_fallback_order": ["nvidia", "gemini", "claude"],
"providers": {
# 2026-04-08 ogt: 切換到 M1 Pro (192.168.0.111)deepseek-r1:14b + gemma3:4b
"ollama": {
"models": {
"default": "deepseek-r1:14b",
"rca": "deepseek-r1:14b",
"summary": "gemma3:4b",
}
},
# 2026-03-28 ogt: 更新到 gemini-2.0-flash (1.5 已停用)
"gemini": {
"models": {
"default": "gemini-2.0-flash",
"rca": "gemini-2.0-flash",
"summary": "gemini-2.0-flash",
}
},
"claude": {
"models": {
"default": "claude-haiku-4-5-20251001",
"rca": "claude-haiku-4-5-20251001",
"summary": "claude-haiku-4-5-20251001",
}
},
# 2026-03-29 ogt: P2-3 加入 NVIDIA (ADR-036)
"nvidia": {
"models": {
"default": "nvidia/nemotron-mini-4b-instruct",
"tool_calling": "nvidia/nemotron-mini-4b-instruct",
}
},
},
}
def _build_complexity_map(self) -> dict[int, str]:
"""建立複雜度對應模型映射"""
# 從 config 或使用預設
ollama_models = self._config.get("providers", {}).get("ollama", {}).get("models", {})
default_model = ollama_models.get("default", "qwen2.5:7b-instruct")
summary_model = ollama_models.get("summary", "llama3.2:3b")
return {
1: summary_model, # 簡單任務,快速回應
2: default_model, # 中等任務
3: default_model, # 複雜任務
4: "gemini", # 需要雲端能力
5: "claude", # 極複雜,需要最強模型
}
def get_model(self, provider: str, purpose: str = "default") -> str:
"""
取得模型名稱
Args:
provider: 提供者 (ollama, gemini, claude)
purpose: 用途 (default, rca, summary)
Returns:
模型名稱
"""
providers = self._config.get("providers", {})
provider_config = providers.get(provider, {})
models = provider_config.get("models", {})
# 優先取用途fallback 到 default
model = models.get(purpose) or models.get("default")
if not model:
# 最終 fallback
# 2026-03-28 ogt: 更新 gemini-2.0-flash
fallback_map = {
"ollama": "qwen2.5:7b-instruct",
"gemini": "gemini-2.0-flash",
"claude": "claude-haiku-4-5-20251001",
}
model = fallback_map.get(provider, provider)
logger.warning(
"model_not_found_using_fallback",
provider=provider,
purpose=purpose,
fallback=model,
)
return model
def get_fallback_order(self) -> list[str]:
"""取得備援順序"""
return self._config.get("fallback_order", ["ollama", "gemini", "claude"])
def get_model_by_complexity(self, complexity: int) -> str:
"""
依複雜度取得推薦模型
Args:
complexity: 複雜度分數 (1-5)
Returns:
推薦模型名稱
"""
# 確保在範圍內
complexity = max(1, min(5, complexity))
return self._complexity_map.get(complexity, self.get_model("ollama", "default"))
def get_provider_config(self, provider: str) -> dict:
"""取得 provider 完整配置"""
return self._config.get("providers", {}).get(provider, {})
def get_default_provider(self) -> str:
"""取得預設 provider"""
return self._config.get("default_provider", "ollama")
def get_provider_options(self, provider: str) -> dict:
"""取得 provider 的 options"""
provider_config = self.get_provider_config(provider)
return provider_config.get("options", {})
def get_provider_timeout(self, provider: str) -> int:
"""取得 provider 的 timeout (秒)"""
provider_config = self.get_provider_config(provider)
return provider_config.get("timeout_seconds", 30)
# =============================================================================
# Singleton
# =============================================================================
_registry: ModelRegistry | None = None


def get_model_registry() -> ModelRegistry:
    """Return the shared ModelRegistry, creating it on first use."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = ModelRegistry()
    return _registry
def reset_model_registry() -> None:
    """Drop the shared registry instance (intended for tests)."""
    global _registry
    _registry = None
# =============================================================================
# Convenience Functions
# =============================================================================
def get_model(provider: str, purpose: str = "default") -> str:
    """Shortcut: look up a model name on the shared registry."""
    registry = get_model_registry()
    return registry.get_model(provider, purpose)
def get_model_by_complexity(complexity: int) -> str:
    """Shortcut: pick a model by complexity score via the shared registry."""
    registry = get_model_registry()
    return registry.get_model_by_complexity(complexity)