490 lines
18 KiB
Python
490 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
AI 提供者抽象層
|
||
統一 Ollama 和 Gemini 的介面;通用生成一律 Ollama-first,Gemini 僅作備援
|
||
"""
|
||
|
||
import os
|
||
import logging
|
||
from typing import Optional, Dict, Any, List, Union
|
||
from dataclasses import dataclass
|
||
from datetime import date, datetime
|
||
|
||
logger = logging.getLogger(name=__name__)

# Configured AI provider, read from the environment (defaults to 'ollama').
# Gemini must never be the default: AIProviderService sanitizes this value and
# downgrades 'gemini' to 'ollama', using Gemini only after Ollama fails.
AI_PROVIDER = os.environ.get('AI_PROVIDER', 'ollama')
|
||
|
||
# 引入服務
|
||
from .ollama_service import OllamaService, OllamaResponse
|
||
from .gemini_service import GeminiService, GeminiResponse, AVAILABLE_GEMINI_MODELS
|
||
from .elephant_service import ElephantService, ElephantResponse
|
||
|
||
|
||
@dataclass
class AIResponse:
    """Provider-agnostic result of a single AI call.

    Responses from every backend are normalized into this shape. Token counts
    and costs default to zero when a backend does not report them.
    """

    success: bool
    content: str
    model: str
    provider: str  # 'ollama', 'gemini' or 'elephant'
    error: Optional[str] = None
    total_duration: Optional[float] = None
    # Token accounting (stays 0 when the backend does not report usage)
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    # Cost accounting (stays 0.0 for free backends)
    input_cost: float = 0.0
    output_cost: float = 0.0
    total_cost: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict, in declaration order."""
        ordered_keys = (
            'success', 'content', 'model', 'provider', 'error', 'total_duration',
            'input_tokens', 'output_tokens', 'total_tokens',
            'input_cost', 'output_cost', 'total_cost',
        )
        return {key: getattr(self, key) for key in ordered_keys}
|
||
|
||
|
||
class AIProviderService:
    """AI provider service: one interface over Ollama, Gemini and Elephant.

    General-purpose copywriting, keyword extraction and insight generation
    always go through Ollama first (described in the original notes as a
    three-host Ollama cascade). Gemini is only called as a fallback when the
    Ollama attempt fails, and it can never be configured as the default
    provider. Elephant is used only by ``generate`` when it is the requested
    or default provider.
    """

    def __init__(self, default_provider: str = None):
        """Create the service and its backend clients.

        Args:
            default_provider: Preferred provider, 'ollama' or 'elephant'. When
                None, the module-level AI_PROVIDER setting is used. 'gemini'
                is downgraded to 'ollama'.

        Raises:
            ValueError: If the provider name is not supported.
        """
        self._default_provider = self._sanitize_default_provider(default_provider or AI_PROVIDER)
        self._ollama = OllamaService()
        self._gemini = GeminiService()
        self._elephant = ElephantService()

        # Cached get_status() snapshot; 'timestamp' is a time.time() value.
        self._status_cache = {'timestamp': 0, 'data': None}
        self._CACHE_TTL = 60  # seconds

    @property
    def default_provider(self) -> str:
        """The effective default provider ('ollama' or 'elephant')."""
        return self._default_provider

    @default_provider.setter
    def default_provider(self, value: str):
        """Change the default provider; 'gemini' is downgraded to 'ollama'.

        Raises:
            ValueError: If ``value`` is not a supported provider name.
        """
        self._default_provider = self._sanitize_default_provider(value)
        # The cached status snapshot embeds 'default_provider' and
        # 'recommended_provider'; drop it rather than serve stale data for up
        # to _CACHE_TTL seconds after a switch.
        self._status_cache = {'timestamp': 0, 'data': None}
        # Log the effective value; the requested one may have been 'gemini'.
        logger.info("AI 預設提供者已切換至: %s", self._default_provider)

    @staticmethod
    def _sanitize_default_provider(value: str) -> str:
        """Normalize a provider name for use as the default.

        Gemini is fallback-only, so 'gemini' maps to 'ollama' (with a warning).
        None, empty and whitespace-only values also map to 'ollama'.

        Raises:
            ValueError: For any name other than 'ollama', 'elephant' or 'gemini'.
        """
        normalized = (value or '').strip().lower() or 'ollama'
        if normalized == 'gemini':
            logger.warning("AI_PROVIDER=gemini 已被拒絕;Gemini 僅作 Ollama 失敗備援")
            return 'ollama'
        if normalized not in ('ollama', 'elephant'):
            raise ValueError("Provider must be 'ollama' or 'elephant'; gemini is fallback-only")
        return normalized

    def get_status(self, force_refresh: bool = False) -> Dict[str, Any]:
        """Return connection status for every AI backend.

        The result is cached for ``_CACHE_TTL`` seconds. The cached dict object
        itself is returned, not a copy.

        Args:
            force_refresh: Ignore the cache and re-check every backend.

        Returns:
            dict with 'default_provider', one entry each for 'ollama',
            'gemini' and 'elephant', 'recommended_provider' and an ISO-format
            'timestamp'.
        """
        import time

        now = time.time()

        cached = self._status_cache['data']
        if not force_refresh and cached is not None and now - self._status_cache['timestamp'] < self._CACHE_TTL:
            return cached

        ollama_connected = self._ollama.check_connection()
        gemini_connected = self._gemini.check_connection()
        elephant_connected = self._elephant.check_connection()

        status = {
            'default_provider': self._default_provider,
            'ollama': {
                'connected': ollama_connected,
                'model': self._ollama.model if ollama_connected else None,
                'available_models': self._ollama.available_models if ollama_connected else [],
                'type': 'local',
                'cost': 'free'
            },
            'gemini': {
                'connected': gemini_connected,
                'model': self._gemini.model if gemini_connected else None,
                'available_models': AVAILABLE_GEMINI_MODELS,
                'type': 'cloud',
                'cost': 'paid'
            },
            'elephant': {
                'connected': elephant_connected,
                'model': self._elephant.model if elephant_connected else None,
                'available_models': [{'id': 'nvidia/llama-3.1-nemotron-ultra-253b-v1', 'name': 'Nemotron Ultra 253B'}],
                'type': 'cloud',
                'cost': 'efficient'
            },
            'recommended_provider': self._get_recommended_provider(ollama_connected, gemini_connected, elephant_connected),
            'timestamp': datetime.now().isoformat()
        }

        self._status_cache = {'timestamp': now, 'data': status}
        return status

    def _get_recommended_provider(self, ollama_ok: bool, gemini_ok: bool, elephant_ok: bool) -> str:
        """Pick a provider to recommend from backend availability.

        Ollama wins whenever it is reachable. Next comes Elephant, but only if
        it is the configured default. Then Gemini, then Elephant. Returns
        'none' when nothing is reachable.
        """
        if ollama_ok:
            return 'ollama'
        if self._default_provider == 'elephant' and elephant_ok:
            return 'elephant'
        if gemini_ok:
            return 'gemini'
        if elephant_ok:
            return 'elephant'
        return 'none'

    def _convert_response(self, response: Union[OllamaResponse, GeminiResponse, ElephantResponse],
                          provider: str) -> AIResponse:
        """Normalize a backend-specific response into an AIResponse.

        Args:
            response: Response object returned by one of the backend services.
            provider: Provider label used only when ``response`` is of an
                unrecognized type.

        Returns:
            AIResponse; a failed one with error 'Unknown response type' when
            the response type is not recognized.
        """
        if isinstance(response, OllamaResponse):
            # Ollama responses carry no token accounting; the local model is
            # free, so tokens and costs keep their zero defaults.
            return AIResponse(
                success=response.success,
                content=response.content,
                model=response.model,
                provider='ollama',
                error=response.error,
                total_duration=response.total_duration,
            )

        if isinstance(response, GeminiResponse):
            source = 'gemini'
        elif isinstance(response, ElephantResponse):
            source = 'elephant'
        else:
            return AIResponse(
                success=False,
                content='',
                model='unknown',
                provider=provider,
                error='Unknown response type'
            )

        # Gemini and Elephant responses expose the same token/cost fields.
        return AIResponse(
            success=response.success,
            content=response.content,
            model=response.model,
            provider=source,
            error=response.error,
            total_duration=response.total_duration,
            input_tokens=response.input_tokens,
            output_tokens=response.output_tokens,
            total_tokens=response.total_tokens,
            input_cost=response.input_cost,
            output_cost=response.output_cost,
            total_cost=response.total_cost
        )

    def _gemini_fallback(self, ollama_result: AIResponse, fallback_call) -> AIResponse:
        """Return ``ollama_result`` if it succeeded, otherwise try Gemini.

        This is what keeps the general AI entry points Ollama-first.

        Args:
            ollama_result: Already-converted result of the Ollama attempt.
            fallback_call: Zero-argument callable performing the Gemini request.

        Returns:
            The Ollama result on success. Otherwise the Gemini result, with its
            error prefixed by the Ollama error when both fail. If the Gemini
            call raises, the Ollama result is returned with the exception
            appended to its error.
        """
        if ollama_result.success:
            return ollama_result

        logger.warning("Ollama 主路徑失敗,啟用 Gemini 備援:%s", ollama_result.error)
        try:
            gemini_response = fallback_call()
        except Exception as exc:
            # Best-effort fallback: report the failure in the result instead
            # of raising, but keep the traceback in the logs.
            logger.exception("Gemini 備援呼叫發生例外")
            ollama_result.error = f"{ollama_result.error}; Gemini 備援例外:{exc}"
            return ollama_result

        gemini_result = self._convert_response(gemini_response, 'gemini')
        if not gemini_result.success and ollama_result.error:
            gemini_result.error = f"Ollama 失敗:{ollama_result.error};Gemini 備援失敗:{gemini_result.error}"
        return gemini_result

    def _ollama_first(self, ollama_call, gemini_call) -> AIResponse:
        """Run ``ollama_call`` and fall back to ``gemini_call`` only if it fails.

        Both arguments are zero-argument callables returning a backend
        response object.
        """
        ollama_result = self._convert_response(ollama_call(), 'ollama')
        return self._gemini_fallback(ollama_result, gemini_call)

    def generate(self, prompt: str, provider: str = None, model: str = None,
                 system_prompt: str = None, temperature: float = 0.7,
                 timeout: int = None) -> AIResponse:
        """Generate free-form text.

        Args:
            prompt: User prompt.
            provider: 'elephant' sends the request to Elephant only, with no
                fallback. Any other value (including 'gemini') goes Ollama-first,
                with Gemini used only if Ollama fails. Defaults to the
                service's default provider.
            model: Model name, forwarded to whichever backend is called.
            system_prompt: System prompt.
            temperature: Sampling temperature.
            timeout: Timeout in seconds.

        Returns:
            AIResponse
        """
        provider = (provider or self._default_provider or 'ollama').strip().lower()

        if provider == 'elephant':
            response = self._elephant.generate(
                prompt=prompt,
                model=model,
                system_prompt=system_prompt,
                temperature=temperature,
                timeout=timeout
            )
            return self._convert_response(response, provider)

        return self._ollama_first(
            lambda: self._ollama.generate(
                prompt=prompt,
                model=model,
                system_prompt=system_prompt,
                temperature=temperature,
                timeout=timeout
            ),
            lambda: self._gemini.generate(
                prompt=prompt,
                model=model,
                system_prompt=system_prompt,
                temperature=temperature,
                timeout=timeout
            ),
        )

    def generate_sales_copy(self, product_name: str, provider: str = None,
                            model: str = None, trend_keywords: List[str] = None,
                            style: str = "吸睛", upcoming_holidays: List[Dict] = None,
                            bestseller_products: List[Dict] = None) -> AIResponse:
        """Generate sales copy for a product (Ollama-first, Gemini fallback).

        Args:
            product_name: Product name.
            provider: Accepted for interface compatibility but ignored. Copy
                generation always goes Ollama-first with Gemini as fallback.
            model: Forwarded only to the Gemini fallback.
            trend_keywords: Trending keywords to weave in.
            style: Copywriting style.
            upcoming_holidays: Upcoming holidays.
            bestseller_products: Competitors' best-selling products.

        Returns:
            AIResponse
        """
        return self._ollama_first(
            lambda: self._ollama.generate_sales_copy(
                product_name=product_name,
                trend_keywords=trend_keywords,
                style=style,
                upcoming_holidays=upcoming_holidays,
                bestseller_products=bestseller_products
            ),
            lambda: self._gemini.generate_sales_copy(
                product_name=product_name,
                trend_keywords=trend_keywords,
                style=style,
                upcoming_holidays=upcoming_holidays,
                bestseller_products=bestseller_products,
                model=model
            ),
        )

    def extract_keywords(self, text: str, provider: str = None,
                         model: str = None, max_keywords: int = 10) -> AIResponse:
        """Extract keywords from text (Ollama-first, Gemini fallback).

        Args:
            text: Text to analyze.
            provider: Accepted for interface compatibility but ignored.
            model: Forwarded only to the Gemini fallback.
            max_keywords: Maximum number of keywords.

        Returns:
            AIResponse
        """
        return self._ollama_first(
            lambda: self._ollama.extract_keywords(text, max_keywords),
            lambda: self._gemini.extract_keywords(text, max_keywords, model),
        )

    def search_product_insights(self, product_name: str, provider: str = None,
                                model: str = None, include_competitors: bool = True,
                                include_trends: bool = True,
                                web_context: str = "") -> AIResponse:
        """Produce market insights for a product (Ollama-first, Gemini fallback).

        Args:
            product_name: Product name.
            provider: Accepted for interface compatibility but ignored.
            model: Forwarded only to the Gemini fallback.
            include_competitors: Include competitor analysis.
            include_trends: Include trend analysis.
            web_context: Web search results. Passed to Ollama only; the
                Gemini fallback does not receive it.

        Returns:
            AIResponse
        """
        return self._ollama_first(
            lambda: self._ollama.search_product_insights(
                product_name=product_name,
                include_competitors=include_competitors,
                include_trends=include_trends,
                web_context=web_context
            ),
            lambda: self._gemini.search_product_insights(
                product_name=product_name,
                include_competitors=include_competitors,
                include_trends=include_trends,
                model=model
            ),
        )

    def web_search(self, query: str, provider: str = None, model: str = None,
                   num_results: int = 5, search_type: str = "general") -> AIResponse:
        """Run a web search. Only Ollama supports this; there is no fallback.

        Args:
            query: Search query.
            provider: Ignored.
            model: Ignored.
            num_results: Number of results to return.
            search_type: Search type.

        Returns:
            AIResponse
        """
        response = self._ollama.web_search(query, num_results, search_type)
        return self._convert_response(response, 'ollama')

    def search_trend_keywords(self, category: str, provider: str = None,
                              model: str = None, time_range: str = "week") -> AIResponse:
        """Search trending keywords for a category. Only Ollama supports this.

        Args:
            category: Product category.
            provider: Ignored.
            model: Ignored.
            time_range: Time range.

        Returns:
            AIResponse
        """
        response = self._ollama.search_trend_keywords(category, time_range)
        return self._convert_response(response, 'ollama')
|
||
|
||
|
||
# Module-wide shared service instance used by the convenience functions in this
# module. It is built at import time, so AIProviderService.__init__ runs on
# import and raises ValueError for an unsupported AI_PROVIDER value.
ai_provider_service = AIProviderService()
|
||
|
||
|
||
# Convenience functions


def get_ai_status(force_refresh: bool = False) -> Dict[str, Any]:
    """Return the status snapshot of the shared AI provider service.

    Args:
        force_refresh: Bypass the service's status cache when True.
    """
    service = ai_provider_service
    return service.get_status(force_refresh=force_refresh)
|
||
|
||
|
||
def set_ai_provider(provider: str) -> bool:
    """Set the default provider of the shared AI provider service.

    Returns:
        False when the service rejects the name with ValueError, True
        otherwise. 'gemini' is accepted (True) but the service downgrades it
        to 'ollama', because Gemini is fallback-only.
    """
    try:
        ai_provider_service.default_provider = provider
    except ValueError:
        return False
    return True
|
||
|
||
|
||
def generate_copy(product_name: str, provider: str = None, model: str = None,
                  **kwargs) -> AIResponse:
    """Generate sales copy through the shared service (convenience wrapper).

    Extra keyword arguments (e.g. trend_keywords, style) are forwarded
    unchanged to ``AIProviderService.generate_sales_copy``.
    """
    call_kwargs = dict(kwargs, product_name=product_name, provider=provider, model=model)
    return ai_provider_service.generate_sales_copy(**call_kwargs)
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: report backend status, then try one copy generation.
    logging.basicConfig(level=logging.INFO)

    service = AIProviderService()

    # Check service status
    print("檢查 AI 服務狀態...")
    status = service.get_status()
    print(f"Ollama: {'✅ 已連線' if status['ollama']['connected'] else '❌ 未連線'}")
    print(f"Gemini: {'✅ 已連線' if status['gemini']['connected'] else '❌ 未連線'}")
    print(f"Elephant: {'✅ 已連線' if status['elephant']['connected'] else '❌ 未連線'}")
    print(f"預設提供者: {status['default_provider']}")
    print(f"推薦提供者: {status['recommended_provider']}")

    # Try copy generation only when at least one backend is reachable.
    if status['recommended_provider'] != 'none':
        print(f"\n使用 {status['recommended_provider']} 測試文案生成...")
        result = service.generate_sales_copy(
            product_name="玻尿酸保濕面膜",
            provider=status['recommended_provider'],
            trend_keywords=["換季保養", "敏感肌"],
            style="吸睛"
        )

        if result.success:
            print(f"\n生成結果:\n{result.content}")
            print(f"\n提供者: {result.provider}")
            print(f"模型: {result.model}")
            # total_duration is Optional; formatting None with ':.2f' raises TypeError.
            if result.total_duration is not None:
                print(f"耗時: {result.total_duration:.2f} 秒")
            if result.total_cost > 0:
                print(f"費用: ${result.total_cost:.6f} USD")
        else:
            print(f"生成失敗: {result.error}")
|