Files
ewoooc/services/elephant_service.py
OoO 89e7f2ccd2
All checks were successful
CD Pipeline / deploy (push) Successful in 1m46s
fix(ai): 擴大 ElephantAlpha 暫時性 fallback
2026-04-30 13:59:12 +08:00

218 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Elephant Alpha AI 服務模組
負責與 NVIDIA NIM(OpenAI 相容 chat completions)API 互動,提供高效率、長上下文的 Elephant Alpha Worker AI 功能
"""
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import requests
logger = logging.getLogger(__name__)

# Elephant Alpha configuration (NVIDIA NIM API). Each value below can be
# overridden through the environment variable named in the getenv call.
NVIDIA_API_KEY = os.getenv('NVIDIA_API_KEY', '')

# Base endpoint; any trailing slash is stripped so the URL join below is clean.
ELEPHANT_ALPHA_BASE_URL = os.getenv(
    'ELEPHANT_ALPHA_NEMOTRON_NIM_ENDPOINT', 'https://integrate.api.nvidia.com/v1'
).rstrip('/')

# Full chat-completions URL; an explicit ELEPHANT_ALPHA_URL takes precedence
# over the one derived from the base endpoint.
ELEPHANT_ALPHA_URL = os.getenv('ELEPHANT_ALPHA_URL', ELEPHANT_ALPHA_BASE_URL + '/chat/completions')

DEFAULT_ELEPHANT_MODEL = os.getenv('ELEPHANT_ALPHA_MODEL', 'nvidia/llama-3.3-nemotron-super-49b-v1.5')

# Comma-separated fallback chain; surrounding whitespace and empty entries are dropped.
_fallback_raw = None  # placeholder removed below; see list construction
ELEPHANT_FALLBACK_MODELS = list(filter(None, map(str.strip, os.getenv(
    'ELEPHANT_ALPHA_FALLBACK_MODELS',
    'nvidia/llama-3.3-nemotron-super-49b-v1.5,nvidia/llama-3.1-nemotron-70b-instruct,meta/llama-3.1-8b-instruct',
).split(','))))
del _fallback_raw

# Per-request timeout in seconds (default: 2 minutes).
ELEPHANT_TIMEOUT = int(os.getenv('ELEPHANT_TIMEOUT', '120'))

# HTTP statuses after which the service moves on to the next fallback model.
ELEPHANT_FALLBACK_HTTP_STATUS_CODES = {
    403, 404, 408, 409, 425, 429,  # 4xx
    500, 502, 503, 504,            # 5xx
}

# Elephant Alpha pricing in USD per 1M tokens (NVIDIA NIM pricing).
# NOTE(review): the default fallback models have no entry here, so
# calculate_cost prices them with the ultra-253b rates — confirm.
ELEPHANT_PRICING = {
    'nvidia/llama-3.1-nemotron-ultra-253b-v1': {'input': 0.10, 'output': 0.40},
    'nvidia/llama-3.3-nemotron-super-49b-v1.5': {'input': 0.10, 'output': 0.40},
}
@dataclass
class ElephantResponse:
    """Outcome of one Elephant Alpha generation request.

    ``success`` says whether a completion was obtained. On failure ``content``
    is normally empty and ``error`` holds a description. Timing, token and cost
    fields default to empty/zero and are filled in by the caller when known.
    """

    # --- outcome and payload ---
    success: bool
    content: str
    model: str  # model id this response refers to
    error: Optional[str] = None

    # --- timing (elapsed seconds, when measured) ---
    total_duration: Optional[float] = None

    # --- token usage ---
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0

    # --- estimated cost in USD ---
    input_cost: float = 0.0
    output_cost: float = 0.0
    total_cost: float = 0.0
class ElephantService:
    """Elephant Alpha worker AI client ("100B efficiency-tier worker").

    Sends chat-completions requests to ``ELEPHANT_ALPHA_URL`` (an NVIDIA NIM
    endpoint by default) and walks a model fallback chain on transient errors.
    """

    # W3-A guardrail 2: degraded-connection cache (300s TTL) so the API is not
    # probed on every check. It is a class attribute, so all instances share it.
    # NOTE(review): the cache is not keyed by api_key/model; instances with
    # different credentials see each other's result — confirm that is intended.
    _connection_cache: Dict[str, Any] = {"ok": False, "checked_at": 0.0}

    # Completion budget for the check_connection probe. The reply is discarded,
    # so a tiny budget keeps the probe cheap and short.
    _PROBE_MAX_TOKENS = 16

    def __init__(self, api_key: str = None, model: str = None):
        """Initialise the service.

        Args:
            api_key: Bearer token for the API; falls back to NVIDIA_API_KEY.
            model: Primary model id; falls back to DEFAULT_ELEPHANT_MODEL.
        """
        self.api_key = api_key or NVIDIA_API_KEY
        self.model = model or DEFAULT_ELEPHANT_MODEL

    def check_connection(self, cache_seconds: int = 300) -> bool:
        """Report whether the API is usable, caching the answer class-wide.

        The 300s default is aligned with the Anthropic prompt-cache TTL so the
        per-minute EA loop does not call the API every time.

        Args:
            cache_seconds: How long a previous probe result stays valid.

        Returns:
            True if the probe ``generate`` call succeeded. False when no API key
            is configured (this result is not cached) or when the probe failed.
        """
        if not self.api_key:
            return False
        now = time.time()
        cache = ElephantService._connection_cache
        if cache["checked_at"] and (now - cache["checked_at"]) < cache_seconds:
            return cache["ok"]
        try:
            # Only reachability matters here. Capping max_tokens avoids paying
            # for (and waiting on) a long generation that could outrun the 10s
            # probe timeout and be misreported as an outage.
            response = self.generate("hi", timeout=10, max_tokens=self._PROBE_MAX_TOKENS)
            result = response.success
        except Exception:  # best-effort probe: any failure counts as "unavailable"
            result = False
        ElephantService._connection_cache = {"ok": result, "checked_at": now}
        return result

    @staticmethod
    def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Dict[str, float]:
        """Estimate USD cost from token counts using ELEPHANT_PRICING.

        Models missing from the table are priced with the
        ``nvidia/llama-3.1-nemotron-ultra-253b-v1`` rates. Values are rounded
        to 6 decimal places.
        """
        pricing = ELEPHANT_PRICING.get(model, ELEPHANT_PRICING['nvidia/llama-3.1-nemotron-ultra-253b-v1'])
        input_cost = (input_tokens / 1_000_000) * pricing['input']
        output_cost = (output_tokens / 1_000_000) * pricing['output']
        return {
            'input_cost': round(input_cost, 6),
            'output_cost': round(output_cost, 6),
            'total_cost': round(input_cost + output_cost, 6),
        }

    @staticmethod
    def _model_candidates(primary_model: str) -> List[str]:
        """Return the primary model followed by ELEPHANT_FALLBACK_MODELS.

        The list is de-duplicated, keeps first-seen order, and skips empty names.
        """
        candidates: List[str] = []
        for model_name in [primary_model, *ELEPHANT_FALLBACK_MODELS]:
            if model_name and model_name not in candidates:
                candidates.append(model_name)
        return candidates

    @staticmethod
    def _has_next_model(model_name: str, model_candidates: List[str]) -> bool:
        """Return True if *model_name* is not the last entry of *model_candidates*."""
        return bool(model_candidates) and model_name != model_candidates[-1]

    @staticmethod
    def _extract_completion(data: Dict[str, Any]) -> Tuple[str, int, int]:
        """Pull ``(content, prompt_tokens, completion_tokens)`` from a response body.

        Falls back to ``reasoning_content`` and then ``reasoning`` when
        ``content`` is empty or null. A missing or ``null`` ``usage`` block, or
        null token counts, are treated as 0. Without this, a successful reply
        would crash later and be reported as a failure.

        Raises:
            KeyError / IndexError / TypeError (e.g.) when ``choices[0].message``
            is absent or malformed.
        """
        message = data["choices"][0]["message"]
        content = message.get("content") or message.get("reasoning_content") or message.get("reasoning") or ""
        usage = data.get("usage") or {}
        input_tokens = usage.get("prompt_tokens") or 0
        output_tokens = usage.get("completion_tokens") or 0
        return content, input_tokens, output_tokens

    def generate(self, prompt: str, model: str = None,
                 system_prompt: str = None, temperature: float = 0.3,
                 json_mode: bool = False, timeout: int = None,
                 max_tokens: int = 8000) -> "ElephantResponse":
        """Generate text for *prompt* (main entry point), walking the fallback chain.

        The service tries ``model`` (or the instance default) first, then each
        remaining entry of ELEPHANT_FALLBACK_MODELS.

        It moves to the next candidate on:
        - an HTTP status in ELEPHANT_FALLBACK_HTTP_STATUS_CODES,
        - a timeout, or
        - a connection error.

        Any other failure, or a failure on the last candidate, is returned
        immediately. Request and parsing failures are reported as
        ``ElephantResponse(success=False, error=...)`` rather than raised.

        Args:
            prompt: User message content.
            model: Primary model id; defaults to ``self.model``.
            system_prompt: Optional system message prepended to the conversation.
            temperature: Sampling temperature sent to the API.
            json_mode: When True, request ``response_format={"type": "json_object"}``.
            timeout: Per-request timeout in seconds; falsy means ELEPHANT_TIMEOUT.
            max_tokens: Completion token cap sent to the API (default 8000).

        Returns:
            ElephantResponse with content, usage and estimated cost on success.
        """
        primary_model = model or self.model
        request_timeout = timeout or ELEPHANT_TIMEOUT
        if not self.api_key:
            return ElephantResponse(success=False, content='', model=primary_model, error="API Key 未設定")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        last_error = ""
        model_candidates = self._model_candidates(primary_model)
        for model_name in model_candidates:
            payload = {
                "model": model_name,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
            if json_mode:
                payload["response_format"] = {"type": "json_object"}
            try:
                # monotonic clock: durations must not jump with wall-clock changes.
                start_time = time.monotonic()
                response = requests.post(
                    ELEPHANT_ALPHA_URL,
                    json=payload,
                    headers=headers,
                    timeout=request_timeout,
                )
                response.raise_for_status()
                duration = time.monotonic() - start_time
                content, input_tokens, output_tokens = self._extract_completion(response.json())
                costs = self.calculate_cost(model_name, input_tokens, output_tokens)
                return ElephantResponse(
                    success=True,
                    content=content,
                    model=model_name,
                    total_duration=duration,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=input_tokens + output_tokens,
                    input_cost=costs['input_cost'],
                    output_cost=costs['output_cost'],
                    total_cost=costs['total_cost'],
                )
            except requests.HTTPError as e:
                status_code = e.response.status_code if e.response is not None else None
                last_error = f"{model_name}: {e}"
                if status_code in ELEPHANT_FALLBACK_HTTP_STATUS_CODES and self._has_next_model(model_name, model_candidates):
                    # Logs the model that just failed; the loop then tries the next one.
                    logger.warning("[Elephant] NIM 模型/API 暫時不可用,改用 fallback: %s (%s)", model_name, status_code)
                    continue
                logger.error("[Elephant] 生成失敗: %s", e)
                return ElephantResponse(success=False, content='', model=model_name, error=last_error)
            except (requests.Timeout, requests.ConnectionError) as e:
                last_error = f"{model_name}: {e}"
                if self._has_next_model(model_name, model_candidates):
                    logger.warning("[Elephant] NIM 暫時性連線錯誤,改用 fallback: %s (%s)", model_name, e)
                    continue
                logger.error("[Elephant] 生成失敗: %s", e)
                return ElephantResponse(success=False, content='', model=model_name, error=last_error)
            except Exception as e:
                # Non-transient failures (e.g. an unparsable body) are not retried.
                last_error = f"{model_name}: {e}"
                logger.error("[Elephant] 生成失敗: %s", e)
                return ElephantResponse(success=False, content='', model=model_name, error=last_error)

        # Reached only when there were no candidates at all (the last candidate
        # always returns inside the loop).
        failed_model = model_candidates[-1] if model_candidates else primary_model
        return ElephantResponse(success=False, content='', model=failed_model, error=last_error or "所有 Elephant fallback model 均不可用")
# Shared module-level instance, constructed at import time with the
# environment defaults (NVIDIA_API_KEY / DEFAULT_ELEPHANT_MODEL).
elephant_service: ElephantService = ElephantService()