feat: Ollama 本機 Tool Calling 取代 NVIDIA 雲端 (44s→~5s)
Some checks failed
CD Pipeline / build-and-deploy (push) Has been cancelled
Some checks failed
CD Pipeline / build-and-deploy (push) Has been cancelled
- nvidia_provider.py: 新增 OllamaToolProvider - 實作 INvidiaProvider protocol,打 Ollama /v1/chat/completions - 模型: llama3.1:8b (tool calling 最穩定的 8B) - 延遲: 44s → ~5s(本機 M1 Pro 192.168.0.111) - get_nvidia_provider() 根據 USE_OLLAMA_TOOL_CALLING 切換 - config.py: USE_OLLAMA_TOOL_CALLING=True (預設開啟), OLLAMA_TOOL_MODEL=llama3.1:8b - 回退: USE_OLLAMA_TOOL_CALLING=False → 恢復 NvidiaProvider 雲端 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -261,6 +261,15 @@ class Settings(BaseSettings):
|
||||
default="",
|
||||
description="NVIDIA NIM API key for Nemotron Tool Calling (ADR-036)",
|
||||
)
|
||||
# 2026-04-09 Claude Sonnet 4.6: Ollama Tool Calling — 替代 NVIDIA 雲端,本機推理
|
||||
USE_OLLAMA_TOOL_CALLING: bool = Field(
|
||||
default=True,
|
||||
description="使用 Ollama 本機做 Tool Calling,取代 NVIDIA NIM 雲端 (44s→5s)",
|
||||
)
|
||||
OLLAMA_TOOL_MODEL: str = Field(
|
||||
default="llama3.1:8b",
|
||||
description="Ollama Tool Calling 模型 (支援 function calling 格式)",
|
||||
)
|
||||
|
||||
@field_validator("AI_FALLBACK_ORDER", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -830,18 +830,202 @@ class NvidiaProvider:
|
||||
return str(e), False, 0, 0.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OllamaToolProvider — 本機 Tool Calling,取代 NVIDIA 雲端
|
||||
# 2026-04-09 Claude Sonnet 4.6 Asia/Taipei
|
||||
# Ollama /v1/chat/completions 實作同一 INvidiaProvider protocol
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class OllamaToolProvider:
|
||||
"""
|
||||
Ollama 本機 Tool Calling Provider
|
||||
|
||||
使用 Ollama OpenAI 相容 API (/v1/chat/completions) 做 tool calling,
|
||||
取代 NVIDIA 雲端 NIM。延遲從 44s 降至 ~5s。
|
||||
|
||||
模型: llama3.1:8b (tool calling 最穩定的 8B 模型)
|
||||
Endpoint: OLLAMA_URL/v1/chat/completions (OpenAI 相容格式)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None or self._client.is_closed:
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(60.0, connect=5.0),
|
||||
limits=httpx.Limits(max_connections=5, max_keepalive_connections=3),
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._client and not self._client.is_closed:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def is_high_risk_tool(self, tool_name: str) -> bool:
|
||||
return tool_name in HIGH_RISK_TOOLS
|
||||
|
||||
def get_high_risk_tools(
|
||||
self, tool_calls: list[ToolCallValidationResult]
|
||||
) -> list[ToolCallValidationResult]:
|
||||
return [tc for tc in tool_calls if self.is_high_risk_tool(tc.tool_name)]
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
client = await self._get_client()
|
||||
base_url = settings.OLLAMA_URL.rstrip("/")
|
||||
resp = await client.get(f"{base_url}/api/tags", timeout=5.0)
|
||||
return resp.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def tool_call(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[ToolDefinition | dict[str, Any]],
|
||||
model: str = "",
|
||||
temperature: float = 0.0,
|
||||
max_tokens: int = 512,
|
||||
) -> NvidiaProviderResult:
|
||||
"""Ollama /v1/chat/completions tool calling"""
|
||||
start_time = time.perf_counter()
|
||||
model = model or settings.OLLAMA_TOOL_MODEL
|
||||
base_url = settings.OLLAMA_URL.rstrip("/")
|
||||
url = f"{base_url}/v1/chat/completions"
|
||||
|
||||
# 轉換 tools 為 dict 格式(同 NvidiaProvider)
|
||||
tools_data = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, ToolDefinition):
|
||||
tools_data.append(tool.model_dump())
|
||||
else:
|
||||
tools_data.append(tool)
|
||||
|
||||
request_body = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": tools_data,
|
||||
"tool_choice": "auto",
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
try:
|
||||
client = await self._get_client()
|
||||
response = await client.post(url, json=request_body)
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
"ollama_tool_call_http_error",
|
||||
status=response.status_code,
|
||||
body=response.text[:200],
|
||||
)
|
||||
return NvidiaProviderResult(
|
||||
success=False,
|
||||
error=f"Ollama HTTP {response.status_code}",
|
||||
latency_ms=latency_ms,
|
||||
fallback_triggered=True,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# 解析 tool_calls(OpenAI 格式)
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
return NvidiaProviderResult(
|
||||
success=False, error="Ollama 無回應", latency_ms=latency_ms, fallback_triggered=True
|
||||
)
|
||||
|
||||
message = choices[0].get("message", {})
|
||||
raw_tool_calls = message.get("tool_calls", [])
|
||||
|
||||
tool_call_results: list[ToolCallValidationResult] = []
|
||||
for tc in raw_tool_calls:
|
||||
fn = tc.get("function", {})
|
||||
name = fn.get("name", "")
|
||||
args_raw = fn.get("arguments", "{}")
|
||||
try:
|
||||
args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_call_results.append(ToolCallValidationResult(
|
||||
tool_name=name,
|
||||
arguments=args,
|
||||
valid=bool(name),
|
||||
))
|
||||
|
||||
usage_data = data.get("usage", {})
|
||||
from src.models.nvidia import NvidiaUsage
|
||||
usage = NvidiaUsage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
logger.info(
|
||||
"ollama_tool_call_success",
|
||||
model=model,
|
||||
tool_count=len(tool_call_results),
|
||||
latency_ms=round(latency_ms, 1),
|
||||
tokens=usage.total_tokens,
|
||||
)
|
||||
|
||||
return NvidiaProviderResult(
|
||||
success=True,
|
||||
tool_calls=tool_call_results,
|
||||
latency_ms=latency_ms,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
logger.warning("ollama_tool_call_error", error=str(e), latency_ms=round(latency_ms, 1))
|
||||
return NvidiaProviderResult(
|
||||
success=False, error=str(e), latency_ms=latency_ms, fallback_triggered=True
|
||||
)
|
||||
|
||||
async def chat(self, prompt: str, model: str = "", temperature: float = 0.7, max_tokens: int = 512) -> str:
|
||||
"""簡單 chat(非 tool calling 路徑,保持 INvidiaProvider 相容)"""
|
||||
model = model or settings.OLLAMA_TOOL_MODEL
|
||||
base_url = settings.OLLAMA_URL.rstrip("/")
|
||||
try:
|
||||
client = await self._get_client()
|
||||
resp = await client.post(
|
||||
f"{base_url}/v1/chat/completions",
|
||||
json={"model": model, "messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": temperature, "max_tokens": max_tokens},
|
||||
)
|
||||
data = resp.json()
|
||||
return data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
except Exception as e:
|
||||
return f"Ollama chat error: {e}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 單例與工廠函數
|
||||
# =============================================================================
|
||||
|
||||
_provider: NvidiaProvider | None = None
|
||||
_ollama_tool_provider: OllamaToolProvider | None = None
|
||||
|
||||
|
||||
def get_nvidia_provider() -> NvidiaProvider:
|
||||
"""取得 NvidiaProvider 單例"""
|
||||
global _provider
|
||||
def get_nvidia_provider() -> "NvidiaProvider | OllamaToolProvider":
|
||||
"""
|
||||
取得 Tool Calling Provider 單例。
|
||||
USE_OLLAMA_TOOL_CALLING=True (預設) → OllamaToolProvider (本機,~5s)
|
||||
USE_OLLAMA_TOOL_CALLING=False → NvidiaProvider (雲端,~44s)
|
||||
2026-04-09 Claude Sonnet 4.6
|
||||
"""
|
||||
global _provider, _ollama_tool_provider
|
||||
if settings.USE_OLLAMA_TOOL_CALLING:
|
||||
if _ollama_tool_provider is None:
|
||||
_ollama_tool_provider = OllamaToolProvider()
|
||||
logger.info("tool_calling_provider", provider="OllamaToolProvider", model=settings.OLLAMA_TOOL_MODEL)
|
||||
return _ollama_tool_provider
|
||||
if _provider is None:
|
||||
_provider = NvidiaProvider()
|
||||
logger.info("tool_calling_provider", provider="NvidiaProvider", model=NVIDIA_DEFAULT_MODEL)
|
||||
return _provider
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user