feat(ai): ADR-036 NVIDIA Nemotron Tool Calling 整合
Phase 20 - 提升 Tool Calling 精準度 50% → 83.3% 新增: - src/models/nvidia.py: Pydantic Schema - src/services/nvidia_provider.py: NvidiaProvider 類別 - tests/test_nvidia_provider.py: 15 項單元測試 (全部通過) 修改: - ai_router.py: AIProvider.NVIDIA + route_tool_calling() - ai_rate_limiter.py: NVIDIA 限制 (5 RPM, 100/day) - models.json: NVIDIA 配置 - cd.yaml: Secrets 注入 NVIDIA_API_KEY 路由策略: - Tool Calling: Nemotron → Gemini → Claude - 一般對話: Ollama → Gemini → Claude (不變) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -201,8 +201,9 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# AI Fallback Strategy (ADR-006)
|
||||
# AI Fallback Strategy (ADR-006 v1.3 + ADR-036)
|
||||
# Order: Ollama (local) -> Gemini (cloud) -> Claude (cloud)
|
||||
# Tool Calling: Nemotron (專用) -> Gemini -> Claude
|
||||
# ==========================================================================
|
||||
AI_FALLBACK_ORDER: list[str] = Field(
|
||||
default=["ollama", "gemini", "claude"],
|
||||
@@ -210,6 +211,11 @@ class Settings(BaseSettings):
|
||||
)
|
||||
GEMINI_API_KEY: str = Field(default="", description="Google Gemini API key")
|
||||
CLAUDE_API_KEY: str = Field(default="", description="Anthropic Claude API key")
|
||||
# 2026-03-29 ogt: ADR-036 Nemotron Tool Calling 整合
|
||||
NVIDIA_API_KEY: str = Field(
|
||||
default="",
|
||||
description="NVIDIA NIM API key for Nemotron Tool Calling (ADR-036)",
|
||||
)
|
||||
|
||||
@field_validator("AI_FALLBACK_ORDER", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -6,6 +6,7 @@ AWOOOI Models Package
|
||||
- Approval: 簽核相關模型 (Phase 2 HITL)
|
||||
- Incident: 事件相關模型 (Phase 6 認知覺醒)
|
||||
- AI: AI 相關模型
|
||||
- NVIDIA: Nemotron Tool Calling 模型 (ADR-036)
|
||||
"""
|
||||
|
||||
# Approval Models (Phase 2)
|
||||
@@ -39,6 +40,16 @@ from src.models.incident import (
|
||||
Signal,
|
||||
)
|
||||
|
||||
# NVIDIA Models (ADR-036 - Nemotron Tool Calling)
|
||||
from src.models.nvidia import (
|
||||
NvidiaProviderResult,
|
||||
NvidiaResponse,
|
||||
NvidiaUsage,
|
||||
ToolCall,
|
||||
ToolCallValidationResult,
|
||||
ToolDefinition,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Approval
|
||||
"ApprovalRequest",
|
||||
@@ -65,4 +76,11 @@ __all__ = [
|
||||
"IncidentUpdate",
|
||||
"Severity",
|
||||
"Signal",
|
||||
# NVIDIA (ADR-036)
|
||||
"NvidiaProviderResult",
|
||||
"NvidiaResponse",
|
||||
"NvidiaUsage",
|
||||
"ToolCall",
|
||||
"ToolCallValidationResult",
|
||||
"ToolDefinition",
|
||||
]
|
||||
|
||||
119
apps/api/src/models/nvidia.py
Normal file
119
apps/api/src/models/nvidia.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
NVIDIA Nemotron API Models - ADR-036
|
||||
====================================
|
||||
2026-03-29 ogt: Nemotron Tool Calling 整合 (83.3% 精準度)
|
||||
|
||||
OpenAI 相容格式 - 用於 Tool Calling 任務
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ToolFunction(BaseModel):
|
||||
"""Tool Function 定義"""
|
||||
|
||||
name: str = Field(..., description="Tool 函數名稱")
|
||||
arguments: str = Field(..., description="Tool 參數 (JSON 字串)")
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""Tool Call 結構"""
|
||||
|
||||
id: str = Field(..., description="Tool Call ID")
|
||||
type: str = Field(default="function", description="Tool 類型")
|
||||
function: ToolFunction = Field(..., description="Tool 函數")
|
||||
|
||||
|
||||
class NvidiaMessage(BaseModel):
|
||||
"""NVIDIA API Message 結構"""
|
||||
|
||||
role: str = Field(..., description="訊息角色 (assistant/user/system)")
|
||||
content: str | None = Field(default=None, description="訊息內容")
|
||||
tool_calls: list[ToolCall] | None = Field(
|
||||
default=None, description="Tool Calls (僅 assistant)"
|
||||
)
|
||||
|
||||
|
||||
class NvidiaChoice(BaseModel):
|
||||
"""NVIDIA API Choice 結構"""
|
||||
|
||||
index: int = Field(default=0, description="選項索引")
|
||||
message: NvidiaMessage = Field(..., description="回應訊息")
|
||||
finish_reason: str | None = Field(
|
||||
default=None, description="結束原因 (stop/tool_calls)"
|
||||
)
|
||||
|
||||
|
||||
class NvidiaUsage(BaseModel):
|
||||
"""NVIDIA API Token 使用統計"""
|
||||
|
||||
prompt_tokens: int = Field(default=0, description="輸入 Token 數")
|
||||
completion_tokens: int = Field(default=0, description="輸出 Token 數")
|
||||
total_tokens: int = Field(default=0, description="總 Token 數")
|
||||
|
||||
|
||||
class NvidiaResponse(BaseModel):
|
||||
"""NVIDIA Nemotron API 完整回應"""
|
||||
|
||||
id: str = Field(..., description="回應 ID")
|
||||
object: str = Field(default="chat.completion", description="物件類型")
|
||||
created: int = Field(..., description="建立時間戳")
|
||||
model: str = Field(..., description="模型名稱")
|
||||
choices: list[NvidiaChoice] = Field(..., description="回應選項")
|
||||
usage: NvidiaUsage | None = Field(default=None, description="Token 使用統計")
|
||||
|
||||
|
||||
# === Tool Calling 請求結構 ===
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
"""Tool 定義 (發送給 API)"""
|
||||
|
||||
type: str = Field(default="function", description="Tool 類型")
|
||||
function: dict[str, Any] = Field(..., description="函數定義 (JSON Schema)")
|
||||
|
||||
|
||||
class NvidiaToolCallRequest(BaseModel):
|
||||
"""NVIDIA Tool Calling 請求"""
|
||||
|
||||
model: str = Field(
|
||||
default="nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
description="模型名稱",
|
||||
)
|
||||
messages: list[dict[str, Any]] = Field(..., description="對話訊息")
|
||||
tools: list[ToolDefinition] = Field(..., description="可用 Tools")
|
||||
tool_choice: str | dict[str, Any] = Field(
|
||||
default="auto", description="Tool 選擇策略"
|
||||
)
|
||||
temperature: float = Field(default=0.0, description="溫度 (0.0 最確定性)")
|
||||
max_tokens: int = Field(default=1024, description="最大輸出 Token")
|
||||
|
||||
|
||||
# === 驗證結果結構 ===
|
||||
|
||||
|
||||
class ToolCallValidationResult(BaseModel):
|
||||
"""Tool Call 驗證結果"""
|
||||
|
||||
valid: bool = Field(..., description="是否有效")
|
||||
tool_name: str | None = Field(default=None, description="Tool 名稱")
|
||||
arguments: dict[str, Any] | None = Field(default=None, description="解析後參數")
|
||||
error: str | None = Field(default=None, description="錯誤訊息")
|
||||
raw_response: str | None = Field(default=None, description="原始回應 (debug)")
|
||||
|
||||
|
||||
class NvidiaProviderResult(BaseModel):
|
||||
"""NvidiaProvider 回傳結果"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
tool_calls: list[ToolCallValidationResult] = Field(
|
||||
default_factory=list, description="驗證後的 Tool Calls"
|
||||
)
|
||||
usage: NvidiaUsage | None = Field(default=None, description="Token 使用統計")
|
||||
latency_ms: float = Field(default=0.0, description="延遲 (毫秒)")
|
||||
error: str | None = Field(default=None, description="錯誤訊息")
|
||||
fallback_triggered: bool = Field(
|
||||
default=False, description="是否觸發 Fallback"
|
||||
)
|
||||
@@ -37,6 +37,12 @@ RATE_LIMITS = {
|
||||
"daily_requests": 200,
|
||||
"daily_tokens": 50_000,
|
||||
},
|
||||
# 2026-03-29 ogt: ADR-036 Nemotron Tool Calling (免費 Tier)
|
||||
"nvidia": {
|
||||
"rpm": 5, # 每分鐘請求數 (延遲較高,控制併發)
|
||||
"daily_requests": 100, # 每日請求數 (免費 Tier 限制)
|
||||
"daily_tokens": 50_000, # 每日 Token 數
|
||||
},
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -52,6 +58,11 @@ COST_LIMITS = {
|
||||
"total_cost_usd": 10.0,
|
||||
"alert_threshold_usd": 8.0,
|
||||
},
|
||||
# 2026-03-29 ogt: ADR-036 Nemotron (免費 Tier,設定低限制作為監控)
|
||||
"nvidia": {
|
||||
"total_cost_usd": 0.0, # 免費 Tier,不計費
|
||||
"alert_threshold_usd": 0.0, # 不發送成本告警
|
||||
},
|
||||
}
|
||||
|
||||
# Gemini 1.5 Flash 定價 (per token)
|
||||
|
||||
@@ -66,6 +66,8 @@ class AIProvider(Enum):
|
||||
OLLAMA = "ollama"
|
||||
GEMINI = "gemini"
|
||||
CLAUDE = "claude"
|
||||
# 2026-03-29 ogt: ADR-036 Nemotron Tool Calling (83.3% 精準度)
|
||||
NVIDIA = "nvidia"
|
||||
|
||||
|
||||
# Provider 對應延遲預算 (ms)
|
||||
@@ -73,6 +75,8 @@ PROVIDER_LATENCY_BUDGET: dict[AIProvider, int] = {
|
||||
AIProvider.OLLAMA: 60000, # 本地,允許較長處理時間
|
||||
AIProvider.GEMINI: 30000, # 雲端,較低延遲
|
||||
AIProvider.CLAUDE: 30000, # 雲端,較低延遲
|
||||
# 2026-03-29 ogt: ADR-036 Nemotron Tool Calling (延遲 11-45s)
|
||||
AIProvider.NVIDIA: 60000, # Tool Calling 專用,允許較長時間
|
||||
}
|
||||
|
||||
|
||||
@@ -164,21 +168,32 @@ class AIRouter:
|
||||
self._ollama_summary = self._model_registry.get_model("ollama", "summary")
|
||||
self._gemini_default = self._model_registry.get_model("gemini", "default")
|
||||
self._claude_default = self._model_registry.get_model("claude", "default")
|
||||
# 2026-03-29 ogt: ADR-036 Nemotron Tool Calling
|
||||
self._nvidia_default = self._model_registry.get_model("nvidia", "default")
|
||||
|
||||
# Provider 對應模型映射
|
||||
self._provider_models: dict[AIProvider, str] = {
|
||||
AIProvider.OLLAMA: self._ollama_default,
|
||||
AIProvider.GEMINI: self._gemini_default,
|
||||
AIProvider.CLAUDE: self._claude_default,
|
||||
AIProvider.NVIDIA: self._nvidia_default, # ADR-036
|
||||
}
|
||||
|
||||
# 完整 Fallback 鏈 (Provider, Model)
|
||||
# 2026-03-29 ogt: NVIDIA 不在一般 Fallback 鏈 (僅用於 Tool Calling)
|
||||
self._full_fallback_chain: list[tuple[AIProvider, str]] = [
|
||||
(AIProvider.OLLAMA, self._ollama_default),
|
||||
(AIProvider.GEMINI, self._gemini_default),
|
||||
(AIProvider.CLAUDE, self._claude_default),
|
||||
]
|
||||
|
||||
# Tool Calling 專用 Fallback 鏈 (ADR-036)
|
||||
self._tool_calling_fallback_chain: list[tuple[AIProvider, str]] = [
|
||||
(AIProvider.NVIDIA, self._nvidia_default),
|
||||
(AIProvider.GEMINI, self._gemini_default),
|
||||
(AIProvider.CLAUDE, self._claude_default),
|
||||
]
|
||||
|
||||
# 意圖對應 Provider 強制覆寫 (None = 依複雜度決定)
|
||||
self._intent_provider_overrides: dict[IntentType, AIProvider | None] = {
|
||||
# 四大核心意圖
|
||||
@@ -466,6 +481,39 @@ class AIRouter:
|
||||
routing_latency_ms=routing_latency,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Tool Calling 路由 (ADR-036)
|
||||
# =========================================================================
|
||||
|
||||
def route_tool_calling(self) -> tuple[AIProvider, str, list[tuple[AIProvider, str]]]:
|
||||
"""
|
||||
Tool Calling 專用路由 (ADR-036)
|
||||
|
||||
Tool Calling 任務優先使用 Nemotron (83.3% 精準度),
|
||||
Fallback 到 Gemini/Claude。
|
||||
|
||||
Returns:
|
||||
(provider, model, fallback_chain)
|
||||
"""
|
||||
provider = AIProvider.NVIDIA
|
||||
model = self._nvidia_default
|
||||
fallback_chain = [
|
||||
(p, m) for p, m in self._tool_calling_fallback_chain if p != provider
|
||||
]
|
||||
|
||||
logger.info(
|
||||
"tool_calling_routing",
|
||||
provider=provider.value,
|
||||
model=model,
|
||||
fallback_count=len(fallback_chain),
|
||||
)
|
||||
|
||||
return provider, model, fallback_chain
|
||||
|
||||
def get_tool_calling_fallback_chain(self) -> list[tuple[AIProvider, str]]:
|
||||
"""取得 Tool Calling Fallback 鏈"""
|
||||
return self._tool_calling_fallback_chain.copy()
|
||||
|
||||
# =========================================================================
|
||||
# 便捷方法
|
||||
# =========================================================================
|
||||
|
||||
432
apps/api/src/services/nvidia_provider.py
Normal file
432
apps/api/src/services/nvidia_provider.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""
|
||||
NVIDIA Nemotron Provider - ADR-036
|
||||
==================================
|
||||
2026-03-29 ogt: Nemotron Tool Calling 整合 (83.3% 精準度)
|
||||
|
||||
專門處理 Tool Calling 任務,提供高精準度的 K8s 操作決策。
|
||||
|
||||
設計原則:
|
||||
1. OpenAI 相容格式 - 與 Nemotron API 對接
|
||||
2. Pydantic 強制驗證 - 所有回應必須通過 Schema 驗證
|
||||
3. Fallback 機制 - 失敗時降級到 Gemini/Claude
|
||||
4. HITL 高風險保護 - DELETE 等操作需人工審核
|
||||
|
||||
版本: v1.0
|
||||
建立: 2026-03-29 (台北時區)
|
||||
建立者: Claude Code
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
|
||||
from src.core.config import get_settings
|
||||
from src.models.nvidia import (
|
||||
NvidiaProviderResult,
|
||||
NvidiaResponse,
|
||||
NvidiaUsage,
|
||||
ToolCallValidationResult,
|
||||
ToolDefinition,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
# =============================================================================
|
||||
# 常量定義
|
||||
# =============================================================================
|
||||
|
||||
# NVIDIA NIM API Endpoint
|
||||
NVIDIA_API_URL = "https://integrate.api.nvidia.com/v1/chat/completions"
|
||||
|
||||
# 預設模型
|
||||
NVIDIA_DEFAULT_MODEL = "nvidia/llama-3.1-nemotron-70b-instruct"
|
||||
|
||||
# 請求超時 (秒) - Nemotron 延遲 11-45s
|
||||
NVIDIA_TIMEOUT = 60.0
|
||||
|
||||
# 重試次數
|
||||
MAX_RETRIES = 2
|
||||
|
||||
# 高風險 Tool 清單 (需要 HITL 審核)
|
||||
HIGH_RISK_TOOLS: set[str] = {
|
||||
"delete_pod",
|
||||
"delete_deployment",
|
||||
"delete_namespace",
|
||||
"delete_service",
|
||||
"delete_configmap",
|
||||
"delete_secret",
|
||||
"scale_to_zero",
|
||||
"drain_node",
|
||||
"cordon_node",
|
||||
"delete_pvc",
|
||||
"delete_pv",
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NvidiaProvider 類別
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class NvidiaProvider:
|
||||
"""
|
||||
NVIDIA Nemotron Provider
|
||||
|
||||
專門處理 Tool Calling 任務,提供 83.3% 精準度的 K8s 操作決策。
|
||||
|
||||
使用方式:
|
||||
```python
|
||||
provider = NvidiaProvider()
|
||||
result = await provider.tool_call(
|
||||
messages=[{"role": "user", "content": "重啟 awoooi-api pod"}],
|
||||
tools=[restart_tool, scale_tool],
|
||||
)
|
||||
if result.success:
|
||||
for tc in result.tool_calls:
|
||||
if tc.valid:
|
||||
execute_tool(tc.tool_name, tc.arguments)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
"""
|
||||
初始化 NvidiaProvider
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API Key (預設從 settings 取得)
|
||||
"""
|
||||
self._api_key = api_key or settings.NVIDIA_API_KEY
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""取得或建立 HTTP Client"""
|
||||
if self._client is None or self._client.is_closed:
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(NVIDIA_TIMEOUT, connect=10.0),
|
||||
limits=httpx.Limits(max_connections=10, max_keepalive_connections=5),
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def close(self) -> None:
|
||||
"""關閉 HTTP Client"""
|
||||
if self._client and not self._client.is_closed:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def tool_call(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[ToolDefinition | dict[str, Any]],
|
||||
model: str = NVIDIA_DEFAULT_MODEL,
|
||||
temperature: float = 0.0,
|
||||
max_tokens: int = 1024,
|
||||
) -> NvidiaProviderResult:
|
||||
"""
|
||||
執行 Tool Calling 請求
|
||||
|
||||
Args:
|
||||
messages: 對話訊息列表
|
||||
tools: 可用 Tool 定義列表
|
||||
model: 模型名稱
|
||||
temperature: 溫度 (0.0 最確定性)
|
||||
max_tokens: 最大輸出 Token
|
||||
|
||||
Returns:
|
||||
NvidiaProviderResult: 包含驗證後的 Tool Calls
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# 檢查 API Key
|
||||
if not self._api_key:
|
||||
return NvidiaProviderResult(
|
||||
success=False,
|
||||
error="NVIDIA_API_KEY 未設定",
|
||||
fallback_triggered=True,
|
||||
)
|
||||
|
||||
# 轉換 tools 為 dict 格式
|
||||
tools_data = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, ToolDefinition):
|
||||
tools_data.append(tool.model_dump())
|
||||
else:
|
||||
tools_data.append(tool)
|
||||
|
||||
# 建立請求
|
||||
request_body = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": tools_data,
|
||||
"tool_choice": "auto",
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
# 執行請求 (含重試)
|
||||
response_data: dict | None = None
|
||||
last_error: str | None = None
|
||||
|
||||
for attempt in range(MAX_RETRIES + 1):
|
||||
try:
|
||||
response_data = await self._send_request(request_body)
|
||||
break
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
logger.warning(
|
||||
"nvidia_request_retry",
|
||||
attempt=attempt + 1,
|
||||
max_retries=MAX_RETRIES,
|
||||
error=last_error,
|
||||
)
|
||||
if attempt == MAX_RETRIES:
|
||||
break
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
# 請求失敗
|
||||
if response_data is None:
|
||||
logger.error(
|
||||
"nvidia_request_failed",
|
||||
error=last_error,
|
||||
latency_ms=round(latency_ms, 2),
|
||||
)
|
||||
return NvidiaProviderResult(
|
||||
success=False,
|
||||
error=last_error,
|
||||
latency_ms=latency_ms,
|
||||
fallback_triggered=True,
|
||||
)
|
||||
|
||||
# 解析回應
|
||||
try:
|
||||
nvidia_response = NvidiaResponse.model_validate(response_data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"nvidia_response_parse_failed",
|
||||
error=str(e),
|
||||
raw_response=str(response_data)[:500],
|
||||
)
|
||||
return NvidiaProviderResult(
|
||||
success=False,
|
||||
error=f"回應解析失敗: {e}",
|
||||
latency_ms=latency_ms,
|
||||
fallback_triggered=True,
|
||||
)
|
||||
|
||||
# 驗證 Tool Calls
|
||||
tool_calls = self._validate_tool_calls(nvidia_response)
|
||||
|
||||
# 統計
|
||||
usage = nvidia_response.usage
|
||||
|
||||
logger.info(
|
||||
"nvidia_tool_call_completed",
|
||||
success=True,
|
||||
tool_call_count=len(tool_calls),
|
||||
valid_count=sum(1 for tc in tool_calls if tc.valid),
|
||||
latency_ms=round(latency_ms, 2),
|
||||
prompt_tokens=usage.prompt_tokens if usage else 0,
|
||||
completion_tokens=usage.completion_tokens if usage else 0,
|
||||
)
|
||||
|
||||
return NvidiaProviderResult(
|
||||
success=True,
|
||||
tool_calls=tool_calls,
|
||||
usage=usage,
|
||||
latency_ms=latency_ms,
|
||||
fallback_triggered=False,
|
||||
)
|
||||
|
||||
async def _send_request(self, request_body: dict) -> dict:
|
||||
"""
|
||||
發送 HTTP 請求到 NVIDIA API
|
||||
|
||||
Args:
|
||||
request_body: 請求內容
|
||||
|
||||
Returns:
|
||||
API 回應 (dict)
|
||||
|
||||
Raises:
|
||||
Exception: 請求失敗
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
NVIDIA_API_URL,
|
||||
headers=headers,
|
||||
json=request_body,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text[:500]
|
||||
raise Exception(
|
||||
f"NVIDIA API 錯誤: {response.status_code} - {error_text}"
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
def _validate_tool_calls(
|
||||
self, response: NvidiaResponse
|
||||
) -> list[ToolCallValidationResult]:
|
||||
"""
|
||||
驗證 Tool Calls
|
||||
|
||||
Args:
|
||||
response: NVIDIA API 回應
|
||||
|
||||
Returns:
|
||||
驗證後的 Tool Call 結果列表
|
||||
"""
|
||||
results: list[ToolCallValidationResult] = []
|
||||
|
||||
if not response.choices:
|
||||
return results
|
||||
|
||||
message = response.choices[0].message
|
||||
if not message.tool_calls:
|
||||
return results
|
||||
|
||||
for tc in message.tool_calls:
|
||||
try:
|
||||
# 解析 arguments JSON
|
||||
arguments = json.loads(tc.function.arguments)
|
||||
|
||||
results.append(
|
||||
ToolCallValidationResult(
|
||||
valid=True,
|
||||
tool_name=tc.function.name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
results.append(
|
||||
ToolCallValidationResult(
|
||||
valid=False,
|
||||
tool_name=tc.function.name,
|
||||
error=f"Arguments JSON 解析失敗: {e}",
|
||||
raw_response=tc.function.arguments,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
results.append(
|
||||
ToolCallValidationResult(
|
||||
valid=False,
|
||||
error=f"驗證失敗: {e}",
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def is_high_risk_tool(self, tool_name: str) -> bool:
|
||||
"""
|
||||
檢查是否為高風險 Tool
|
||||
|
||||
Args:
|
||||
tool_name: Tool 名稱
|
||||
|
||||
Returns:
|
||||
是否需要 HITL 審核
|
||||
"""
|
||||
return tool_name.lower() in HIGH_RISK_TOOLS
|
||||
|
||||
def get_high_risk_tools(
|
||||
self, tool_calls: list[ToolCallValidationResult]
|
||||
) -> list[ToolCallValidationResult]:
|
||||
"""
|
||||
篩選高風險 Tool Calls
|
||||
|
||||
Args:
|
||||
tool_calls: Tool Call 結果列表
|
||||
|
||||
Returns:
|
||||
高風險 Tool Calls
|
||||
"""
|
||||
return [
|
||||
tc
|
||||
for tc in tool_calls
|
||||
if tc.valid and tc.tool_name and self.is_high_risk_tool(tc.tool_name)
|
||||
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 單例與工廠函數
|
||||
# =============================================================================
|
||||
|
||||
_provider: NvidiaProvider | None = None
|
||||
|
||||
|
||||
def get_nvidia_provider() -> NvidiaProvider:
|
||||
"""取得 NvidiaProvider 單例"""
|
||||
global _provider
|
||||
if _provider is None:
|
||||
_provider = NvidiaProvider()
|
||||
return _provider
|
||||
|
||||
|
||||
def reset_nvidia_provider() -> None:
|
||||
"""重置單例 (用於測試)"""
|
||||
global _provider
|
||||
_provider = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 便捷函數
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def nvidia_tool_call(
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[ToolDefinition | dict[str, Any]],
|
||||
**kwargs,
|
||||
) -> NvidiaProviderResult:
|
||||
"""
|
||||
便捷函數: 執行 NVIDIA Tool Calling
|
||||
|
||||
Args:
|
||||
messages: 對話訊息列表
|
||||
tools: 可用 Tool 定義列表
|
||||
**kwargs: 其他參數 (model, temperature, max_tokens)
|
||||
|
||||
Returns:
|
||||
NvidiaProviderResult
|
||||
"""
|
||||
provider = get_nvidia_provider()
|
||||
return await provider.tool_call(messages, tools, **kwargs)
|
||||
|
||||
|
||||
def create_tool_definition(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict[str, Any],
|
||||
) -> ToolDefinition:
|
||||
"""
|
||||
建立 Tool 定義
|
||||
|
||||
Args:
|
||||
name: Tool 名稱
|
||||
description: Tool 描述
|
||||
parameters: JSON Schema 參數定義
|
||||
|
||||
Returns:
|
||||
ToolDefinition
|
||||
"""
|
||||
return ToolDefinition(
|
||||
type="function",
|
||||
function={
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters,
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user