218 lines
8.3 KiB
Python
218 lines
8.3 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
Elephant Alpha AI 服務模組
|
||
負責與 OpenRouter / Elephant Alpha API 互動,提供高效率、長上下文的 Worker AI 功能
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
import json
|
||
import logging
|
||
import requests
|
||
from typing import Optional, Dict, Any, List
|
||
from dataclasses import dataclass
|
||
|
||
logger = logging.getLogger(__name__)

# --- Elephant Alpha settings (NVIDIA NIM API); every value can be overridden by an env var.
NVIDIA_API_KEY = os.environ.get('NVIDIA_API_KEY', '')

# Base endpoint with any trailing slash removed, so appending a path never doubles the slash.
ELEPHANT_ALPHA_BASE_URL = os.environ.get(
    'ELEPHANT_ALPHA_NEMOTRON_NIM_ENDPOINT', 'https://integrate.api.nvidia.com/v1'
).rstrip('/')

# Full chat-completions URL; an explicit ELEPHANT_ALPHA_URL takes precedence over the derived one.
ELEPHANT_ALPHA_URL = os.environ.get(
    'ELEPHANT_ALPHA_URL', ELEPHANT_ALPHA_BASE_URL + '/chat/completions'
)

DEFAULT_ELEPHANT_MODEL = os.environ.get(
    'ELEPHANT_ALPHA_MODEL', 'nvidia/llama-3.3-nemotron-super-49b-v1.5'
)

# Comma-separated fallback chain: entries are whitespace-trimmed and empty entries dropped.
ELEPHANT_FALLBACK_MODELS = list(filter(None, map(str.strip, os.environ.get(
    'ELEPHANT_ALPHA_FALLBACK_MODELS',
    'nvidia/llama-3.3-nemotron-super-49b-v1.5,nvidia/llama-3.1-nemotron-70b-instruct,meta/llama-3.1-8b-instruct',
).split(','))))

# Per-request timeout in seconds (default: 2 minutes).
ELEPHANT_TIMEOUT = int(os.environ.get('ELEPHANT_TIMEOUT', '120'))

# HTTP statuses on which a request moves on to the next fallback model (if one remains).
ELEPHANT_FALLBACK_HTTP_STATUS_CODES = {
    403, 404, 408, 409, 425, 429,  # forbidden, not found, request timeout, conflict, too early, too many requests
    500, 502, 503, 504,            # server-side errors
}

# Elephant Alpha pricing in USD per 1M tokens (NVIDIA NIM pricing).
ELEPHANT_PRICING = {
    'nvidia/llama-3.1-nemotron-ultra-253b-v1': dict(input=0.10, output=0.40),
    'nvidia/llama-3.3-nemotron-super-49b-v1.5': dict(input=0.10, output=0.40),
}
|
||
|
||
@dataclass(init=True, repr=True, eq=True)
class ElephantResponse:
    """Outcome of a single Elephant Alpha call.

    ``success`` tells whether a completion was obtained; ``error`` holds the failure
    message when it was not. Usage and cost fields default to zero.
    """

    # Outcome and payload.
    success: bool
    content: str
    model: str  # model name the response (or the failure) is attributed to
    error: Optional[str] = None

    # Wall-clock duration of the request in seconds, when measured.
    total_duration: Optional[float] = None

    # Token usage.
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0

    # Costs in USD.
    input_cost: float = 0.0
    output_cost: float = 0.0
    total_cost: float = 0.0
|
||
|
||
class ElephantService:
    """Elephant Alpha AI service: a worker backed by an NVIDIA NIM chat-completions endpoint.

    Each ``generate`` call posts to ``ELEPHANT_ALPHA_URL``. On an HTTP status listed in
    ``ELEPHANT_FALLBACK_HTTP_STATUS_CODES`` or on a timeout/connection error, it retries
    with the next model from ``ELEPHANT_FALLBACK_MODELS``.
    """

    # Completion budget used by ``generate`` when no ``max_tokens`` is given
    # (the value that used to be hard-coded in the request payload).
    DEFAULT_MAX_TOKENS = 8000

    # Completion budget for the ``check_connection`` probe. The probe only needs a
    # successful HTTP round trip; letting the model spend up to DEFAULT_MAX_TOKENS answering
    # "hi" wastes tokens and can overrun the probe's 10 s timeout, reporting a false
    # "unavailable".
    PING_MAX_TOKENS = 16

    # W3-A guardrail 2: degraded-connection cache (300 s TTL) so the endpoint is not pinged
    # on every check. This is a class attribute, so one cached result is shared by all
    # instances regardless of their api_key or model.
    _connection_cache: Dict[str, Any] = {"ok": False, "checked_at": 0.0}

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """Initialise the service.

        Args:
            api_key: Bearer token for the API; ``NVIDIA_API_KEY`` is used when falsy.
            model: Primary model name; ``DEFAULT_ELEPHANT_MODEL`` is used when falsy.
        """
        self.api_key = api_key or NVIDIA_API_KEY
        self.model = model or DEFAULT_ELEPHANT_MODEL

    def check_connection(self, cache_seconds: int = 300) -> bool:
        """Return whether the API is usable, caching the answer for ``cache_seconds``.

        ``cache_seconds=300`` is aligned with the Anthropic prompt-cache TTL so the
        once-a-minute EA loop does not hit the API on every iteration. Without an API key
        this returns ``False`` immediately and leaves the cache untouched.

        Args:
            cache_seconds: How long a previous probe result is reused.

        Returns:
            ``True`` if the probe (or the cached probe result) succeeded.
        """
        if not self.api_key:
            return False

        now = time.time()
        cache = ElephantService._connection_cache
        # checked_at == 0 means "never probed", so the first call always hits the network.
        if cache["checked_at"] and (now - cache["checked_at"]) < cache_seconds:
            return cache["ok"]

        try:
            # A tiny output budget keeps the probe cheap and well inside its 10 s timeout.
            response = self.generate("hi", timeout=10, max_tokens=self.PING_MAX_TOKENS)
            result = response.success
        except Exception:
            # Best-effort probe: any unexpected error just means "unavailable", but keep a trace.
            logger.debug("[Elephant] connection check raised", exc_info=True)
            result = False

        ElephantService._connection_cache = {"ok": result, "checked_at": now}
        return result

    @staticmethod
    def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Dict[str, float]:
        """Compute the USD cost of a call from its token counts.

        Models missing from ``ELEPHANT_PRICING`` are billed at the
        ``nvidia/llama-3.1-nemotron-ultra-253b-v1`` rate.

        Returns:
            A dict with ``input_cost``, ``output_cost`` and ``total_cost``, each rounded
            to 6 decimal places.
        """
        pricing = ELEPHANT_PRICING.get(model, ELEPHANT_PRICING['nvidia/llama-3.1-nemotron-ultra-253b-v1'])
        input_cost = (input_tokens / 1_000_000) * pricing['input']
        output_cost = (output_tokens / 1_000_000) * pricing['output']
        return {
            'input_cost': round(input_cost, 6),
            'output_cost': round(output_cost, 6),
            'total_cost': round(input_cost + output_cost, 6),
        }

    @staticmethod
    def _model_candidates(primary_model: str) -> List[str]:
        """Return the primary model followed by the fallbacks.

        Empty names are dropped and duplicates removed; first-seen order is preserved.
        """
        return list(dict.fromkeys(
            name for name in [primary_model, *ELEPHANT_FALLBACK_MODELS] if name
        ))

    @staticmethod
    def _has_next_model(model_name: str, model_candidates: List[str]) -> bool:
        """Return True if ``model_name`` is not the last candidate (candidates are unique)."""
        return bool(model_candidates) and model_name != model_candidates[-1]

    def generate(self, prompt: str, model: Optional[str] = None,
                 system_prompt: Optional[str] = None, temperature: float = 0.3,
                 json_mode: bool = False, timeout: Optional[int] = None,
                 max_tokens: Optional[int] = None) -> ElephantResponse:
        """Generate text (main entry point).

        Tries the requested model first, then each fallback model. It moves on only for
        HTTP statuses in ``ELEPHANT_FALLBACK_HTTP_STATUS_CODES`` or timeout/connection
        errors. Failures are reported in the returned object rather than raised.

        Args:
            prompt: User message.
            model: Model to try first; defaults to ``self.model``.
            system_prompt: Optional system message placed before the user message.
            temperature: Sampling temperature.
            json_mode: Send ``response_format={"type": "json_object"}``.
            timeout: Per-request timeout in seconds; ``ELEPHANT_TIMEOUT`` when falsy.
            max_tokens: Completion budget; ``DEFAULT_MAX_TOKENS`` (8000) when falsy.

        Returns:
            An ``ElephantResponse``; on failure ``success`` is False and ``error`` is set.
        """
        primary_model = model or self.model
        request_timeout = timeout or ELEPHANT_TIMEOUT
        token_budget = max_tokens or self.DEFAULT_MAX_TOKENS

        if not self.api_key:
            return ElephantResponse(success=False, content='', model=primary_model, error="API Key 未設定")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        last_error = ""
        model_candidates = self._model_candidates(primary_model)
        for model_name in model_candidates:
            payload = {
                "model": model_name,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": token_budget
            }
            if json_mode:
                payload["response_format"] = {"type": "json_object"}

            try:
                start_time = time.time()
                response = requests.post(
                    ELEPHANT_ALPHA_URL,
                    json=payload,
                    headers=headers,
                    timeout=request_timeout
                )
                response.raise_for_status()
                end_time = time.time()

                data = response.json()
                message = data["choices"][0]["message"]
                # Fall back to the reasoning fields when "content" is missing or empty.
                content = message.get("content") or message.get("reasoning_content") or message.get("reasoning") or ""

                # Token usage. "usage" (or its counters) may be null in OpenAI-compatible
                # responses; treat that as 0 instead of failing a successful completion.
                usage = data.get("usage") or {}
                input_tokens = usage.get("prompt_tokens") or 0
                output_tokens = usage.get("completion_tokens") or 0

                costs = self.calculate_cost(model_name, input_tokens, output_tokens)

                return ElephantResponse(
                    success=True,
                    content=content,
                    model=model_name,
                    total_duration=end_time - start_time,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=input_tokens + output_tokens,
                    input_cost=costs['input_cost'],
                    output_cost=costs['output_cost'],
                    total_cost=costs['total_cost']
                )

            except requests.HTTPError as e:
                status_code = e.response.status_code if e.response is not None else None
                last_error = f"{model_name}: {e}"
                if status_code in ELEPHANT_FALLBACK_HTTP_STATUS_CODES and self._has_next_model(model_name, model_candidates):
                    logger.warning("[Elephant] NIM 模型/API 暫時不可用,改用 fallback: %s (%s)", model_name, status_code)
                    continue
                logger.error("[Elephant] 生成失敗: %s", e)
                return ElephantResponse(success=False, content='', model=model_name, error=last_error)

            except (requests.Timeout, requests.ConnectionError) as e:
                last_error = f"{model_name}: {e}"
                if self._has_next_model(model_name, model_candidates):
                    logger.warning("[Elephant] NIM 暫時性連線錯誤,改用 fallback: %s (%s)", model_name, e)
                    continue
                logger.error("[Elephant] 生成失敗: %s", e)
                return ElephantResponse(success=False, content='', model=model_name, error=last_error)

            except Exception as e:
                # Malformed body, unexpected schema, etc.: not retried on another model.
                last_error = f"{model_name}: {e}"
                logger.error("[Elephant] 生成失敗: %s", e)
                return ElephantResponse(success=False, content='', model=model_name, error=last_error)

        # Only reached when there are no candidate models at all.
        failed_model = model_candidates[-1] if model_candidates else primary_model
        return ElephantResponse(success=False, content='', model=failed_model, error=last_error or "所有 Elephant fallback model 均不可用")


# Module-level singleton instance, built from the env-derived defaults.
elephant_service = ElephantService()
|