"""
Embedding Service - Ollama bge-m3:latest 專用向量化
===================================================

使用 Ollama bge-m3:latest 提供文本向量化功能(1024 維)。
bge-m3 為專用多語言 embedding 模型,支援繁中/英文語義搜尋。

Phase 13.2 #84 - RAG Tool 基礎設施
ADR-110 2026-05-04: GCP-A Primary 升級 bge-m3(768→1024 維遷移)

版本: v1.2
建立日期: 2026-03-26 20:30 (台北時區)
更新日期: 2026-05-04 (台北時區) — ADR-110 bge-m3 升級
建立者: Claude Code
更新者: ogt + Claude Sonnet 4.6 (ADR-110 GCP-A Primary)
"""
|
||
|
||
import asyncio
|
||
from typing import Protocol
|
||
|
||
import httpx
|
||
import structlog
|
||
|
||
from src.services.model_registry import get_model as _get_model
|
||
from src.services.ollama_endpoint_circuit_breaker import (
|
||
filter_ollama_urls_with_cooldown,
|
||
record_ollama_endpoint_failure,
|
||
record_ollama_endpoint_success,
|
||
)
|
||
from src.services.ollama_endpoint_resolver import resolve_ollama_order
|
||
|
||
logger = structlog.get_logger(__name__)
|
||
|
||
|
||
# =============================================================================
|
||
# Interface (DI Protocol)
|
||
# =============================================================================
|
||
|
||
|
||
class IEmbeddingService(Protocol):
    """Structural (duck-typed) interface for embedding services, used for DI.

    Any object providing these async methods and the ``dimension`` property
    satisfies this Protocol; explicit inheritance is not required.
    """

    async def embed_text(self, text: str) -> list[float]:
        """Return the embedding vector for a single text."""

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Return embedding vectors for a batch of texts."""

    @property
    def dimension(self) -> int:
        """Length of the vectors produced by this service."""
|
||
|
||
|
||
# =============================================================================
|
||
# Implementation
|
||
# =============================================================================
|
||
|
||
|
||
class OllamaEmbeddingService:
    """
    Ollama embedding service.

    Converts text into vectors through the Ollama HTTP API
    (``POST {endpoint}/api/embeddings``). The default model is
    ``bge-m3:latest``, listed in MODEL_DIMENSIONS as 1024-dimensional.

    Endpoints are either the single ``ollama_url`` given to the constructor or
    the ordered list from ``resolve_ollama_order("embedding")``. Each request
    tries the endpoints that pass the circuit-breaker cooldown filter, in
    order, and returns the first successful result.
    NOTE(review): the original notes name GCP-A (34.143.170.20) as the primary
    host; the effective order is whatever resolve_ollama_order returns.

    Usage:
        service = OllamaEmbeddingService()
        vector = await service.embed_text("維運手冊")  # "operations manual"
    """

    # Known vector dimensions per model, so `dimension` is sensible before the
    # first request (avoids hard-coding a single value).
    MODEL_DIMENSIONS: dict[str, int] = {
        "qwen2.5:7b-instruct": 3584,
        "qwen2.5:3b-instruct": 2048,
        "llama3.2:3b": 3072,
        "nomic-embed-text": 768,
        # 2026-05-04 ogt + Claude Sonnet 4.6: ADR-110 GCP-A Primary — bge-m3 is the
        # dedicated embedding model and produces 1024-dim vectors; the pgvector
        # schema was migrated to vector(1024) (see embedding_bge_m3_1024.sql).
        "bge-m3:latest": 1024,
        "bge-m3": 1024,
    }
    # Fallback dimension for models not listed in MODEL_DIMENSIONS.
    DEFAULT_DIMENSION = 3584

    def __init__(
        self,
        model: str = "bge-m3:latest",
        ollama_url: str | None = None,
        timeout: float = 30.0,
        default_dimension: int | None = None,
    ) -> None:
        """
        Initialize the embedding service.

        Args:
            model: Ollama model name; it must support the embeddings API.
            ollama_url: Explicit Ollama base URL. When omitted (or empty), the
                endpoints come from ``resolve_ollama_order("embedding")``;
                entries without a URL are skipped.
            timeout: Per-request timeout in seconds.
            default_dimension: Value reported by ``dimension`` until the first
                non-empty embedding is received. When not given (or falsy),
                it is looked up in MODEL_DIMENSIONS, falling back to
                DEFAULT_DIMENSION.

        P1 fix (2026-03-29): dimension is configurable so more models are
        supported.
        """
        self._model = model
        if ollama_url:
            self._ollama_endpoints = ((ollama_url, "custom"),)
        else:
            self._ollama_endpoints = tuple(
                (endpoint.url, endpoint.provider_name)
                for endpoint in resolve_ollama_order("embedding")
                if endpoint.url
            )
        # Primary (first) endpoint URL, or "" when no endpoint is configured.
        self._ollama_url = self._ollama_endpoints[0][0] if self._ollama_endpoints else ""
        self._timeout = timeout
        self._default_dimension = default_dimension or self.MODEL_DIMENSIONS.get(
            model, self.DEFAULT_DIMENSION
        )
        # Actual dimension, detected from the first non-empty embedding.
        self._dimension: int | None = None
        # Lazily created by _get_client and shared across requests.
        self._client: httpx.AsyncClient | None = None

    @property
    def dimension(self) -> int:
        """
        Vector dimension.

        Before any embedding has been received this is the configured default
        (see ``__init__``). After the first non-empty embedding, it is that
        vector's length.
        """
        if self._dimension is None:
            return self._default_dimension
        return self._dimension

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it on first use (connection pooling)."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._timeout),
                limits=httpx.Limits(max_connections=10),
            )
        return self._client

    def _active_ollama_endpoints(self) -> tuple[tuple[str, str], ...]:
        """
        Return the (url, provider_name) pairs whose URL passes
        ``filter_ollama_urls_with_cooldown``, preserving the configured order.
        """
        active_urls = filter_ollama_urls_with_cooldown(
            url for url, _provider_name in self._ollama_endpoints
        )
        active_url_set = set(active_urls)
        return tuple(
            (url, provider_name)
            for url, provider_name in self._ollama_endpoints
            if url in active_url_set
        )

    async def embed_text(self, text: str) -> list[float]:
        """
        Convert a single text into an embedding vector.

        Active endpoints are tried in order and the first successful response
        is returned. Timeouts, HTTP 5xx responses and other httpx transport
        errors are recorded as endpoint failures for the circuit breaker. A
        response whose body parsed is recorded as a success. Any other error
        is logged, and the next endpoint is tried.

        Args:
            text: Text to embed.

        Returns:
            list[float]: The embedding vector. Its length depends on the model
            (1024 for the default bge-m3). It may be empty if the server
            returns no "embedding" field; a warning is logged in that case.

        Raises:
            EmbeddingError: No endpoint is active, or every active endpoint
                failed. The message reflects the last error seen (timeout,
                HTTP status, or a generic failure).
        """
        client = await self._get_client()

        last_error: Exception | None = None
        for endpoint_url, provider_name in self._active_ollama_endpoints():
            try:
                response = await client.post(
                    f"{endpoint_url}/api/embeddings",
                    json={
                        "model": self._model,
                        "prompt": text,
                    },
                )
                response.raise_for_status()

                data = response.json()
                embedding = data.get("embedding", [])
                # Record success only after the body parsed, so a malformed
                # 200 response does not count as a healthy endpoint.
                record_ollama_endpoint_success(endpoint_url)

                if not embedding:
                    # Existing behavior: the empty vector is still returned.
                    # Warn so it is visible before it reaches a fixed-size
                    # vector column downstream.
                    logger.warning(
                        "embedding_empty",
                        model=self._model,
                        text_len=len(text),
                        provider=provider_name,
                    )
                elif self._dimension is None:
                    # Cache the real dimension from the first non-empty vector.
                    self._dimension = len(embedding)
                    logger.info(
                        "embedding_dimension_detected",
                        model=self._model,
                        dimension=self._dimension,
                        provider=provider_name,
                    )

                return embedding

            except httpx.TimeoutException as e:
                last_error = e
                record_ollama_endpoint_failure(endpoint_url)
                logger.error(
                    "embedding_timeout",
                    model=self._model,
                    text_len=len(text),
                    provider=provider_name,
                )
            except httpx.HTTPStatusError as e:
                last_error = e
                # Only server-side errors count against the endpoint; 4xx
                # (e.g. a bad request) is not the endpoint's fault.
                if e.response.status_code >= 500:
                    record_ollama_endpoint_failure(endpoint_url)
                logger.error(
                    "embedding_http_error",
                    status=e.response.status_code,
                    model=self._model,
                    provider=provider_name,
                )
            except Exception as e:
                last_error = e
                if isinstance(e, httpx.TransportError):
                    record_ollama_endpoint_failure(endpoint_url)
                logger.error(
                    "embedding_error",
                    error=str(e),
                    model=self._model,
                    provider=provider_name,
                )

        # The loop never ran: no endpoint configured, or all were filtered out.
        if last_error is None:
            raise EmbeddingError(
                "No active Ollama endpoints for embedding "
                "(none configured or all filtered out by cooldown)"
            )
        if isinstance(last_error, httpx.TimeoutException):
            raise EmbeddingError(f"Embedding timeout after {self._timeout}s") from last_error
        if isinstance(last_error, httpx.HTTPStatusError):
            raise EmbeddingError(
                f"Ollama API error: {last_error.response.status_code}"
            ) from last_error
        raise EmbeddingError("Embedding failed on all Ollama endpoints") from last_error

    async def embed_batch(
        self,
        texts: list[str],
        concurrency: int = 5,
    ) -> list[list[float]]:
        """
        Embed several texts concurrently.

        Args:
            texts: Texts to embed.
            concurrency: Maximum number of in-flight requests (limits load on
                Ollama). Must be at least 1.

        Returns:
            list[list[float]]: One vector per input text, in input order.

        Raises:
            ValueError: ``concurrency`` is less than 1. A zero-permit
                semaphore would block every task forever.
            EmbeddingError: Propagated from the first failing ``embed_text``
                call. ``asyncio.gather`` does not cancel the remaining
                calls.
        """
        if not texts:
            return []
        if concurrency < 1:
            raise ValueError(f"concurrency must be >= 1, got {concurrency}")

        semaphore = asyncio.Semaphore(concurrency)

        async def embed_with_semaphore(text: str) -> list[float]:
            async with semaphore:
                return await self.embed_text(text)

        # gather returns results in the same order as the awaitables passed in.
        results = await asyncio.gather(
            *(embed_with_semaphore(text) for text in texts),
            return_exceptions=False,
        )

        logger.info(
            "batch_embedding_complete",
            count=len(texts),
            model=self._model,
        )

        return results

    async def close(self) -> None:
        """Close the shared HTTP client if one exists; _get_client recreates it on next use."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
||
|
||
|
||
# =============================================================================
|
||
# Errors
|
||
# =============================================================================
|
||
|
||
|
||
class EmbeddingError(Exception):
    """Raised when an embedding operation fails."""
|
||
|
||
|
||
# =============================================================================
|
||
# Singleton Factory
|
||
# =============================================================================
|
||
|
||
# Module-level shared instance, created lazily by get_embedding_service().
_embedding_service: OllamaEmbeddingService | None = None


def get_embedding_service() -> OllamaEmbeddingService:
    """
    Return the shared OllamaEmbeddingService instance, creating it on first use.

    D1 centralization (2026-04-11): the default model is read from models.json
    (providers.ollama.models.embedding) via ``_get_model("ollama", "embedding")``.

    Returns:
        OllamaEmbeddingService: The module-level shared instance.
    """
    global _embedding_service
    if _embedding_service is not None:
        return _embedding_service

    configured_model = _get_model("ollama", "embedding")
    _embedding_service = OllamaEmbeddingService(model=configured_model)
    return _embedding_service
|