1127 lines
40 KiB
Python
1127 lines
40 KiB
Python
"""
|
||
NVIDIA Nemotron Provider - ADR-036
|
||
==================================
|
||
2026-03-29 ogt: Nemotron Tool Calling 整合 (83.3% 精準度)
|
||
|
||
專門處理 Tool Calling 任務,提供高精準度的 K8s 操作決策。
|
||
|
||
設計原則:
|
||
1. OpenAI 相容格式 - 與 Nemotron API 對接
|
||
2. Pydantic 強制驗證 - 所有回應必須通過 Schema 驗證
|
||
3. Fallback 機制 - 失敗時降級到 Gemini/Claude
|
||
4. HITL 高風險保護 - DELETE 等操作需人工審核
|
||
|
||
版本: v1.0
|
||
建立: 2026-03-29 (台北時區)
|
||
建立者: Claude Code
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import random
|
||
import time
|
||
from enum import Enum
|
||
from typing import Any, Protocol, runtime_checkable # 2026-03-29 ogt: P2-1 Protocol
|
||
|
||
import httpx
|
||
import structlog
|
||
from prometheus_client import Counter, Histogram # 2026-03-29 ogt: P3-3 Prometheus
|
||
|
||
from src.core.config import get_settings
|
||
from src.core.telemetry import get_tracer # 2026-03-29 ogt: P1-2 OTEL 追蹤
|
||
from src.models.nvidia import (
|
||
NvidiaProviderResult,
|
||
NvidiaResponse,
|
||
ToolCallValidationResult,
|
||
ToolDefinition,
|
||
)
|
||
from src.services.langfuse_client import ( # 2026-03-29 ogt: P1-1 Langfuse 整合
|
||
LangfuseTraceContext,
|
||
)
|
||
from src.services.ollama_endpoint_resolver import (
|
||
resolve_ollama_endpoint,
|
||
resolve_ollama_order,
|
||
)
|
||
|
||
# Module-level structured logger, named after this module.
logger = structlog.get_logger(__name__)
# Application settings, resolved once at import time; every `settings.<FIELD>`
# read in this module sees this same object.
settings = get_settings()

# OTEL tracer used for the spans opened in this module (P1-2 fix).
_tracer = get_tracer("nvidia_provider")
|
||
|
||
|
||
# =============================================================================
|
||
# Protocol 定義 (P2-1 修復)
|
||
# =============================================================================
|
||
|
||
|
||
@runtime_checkable
class INvidiaProvider(Protocol):
    """
    NVIDIA Provider Interface - P2-1 fix.

    2026-03-29 ogt: structural interface for NvidiaProvider so it can be
    injected (DI) and replaced in tests.

    Because the protocol is ``@runtime_checkable``, ``isinstance(obj,
    INvidiaProvider)`` only checks that the members below exist on ``obj``.
    It does not check their signatures or return types.

    Usage:
        ```python
        async def process_tool_call(provider: INvidiaProvider):
            result = await provider.tool_call(messages, tools)
        ```
    """

    async def tool_call(
        self,
        messages: list[dict[str, Any]],
        tools: list[ToolDefinition | dict[str, Any]],
        model: str = ...,
        temperature: float = ...,
        max_tokens: int = ...,
    ) -> NvidiaProviderResult:
        """Run a tool-calling request and return the validated result."""
        ...

    def is_high_risk_tool(self, tool_name: str) -> bool:
        """Return True when ``tool_name`` requires HITL review."""
        ...

    def get_high_risk_tools(
        self, tool_calls: list[ToolCallValidationResult]
    ) -> list[ToolCallValidationResult]:
        """Return the subset of ``tool_calls`` that are high risk."""
        ...

    async def close(self) -> None:
        """Release held resources (e.g. the HTTP client)."""
        ...

    async def chat(
        self,
        prompt: str,
        model: str = ...,
        temperature: float = ...,
        max_tokens: int = ...,
    ) -> tuple[str, bool, int, float]:
        """
        Plain chat (no tool calling) - added 2026-03-29 ogt.

        Returns:
            tuple: (response_text, success, total_tokens, cost_usd)
        """
        ...
|
||
|
||
# =============================================================================
|
||
# 常量定義
|
||
# =============================================================================
|
||
|
||
# NVIDIA NIM API endpoint (OpenAI-compatible chat completions).
NVIDIA_API_URL = "https://integrate.api.nvidia.com/v1/chat/completions"

# Default tool-calling model (2026-03-31 ogt: restored to nemotron-mini-4b-instruct).
NVIDIA_DEFAULT_MODEL = "nvidia/nemotron-mini-4b-instruct"
|
||
|
||
# 請求超時 (秒)
|
||
# 2026-04-01 ogt: 設為 30s (平衡點)
|
||
# 2026-04-03 ogt: 改從 config 讀取,與 NEMOTRON_TIMEOUT_SECONDS=55 對齊
|
||
# Memory 記載 NIM 免費 tier 延遲 11-45s,30s 硬編碼導致慢請求全超時
|
||
def _get_nvidia_timeout() -> float:
|
||
try:
|
||
from src.core.config import get_settings
|
||
return float(get_settings().NEMOTRON_TIMEOUT_SECONDS)
|
||
except Exception:
|
||
return 45.0
|
||
|
||
# Overall request timeout in seconds, resolved once at import time and used
# for the NVIDIA HTTP client.
NVIDIA_TIMEOUT = _get_nvidia_timeout()

# Retries after the first attempt (tool_call makes up to MAX_RETRIES + 1
# attempts in total).
MAX_RETRIES = 2
|
||
|
||
# =============================================================================
|
||
# P3-1: Circuit Breaker 配置 (2026-03-29 ogt)
|
||
# =============================================================================
|
||
|
||
# Circuit Breaker thresholds
CIRCUIT_BREAKER_FAILURE_THRESHOLD = 3  # failures (since the last success) that open the circuit
CIRCUIT_BREAKER_RECOVERY_TIMEOUT = 60  # seconds to stay OPEN before moving to HALF_OPEN
CIRCUIT_BREAKER_HALF_OPEN_REQUESTS = 1  # trial requests intended to be allowed while HALF_OPEN

# P3-2: exponential backoff configuration
RETRY_BASE_DELAY = 1.0  # base delay (seconds)
RETRY_MAX_DELAY = 30.0  # upper bound on a single delay (seconds)
RETRY_EXPONENTIAL_BASE = 2  # delay = RETRY_BASE_DELAY * RETRY_EXPONENTIAL_BASE ** attempt
|
||
|
||
# =============================================================================
|
||
# P3-3: Prometheus Metrics (2026-03-29 ogt)
|
||
# =============================================================================
|
||
|
||
# Request counter. Label values used in this module include
# status in {"success", "error", "circuit_open", "timeout"} and
# tool_name = the called tool's name, "" when unknown, or "chat" for chat().
NVIDIA_REQUESTS_TOTAL = Counter(
    "nvidia_tool_call_requests_total",
    "Total NVIDIA Tool Calling requests",
    ["status", "tool_name"],
)

# End-to-end latency in seconds (including retries); observed by both the
# tool-calling and the plain chat paths.
NVIDIA_LATENCY_HISTOGRAM = Histogram(
    "nvidia_tool_call_latency_seconds",
    "NVIDIA Tool Calling latency in seconds",
    buckets=[1, 5, 10, 15, 20, 30, 45, 60],
)

# Counts circuit-breaker state *transitions* (a Counter, not a gauge of the
# current state), labelled by the from/to state values.
NVIDIA_CIRCUIT_BREAKER_STATE = Counter(
    "nvidia_circuit_breaker_state_changes_total",
    "Circuit breaker state changes",
    ["from_state", "to_state"],
)
|
||
|
||
|
||
# =============================================================================
|
||
# P3-1: Circuit Breaker 狀態機 (2026-03-29 ogt)
|
||
# =============================================================================
|
||
|
||
|
||
class CircuitState(Enum):
    """States of the circuit breaker.

    The ``.value`` strings are what gets written to metric labels and log
    fields on state transitions.
    """

    # Normal operation: requests pass through.
    CLOSED = "closed"
    # Tripped: requests are rejected.
    OPEN = "open"
    # Recovery probe: trial requests are allowed to test the upstream.
    HALF_OPEN = "half_open"
|
||
|
||
|
||
class CircuitBreaker:
    """
    Circuit Breaker implementation - P3-1 optimization.

    Stops sending requests to the NVIDIA API after repeated failures, so a
    failing upstream does not cause cascading failures.

    State transitions:
        CLOSED    -> (failures since last success >= threshold)    -> OPEN
        OPEN      -> (recovery_timeout seconds since last failure) -> HALF_OPEN
        HALF_OPEN -> (success)                                     -> CLOSED
        HALF_OPEN -> (failure)                                     -> OPEN

    While HALF_OPEN, at most ``half_open_max_requests`` trial requests are
    admitted by ``can_execute``; further callers are rejected until a trial
    reports back. Some callers may never call ``record_success`` or
    ``record_failure`` for an admitted trial, for example on an early return.
    For that case the trial slots are reclaimed after another
    ``recovery_timeout`` seconds, so the breaker cannot stay blocked forever.

    Times are measured with wall-clock ``time.time()``.
    """

    def __init__(
        self,
        failure_threshold: int = CIRCUIT_BREAKER_FAILURE_THRESHOLD,
        recovery_timeout: float = CIRCUIT_BREAKER_RECOVERY_TIMEOUT,
        half_open_max_requests: int = CIRCUIT_BREAKER_HALF_OPEN_REQUESTS,
    ):
        """
        Args:
            failure_threshold: Failures (since the last success) that open
                the circuit.
            recovery_timeout: Seconds to stay OPEN before trial requests are
                allowed. Also the age after which unreported trial slots
                are reclaimed.
            half_open_max_requests: Trial requests admitted per HALF_OPEN
                period. Values below 1 are treated as 1.
        """
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: float = 0
        self._failure_threshold = failure_threshold
        self._recovery_timeout = recovery_timeout
        self._half_open_max_requests = max(1, half_open_max_requests)
        # Trial requests admitted in the current HALF_OPEN period, and the
        # time that period (or its latest slot reclaim) started.
        self._half_open_in_flight = 0
        self._half_open_since: float = 0.0

    @property
    def state(self) -> CircuitState:
        """Current state; moves OPEN -> HALF_OPEN once the recovery timeout has elapsed."""
        if self._state == CircuitState.OPEN:
            if time.time() - self._last_failure_time >= self._recovery_timeout:
                self._transition_to(CircuitState.HALF_OPEN)
        return self._state

    def _transition_to(self, new_state: CircuitState) -> None:
        """Switch to ``new_state``; real changes are counted in Prometheus and logged."""
        old_state = self._state
        if old_state != new_state:
            NVIDIA_CIRCUIT_BREAKER_STATE.labels(
                from_state=old_state.value, to_state=new_state.value
            ).inc()
            logger.info(
                "circuit_breaker_state_change",
                from_state=old_state.value,
                to_state=new_state.value,
            )
            # Every state change starts a fresh set of HALF_OPEN trial slots.
            self._half_open_in_flight = 0
            self._half_open_since = time.time()
        self._state = new_state

    def can_execute(self) -> bool:
        """
        Return whether a request may be sent now.

        CLOSED always allows. OPEN always rejects. HALF_OPEN allows up to
        ``half_open_max_requests`` trial requests; each allowed call
        consumes one trial slot.
        """
        state = self.state  # may move OPEN -> HALF_OPEN
        if state == CircuitState.CLOSED:
            return True
        if state == CircuitState.OPEN:
            return False

        # HALF_OPEN: admit only a bounded number of trial requests.
        now = time.time()
        if self._half_open_in_flight >= self._half_open_max_requests:
            if now - self._half_open_since < self._recovery_timeout:
                return False
            # Admitted trials never reported back; reclaim their slots.
            self._half_open_in_flight = 0
            self._half_open_since = now
        self._half_open_in_flight += 1
        return True

    def record_success(self) -> None:
        """Record a success: HALF_OPEN closes the circuit; the failure count is reset."""
        if self._state == CircuitState.HALF_OPEN:
            self._transition_to(CircuitState.CLOSED)
        self._failure_count = 0

    def record_failure(self) -> None:
        """Record a failure; may open (or re-open) the circuit."""
        self._failure_count += 1
        self._last_failure_time = time.time()

        if self._state == CircuitState.HALF_OPEN:
            # A failed trial request re-opens the circuit immediately.
            self._transition_to(CircuitState.OPEN)
        elif self._failure_count >= self._failure_threshold:
            # Threshold reached: open the circuit.
            self._transition_to(CircuitState.OPEN)
|
||
|
||
|
||
# High-risk tools that require HITL (human-in-the-loop) review before they
# are executed. NOTE: every entry is a lowercase name, so a case-sensitive
# membership test would miss variants such as "Delete_Pod".
HIGH_RISK_TOOLS: set[str] = {
    "delete_pod",
    "delete_deployment",
    "delete_namespace",
    "delete_service",
    "delete_configmap",
    "delete_secret",
    "scale_to_zero",
    "drain_node",
    "cordon_node",
    "delete_pvc",
    "delete_pv",
}
|
||
|
||
|
||
# =============================================================================
|
||
# NvidiaProvider 類別
|
||
# =============================================================================
|
||
|
||
|
||
class NvidiaProvider:
    """
    NVIDIA Nemotron Provider.

    Sends tool-calling requests to the NVIDIA NIM chat-completions API and
    validates the tool calls it returns. The module header cites 83.3%
    accuracy for K8s operation decisions.

    Usage:
        ```python
        provider = NvidiaProvider()
        result = await provider.tool_call(
            messages=[{"role": "user", "content": "Restart the awoooi-api pod"}],
            tools=[restart_tool, scale_tool],
        )
        if result.success:
            for tc in result.tool_calls:
                if tc.valid:
                    execute_tool(tc.tool_name, tc.arguments)
        ```
    """

    def __init__(self, api_key: str | None = None):
        """
        Initialize the provider.

        Args:
            api_key: NVIDIA API key. When falsy, ``settings.NVIDIA_API_KEY``
                is used instead.

        2026-03-29 ogt: P3-1 added the circuit breaker.
        """
        self._api_key = api_key or settings.NVIDIA_API_KEY
        self._client: httpx.AsyncClient | None = None
        self._circuit_breaker = CircuitBreaker()  # P3-1: Circuit Breaker

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it when missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(NVIDIA_TIMEOUT, connect=10.0),
                limits=httpx.Limits(max_connections=10, max_keepalive_connections=5),
            )
        return self._client

    async def close(self) -> None:
        """Close the HTTP client if it is open, and drop the reference."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def tool_call(
        self,
        messages: list[dict[str, Any]],
        tools: list[ToolDefinition | dict[str, Any]],
        model: str = NVIDIA_DEFAULT_MODEL,
        temperature: float = 0.0,
        max_tokens: int = 1024,
    ) -> NvidiaProviderResult:
        """
        Run a tool-calling request.

        Args:
            messages: Chat messages.
            tools: Available tool definitions (ToolDefinition or plain dicts).
            model: Model name.
            temperature: Sampling temperature (0.0 = most deterministic).
            max_tokens: Maximum output tokens.

        Returns:
            NvidiaProviderResult containing the validated tool calls. On any
            failure (circuit open, missing key, request or parse failure)
            ``success`` is False and ``fallback_triggered`` is True.

        2026-03-29 ogt: P1-1/P1-2 - OTEL + Langfuse tracing.
        2026-03-29 ogt: P3-1/P3-2/P3-3 - circuit breaker + exponential backoff + Prometheus.
        """
        start_time = time.perf_counter()

        # P1-2: one OTEL span wraps the whole tool-calling flow.
        with _tracer.start_as_current_span("nvidia_tool_call") as span:
            span.set_attribute("ai.provider", "nvidia")
            span.set_attribute("ai.model", model)
            span.set_attribute("ai.tool_count", len(tools))

            # P3-1: circuit breaker check.
            if not self._circuit_breaker.can_execute():
                span.set_attribute("ai.error", "circuit_breaker_open")
                NVIDIA_REQUESTS_TOTAL.labels(status="circuit_open", tool_name="").inc()
                logger.warning(
                    "nvidia_circuit_breaker_open",
                    state=self._circuit_breaker.state.value,
                )
                return NvidiaProviderResult(
                    success=False,
                    error="Circuit Breaker OPEN - NVIDIA API 暫時不可用",
                    fallback_triggered=True,
                )

            # API key check.
            if not self._api_key:
                span.set_attribute("ai.error", "api_key_not_set")
                NVIDIA_REQUESTS_TOTAL.labels(status="error", tool_name="").inc()
                return NvidiaProviderResult(
                    success=False,
                    error="NVIDIA_API_KEY 未設定",
                    fallback_triggered=True,
                )

            # Normalize tools to plain dicts and collect their names for tracing.
            tools_data: list[Any] = []
            tool_names: list[Any] = []
            for tool in tools:
                if isinstance(tool, ToolDefinition):
                    tools_data.append(tool.model_dump())
                    tool_names.append(tool.function.get("name", "unknown"))
                else:
                    tools_data.append(tool)
                    tool_names.append(tool.get("function", {}).get("name", "unknown"))

            span.set_attribute("ai.tool_names", ",".join(tool_names))

            request_body = {
                "model": model,
                "messages": messages,
                "tools": tools_data,
                "tool_choice": "auto",
                "temperature": temperature,
                "max_tokens": max_tokens,
            }

            # P1-1: Langfuse tracing.
            with LangfuseTraceContext(
                name="nvidia_tool_call",
                metadata={"model": model, "tool_count": len(tools)},
            ) as langfuse_ctx:

                # Send the request, retrying with P3-2 exponential backoff.
                response_data: dict | None = None
                last_error: str | None = None

                for attempt in range(MAX_RETRIES + 1):
                    try:
                        response_data = await self._send_request(request_body)
                        self._circuit_breaker.record_success()  # P3-1
                        break
                    except Exception as e:
                        last_error = str(e)
                        span.set_attribute(f"ai.retry.{attempt}", last_error)
                        logger.warning(
                            "nvidia_request_retry",
                            attempt=attempt + 1,
                            max_retries=MAX_RETRIES,
                            error=last_error,
                        )
                        if attempt == MAX_RETRIES:
                            # Only the final failed attempt counts toward the breaker.
                            self._circuit_breaker.record_failure()  # P3-1
                            break
                        # P3-2: exponential backoff with up to 10% jitter.
                        delay = min(
                            RETRY_BASE_DELAY * (RETRY_EXPONENTIAL_BASE ** attempt),
                            RETRY_MAX_DELAY,
                        )
                        jitter = random.uniform(0, delay * 0.1)
                        await asyncio.sleep(delay + jitter)

                latency_ms = (time.perf_counter() - start_time) * 1000
                latency_seconds = latency_ms / 1000
                span.set_attribute("ai.latency_ms", round(latency_ms, 2))
                NVIDIA_LATENCY_HISTOGRAM.observe(latency_seconds)  # P3-3

                # All attempts failed.
                if response_data is None:
                    span.set_attribute("ai.success", False)
                    span.set_attribute("ai.error", last_error or "unknown")
                    NVIDIA_REQUESTS_TOTAL.labels(status="error", tool_name="").inc()
                    logger.error(
                        "nvidia_request_failed",
                        error=last_error,
                        latency_ms=round(latency_ms, 2),
                    )
                    return NvidiaProviderResult(
                        success=False,
                        error=last_error,
                        latency_ms=latency_ms,
                        fallback_triggered=True,
                    )

                # Parse the response body into the typed model.
                try:
                    nvidia_response = NvidiaResponse.model_validate(response_data)
                except Exception as e:
                    span.set_attribute("ai.success", False)
                    span.set_attribute("ai.error", f"parse_failed: {e}")
                    # Count parse failures as errors too, like every other failure path.
                    NVIDIA_REQUESTS_TOTAL.labels(status="error", tool_name="").inc()
                    logger.error(
                        "nvidia_response_parse_failed",
                        error=str(e),
                        raw_response=str(response_data)[:500],
                    )
                    return NvidiaProviderResult(
                        success=False,
                        error=f"回應解析失敗: {e}",
                        latency_ms=latency_ms,
                        fallback_triggered=True,
                    )

                tool_calls = self._validate_tool_calls(nvidia_response)
                valid_count = sum(1 for tc in tool_calls if tc.valid)

                usage = nvidia_response.usage
                prompt_tokens = usage.prompt_tokens if usage else 0
                completion_tokens = usage.completion_tokens if usage else 0
                total_tokens = usage.total_tokens if usage else 0

                # P1-2: OTEL attributes.
                span.set_attribute("ai.success", True)
                span.set_attribute("ai.tool_call_count", len(tool_calls))
                span.set_attribute("ai.valid_count", valid_count)
                span.set_attribute("ai.prompt_tokens", prompt_tokens)
                span.set_attribute("ai.completion_tokens", completion_tokens)
                span.set_attribute("ai.total_tokens", total_tokens)

                # P1-1: Langfuse generation record (best effort, never fails the call).
                if langfuse_ctx and hasattr(langfuse_ctx, "generation"):
                    try:
                        langfuse_ctx.generation(
                            name="nvidia_nemotron",
                            model=model,
                            input={"messages": messages, "tools": tool_names},
                            output={
                                "tool_calls": [
                                    {"name": tc.tool_name, "args": tc.arguments}
                                    for tc in tool_calls
                                    if tc.valid
                                ]
                            },
                            usage={"input": prompt_tokens, "output": completion_tokens},
                            metadata={
                                "latency_ms": round(latency_ms, 2),
                                "valid_count": valid_count,
                            },
                        )
                    except Exception as le:
                        logger.warning("langfuse_generation_failed_safe", error=str(le))

                # P3-3: one success sample per valid, named tool call.
                for tc in tool_calls:
                    if tc.valid and tc.tool_name:
                        NVIDIA_REQUESTS_TOTAL.labels(
                            status="success", tool_name=tc.tool_name
                        ).inc()

                logger.info(
                    "nvidia_tool_call_completed",
                    success=True,
                    tool_call_count=len(tool_calls),
                    valid_count=valid_count,
                    latency_ms=round(latency_ms, 2),
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                )

                return NvidiaProviderResult(
                    success=True,
                    tool_calls=tool_calls,
                    usage=usage,
                    latency_ms=latency_ms,
                    fallback_triggered=False,
                )

    async def _send_request(self, request_body: dict) -> dict:
        """
        POST ``request_body`` to the NVIDIA API.

        Args:
            request_body: JSON request payload.

        Returns:
            The decoded JSON response.

        Raises:
            Exception: on a non-200 status (message includes up to 500 chars
                of the body). Errors raised by httpx or by JSON decoding
                propagate unchanged.
        """
        client = await self._get_client()

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

        response = await client.post(
            NVIDIA_API_URL,
            headers=headers,
            json=request_body,
        )

        if response.status_code != 200:
            error_text = response.text[:500]
            raise Exception(
                f"NVIDIA API 錯誤: {response.status_code} - {error_text}"
            )

        return response.json()

    def _validate_tool_calls(
        self, response: NvidiaResponse
    ) -> list[ToolCallValidationResult]:
        """
        Validate the tool calls in the first choice of ``response``.

        Each tool call yields one ToolCallValidationResult:
        - valid when ``arguments`` decodes to a JSON object (an already
          decoded dict is accepted as-is);
        - invalid, with ``error`` and ``raw_response``, when ``arguments``
          is not valid JSON or decodes to something other than an object;
        - invalid with a generic error on any other unexpected failure,
          keeping the tool name when it could be read.

        Args:
            response: Parsed NVIDIA API response.

        Returns:
            One result per tool call. Empty when there are no choices or no
            tool calls.
        """
        results: list[ToolCallValidationResult] = []

        if not response.choices:
            return results

        message = response.choices[0].message
        if not message.tool_calls:
            return results

        for tc in message.tool_calls:
            # Read the name up front so even unexpected failures can report it.
            function = getattr(tc, "function", None)
            tool_name = getattr(function, "name", None)
            try:
                raw_arguments = function.arguments
                arguments = (
                    raw_arguments
                    if isinstance(raw_arguments, dict)
                    else json.loads(raw_arguments)
                )
                if not isinstance(arguments, dict):
                    # Tool arguments must be a JSON object (e.g. not a list).
                    results.append(
                        ToolCallValidationResult(
                            valid=False,
                            tool_name=tool_name,
                            error=(
                                "Arguments must decode to a JSON object, got "
                                f"{type(arguments).__name__}"
                            ),
                            raw_response=raw_arguments,
                        )
                    )
                    continue

                results.append(
                    ToolCallValidationResult(
                        valid=True,
                        tool_name=tool_name,
                        arguments=arguments,
                    )
                )
            except json.JSONDecodeError as e:
                results.append(
                    ToolCallValidationResult(
                        valid=False,
                        tool_name=tool_name,
                        error=f"Arguments JSON 解析失敗: {e}",
                        raw_response=raw_arguments,
                    )
                )
            except Exception as e:
                # Only pass the name when it is a real string, so building the
                # error result cannot itself fail validation.
                name_kwargs = {"tool_name": tool_name} if isinstance(tool_name, str) else {}
                results.append(
                    ToolCallValidationResult(
                        valid=False,
                        error=f"驗證失敗: {e}",
                        **name_kwargs,
                    )
                )

        return results

    def is_high_risk_tool(self, tool_name: str) -> bool:
        """
        Return whether ``tool_name`` requires HITL review.

        The comparison is case-insensitive. Empty or None names are never
        high risk.
        """
        if not tool_name:
            return False
        return tool_name.lower() in HIGH_RISK_TOOLS

    def get_high_risk_tools(
        self, tool_calls: list[ToolCallValidationResult]
    ) -> list[ToolCallValidationResult]:
        """
        Return the valid, named tool calls that are high risk.

        Args:
            tool_calls: Tool call validation results.

        Returns:
            The high-risk subset, in original order.
        """
        return [
            tc
            for tc in tool_calls
            if tc.valid and tc.tool_name and self.is_high_risk_tool(tc.tool_name)
        ]

    async def chat(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.1,
        max_tokens: int = 2048,
        use_json_mode: bool = False,
    ) -> tuple[str, bool, int, float]:
        """
        Plain chat (no tool calling), e.g. for RCA analysis or free-form chat.

        2026-03-29 ogt: added to follow the modularization rules.
        2026-03-31 ogt: added ``use_json_mode`` (Phase 22 fix).
            - RCA/analysis: use_json_mode=True (structured output)
            - Conversation: use_json_mode=False (default, natural language)

        Args:
            prompt: User message content.
            model: Model name. Defaults to ModelRegistry's ("nvidia", "rca") model.
            temperature: Sampling temperature.
            max_tokens: Maximum output tokens.
            use_json_mode: Request ``response_format={"type": "json_object"}``.

        Returns:
            tuple: (response_text, success, total_tokens, cost_usd). On
            failure, response_text holds an error description.
        """
        start_time = time.perf_counter()

        with _tracer.start_as_current_span("nvidia_chat") as span:
            span.set_attribute("ai.provider", "nvidia")

            # Circuit breaker check.
            if not self._circuit_breaker.can_execute():
                span.set_attribute("ai.error", "circuit_breaker_open")
                NVIDIA_REQUESTS_TOTAL.labels(status="circuit_open", tool_name="chat").inc()
                logger.warning("nvidia_chat_circuit_breaker_open")
                return "Circuit Breaker OPEN - NVIDIA API 暫時不可用", False, 0, 0.0

            # API key check.
            if not self._api_key:
                span.set_attribute("ai.error", "api_key_not_set")
                return "NVIDIA_API_KEY not configured", False, 0, 0.0

            # Resolve the model through ModelRegistry (imported lazily).
            from src.services.model_registry import get_model_registry

            registry = get_model_registry()
            model_name = model or registry.get_model("nvidia", "rca")

            span.set_attribute("ai.model", model_name)

            logger.info(
                "nvidia_chat_request_start",
                model=model_name,
                prompt_length=len(prompt),
            )

            with LangfuseTraceContext(
                name="nvidia_chat",
                metadata={"model": model_name, "task": "rca"},
            ) as langfuse_ctx:
                try:
                    client = await self._get_client()

                    response = await client.post(
                        NVIDIA_API_URL,
                        headers={
                            "Authorization": f"Bearer {self._api_key}",
                            "Content-Type": "application/json",
                        },
                        json={
                            "model": model_name,
                            "messages": [{"role": "user", "content": prompt}],
                            "temperature": temperature,
                            "max_tokens": max_tokens,
                            **({"response_format": {"type": "json_object"}} if use_json_mode else {}),
                        },
                    )
                    response.raise_for_status()
                    data = response.json()

                    self._circuit_breaker.record_success()

                    text = data["choices"][0]["message"]["content"]

                    # Token usage.
                    usage = data.get("usage", {})
                    prompt_tokens = usage.get("prompt_tokens", 0)
                    completion_tokens = usage.get("completion_tokens", 0)
                    total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens)

                    # NVIDIA NIM free tier = $0.
                    cost_usd = 0.0

                    latency_ms = (time.perf_counter() - start_time) * 1000
                    span.set_attribute("ai.latency_ms", latency_ms)
                    span.set_attribute("ai.total_tokens", total_tokens)

                    # Prometheus.
                    NVIDIA_REQUESTS_TOTAL.labels(status="success", tool_name="chat").inc()
                    NVIDIA_LATENCY_HISTOGRAM.observe(latency_ms / 1000)

                    # Langfuse (best effort).
                    if langfuse_ctx and hasattr(langfuse_ctx, "generation"):
                        try:
                            langfuse_ctx.generation(
                                name="nvidia_chat",
                                model=model_name,
                                input=prompt[:500],
                                output=text[:500],
                                metadata={
                                    "total_tokens": total_tokens,
                                    "cost_usd": cost_usd,
                                    "latency_ms": round(latency_ms, 2),
                                },
                            )
                        except Exception as le:
                            logger.warning("langfuse_chat_generation_failed_safe", error=str(le))

                    logger.info(
                        "nvidia_chat_response_received",
                        model=model_name,
                        response_length=len(text),
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=total_tokens,
                        latency_ms=round(latency_ms, 2),
                    )

                    return text, True, total_tokens, cost_usd

                except httpx.TimeoutException as e:
                    # 2026-04-01 ogt: timeouts do not count toward the circuit
                    # breaker. The Nemo free tier is occasionally slow, and the
                    # next request should still try it first. Only hard errors
                    # (auth / rate limit) should open the circuit.
                    span.set_attribute("ai.error", "timeout")
                    NVIDIA_REQUESTS_TOTAL.labels(status="timeout", tool_name="chat").inc()
                    logger.warning("nvidia_chat_timeout", error=str(e))
                    return f"Timeout: {e}", False, 0, 0.0

                except httpx.HTTPStatusError as e:
                    # 2026-03-31 ogt: log the response body to diagnose 400 errors.
                    self._circuit_breaker.record_failure()  # hard error: counts toward the breaker
                    span.set_attribute("ai.error", str(e))
                    NVIDIA_REQUESTS_TOTAL.labels(status="error", tool_name="chat").inc()
                    error_response = e.response
                    response_text = (
                        error_response.text if error_response is not None else "No response body"
                    )
                    logger.warning(
                        "nvidia_chat_failed",
                        error=str(e),
                        error_type="HTTPStatusError",
                        status_code=error_response.status_code if error_response is not None else None,
                        response_body=response_text[:500],  # truncated to keep logs small
                    )
                    return str(e), False, 0, 0.0

                except Exception as e:
                    self._circuit_breaker.record_failure()
                    span.set_attribute("ai.error", str(e))
                    NVIDIA_REQUESTS_TOTAL.labels(status="error", tool_name="chat").inc()
                    import traceback

                    logger.warning(
                        "nvidia_chat_failed",
                        error=str(e),
                        error_type=type(e).__name__,
                        stacktrace=traceback.format_exc(),
                    )
                    return str(e), False, 0, 0.0
|
||
|
||
|
||
# =============================================================================
|
||
# OllamaToolProvider — 本機 Tool Calling,取代 NVIDIA 雲端
|
||
# 2026-04-09 Claude Sonnet 4.6 Asia/Taipei
|
||
# Ollama /v1/chat/completions 實作同一 INvidiaProvider protocol
|
||
# =============================================================================
|
||
|
||
|
||
class OllamaToolProvider:
    """
    Local Ollama tool-calling provider.

    Uses Ollama's OpenAI-compatible API (``/v1/chat/completions``) for tool
    calling instead of the NVIDIA cloud NIM. The original note cites a
    latency drop from ~44s to ~5s.

    Model: ``settings.OLLAMA_TOOL_MODEL`` unless overridden per call. The
    original note names llama3.1:8b as the most stable 8B tool-calling model.
    Endpoints: the "hermes" lane from the Ollama endpoint resolver, tried in
    order until one answers.

    Note: ``chat`` returns a plain string (error text on failure), not the
    ``(text, success, tokens, cost)`` tuple returned by ``NvidiaProvider.chat``.
    """

    def __init__(self) -> None:
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it when missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0, connect=5.0),
                limits=httpx.Limits(max_connections=5, max_keepalive_connections=3),
            )
        return self._client

    async def close(self) -> None:
        """Close the HTTP client if it is open, and drop the reference."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    def is_high_risk_tool(self, tool_name: str) -> bool:
        """
        Return whether ``tool_name`` requires HITL review.

        The lookup is case-insensitive because HIGH_RISK_TOOLS holds
        lowercase names, so e.g. "Delete_Pod" is still caught. Empty or
        non-string names are never high risk.
        """
        if not isinstance(tool_name, str) or not tool_name:
            return False
        return tool_name.lower() in HIGH_RISK_TOOLS

    def get_high_risk_tools(
        self, tool_calls: list[ToolCallValidationResult]
    ) -> list[ToolCallValidationResult]:
        """
        Return the tool calls whose name is high risk.

        Invalid calls are deliberately not filtered out. A high-risk name is
        flagged even when its arguments failed validation.
        """
        return [tc for tc in tool_calls if self.is_high_risk_tool(tc.tool_name)]

    def _base_url(self) -> str:
        """Return the first Hermes/tool endpoint for backward compatibility."""
        return resolve_ollama_endpoint("hermes").rstrip("/")

    def _base_urls(self) -> list[str]:
        """Return the deduplicated Hermes endpoints in failover order (GCP-A -> GCP-B -> 111)."""
        urls = [self._base_url()]
        urls.extend(endpoint.url.rstrip("/") for endpoint in resolve_ollama_order("hermes") if endpoint.url)
        deduped: list[str] = []
        for url in urls:
            if url and url not in deduped:
                deduped.append(url)
        return deduped

    async def health_check(self) -> bool:
        """Return True if any endpoint answers ``GET /api/tags`` with HTTP 200."""
        try:
            client = await self._get_client()
            for base_url in self._base_urls():
                try:
                    resp = await client.get(f"{base_url}/api/tags", timeout=5.0)
                    if resp.status_code == 200:
                        return True
                except Exception:
                    continue
            return False
        except Exception:
            return False

    @staticmethod
    def _first_message(data: Any) -> dict[str, Any] | None:
        """
        Return the first choice's ``message`` dict from a chat-completions body.

        Returns None when the body has no usable ``choices`` list. Returns
        ``{}`` when the first choice has no message dict.
        """
        choices = data.get("choices") if isinstance(data, dict) else None
        if not choices or not isinstance(choices, list):
            return None
        first = choices[0]
        message = first.get("message") if isinstance(first, dict) else None
        return message if isinstance(message, dict) else {}

    @staticmethod
    def _parse_tool_call(raw: Any) -> ToolCallValidationResult:
        """
        Convert one OpenAI-format tool call into a ToolCallValidationResult.

        Missing, null or blank ``arguments`` are treated as ``{}``. Arguments
        that are not valid JSON, or that do not decode to a JSON object, give
        an invalid result with ``error`` set. They no longer silently become
        ``{}`` and get reported as valid. A call is valid only when it has a
        non-empty string name.
        """
        fn = raw.get("function") if isinstance(raw, dict) else None
        if not isinstance(fn, dict):
            fn = {}
        name = fn.get("name")
        if not isinstance(name, str):
            name = ""

        args_raw = fn.get("arguments")
        args: Any
        if args_raw is None or (isinstance(args_raw, str) and not args_raw.strip()):
            args = {}
        elif isinstance(args_raw, str):
            try:
                args = json.loads(args_raw)
            except json.JSONDecodeError as e:
                return ToolCallValidationResult(
                    valid=False,
                    tool_name=name,
                    error=f"Arguments JSON 解析失敗: {e}",
                    raw_response=args_raw,
                )
        else:
            # Some servers return arguments already decoded.
            args = args_raw

        if not isinstance(args, dict):
            return ToolCallValidationResult(
                valid=False,
                tool_name=name,
                error=f"Arguments must decode to a JSON object, got {type(args).__name__}",
                raw_response=str(args_raw),
            )

        return ToolCallValidationResult(
            tool_name=name,
            arguments=args,
            valid=bool(name),
        )

    async def tool_call(
        self,
        messages: list[dict[str, Any]],
        tools: list[ToolDefinition | dict[str, Any]],
        model: str = "",
        temperature: float = 0.0,
        max_tokens: int = 512,
    ) -> NvidiaProviderResult:
        """
        Run tool calling through Ollama's ``/v1/chat/completions``.

        Endpoints from ``_base_urls()`` are tried in order. An endpoint is
        skipped, and the next one tried, when the request raises an
        ``httpx.HTTPError`` (connection refused, timeout, ...), returns a
        non-200 status, returns a non-JSON body, or returns no choices. The
        first usable answer is returned.

        Args:
            messages: Chat messages.
            tools: Tool definitions (ToolDefinition or plain dicts).
            model: Model name. Defaults to ``settings.OLLAMA_TOOL_MODEL``.
            temperature: Sampling temperature.
            max_tokens: Maximum output tokens.

        Returns:
            NvidiaProviderResult. On failure ``success`` is False,
            ``fallback_triggered`` is True, and ``error`` holds the last
            endpoint error.
        """
        start_time = time.perf_counter()
        model = model or settings.OLLAMA_TOOL_MODEL

        # Normalize tools to plain dicts (same as NvidiaProvider).
        tools_data = [
            tool.model_dump() if isinstance(tool, ToolDefinition) else tool
            for tool in tools
        ]

        request_body = {
            "model": model,
            "messages": messages,
            "tools": tools_data,
            "tool_choice": "auto",
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        try:
            from src.models.nvidia import NvidiaUsage

            client = await self._get_client()
            last_error = ""
            for base_url in self._base_urls():
                try:
                    response = await client.post(
                        f"{base_url}/v1/chat/completions",
                        json=request_body,
                    )
                except httpx.HTTPError as e:
                    # This endpoint is unreachable or timed out: fail over to the next.
                    last_error = str(e) or type(e).__name__
                    logger.warning(
                        "ollama_tool_call_request_error",
                        error=last_error,
                        endpoint=base_url,
                    )
                    continue

                latency_ms = (time.perf_counter() - start_time) * 1000

                if response.status_code != 200:
                    last_error = f"Ollama HTTP {response.status_code}"
                    logger.warning(
                        "ollama_tool_call_http_error",
                        status=response.status_code,
                        body=response.text[:200],
                        endpoint=base_url,
                    )
                    continue

                try:
                    data = response.json()
                except ValueError as e:
                    last_error = f"Ollama invalid JSON: {e}"
                    logger.warning(
                        "ollama_tool_call_invalid_json",
                        error=str(e),
                        endpoint=base_url,
                    )
                    continue

                # Parse tool_calls (OpenAI format).
                message = self._first_message(data)
                if message is None:
                    last_error = "Ollama 無回應"
                    continue

                # `tool_calls` may be missing or explicitly null.
                raw_tool_calls = message.get("tool_calls") or []
                tool_call_results = [self._parse_tool_call(raw) for raw in raw_tool_calls]

                usage_data = data.get("usage") or {}
                prompt_tokens = usage_data.get("prompt_tokens", 0)
                completion_tokens = usage_data.get("completion_tokens", 0)
                usage = NvidiaUsage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=usage_data.get("total_tokens", prompt_tokens + completion_tokens),
                )
                logger.info(
                    "ollama_tool_call_success",
                    model=model,
                    tool_count=len(tool_call_results),
                    latency_ms=round(latency_ms, 1),
                    tokens=usage.total_tokens,
                    endpoint=base_url,
                )

                return NvidiaProviderResult(
                    success=True,
                    tool_calls=tool_call_results,
                    latency_ms=latency_ms,
                    usage=usage,
                )

            latency_ms = (time.perf_counter() - start_time) * 1000
            return NvidiaProviderResult(
                success=False,
                error=last_error or "Ollama 無回應",
                latency_ms=latency_ms,
                fallback_triggered=True,
            )

        except Exception as e:
            latency_ms = (time.perf_counter() - start_time) * 1000
            logger.warning("ollama_tool_call_error", error=str(e), latency_ms=round(latency_ms, 1))
            return NvidiaProviderResult(
                success=False, error=str(e), latency_ms=latency_ms, fallback_triggered=True
            )

    async def chat(self, prompt: str, model: str = "", temperature: float = 0.7, max_tokens: int = 512) -> str:
        """
        Plain (non tool-calling) chat, kept for INvidiaProvider compatibility.

        Returns the first choice's content, or ``""`` when it is missing or
        null. Errors are not raised. If every endpoint fails, a string
        starting with "Ollama chat error:" is returned instead.
        """
        model = model or settings.OLLAMA_TOOL_MODEL
        try:
            client = await self._get_client()
            last_error = ""
            for base_url in self._base_urls():
                try:
                    resp = await client.post(
                        f"{base_url}/v1/chat/completions",
                        json={"model": model, "messages": [{"role": "user", "content": prompt}],
                              "temperature": temperature, "max_tokens": max_tokens},
                    )
                    if resp.status_code != 200:
                        last_error = f"http_{resp.status_code}"
                        continue
                    data = resp.json()
                    # `or ""` keeps the declared str return type when content is null.
                    return data.get("choices", [{}])[0].get("message", {}).get("content", "") or ""
                except Exception as e:
                    last_error = str(e)
                    continue
            return f"Ollama chat error: {last_error or 'no endpoint'}"
        except Exception as e:
            return f"Ollama chat error: {e}"
|
||
|
||
|
||
# =============================================================================
|
||
# 單例與工廠函數
|
||
# =============================================================================
|
||
|
||
# Lazily created singletons, one per backend; populated by get_nvidia_provider().
_provider: NvidiaProvider | None = None
_ollama_tool_provider: OllamaToolProvider | None = None
|
||
|
||
|
||
def get_nvidia_provider() -> NvidiaProvider | OllamaToolProvider:
    """
    Return the tool-calling provider singleton.

    ``settings.USE_OLLAMA_TOOL_CALLING`` is checked on every call:
        truthy -> OllamaToolProvider (local; documented as the default, ~5s)
        falsy  -> NvidiaProvider (cloud, ~44s)

    Each backend's instance is created on first use, logged once, and then
    reused.

    2026-04-09 Claude Sonnet 4.6
    """
    global _provider, _ollama_tool_provider

    if not settings.USE_OLLAMA_TOOL_CALLING:
        if _provider is None:
            _provider = NvidiaProvider()
            logger.info("tool_calling_provider", provider="NvidiaProvider", model=NVIDIA_DEFAULT_MODEL)
        return _provider

    if _ollama_tool_provider is None:
        _ollama_tool_provider = OllamaToolProvider()
        logger.info("tool_calling_provider", provider="OllamaToolProvider", model=settings.OLLAMA_TOOL_MODEL)
    return _ollama_tool_provider
|
||
|
||
|
||
def reset_nvidia_provider() -> None:
    """
    Reset the provider singletons (used by tests).

    Clears both the NVIDIA and the Ollama singletons, so the next
    ``get_nvidia_provider()`` call builds a fresh instance for whichever
    backend is configured. Previously only the NVIDIA singleton was cleared,
    so the Ollama one survived a reset.

    Open HTTP clients are not closed here because this function is
    synchronous. Await ``provider.close()`` beforehand if that matters.
    """
    global _provider, _ollama_tool_provider
    _provider = None
    _ollama_tool_provider = None
|
||
|
||
|
||
# =============================================================================
|
||
# 便捷函數
|
||
# =============================================================================
|
||
|
||
|
||
async def nvidia_tool_call(
    messages: list[dict[str, Any]],
    tools: list[ToolDefinition | dict[str, Any]],
    **kwargs,
) -> NvidiaProviderResult:
    """
    Convenience wrapper that runs a tool call on the configured provider.

    The provider comes from ``get_nvidia_provider()``, so depending on
    settings it may be the Ollama backend rather than NVIDIA.

    Args:
        messages: Chat messages.
        tools: Available tool definitions.
        **kwargs: Forwarded to ``tool_call`` (model, temperature, max_tokens).

    Returns:
        NvidiaProviderResult
    """
    return await get_nvidia_provider().tool_call(messages, tools, **kwargs)
|
||
|
||
|
||
def create_tool_definition(
    name: str,
    description: str,
    parameters: dict[str, Any],
) -> ToolDefinition:
    """
    Build a function-type tool definition.

    Args:
        name: Tool name.
        description: Human-readable description of the tool.
        parameters: JSON Schema describing the tool's arguments.

    Returns:
        ToolDefinition with ``type="function"`` and a ``function`` dict
        holding ``name``, ``description`` and ``parameters``.
    """
    function_spec: dict[str, Any] = dict(
        name=name,
        description=description,
        parameters=parameters,
    )
    return ToolDefinition(type="function", function=function_spec)
|