Files
awoooi/apps/api/src/services/embedding_service.py
Your Name 2dcd214156
All checks were successful
CD Pipeline / tests (push) Successful in 58s
Code Review / ai-code-review (push) Successful in 11s
CD Pipeline / build-and-deploy (push) Successful in 3m45s
CD Pipeline / post-deploy-checks (push) Successful in 1m17s
fix(ollama): cooldown noisy failed endpoints
2026-05-25 12:11:48 +08:00

308 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Embedding Service - Ollama bge-m3:latest 專用向量化
===================================================
使用 Ollama bge-m3:latest 提供文本向量化功能(1024 維)。
bge-m3 為專用多語言 embedding 模型,支援繁中/英文語義搜尋。
Phase 13.2 #84 - RAG Tool 基礎設施
ADR-110 2026-05-04: GCP-A Primary 升級 bge-m3(768→1024 維遷移)
版本: v1.2
建立日期: 2026-03-26 20:30 (台北時區)
更新日期: 2026-05-04 (台北時區) — ADR-110 bge-m3 升級
建立者: Claude Code
更新者: ogt + Claude Sonnet 4.6 (ADR-110 GCP-A Primary)
"""
import asyncio
from typing import Protocol
import httpx
import structlog
from src.services.model_registry import get_model as _get_model
from src.services.ollama_endpoint_circuit_breaker import (
filter_ollama_urls_with_cooldown,
record_ollama_endpoint_failure,
record_ollama_endpoint_success,
)
from src.services.ollama_endpoint_resolver import resolve_ollama_order
logger = structlog.get_logger(__name__)
# =============================================================================
# Interface (DI Protocol)
# =============================================================================
class IEmbeddingService(Protocol):
    """
    Dependency-injection contract for text embedding services.

    Any object that exposes the ``dimension`` property and the two coroutine
    methods below structurally satisfies this protocol.
    """

    @property
    def dimension(self) -> int:
        """Length of the vectors produced by this service."""
        ...

    async def embed_text(self, text: str) -> list[float]:
        """Convert a single text into one embedding vector."""
        ...

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed several texts, returning one vector per input text."""
        ...
# =============================================================================
# Implementation
# =============================================================================
class OllamaEmbeddingService:
    """
    Ollama embedding service.

    Vectorizes text through the Ollama HTTP API (``POST {endpoint}/api/embeddings``).
    The default model is ``bge-m3:latest``, which produces 1024-dimensional
    vectors (ADR-110: GCP-A primary). Candidate endpoints come from an explicit
    ``ollama_url`` or, when omitted, from ``resolve_ollama_order("embedding")``.
    Each request walks the candidates that pass the circuit-breaker cooldown
    filter, in configured order, until one succeeds.

    Usage:
        service = OllamaEmbeddingService()
        vector = await service.embed_text("維運手冊")
    """

    # Known embedding dimensions per model (P1 fix: avoid a single hard-coded value).
    MODEL_DIMENSIONS: dict[str, int] = {
        "qwen2.5:7b-instruct": 3584,
        "qwen2.5:3b-instruct": 2048,
        "llama3.2:3b": 3072,
        "nomic-embed-text": 768,
        # 2026-05-04 ogt + Claude Sonnet 4.6: ADR-110 GCP-A Primary — dedicated bge-m3 embedding model.
        # bge-m3 produces 1024-dim vectors; the pgvector schema has been migrated
        # to vector(1024) (see embedding_bge_m3_1024.sql).
        "bge-m3:latest": 1024,
        "bge-m3": 1024,
    }
    # Fallback dimension for models missing from MODEL_DIMENSIONS.
    DEFAULT_DIMENSION = 3584

    def __init__(
        self,
        model: str = "bge-m3:latest",
        ollama_url: str | None = None,
        timeout: float = 30.0,
        default_dimension: int | None = None,
    ) -> None:
        """
        Initialize the embedding service.

        Args:
            model: Ollama model name (must support embeddings).
            ollama_url: A single Ollama base URL to use exclusively. When omitted,
                the ordered endpoint list comes from
                ``resolve_ollama_order("embedding")``; entries without a URL
                are dropped.
            timeout: Per-request HTTP timeout in seconds.
            default_dimension: Dimension reported by ``dimension`` before the
                first successful embed call. When omitted (or 0), it is looked
                up in MODEL_DIMENSIONS, falling back to DEFAULT_DIMENSION.

        P1 fix (2026-03-29): the dimension is configurable to support more models.
        """
        self._model = model
        if ollama_url:
            self._ollama_endpoints = ((ollama_url, "custom"),)
        else:
            self._ollama_endpoints = tuple(
                (endpoint.url, endpoint.provider_name)
                for endpoint in resolve_ollama_order("embedding")
                if endpoint.url
            )
        # First configured endpoint URL, or "" when none was resolved.
        self._ollama_url = self._ollama_endpoints[0][0] if self._ollama_endpoints else ""
        self._timeout = timeout
        self._default_dimension = default_dimension or self.MODEL_DIMENSIONS.get(
            model, self.DEFAULT_DIMENSION
        )
        # Actual dimension, detected from the first non-empty embedding.
        self._dimension: int | None = None
        # Lazily created by _get_client() and released by close().
        self._client: httpx.AsyncClient | None = None

    @property
    def dimension(self) -> int:
        """
        Vector dimension.

        The first successful embed call detects and caches the real dimension.
        Until then this returns the configured default, which is taken from
        ``default_dimension`` or MODEL_DIMENSIONS.
        """
        if self._dimension is None:
            return self._default_dimension
        return self._dimension

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it on first use (pool of up to 10 connections)."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._timeout),
                limits=httpx.Limits(max_connections=10),
            )
        return self._client

    def _active_ollama_endpoints(self) -> tuple[tuple[str, str], ...]:
        """
        Return the configured ``(url, provider_name)`` pairs that pass the cooldown filter.

        Configured order is preserved. Which URLs are dropped is decided by
        ``filter_ollama_urls_with_cooldown``; presumably it removes endpoints in
        circuit-breaker cooldown, but verify this in that module.
        """
        active_urls = filter_ollama_urls_with_cooldown(
            url for url, _provider_name in self._ollama_endpoints
        )
        active_url_set = set(active_urls)
        return tuple(
            (url, provider_name)
            for url, provider_name in self._ollama_endpoints
            if url in active_url_set
        )

    async def embed_text(self, text: str) -> list[float]:
        """
        Convert a single text into an embedding vector.

        Active endpoints are tried in order, and the first successful response
        is returned. Timeouts, transport errors and HTTP 5xx responses are
        recorded as endpoint failures for the circuit breaker. HTTP 4xx and other
        errors just fall through to the next endpoint.

        Args:
            text: Text to embed.

        Returns:
            list[float]: The vector stored under ``"embedding"`` in the response.
            Its length depends on the model (1024 for the default bge-m3). It may
            be empty if the server returned no embedding.

        Raises:
            EmbeddingError: No endpoint was available, or every endpoint failed.
                The message reflects the last error seen.
        """
        client = await self._get_client()
        endpoints = self._active_ollama_endpoints()
        if not endpoints:
            # Nothing to try: no endpoint is configured, or all were filtered out
            # by the cooldown check. Log it so this case is distinguishable from
            # "tried every endpoint and failed".
            logger.warning(
                "embedding_no_active_endpoints",
                model=self._model,
                configured=len(self._ollama_endpoints),
            )
        last_error: Exception | None = None
        for endpoint_url, provider_name in endpoints:
            try:
                # Strip a trailing slash so a custom base URL like "http://host:11434/"
                # does not produce "//api/embeddings". The unmodified endpoint_url is
                # still used as the circuit-breaker key below.
                response = await client.post(
                    f"{endpoint_url.rstrip('/')}/api/embeddings",
                    json={
                        "model": self._model,
                        "prompt": text,
                    },
                )
                response.raise_for_status()
                record_ollama_endpoint_success(endpoint_url)
                data = response.json()
                embedding = data.get("embedding", [])
                # Cache the real dimension the first time a non-empty vector arrives.
                if self._dimension is None and embedding:
                    self._dimension = len(embedding)
                    logger.info(
                        "embedding_dimension_detected",
                        model=self._model,
                        dimension=self._dimension,
                        provider=provider_name,
                    )
                return embedding
            except httpx.TimeoutException as e:
                last_error = e
                record_ollama_endpoint_failure(endpoint_url)
                logger.error(
                    "embedding_timeout",
                    model=self._model,
                    text_len=len(text),
                    provider=provider_name,
                )
            except httpx.HTTPStatusError as e:
                last_error = e
                # Only server-side errors count against the endpoint; 4xx (e.g. a
                # missing model) is a request problem, not an unhealthy endpoint.
                if e.response.status_code >= 500:
                    record_ollama_endpoint_failure(endpoint_url)
                logger.error(
                    "embedding_http_error",
                    status=e.response.status_code,
                    model=self._model,
                    provider=provider_name,
                )
            except Exception as e:
                last_error = e
                # Connection-level failures (refused, reset, DNS...) trip the breaker;
                # other errors (e.g. malformed JSON) just move on.
                if isinstance(e, httpx.TransportError):
                    record_ollama_endpoint_failure(endpoint_url)
                logger.error(
                    "embedding_error",
                    error=str(e),
                    model=self._model,
                    provider=provider_name,
                )
        if isinstance(last_error, httpx.TimeoutException):
            raise EmbeddingError(f"Embedding timeout after {self._timeout}s") from last_error
        if isinstance(last_error, httpx.HTTPStatusError):
            raise EmbeddingError(
                f"Ollama API error: {last_error.response.status_code}"
            ) from last_error
        raise EmbeddingError("Embedding failed on all Ollama endpoints") from last_error

    async def embed_batch(
        self,
        texts: list[str],
        concurrency: int = 5,
    ) -> list[list[float]]:
        """
        Embed several texts concurrently.

        Args:
            texts: Texts to embed.
            concurrency: Maximum number of in-flight requests, to avoid
                overloading Ollama. Must be >= 1.

        Returns:
            list[list[float]]: One vector per input text, in input order.

        Raises:
            ValueError: ``concurrency`` is less than 1. Without this check a
                zero-permit semaphore would block forever.
            EmbeddingError: Propagated from the first failing ``embed_text``
                call. asyncio.gather does not cancel the remaining calls.
        """
        if not texts:
            return []
        if concurrency < 1:
            raise ValueError(f"concurrency must be >= 1, got {concurrency}")
        semaphore = asyncio.Semaphore(concurrency)

        async def embed_with_semaphore(text: str) -> list[float]:
            async with semaphore:
                return await self.embed_text(text)

        # gather preserves input order in its result list.
        results = await asyncio.gather(
            *(embed_with_semaphore(text) for text in texts),
            return_exceptions=False,
        )
        logger.info(
            "batch_embedding_complete",
            count=len(texts),
            model=self._model,
        )
        return list(results)

    async def close(self) -> None:
        """Close the shared HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None
# =============================================================================
# Errors
# =============================================================================
class EmbeddingError(Exception):
    """Raised when an embedding operation fails, e.g. every Ollama endpoint errored or timed out."""
# =============================================================================
# Singleton Factory
# =============================================================================
# Module-level shared instance; created on the first get_embedding_service() call.
_embedding_service: OllamaEmbeddingService | None = None


def get_embedding_service() -> OllamaEmbeddingService:
    """
    Return the shared Embedding Service instance, creating it on first use.

    D1 centralization (2026-04-11): the default model is read via
    ``_get_model("ollama", "embedding")``, i.e. the
    providers.ollama.models.embedding entry of models.json.

    Returns:
        OllamaEmbeddingService: The shared instance.
    """
    global _embedding_service
    if _embedding_service is not None:
        return _embedding_service
    embedding_model = _get_model("ollama", "embedding")
    _embedding_service = OllamaEmbeddingService(model=embedding_model)
    return _embedding_service