Files
awoooi/apps/api/src/services/embedding_service.py
OG T 8724ed7dcf fix(mcp): P1 修復 - DI 一致性 + 測試補充 + 配置優化
首席架構師審查 P1 修復清單:

P1-1 RAG Provider DI 模式一致性:
- 支援 rag_service 參數注入
- 新增 close() 方法
- TYPE_CHECKING 延遲導入

P1-3 RAG 測試補充:
- test_rag_provider.py (9 tests)
- DI 注入/Lazy Load/Tool Schema/驗證/Close

P1-4 Grafana Config 快取優化:
- URL/Key 首次查詢後快取
- 減少重複 settings 存取

P1-5 Embedding 維度配置化:
- MODEL_DIMENSIONS 字典 (qwen/llama/nomic)
- default_dimension 參數
- 支援更多模型

測試: 9/9 PASSED

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-29 16:23:30 +08:00

249 lines
7.1 KiB
Python

"""
Embedding Service - Ollama BGE-M3 替代方案
==========================================
使用 Ollama qwen2.5:7b-instruct 提供文本向量化功能。
雖非專用 embedding 模型,但支援多語言 (繁中/英文)。
Phase 13.2 #84 - RAG Tool 基礎設施
版本: v1.1
建立日期: 2026-03-26 20:30 (台北時區)
更新日期: 2026-03-29 20:50 (台北時區)
建立者: Claude Code
更新者: Claude Code (P1 修復: 維度配置化)
"""
import asyncio
from typing import Protocol
import httpx
import structlog
from src.core.config import settings
# Module-level structlog logger, named after this module; used for all events below.
logger = structlog.get_logger(__name__)
# =============================================================================
# Interface (DI Protocol)
# =============================================================================
class IEmbeddingService(Protocol):
    """Structural interface that any embedding backend must satisfy."""

    @property
    def dimension(self) -> int:
        """Length of the vectors produced by this service."""
        ...

    async def embed_text(self, text: str) -> list[float]:
        """Convert one piece of text into its embedding vector."""
        ...

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Convert several texts into embedding vectors in one call."""
        ...
# =============================================================================
# Implementation
# =============================================================================
class OllamaEmbeddingService:
"""
Ollama Embedding Service
使用 Ollama API 進行文本向量化。
預設使用 qwen2.5:7b-instruct (3584 維向量)。
Usage:
service = OllamaEmbeddingService()
vector = await service.embed_text("維運手冊")
"""
# 已知模型維度 (P1 修復: 避免硬編碼)
MODEL_DIMENSIONS: dict[str, int] = {
"qwen2.5:7b-instruct": 3584,
"qwen2.5:3b-instruct": 2048,
"llama3.2:3b": 3072,
"nomic-embed-text": 768,
}
DEFAULT_DIMENSION = 3584 # 未知模型的預設值
def __init__(
self,
model: str = "qwen2.5:7b-instruct",
ollama_url: str | None = None,
timeout: float = 30.0,
default_dimension: int | None = None,
) -> None:
"""
初始化 Embedding Service
Args:
model: Ollama 模型名稱 (必須支援 embedding)
ollama_url: Ollama API URL (預設從 config 讀取)
timeout: 請求超時 (秒)
default_dimension: 預設向量維度 (可選,未提供則從 MODEL_DIMENSIONS 查詢)
P1 修復 (2026-03-29): 維度配置化,支援更多模型
"""
self._model = model
self._ollama_url = ollama_url or settings.OLLAMA_URL
self._timeout = timeout
self._default_dimension = default_dimension or self.MODEL_DIMENSIONS.get(
model, self.DEFAULT_DIMENSION
)
self._dimension: int | None = None
self._client: httpx.AsyncClient | None = None
@property
def dimension(self) -> int:
"""
向量維度
首次 embed 呼叫會自動偵測實際維度並快取。
偵測前返回 MODEL_DIMENSIONS 中的預設值。
"""
if self._dimension is None:
return self._default_dimension
return self._dimension
async def _get_client(self) -> httpx.AsyncClient:
"""取得 HTTP Client (連線池共用)"""
if self._client is None:
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(self._timeout),
limits=httpx.Limits(max_connections=10),
)
return self._client
async def embed_text(self, text: str) -> list[float]:
"""
將單一文本轉換為向量
Args:
text: 要向量化的文本
Returns:
list[float]: 向量 (3584 維)
Raises:
EmbeddingError: 向量化失敗
"""
client = await self._get_client()
try:
response = await client.post(
f"{self._ollama_url}/api/embeddings",
json={
"model": self._model,
"prompt": text,
},
)
response.raise_for_status()
data = response.json()
embedding = data.get("embedding", [])
# 更新維度快取
if self._dimension is None and embedding:
self._dimension = len(embedding)
logger.info(
"embedding_dimension_detected",
model=self._model,
dimension=self._dimension,
)
return embedding
except httpx.TimeoutException as e:
logger.error("embedding_timeout", model=self._model, text_len=len(text))
raise EmbeddingError(f"Embedding timeout after {self._timeout}s") from e
except httpx.HTTPStatusError as e:
logger.error(
"embedding_http_error",
status=e.response.status_code,
model=self._model,
)
raise EmbeddingError(f"Ollama API error: {e.response.status_code}") from e
except Exception as e:
logger.error("embedding_error", error=str(e), model=self._model)
raise EmbeddingError(f"Embedding failed: {e}") from e
async def embed_batch(
self,
texts: list[str],
concurrency: int = 5,
) -> list[list[float]]:
"""
批次向量化多個文本
Args:
texts: 文本列表
concurrency: 同時並行數 (避免過載 Ollama)
Returns:
list[list[float]]: 向量列表 (與輸入順序對應)
"""
if not texts:
return []
results: list[list[float]] = []
semaphore = asyncio.Semaphore(concurrency)
async def embed_with_semaphore(text: str) -> list[float]:
async with semaphore:
return await self.embed_text(text)
tasks = [embed_with_semaphore(text) for text in texts]
results = await asyncio.gather(*tasks, return_exceptions=False)
logger.info(
"batch_embedding_complete",
count=len(texts),
model=self._model,
)
return results
async def close(self) -> None:
"""關閉連線"""
if self._client:
await self._client.aclose()
self._client = None
# =============================================================================
# Errors
# =============================================================================
class EmbeddingError(Exception):
    """Raised when an embedding operation fails."""
# =============================================================================
# Singleton Factory
# =============================================================================
# Process-wide cached instance; populated on the first factory call.
_embedding_service: OllamaEmbeddingService | None = None


def get_embedding_service() -> OllamaEmbeddingService:
    """
    Return the shared embedding service, creating it on first use.

    Returns:
        OllamaEmbeddingService: The module-level instance, built with
        default constructor arguments the first time this is called.
    """
    global _embedding_service
    instance = _embedding_service
    if instance is None:
        instance = OllamaEmbeddingService()
        _embedding_service = instance
    return instance