Files
ewoooc/services/ai_provider.py
OoO e3da4ffbb3
All checks were successful
CD Pipeline / deploy (push) Successful in 1m6s
阻止 Gemini 成為推薦主路徑
2026-05-21 16:18:35 +08:00

497 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AI 提供者抽象層
統一 Ollama 和 Gemini 的介面;通用生成一律 Ollama-first，Gemini 僅作備援
"""
import logging
import os
from dataclasses import asdict, dataclass
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Union
logger = logging.getLogger(name=__name__)

# Default AI provider name, overridable via the AI_PROVIDER environment variable.
# Gemini may never be the default: AIProviderService coerces 'gemini' back to
# 'ollama' and only calls Gemini as a fallback after Ollama fails.
AI_PROVIDER = os.environ.get('AI_PROVIDER', 'ollama')
# 引入服務
from .ollama_service import OllamaService, OllamaResponse
from .gemini_service import GeminiService, GeminiResponse, AVAILABLE_GEMINI_MODELS
from .elephant_service import ElephantService, ElephantResponse
from .gemini_guard import gemini_disabled_message, is_gemini_fallback_enabled
@dataclass
class AIResponse:
    """Unified AI response structure shared by every provider."""
    success: bool
    content: str
    model: str
    provider: str  # 'ollama', 'gemini' or 'elephant'
    error: Optional[str] = None
    total_duration: Optional[float] = None
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    input_cost: float = 0.0
    output_cost: float = 0.0
    total_cost: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return the response as a plain dict keyed by field name, in declaration order."""
        # asdict() is driven by the field declarations, so a newly added field
        # can never be forgotten here. Every field is a primitive, so its
        # recursive copy yields exactly the attribute values.
        return asdict(self)
class AIProviderService:
    """
    AI provider service - a unified AI interface.

    General copywriting / keyword / insight generation always goes through
    Ollama first. Gemini is only called as a fallback after the Ollama primary
    path fails, and it is never allowed to become the default provider.
    'elephant' may be chosen as the default; ``generate`` then routes to it.
    """

    def __init__(self, default_provider: str = None):
        """
        Initialize the AI provider service.

        Args:
            default_provider: Preferred default provider; when falsy, the
                module-level ``AI_PROVIDER`` setting is used. 'gemini' is
                coerced to 'ollama'; values other than 'ollama' / 'elephant'
                raise ValueError.
        """
        self._default_provider = self._sanitize_default_provider(default_provider or AI_PROVIDER)
        self._ollama = OllamaService()
        self._gemini = GeminiService()
        self._elephant = ElephantService()
        # Cache for get_status() results
        self._status_cache = {'timestamp': 0, 'data': None}
        self._CACHE_TTL = 60  # seconds

    @property
    def default_provider(self) -> str:
        """Return the current (sanitized) default provider."""
        return self._default_provider

    @default_provider.setter
    def default_provider(self, value: str):
        """
        Set the default provider.

        The value is sanitized first: 'gemini' becomes 'ollama', and unknown
        names raise ValueError while leaving the current setting untouched.
        """
        self._default_provider = self._sanitize_default_provider(value)
        # get_status() caches 'default_provider' and 'recommended_provider';
        # drop the cache so the next status call reflects the new setting.
        self._status_cache = {'timestamp': 0, 'data': None}
        # Log the effective provider, not the raw input (which may have been coerced).
        logger.info(f"AI 預設提供者已切換至: {self._default_provider}")

    @staticmethod
    def _sanitize_default_provider(value: str) -> str:
        """
        Normalize a provider name for use as the default.

        Gemini is fallback-only, so 'gemini' is forced back to 'ollama'.
        A falsy value becomes 'ollama'. Matching is case/whitespace-insensitive.

        Raises:
            ValueError: if the name is not 'ollama', 'elephant' or 'gemini'.
        """
        normalized = (value or 'ollama').strip().lower()
        if normalized == 'gemini':
            logger.warning("AI_PROVIDER=gemini 已被拒絕,Gemini 僅作 Ollama 失敗備援")
            return 'ollama'
        if normalized not in ('ollama', 'elephant'):
            raise ValueError("Provider must be 'ollama' or 'elephant'; gemini is fallback-only")
        return normalized

    def get_status(self, force_refresh: bool = False) -> Dict[str, Any]:
        """
        Return the status of every AI service.

        Results are cached for ``self._CACHE_TTL`` seconds.

        Args:
            force_refresh: bypass the cache and re-check every service.

        Returns:
            dict with 'default_provider', one entry each for 'ollama',
            'gemini' and 'elephant', plus 'recommended_provider' and
            'timestamp'. The cached dict itself is returned, so callers
            should not mutate it.
        """
        import time
        now = time.time()

        # Serve from cache while it is still fresh
        if not force_refresh and self._status_cache['data'] is not None:
            if now - self._status_cache['timestamp'] < self._CACHE_TTL:
                return self._status_cache['data']

        ollama_connected = self._ollama.check_connection()
        gemini_connected = self._gemini.check_connection()
        elephant_connected = self._elephant.check_connection()

        status = {
            'default_provider': self._default_provider,
            'ollama': {
                'connected': ollama_connected,
                'model': self._ollama.model if ollama_connected else None,
                'available_models': self._ollama.available_models if ollama_connected else [],
                'type': 'local',
                'cost': 'free'
            },
            'gemini': {
                'connected': gemini_connected,
                'model': self._gemini.model if gemini_connected else None,
                'available_models': AVAILABLE_GEMINI_MODELS,
                'type': 'cloud',
                'cost': 'paid'
            },
            'elephant': {
                'connected': elephant_connected,
                'model': self._elephant.model if elephant_connected else None,
                'available_models': [{'id': 'nvidia/llama-3.1-nemotron-ultra-253b-v1', 'name': 'Nemotron Ultra 253B'}],
                'type': 'cloud',
                'cost': 'efficient'
            },
            'recommended_provider': self._get_recommended_provider(ollama_connected, gemini_connected, elephant_connected),
            'timestamp': datetime.now().isoformat()
        }

        self._status_cache = {'timestamp': now, 'data': status}
        return status

    def _get_recommended_provider(self, ollama_ok: bool, gemini_ok: bool, elephant_ok: bool) -> str:
        """
        Recommend the primary provider based on availability.

        Ollama is preferred, then Elephant. Gemini is fallback-only and is
        never recommended as the primary path, so ``gemini_ok`` is accepted
        for signature compatibility but deliberately ignored.

        Returns:
            'ollama', 'elephant', or 'none' when neither is reachable.
        """
        if ollama_ok:
            return 'ollama'
        if elephant_ok:
            return 'elephant'
        return 'none'

    def _convert_response(self, response: Union[OllamaResponse, GeminiResponse, ElephantResponse],
                          provider: str) -> AIResponse:
        """
        Convert a backend-specific response into an AIResponse.

        Args:
            response: response object from one of the backend services.
            provider: provider label used only when the response type is unknown.
        """
        if isinstance(response, OllamaResponse):
            # Ollama is local and free: token/cost fields keep their zero defaults.
            return AIResponse(
                success=response.success,
                content=response.content,
                model=response.model,
                provider='ollama',
                error=response.error,
                total_duration=response.total_duration,
            )
        if isinstance(response, (GeminiResponse, ElephantResponse)):
            # Both cloud backends expose the same token/cost accounting fields.
            return AIResponse(
                success=response.success,
                content=response.content,
                model=response.model,
                provider='gemini' if isinstance(response, GeminiResponse) else 'elephant',
                error=response.error,
                total_duration=response.total_duration,
                input_tokens=response.input_tokens,
                output_tokens=response.output_tokens,
                total_tokens=response.total_tokens,
                input_cost=response.input_cost,
                output_cost=response.output_cost,
                total_cost=response.total_cost
            )
        return AIResponse(
            success=False,
            content='',
            model='unknown',
            provider=provider,
            error='Unknown response type'
        )

    def _gemini_fallback(self, ollama_result: AIResponse, fallback_call) -> AIResponse:
        """
        Call Gemini only when the Ollama primary path failed (Ollama-first policy).

        Args:
            ollama_result: converted result of the Ollama attempt.
            fallback_call: zero-argument callable performing the Gemini request.

        Returns:
            ``ollama_result`` unchanged on success. When the fallback is
            disabled or raises, ``ollama_result`` with an extended error.
            Otherwise the converted Gemini result.
        """
        if ollama_result.success:
            return ollama_result

        if not is_gemini_fallback_enabled("ai_provider"):
            disabled = gemini_disabled_message("ai_provider")
            logger.warning("Ollama 主路徑失敗,但 %s", disabled)
            ollama_result.error = (
                f"{ollama_result.error}; {disabled}"
                if ollama_result.error else disabled
            )
            return ollama_result

        logger.warning("Ollama 主路徑失敗,啟用 Gemini 備援:%s", ollama_result.error)
        try:
            gemini_response = fallback_call()
        except Exception as exc:
            # Best-effort fallback: keep the Ollama result but record why Gemini failed.
            logger.exception("Gemini 備援呼叫發生例外")
            prefix = f"{ollama_result.error}; " if ollama_result.error else ""
            ollama_result.error = f"{prefix}Gemini 備援例外:{exc}"
            return ollama_result

        gemini_result = self._convert_response(gemini_response, 'gemini')
        if not gemini_result.success and ollama_result.error:
            # Keep both failure reasons so the caller sees the whole chain.
            gemini_result.error = f"Ollama 失敗:{ollama_result.error}; Gemini 備援失敗:{gemini_result.error}"
        return gemini_result

    def generate(self, prompt: str, provider: str = None, model: str = None,
                 system_prompt: str = None, temperature: float = 0.7,
                 timeout: int = None) -> AIResponse:
        """
        Generate text (unified interface).

        'elephant' (explicit or as the default) is routed to Elephant. Every
        other value, including 'gemini', runs Ollama first, with Gemini only
        as a fallback.

        Args:
            prompt: user prompt
            provider: requested provider
            model: model name
            system_prompt: system prompt
            temperature: sampling temperature
            timeout: timeout in seconds

        Returns:
            AIResponse
        """
        provider = (provider or self._default_provider or 'ollama').strip().lower()
        if provider == 'elephant':
            response = self._elephant.generate(
                prompt=prompt,
                model=model,
                system_prompt=system_prompt,
                temperature=temperature,
                timeout=timeout
            )
            return self._convert_response(response, provider)

        ollama_response = self._ollama.generate(
            prompt=prompt,
            model=model,
            system_prompt=system_prompt,
            temperature=temperature,
            timeout=timeout
        )
        ollama_result = self._convert_response(ollama_response, 'ollama')
        return self._gemini_fallback(
            ollama_result,
            lambda: self._gemini.generate(
                prompt=prompt,
                model=model,
                system_prompt=system_prompt,
                temperature=temperature,
                timeout=timeout
            ),
        )

    def generate_sales_copy(self, product_name: str, provider: str = None,
                            model: str = None, trend_keywords: List[str] = None,
                            style: str = "吸睛", upcoming_holidays: List[Dict] = None,
                            bestseller_products: List[Dict] = None) -> AIResponse:
        """
        Generate sales copy (unified interface): Ollama first, Gemini fallback.

        NOTE(review): ``provider`` is accepted for interface compatibility but
        does not change routing; unlike ``generate``, 'elephant' is not
        honored here. ``model`` is only forwarded to the Gemini fallback.

        Args:
            product_name: product name
            provider: requested provider (currently ignored)
            model: model name for the Gemini fallback
            trend_keywords: trending keywords
            style: copywriting style
            upcoming_holidays: upcoming holidays
            bestseller_products: competitors' best-selling products

        Returns:
            AIResponse
        """
        ollama_response = self._ollama.generate_sales_copy(
            product_name=product_name,
            trend_keywords=trend_keywords,
            style=style,
            upcoming_holidays=upcoming_holidays,
            bestseller_products=bestseller_products
        )
        ollama_result = self._convert_response(ollama_response, 'ollama')
        return self._gemini_fallback(
            ollama_result,
            lambda: self._gemini.generate_sales_copy(
                product_name=product_name,
                trend_keywords=trend_keywords,
                style=style,
                upcoming_holidays=upcoming_holidays,
                bestseller_products=bestseller_products,
                model=model
            ),
        )

    def extract_keywords(self, text: str, provider: str = None,
                         model: str = None, max_keywords: int = 10) -> AIResponse:
        """
        Extract keywords (unified interface): Ollama first, Gemini fallback.

        Args:
            text: text to analyze
            provider: requested provider (does not change routing)
            model: model name for the Gemini fallback
            max_keywords: maximum number of keywords

        Returns:
            AIResponse
        """
        ollama_response = self._ollama.extract_keywords(text, max_keywords)
        ollama_result = self._convert_response(ollama_response, 'ollama')
        return self._gemini_fallback(
            ollama_result,
            lambda: self._gemini.extract_keywords(text, max_keywords, model),
        )

    def search_product_insights(self, product_name: str, provider: str = None,
                                model: str = None, include_competitors: bool = True,
                                include_trends: bool = True,
                                web_context: str = "") -> AIResponse:
        """
        Search product market insights (unified interface): Ollama first, Gemini fallback.

        Args:
            product_name: product name
            provider: requested provider (does not change routing)
            model: model name for the Gemini fallback
            include_competitors: include competitor analysis
            include_trends: include trend analysis
            web_context: web search results (passed to Ollama only)

        Returns:
            AIResponse
        """
        ollama_response = self._ollama.search_product_insights(
            product_name=product_name,
            include_competitors=include_competitors,
            include_trends=include_trends,
            web_context=web_context
        )
        ollama_result = self._convert_response(ollama_response, 'ollama')
        return self._gemini_fallback(
            ollama_result,
            lambda: self._gemini.search_product_insights(
                product_name=product_name,
                include_competitors=include_competitors,
                include_trends=include_trends,
                model=model
            ),
        )

    def web_search(self, query: str, provider: str = None, model: str = None,
                   num_results: int = 5, search_type: str = "general") -> AIResponse:
        """
        Web search (unified interface); only Ollama is used.

        Args:
            query: search query
            provider: ignored
            model: ignored
            num_results: number of results to return
            search_type: search type

        Returns:
            AIResponse
        """
        response = self._ollama.web_search(query, num_results, search_type)
        return self._convert_response(response, 'ollama')

    def search_trend_keywords(self, category: str, provider: str = None,
                              model: str = None, time_range: str = "week") -> AIResponse:
        """
        Search trending keywords (unified interface); only Ollama is used.

        Args:
            category: product category
            provider: ignored
            model: ignored
            time_range: time range

        Returns:
            AIResponse
        """
        response = self._ollama.search_trend_keywords(category, time_range)
        return self._convert_response(response, 'ollama')
# Shared module-level service instance used by the convenience functions below.
ai_provider_service: AIProviderService = AIProviderService()
# 便捷函數
def get_ai_status(force_refresh: bool = False) -> Dict[str, Any]:
    """
    Return the status of all AI services from the shared service instance.

    Args:
        force_refresh: bypass the status cache when True.
    """
    return ai_provider_service.get_status(force_refresh=force_refresh)
def set_ai_provider(provider: str) -> bool:
    """
    Set the default provider of the shared service instance.

    Returns:
        False if the provider name is rejected (ValueError), True otherwise.
        Note that 'gemini' is accepted but coerced to 'ollama' by the service.
    """
    try:
        ai_provider_service.default_provider = provider
    except ValueError:
        return False
    return True
def generate_copy(product_name: str, provider: str = None, model: str = None,
                  **kwargs) -> AIResponse:
    """
    Generate sales copy via the shared service instance (convenience wrapper).

    Extra keyword arguments are forwarded unchanged to
    ``AIProviderService.generate_sales_copy``.
    """
    # kwargs can never contain the three named parameters, so merging is collision-free.
    call_args = dict(kwargs, product_name=product_name, provider=provider, model=model)
    return ai_provider_service.generate_sales_copy(**call_args)
if __name__ == "__main__":
    # Manual smoke test: report backend connectivity, then generate one
    # sample sales copy with the recommended provider.
    logging.basicConfig(level=logging.INFO)
    service = AIProviderService()

    # Check service status
    print("檢查 AI 服務狀態...")
    status = service.get_status()
    print(f"Ollama: {'✅ 已連線' if status['ollama']['connected'] else '❌ 未連線'}")
    print(f"Gemini: {'✅ 已連線' if status['gemini']['connected'] else '❌ 未連線'}")
    print(f"Elephant: {'✅ 已連線' if status['elephant']['connected'] else '❌ 未連線'}")
    print(f"預設提供者: {status['default_provider']}")
    print(f"推薦提供者: {status['recommended_provider']}")

    # Only attempt generation when some primary provider is reachable
    if status['recommended_provider'] != 'none':
        print(f"\n使用 {status['recommended_provider']} 測試文案生成...")
        result = service.generate_sales_copy(
            product_name="玻尿酸保濕面膜",
            provider=status['recommended_provider'],
            trend_keywords=["換季保養", "敏感肌"],
            style="吸睛"
        )
        if result.success:
            print(f"\n生成結果:\n{result.content}")
            print(f"\n提供者: {result.provider}")
            print(f"模型: {result.model}")
            # total_duration is Optional; formatting None with ':.2f' raises TypeError.
            if result.total_duration is not None:
                print(f"耗時: {result.total_duration:.2f}")
            if result.total_cost > 0:
                print(f"費用: ${result.total_cost:.6f} USD")
        else:
            print(f"生成失敗: {result.error}")