diff --git a/apps/api/src/core/config.py b/apps/api/src/core/config.py index 0123444a..99e38a16 100644 --- a/apps/api/src/core/config.py +++ b/apps/api/src/core/config.py @@ -261,6 +261,15 @@ class Settings(BaseSettings): default="", description="NVIDIA NIM API key for Nemotron Tool Calling (ADR-036)", ) + # 2026-04-09 Claude Sonnet 4.6: Ollama Tool Calling — 替代 NVIDIA 雲端,本機推理 + USE_OLLAMA_TOOL_CALLING: bool = Field( + default=True, + description="使用 Ollama 本機做 Tool Calling,取代 NVIDIA NIM 雲端 (44s→5s)", + ) + OLLAMA_TOOL_MODEL: str = Field( + default="llama3.1:8b", + description="Ollama Tool Calling 模型 (支援 function calling 格式)", + ) @field_validator("AI_FALLBACK_ORDER", mode="before") @classmethod diff --git a/apps/api/src/services/nvidia_provider.py b/apps/api/src/services/nvidia_provider.py index d76d7e4a..6b5be0a6 100644 --- a/apps/api/src/services/nvidia_provider.py +++ b/apps/api/src/services/nvidia_provider.py @@ -830,18 +830,202 @@ class NvidiaProvider: return str(e), False, 0, 0.0 +# ============================================================================= +# OllamaToolProvider — 本機 Tool Calling,取代 NVIDIA 雲端 +# 2026-04-09 Claude Sonnet 4.6 Asia/Taipei +# Ollama /v1/chat/completions 實作同一 INvidiaProvider protocol +# ============================================================================= + + +class OllamaToolProvider: + """ + Ollama 本機 Tool Calling Provider + + 使用 Ollama OpenAI 相容 API (/v1/chat/completions) 做 tool calling, + 取代 NVIDIA 雲端 NIM。延遲從 44s 降至 ~5s。 + + 模型: llama3.1:8b (tool calling 最穩定的 8B 模型) + Endpoint: OLLAMA_URL/v1/chat/completions (OpenAI 相容格式) + """ + + def __init__(self) -> None: + self._client: httpx.AsyncClient | None = None + + async def _get_client(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(60.0, connect=5.0), + limits=httpx.Limits(max_connections=5, max_keepalive_connections=3), + ) + return self._client + + async def close(self) -> None: + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None + + def is_high_risk_tool(self, tool_name: str) -> bool: + return tool_name in HIGH_RISK_TOOLS + + def get_high_risk_tools( + self, tool_calls: list[ToolCallValidationResult] + ) -> list[ToolCallValidationResult]: + return [tc for tc in tool_calls if self.is_high_risk_tool(tc.tool_name)] + + async def health_check(self) -> bool: + try: + client = await self._get_client() + base_url = settings.OLLAMA_URL.rstrip("/") + resp = await client.get(f"{base_url}/api/tags", timeout=5.0) + return resp.status_code == 200 + except Exception: + return False + + async def tool_call( + self, + messages: list[dict[str, Any]], + tools: list[ToolDefinition | dict[str, Any]], + model: str = "", + temperature: float = 0.0, + max_tokens: int = 512, + ) -> NvidiaProviderResult: + """Ollama /v1/chat/completions tool calling""" + start_time = time.perf_counter() + model = model or settings.OLLAMA_TOOL_MODEL + base_url = settings.OLLAMA_URL.rstrip("/") + url = f"{base_url}/v1/chat/completions" + + # 轉換 tools 為 dict 格式(同 NvidiaProvider) + tools_data = [] + for tool in tools: + if isinstance(tool, ToolDefinition): + tools_data.append(tool.model_dump()) + else: + tools_data.append(tool) + + request_body = { + "model": model, + "messages": messages, + "tools": tools_data, + "tool_choice": "auto", + "temperature": temperature, + "max_tokens": max_tokens, + } + + try: + client = await self._get_client() + response = await client.post(url, json=request_body) + latency_ms = (time.perf_counter() - start_time) * 1000 + + if response.status_code != 200: + logger.warning( + "ollama_tool_call_http_error", + status=response.status_code, + body=response.text[:200], + ) + return NvidiaProviderResult( + success=False, + error=f"Ollama HTTP {response.status_code}", + latency_ms=latency_ms, + fallback_triggered=True, + ) + + data = response.json() + # 解析 tool_calls(OpenAI 格式) + choices = data.get("choices", []) + if not choices: + return NvidiaProviderResult( + success=False, error="Ollama 無回應", latency_ms=latency_ms, fallback_triggered=True + ) + + message = choices[0].get("message", {}) + raw_tool_calls = message.get("tool_calls", []) + + tool_call_results: list[ToolCallValidationResult] = [] + for tc in raw_tool_calls: + fn = tc.get("function", {}) + name = fn.get("name", "") + args_raw = fn.get("arguments", "{}") + try: + args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw + except json.JSONDecodeError: + args = {} + tool_call_results.append(ToolCallValidationResult( + tool_name=name, + arguments=args, + valid=bool(name), + )) + + usage_data = data.get("usage", {}) + from src.models.nvidia import NvidiaUsage + usage = NvidiaUsage( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("completion_tokens", 0), + total_tokens=usage_data.get("total_tokens", 0), + ) + logger.info( + "ollama_tool_call_success", + model=model, + tool_count=len(tool_call_results), + latency_ms=round(latency_ms, 1), + tokens=usage.total_tokens, + ) + + return NvidiaProviderResult( + success=True, + tool_calls=tool_call_results, + latency_ms=latency_ms, + usage=usage, + ) + + except Exception as e: + latency_ms = (time.perf_counter() - start_time) * 1000 + logger.warning("ollama_tool_call_error", error=str(e), latency_ms=round(latency_ms, 1)) + return NvidiaProviderResult( + success=False, error=str(e), latency_ms=latency_ms, fallback_triggered=True + ) + + async def chat(self, prompt: str, model: str = "", temperature: float = 0.7, max_tokens: int = 512) -> str: + """簡單 chat(非 tool calling 路徑,保持 INvidiaProvider 相容)""" + model = model or settings.OLLAMA_TOOL_MODEL + base_url = settings.OLLAMA_URL.rstrip("/") + try: + client = await self._get_client() + resp = await client.post( + f"{base_url}/v1/chat/completions", + json={"model": model, "messages": [{"role": "user", "content": prompt}], + "temperature": temperature, "max_tokens": max_tokens}, + ) + data = resp.json() + return data.get("choices", [{}])[0].get("message", {}).get("content", "") + except Exception as e: + return f"Ollama chat error: {e}" + + # ============================================================================= # 單例與工廠函數 # ============================================================================= _provider: NvidiaProvider | None = None +_ollama_tool_provider: OllamaToolProvider | None = None -def get_nvidia_provider() -> NvidiaProvider: - """取得 NvidiaProvider 單例""" - global _provider +def get_nvidia_provider() -> "NvidiaProvider | OllamaToolProvider": + """ + 取得 Tool Calling Provider 單例。 + USE_OLLAMA_TOOL_CALLING=True (預設) → OllamaToolProvider (本機,~5s) + USE_OLLAMA_TOOL_CALLING=False → NvidiaProvider (雲端,~44s) + 2026-04-09 Claude Sonnet 4.6 + """ + global _provider, _ollama_tool_provider + if settings.USE_OLLAMA_TOOL_CALLING: + if _ollama_tool_provider is None: + _ollama_tool_provider = OllamaToolProvider() + logger.info("tool_calling_provider", provider="OllamaToolProvider", model=settings.OLLAMA_TOOL_MODEL) + return _ollama_tool_provider if _provider is None: _provider = NvidiaProvider() + logger.info("tool_calling_provider", provider="NvidiaProvider", model=NVIDIA_DEFAULT_MODEL) return _provider