feat(api): Phase 13.2 AI Rate Limiter + RAG 基礎設施 (#84)
Rate Limiter (防止 Gemini 用量暴衝): - ai_rate_limiter.py: RPM/Daily/Token 三層閥值 - openclaw.py: 整合 rate limit 檢查,超限自動降級 - health.py: /health/ai-usage 監控端點 RAG Tool 基礎 (#84 進行中): - embedding_service.py: Ollama embedding 封裝 - rag_service.py: Redis vector search 服務 閥值設定: - Gemini: 10 RPM, 500/day, 100K tokens/day - Claude: 5 RPM, 200/day, 50K tokens/day Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -240,3 +240,28 @@ async def get_liveness() -> dict[str, str]:
|
||||
"""
|
||||
logger.debug("liveness_probe")
|
||||
return {"status": "alive"}
|
||||
|
||||
|
||||
@router.get("/health/ai-usage")
async def get_ai_usage() -> dict:
    """
    Report current AI API usage for monitoring.

    Phase 13.2: rate limiter integration (2026-03-26).
    Exposes the Gemini and Claude usage counters so that usage spikes
    can be noticed.

    Returns:
        dict: usage statistics under the ``"gemini"`` and ``"claude"``
        keys, plus the configured ``"fallback_order"``.
    """
    # Function-scope import of the rate limiter singleton accessor.
    from src.services.ai_rate_limiter import get_ai_rate_limiter

    limiter = get_ai_rate_limiter()

    # Providers are queried one after another, gemini first.
    usage: dict = {}
    for provider_name in ("gemini", "claude"):
        usage[provider_name] = await limiter.get_usage_stats(provider_name)

    usage["fallback_order"] = settings.AI_FALLBACK_ORDER
    return usage
|
||||
|
||||
272
apps/api/src/services/ai_rate_limiter.py
Normal file
272
apps/api/src/services/ai_rate_limiter.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
AI Rate Limiter - Gemini API 用量閥值控制
|
||||
=========================================
|
||||
|
||||
防止 API 用量暴衝,超過閥值自動降級回 Ollama。
|
||||
|
||||
功能:
|
||||
- 每分鐘請求限制 (RPM)
|
||||
- 每日請求限制
|
||||
- 每日 Token 限制
|
||||
- 超限自動降級
|
||||
|
||||
版本: v1.0
|
||||
建立日期: 2026-03-26 21:00 (台北時區)
|
||||
建立者: Claude Code
|
||||
"""
|
||||
|
||||
import structlog
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration - 閥值設定
|
||||
# =============================================================================
|
||||
|
||||
# Per-provider thresholds. Providers that are not listed here are unlimited.
RATE_LIMITS = {
    "gemini": dict(
        rpm=10,                  # requests per minute
        daily_requests=500,      # requests per day
        daily_tokens=100_000,    # tokens per day
    ),
    "claude": dict(
        rpm=5,
        daily_requests=200,
        daily_tokens=50_000,
    ),
}

# Redis key templates; fill the placeholders in with str.format().
REDIS_KEY_PREFIX = "ai_rate:"
RPM_KEY = REDIS_KEY_PREFIX + "rpm:{provider}"
DAILY_REQ_KEY = REDIS_KEY_PREFIX + "daily_req:{provider}:{date}"
DAILY_TOKEN_KEY = REDIS_KEY_PREFIX + "daily_token:{provider}:{date}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Rate Limiter
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AIRateLimiter:
    """
    Usage limiter for AI APIs.

    Tracks usage with Redis counters and reports when a provider has hit
    one of its thresholds so the caller can fall back to another provider.

    Usage:
        limiter = AIRateLimiter()
        allowed, reason = await limiter.check_and_increment("gemini")
        if not allowed:
            # fall back to Ollama
            provider = "ollama"
    """

    def __init__(self) -> None:
        # Redis client; resolved on first use by _get_redis().
        self._redis = None

    async def _get_redis(self):
        """Return the Redis client, resolving it lazily on first use."""
        if self._redis is None:
            from src.core.redis_client import get_redis

            self._redis = get_redis()
        return self._redis

    def _get_today(self) -> str:
        """Return today's date (Taipei timezone) formatted as ``YYYY-MM-DD``."""
        from src.utils.timezone import now_taipei

        return now_taipei().strftime("%Y-%m-%d")

    @staticmethod
    def _as_count(raw) -> int:
        """Convert a raw Redis counter value (falsy when the key is absent) to int."""
        return int(raw) if raw else 0

    async def check_and_increment(
        self,
        provider: str,
        tokens: int = 0,
    ) -> tuple[bool, str | None]:
        """
        Check the provider's limits and, if all pass, count this request.

        The checks run in order: requests per minute, requests per day,
        tokens per day. The first exceeded threshold rejects the request
        and no counter is incremented.

        NOTE: the reads and the increments are separate Redis round trips,
        so concurrent callers can overshoot a limit slightly.

        Args:
            provider: AI provider name (gemini, claude).
            tokens: Tokens to add to the daily token counter along with
                this request (0 adds nothing; see ``record_tokens``).

        Returns:
            tuple[bool, str | None]: ``(allowed, rejection reason)``.
            Providers without an entry in ``RATE_LIMITS`` are always allowed.
        """
        if provider not in RATE_LIMITS:
            return True, None  # unlimited provider (e.g. ollama)

        limits = RATE_LIMITS[provider]
        r = await self._get_redis()
        today = self._get_today()

        # 1. Requests per minute
        rpm_key = RPM_KEY.format(provider=provider)
        current_rpm = self._as_count(await r.get(rpm_key))

        if current_rpm >= limits["rpm"]:
            logger.warning(
                "ai_rate_limit_rpm",
                provider=provider,
                current=current_rpm,
                limit=limits["rpm"],
            )
            return False, f"RPM limit exceeded ({current_rpm}/{limits['rpm']})"

        # 2. Requests per day
        daily_req_key = DAILY_REQ_KEY.format(provider=provider, date=today)
        current_daily = self._as_count(await r.get(daily_req_key))

        if current_daily >= limits["daily_requests"]:
            logger.warning(
                "ai_rate_limit_daily",
                provider=provider,
                current=current_daily,
                limit=limits["daily_requests"],
            )
            return False, f"Daily request limit exceeded ({current_daily}/{limits['daily_requests']})"

        # 3. Tokens per day (only meaningful when tokens are being recorded)
        daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today)
        current_tokens = self._as_count(await r.get(daily_token_key))

        if current_tokens >= limits["daily_tokens"]:
            logger.warning(
                "ai_rate_limit_tokens",
                provider=provider,
                current=current_tokens,
                limit=limits["daily_tokens"],
            )
            return False, f"Daily token limit exceeded ({current_tokens}/{limits['daily_tokens']})"

        # 4. Increment the counters
        pipe = r.pipeline()

        # RPM: fixed 60-second window. The key is created with its TTL only
        # when it does not exist yet, and INCR leaves the TTL untouched.
        # Refreshing the TTL on every request would keep the counter alive
        # for as long as traffic continues, so it would never reset.
        pipe.set(rpm_key, 0, ex=60, nx=True)
        pipe.incr(rpm_key)

        # Daily requests: the date in the key starts a fresh counter each
        # day; the TTL only cleans up old keys.
        pipe.incr(daily_req_key)
        pipe.expire(daily_req_key, 86400)

        # Daily tokens
        if tokens > 0:
            pipe.incrby(daily_token_key, tokens)
            pipe.expire(daily_token_key, 86400)

        await pipe.execute()

        logger.debug(
            "ai_rate_check_passed",
            provider=provider,
            rpm=current_rpm + 1,
            daily=current_daily + 1,
        )

        return True, None

    async def record_tokens(self, provider: str, tokens: int) -> None:
        """
        Record token usage (call after the provider has responded).

        Does nothing for providers without limits or when ``tokens`` is
        not positive.

        Args:
            provider: AI provider name.
            tokens: Number of tokens used.
        """
        if provider not in RATE_LIMITS or tokens <= 0:
            return

        r = await self._get_redis()
        today = self._get_today()
        daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today)

        await r.incrby(daily_token_key, tokens)
        await r.expire(daily_token_key, 86400)

        logger.debug(
            "ai_tokens_recorded",
            provider=provider,
            tokens=tokens,
        )

    async def get_usage_stats(self, provider: str) -> dict:
        """
        Return the current usage counters for a provider.

        Args:
            provider: AI provider name.

        Returns:
            dict: ``{"provider", "limited": False}`` for providers without
            limits; otherwise the provider, today's date and a
            ``current``/``limit`` pair for ``rpm``, ``daily_requests`` and
            ``daily_tokens``.
        """
        if provider not in RATE_LIMITS:
            return {"provider": provider, "limited": False}

        limits = RATE_LIMITS[provider]
        r = await self._get_redis()
        today = self._get_today()

        rpm_key = RPM_KEY.format(provider=provider)
        daily_req_key = DAILY_REQ_KEY.format(provider=provider, date=today)
        daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today)

        current_rpm = await r.get(rpm_key)
        current_daily = await r.get(daily_req_key)
        current_tokens = await r.get(daily_token_key)

        return {
            "provider": provider,
            "date": today,
            "rpm": {
                "current": self._as_count(current_rpm),
                "limit": limits["rpm"],
            },
            "daily_requests": {
                "current": self._as_count(current_daily),
                "limit": limits["daily_requests"],
            },
            "daily_tokens": {
                "current": self._as_count(current_tokens),
                "limit": limits["daily_tokens"],
            },
        }

    async def reset_limits(self, provider: str) -> None:
        """
        Reset a provider's counters (for emergencies).

        Deletes the RPM key and today's daily request/token keys.

        Args:
            provider: AI provider name.
        """
        r = await self._get_redis()
        today = self._get_today()

        keys = [
            RPM_KEY.format(provider=provider),
            DAILY_REQ_KEY.format(provider=provider, date=today),
            DAILY_TOKEN_KEY.format(provider=provider, date=today),
        ]

        await r.delete(*keys)
        logger.info("ai_rate_limits_reset", provider=provider)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Singleton
|
||||
# =============================================================================
|
||||
|
||||
# Module-wide limiter instance, created on the first accessor call.
_rate_limiter: AIRateLimiter | None = None


def get_ai_rate_limiter() -> AIRateLimiter:
    """Return the shared AIRateLimiter, creating it on first call."""
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = AIRateLimiter()
    return _rate_limiter
|
||||
231
apps/api/src/services/embedding_service.py
Normal file
231
apps/api/src/services/embedding_service.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
Embedding Service - Ollama BGE-M3 替代方案
|
||||
==========================================
|
||||
|
||||
使用 Ollama qwen2.5:7b-instruct 提供文本向量化功能。
|
||||
雖非專用 embedding 模型,但支援多語言 (繁中/英文)。
|
||||
|
||||
Phase 13.2 #84 - RAG Tool 基礎設施
|
||||
|
||||
版本: v1.0
|
||||
建立日期: 2026-03-26 20:30 (台北時區)
|
||||
建立者: Claude Code
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Protocol
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
|
||||
from src.core.config import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Interface (DI Protocol)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class IEmbeddingService(Protocol):
    """Structural interface for embedding services."""

    async def embed_text(self, text: str) -> list[float]:
        """Convert a single text into a vector."""
        ...

    async def embed_batch(
        self,
        texts: list[str],
        concurrency: int = 5,
    ) -> list[list[float]]:
        """
        Vectorize several texts.

        ``concurrency`` is part of the interface because RAGService passes
        it (``embed_batch(texts, concurrency=3)``) and
        OllamaEmbeddingService accepts it.
        """
        ...

    @property
    def dimension(self) -> int:
        """Vector dimension."""
        ...
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Implementation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class OllamaEmbeddingService:
    """
    Ollama embedding service.

    Vectorizes text through the Ollama HTTP API.
    Defaults to qwen2.5:7b-instruct (3584-dimensional vectors).

    Usage:
        service = OllamaEmbeddingService()
        vector = await service.embed_text("維運手冊")
    """

    def __init__(
        self,
        model: str = "qwen2.5:7b-instruct",
        ollama_url: str | None = None,
        timeout: float = 30.0,
    ) -> None:
        """
        Initialize the embedding service.

        Args:
            model: Ollama model name (must support embeddings).
            ollama_url: Ollama API base URL (defaults to ``settings.OLLAMA_URL``).
            timeout: Request timeout in seconds.
        """
        self._model = model
        self._ollama_url = ollama_url or settings.OLLAMA_URL
        self._timeout = timeout
        # Filled in from the first successful embedding.
        self._dimension: int | None = None
        # Created on first use by _get_client().
        self._client: httpx.AsyncClient | None = None

    @property
    def dimension(self) -> int:
        """
        Vector dimension.

        Returns the dimension detected from the first successful embedding;
        until then it returns the default 3584 (qwen2.5:7b-instruct).
        """
        if self._dimension is None:
            # Default value; replaced on the first successful embed.
            return 3584
        return self._dimension

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._timeout),
                limits=httpx.Limits(max_connections=10),
            )
        return self._client

    async def embed_text(self, text: str) -> list[float]:
        """
        Convert a single text into a vector.

        Args:
            text: Text to vectorize.

        Returns:
            list[float]: The embedding vector (never empty).

        Raises:
            EmbeddingError: On timeout, HTTP error status, any other
                request/parse failure, or when the response carries no
                embedding.
        """
        client = await self._get_client()

        try:
            response = await client.post(
                f"{self._ollama_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": text,
                },
            )
            response.raise_for_status()

            data = response.json()
            embedding = data.get("embedding", [])

        except httpx.TimeoutException as e:
            logger.error("embedding_timeout", model=self._model, text_len=len(text))
            raise EmbeddingError(f"Embedding timeout after {self._timeout}s") from e
        except httpx.HTTPStatusError as e:
            logger.error(
                "embedding_http_error",
                status=e.response.status_code,
                model=self._model,
            )
            raise EmbeddingError(f"Ollama API error: {e.response.status_code}") from e
        except Exception as e:
            logger.error("embedding_error", error=str(e), model=self._model)
            raise EmbeddingError(f"Embedding failed: {e}") from e

        # An empty vector is useless to callers (it would be stored or
        # searched as an empty blob), so report it as a failure instead of
        # returning it silently.
        if not embedding:
            logger.error(
                "embedding_error",
                error="empty embedding in response",
                model=self._model,
            )
            raise EmbeddingError("Embedding failed: empty embedding in response")

        # Cache the detected dimension
        if self._dimension is None:
            self._dimension = len(embedding)
            logger.info(
                "embedding_dimension_detected",
                model=self._model,
                dimension=self._dimension,
            )

        return embedding

    async def embed_batch(
        self,
        texts: list[str],
        concurrency: int = 5,
    ) -> list[list[float]]:
        """
        Vectorize several texts.

        Args:
            texts: Texts to vectorize.
            concurrency: Maximum number of requests in flight at once
                (keeps Ollama from being overloaded). Values below 1 are
                treated as 1.

        Returns:
            list[list[float]]: Vectors in the same order as ``texts``.

        Raises:
            EmbeddingError: If any single embedding fails.
        """
        if not texts:
            return []

        # Semaphore(0) would block forever and a negative value raises
        # ValueError, so always allow at least one slot.
        semaphore = asyncio.Semaphore(max(1, concurrency))

        async def embed_with_semaphore(text: str) -> list[float]:
            async with semaphore:
                return await self.embed_text(text)

        tasks = [embed_with_semaphore(text) for text in texts]
        results = await asyncio.gather(*tasks, return_exceptions=False)

        logger.info(
            "batch_embedding_complete",
            count=len(texts),
            model=self._model,
        )

        return results

    async def close(self) -> None:
        """Close the HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Errors
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class EmbeddingError(Exception):
    """Raised when an embedding operation fails."""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Singleton Factory
|
||||
# =============================================================================
|
||||
|
||||
# Module-wide service instance, created on the first accessor call.
_embedding_service: OllamaEmbeddingService | None = None


def get_embedding_service() -> OllamaEmbeddingService:
    """
    Return the shared embedding service, creating it on first call.

    Returns:
        OllamaEmbeddingService: The shared instance.
    """
    global _embedding_service
    if _embedding_service is not None:
        return _embedding_service
    _embedding_service = OllamaEmbeddingService()
    return _embedding_service
|
||||
@@ -35,6 +35,7 @@ from src.models.ai import (
|
||||
OpenClawDecision,
|
||||
)
|
||||
from src.services.langfuse_client import langfuse_trace
|
||||
from src.services.model_registry import get_model_registry
|
||||
from src.services.signoz_client import GoldMetrics, get_signoz_client
|
||||
from src.utils.k8s_naming import normalize_resource_name
|
||||
from src.utils.timezone import now_taipei_iso
|
||||
@@ -270,17 +271,22 @@ class OpenClawService:
|
||||
prompt_length=len(prompt),
|
||||
)
|
||||
|
||||
# 從 ModelRegistry 取得模型配置
|
||||
registry = get_model_registry()
|
||||
model_name = registry.get_model("ollama", "rca")
|
||||
options = registry.get_provider_options("ollama")
|
||||
|
||||
response = await client.post(
|
||||
f"{settings.OLLAMA_URL}/api/generate",
|
||||
json={
|
||||
"model": "qwen2.5:7b-instruct", # 使用更大的模型提高品質
|
||||
"model": model_name,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"format": "json", # 強制 JSON 輸出
|
||||
"options": {
|
||||
"num_predict": 1024, # 增加輸出長度
|
||||
"temperature": 0.1, # 低溫度確保穩定輸出
|
||||
"top_p": 0.9,
|
||||
"num_predict": options.get("num_predict", 1024),
|
||||
"temperature": options.get("temperature", 0.1),
|
||||
"top_p": options.get("top_p", 0.9),
|
||||
},
|
||||
},
|
||||
timeout=httpx.Timeout(float(settings.OPENCLAW_TIMEOUT), connect=10.0),
|
||||
@@ -324,9 +330,12 @@ class OpenClawService:
|
||||
try:
|
||||
client = await self._get_client()
|
||||
|
||||
# Gemini 1.5 Flash 支援 JSON Mode
|
||||
# 從 ModelRegistry 取得模型配置
|
||||
registry = get_model_registry()
|
||||
model_name = registry.get_model("gemini", "rca")
|
||||
|
||||
response = await client.post(
|
||||
f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={settings.GEMINI_API_KEY}",
|
||||
f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent?key={settings.GEMINI_API_KEY}",
|
||||
json={
|
||||
"contents": [{"parts": [{"text": prompt}]}],
|
||||
"generationConfig": {
|
||||
@@ -737,6 +746,24 @@ class OpenClawService:
|
||||
|
||||
return response, provider, success, False # from_cache=False
|
||||
|
||||
# =========================================================================
|
||||
# Public LLM Interface (ILLMProvider Protocol)
|
||||
# =========================================================================
|
||||
|
||||
async def call(self, prompt: str) -> tuple[str, str, bool]:
    """
    Call the LLM (ILLMProvider protocol implementation).

    Used by the #39 Error Analyzer Agent.

    Args:
        prompt: The complete prompt.

    Returns:
        (response, provider, success)
    """
    # Delegates entirely to the fallback chain.
    outcome = await self._call_with_fallback(prompt)
    return outcome
|
||||
|
||||
# =========================================================================
|
||||
# Fallback Chain
|
||||
# =========================================================================
|
||||
@@ -784,7 +811,23 @@ class OpenClawService:
|
||||
DeepLinking.langfuse_trace_url(trace.langfuse_trace_id),
|
||||
)
|
||||
|
||||
# Phase 13.2: Rate Limiter 整合 (2026-03-26)
|
||||
# 防止雲端 API 用量暴衝,超限自動降級
|
||||
from src.services.ai_rate_limiter import get_ai_rate_limiter
|
||||
rate_limiter = get_ai_rate_limiter()
|
||||
|
||||
for provider in settings.AI_FALLBACK_ORDER:
|
||||
# Rate Limit 檢查 (gemini/claude 需檢查,ollama 不限)
|
||||
if provider in ("gemini", "claude"):
|
||||
allowed, reason = await rate_limiter.check_and_increment(provider)
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
"ai_rate_limit_skip",
|
||||
provider=provider,
|
||||
reason=reason,
|
||||
)
|
||||
continue # 跳過此 provider,嘗試下一個
|
||||
|
||||
logger.info("ai_provider_attempt", provider=provider)
|
||||
|
||||
start_time = time.time()
|
||||
@@ -829,13 +872,9 @@ class OpenClawService:
|
||||
return self._generate_mock_response(alert_context or {}, signoz_metrics), "mock_fallback", True
|
||||
|
||||
def _get_model_name(self, provider: str) -> str:
    """
    Return the model name for a provider (from ModelRegistry).

    The diff rendering left both the removed hard-coded model map and the
    added registry lookup in this body, which made the lookup unreachable.
    Only the registry lookup that the commit adds is kept.

    Args:
        provider: Provider name (e.g. ollama, gemini, claude).

    Returns:
        str: The model the registry returns for this provider's "rca" entry.
    """
    registry = get_model_registry()
    return registry.get_model(provider, "rca")
|
||||
|
||||
# =========================================================================
|
||||
# Response Parsing (防禦性解析)
|
||||
|
||||
468
apps/api/src/services/rag_service.py
Normal file
468
apps/api/src/services/rag_service.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""
|
||||
RAG Service - 維運手冊向量搜尋
|
||||
==============================
|
||||
|
||||
Phase 13.2 #84 - Runbook RAG Tool
|
||||
|
||||
功能:
|
||||
- 文檔分段 (Chunking)
|
||||
- 向量索引 (Redis Stack FT.CREATE)
|
||||
- 語義搜尋 (KNN Vector Search)
|
||||
|
||||
版本: v1.0
|
||||
建立日期: 2026-03-26 20:45 (台北時區)
|
||||
建立者: Claude Code
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
import redis.asyncio as redis
|
||||
import structlog
|
||||
|
||||
from src.core.config import settings
|
||||
from src.services.embedding_service import IEmbeddingService, get_embedding_service
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
RAG_CONFIG = dict(
    chunk_size=500,             # characters per chunk
    chunk_overlap=50,           # characters shared by consecutive chunks
    index_name="idx:runbooks",  # Redis index name
    prefix="runbook:",          # key prefix
    ttl_days=30,                # document TTL (days)
)

# Runbook source globs (relative to the project root)
RUNBOOK_SOURCES = [
    "docs/operations/*.md",
    "docs/troubleshooting/*.md",
    "docs/adr/*.md",
    ".agents/skills/*.md",
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Interface
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class IRAGService(Protocol):
    """Structural interface for RAG services."""

    async def index_documents(self, base_path: Path) -> int:
        """Index documents and return how many were indexed."""
        ...

    async def search(self, query: str, top_k: int = 5) -> list[dict]:
        """Run a semantic search and return the relevant passages."""
        ...

    async def get_index_stats(self) -> dict:
        """Return statistics about the index."""
        ...
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Implementation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class RAGService:
    """
    RAG service implementation.

    Uses Redis Stack for vector indexing and search, providing semantic
    search over the runbooks.
    """

    def __init__(
        self,
        redis_client: redis.Redis | None = None,
        embedding_service: IEmbeddingService | None = None,
    ) -> None:
        """
        Initialize the RAG service.

        Args:
            redis_client: Redis connection (dependency injection); resolved
                lazily when omitted.
            embedding_service: Embedding service (dependency injection);
                resolved lazily when omitted.
        """
        self._redis = redis_client
        self._embedding_service = embedding_service
        # Set once the index has been seen or created by this instance.
        self._index_created = False

    async def _get_redis(self) -> redis.Redis:
        """Return the Redis client, resolving it lazily on first use."""
        if self._redis is None:
            from src.core.redis_client import get_redis

            self._redis = get_redis()
        return self._redis

    async def _get_embedding_service(self) -> IEmbeddingService:
        """Return the embedding service, resolving it lazily on first use."""
        if self._embedding_service is None:
            self._embedding_service = get_embedding_service()
        return self._embedding_service

    @staticmethod
    def _to_str(value):
        """Decode ``bytes`` to ``str``; return any other value unchanged."""
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    @staticmethod
    def _is_missing_index_error(error: Exception) -> bool:
        """Return True when a Redis error says the search index does not exist."""
        # The wording and capitalisation of this error differ between
        # RediSearch versions, so match case-insensitively on known variants.
        message = str(error).lower()
        return "unknown index name" in message or "no such index" in message

    # =========================================================================
    # Document Processing
    # =========================================================================

    def _chunk_text(self, text: str, source: str) -> list[dict]:
        """
        Split text into overlapping chunks.

        Args:
            text: Raw text.
            source: Path of the source file.

        Returns:
            list[dict]: Chunks, each with ``chunk_id``, ``content``,
            ``source`` and ``chunk_index``.
        """
        chunk_size = RAG_CONFIG["chunk_size"]
        overlap = RAG_CONFIG["chunk_overlap"]

        chunks = []
        start = 0
        chunk_idx = 0

        while start < len(text):
            end = start + chunk_size

            # Try to break at a paragraph/sentence boundary so sentences are
            # not cut in half.
            if end < len(text):
                # Take the last occurrence of the first separator found in
                # the window from mid-chunk to slightly past the nominal end.
                for sep in ["\n\n", "。", "\n", ". ", ","]:
                    sep_pos = text.rfind(sep, start + chunk_size // 2, end + 50)
                    if sep_pos > start:
                        end = sep_pos + len(sep)
                        break

            chunk_content = text[start:end].strip()

            if chunk_content:
                chunk_id = self._generate_chunk_id(source, chunk_idx)
                chunks.append({
                    "chunk_id": chunk_id,
                    "content": chunk_content,
                    "source": source,
                    "chunk_index": chunk_idx,
                })
                chunk_idx += 1

            if end < len(text):
                # Step back by the overlap, but always advance at least one
                # character so a large configured overlap cannot stall the loop.
                start = max(end - overlap, start + 1)
            else:
                start = len(text)

        return chunks

    def _generate_chunk_id(self, source: str, chunk_idx: int) -> str:
        """Return a chunk ID derived from the source path and chunk index."""
        content = f"{source}:{chunk_idx}"
        return hashlib.md5(content.encode()).hexdigest()[:12]

    # =========================================================================
    # Redis Vector Index
    # =========================================================================

    async def _ensure_index(self) -> None:
        """
        Make sure the vector index exists.

        Creates an HNSW vector index with FT.CREATE when FT.INFO reports
        that the index is missing.

        Raises:
            redis.ResponseError: For any Redis error other than
                "index missing" (on FT.INFO) or "already exists" (on FT.CREATE).
        """
        if self._index_created:
            return

        r = await self._get_redis()
        embedding_service = await self._get_embedding_service()
        # NOTE(review): OllamaEmbeddingService.dimension returns its default
        # (3584) until a first embedding has been made, so with another model
        # the index could be created with the wrong DIM. Confirm.
        dim = embedding_service.dimension

        index_name = RAG_CONFIG["index_name"]
        prefix = RAG_CONFIG["prefix"]

        try:
            # Check whether the index exists
            await r.execute_command("FT.INFO", index_name)
        except redis.ResponseError as e:
            if not self._is_missing_index_error(e):
                raise
        else:
            logger.info("rag_index_exists", index=index_name)
            self._index_created = True
            return

        # Create the vector index
        # Schema: content (TEXT), source (TAG), embedding (VECTOR HNSW)
        try:
            await r.execute_command(
                "FT.CREATE", index_name,
                "ON", "HASH",
                "PREFIX", "1", prefix,
                "SCHEMA",
                "content", "TEXT", "WEIGHT", "1.0",
                "source", "TAG",
                "chunk_index", "NUMERIC",
                "embedding", "VECTOR", "HNSW", "6",
                "TYPE", "FLOAT32",
                "DIM", str(dim),
                "DISTANCE_METRIC", "COSINE",
            )
            logger.info(
                "rag_index_created",
                index=index_name,
                dimension=dim,
            )
            self._index_created = True
        except redis.ResponseError as e:
            # Another caller may have created the index in the meantime.
            if "index already exists" in str(e).lower():
                self._index_created = True
            else:
                logger.error("rag_index_create_failed", error=str(e))
                raise

    async def _store_chunk(
        self,
        chunk: dict,
        embedding: list[float],
    ) -> None:
        """
        Store one chunk in Redis.

        Args:
            chunk: Chunk data (``chunk_id``, ``content``, ``source``,
                ``chunk_index``).
            embedding: The chunk's vector.
        """
        r = await self._get_redis()
        prefix = RAG_CONFIG["prefix"]
        key = f"{prefix}{chunk['chunk_id']}"

        # Pack the float list into bytes (FLOAT32)
        embedding_bytes = struct.pack(f"{len(embedding)}f", *embedding)

        await r.hset(
            key,
            mapping={
                "content": chunk["content"],
                "source": chunk["source"],
                "chunk_index": chunk["chunk_index"],
                "embedding": embedding_bytes,
            },
        )

        # Set the TTL
        ttl_seconds = RAG_CONFIG["ttl_days"] * 24 * 60 * 60
        await r.expire(key, ttl_seconds)

    # =========================================================================
    # Public API
    # =========================================================================

    async def index_documents(self, base_path: Path) -> int:
        """
        Index the runbooks.

        Reads every file matching ``RUNBOOK_SOURCES`` under ``base_path``,
        chunks it, embeds all chunks and stores them in Redis. Files that
        cannot be read are logged and skipped.

        Args:
            base_path: Project root directory.

        Returns:
            int: Number of chunks indexed.
        """
        await self._ensure_index()
        embedding_service = await self._get_embedding_service()

        total_chunks = 0
        all_chunks: list[dict] = []

        # Collect all documents
        for pattern in RUNBOOK_SOURCES:
            for file_path in base_path.glob(pattern):
                if file_path.is_file():
                    try:
                        content = file_path.read_text(encoding="utf-8")
                        relative_path = str(file_path.relative_to(base_path))
                        chunks = self._chunk_text(content, relative_path)
                        all_chunks.extend(chunks)
                        logger.debug(
                            "rag_file_chunked",
                            file=relative_path,
                            chunks=len(chunks),
                        )
                    except Exception as e:
                        # Best effort: one unreadable file must not abort
                        # the whole indexing run.
                        logger.warning(
                            "rag_file_read_error",
                            file=str(file_path),
                            error=str(e),
                        )

        if not all_chunks:
            logger.warning("rag_no_documents_found", patterns=RUNBOOK_SOURCES)
            return 0

        # Embed in batch
        logger.info("rag_embedding_start", chunks=len(all_chunks))
        texts = [c["content"] for c in all_chunks]
        embeddings = await embedding_service.embed_batch(texts, concurrency=3)

        # Store in Redis
        for chunk, embedding in zip(all_chunks, embeddings):
            await self._store_chunk(chunk, embedding)
            total_chunks += 1

        logger.info(
            "rag_index_complete",
            total_chunks=total_chunks,
            sources=len(RUNBOOK_SOURCES),
        )

        return total_chunks

    async def search(
        self,
        query: str,
        top_k: int = 5,
    ) -> list[dict]:
        """
        Semantic search over the runbooks.

        Args:
            query: Natural-language query.
            top_k: Number of results to return (default 5). Values below 1
                yield an empty list.

        Returns:
            list[dict]: Matching passages, each with
                - content: passage text
                - source: source file
                - score: KNN score of the match. The index uses the COSINE
                  distance metric, so lower means more similar.
            An empty list is returned when Redis rejects the search.
        """
        if top_k <= 0:
            return []

        await self._ensure_index()
        r = await self._get_redis()
        embedding_service = await self._get_embedding_service()
        index_name = RAG_CONFIG["index_name"]

        # Embed the query
        query_embedding = await embedding_service.embed_text(query)
        query_bytes = struct.pack(f"{len(query_embedding)}f", *query_embedding)

        # KNN vector search
        # *=>[KNN 5 @embedding $vec AS score]
        try:
            results = await r.execute_command(
                "FT.SEARCH", index_name,
                f"*=>[KNN {top_k} @embedding $vec AS score]",
                "PARAMS", "2", "vec", query_bytes,
                "SORTBY", "score",
                "RETURN", "3", "content", "source", "score",
                "DIALECT", "2",
            )
        except redis.ResponseError as e:
            logger.error("rag_search_error", error=str(e), query=query[:50])
            return []

        # Parse the results
        # Results format: [total, key1, [field1, value1, ...], key2, ...]
        if not results or results[0] == 0:
            return []

        parsed = []
        for i in range(1, len(results), 2):
            fields = results[i + 1] if i + 1 < len(results) else []

            # Turn the flat field list into a dict. Names and values are
            # decoded so parsing works whether the client returns bytes or str.
            field_dict = {}
            for j in range(0, len(fields) - 1, 2):
                field_dict[self._to_str(fields[j])] = self._to_str(fields[j + 1])

            parsed.append({
                "content": field_dict.get("content", ""),
                "source": field_dict.get("source", ""),
                "score": float(field_dict.get("score", 0)),
            })

        logger.info(
            "rag_search_complete",
            query=query[:30],
            results=len(parsed),
        )

        return parsed

    async def get_index_stats(self) -> dict:
        """
        Return statistics about the index.

        Returns:
            dict: ``index_name``, ``num_docs``, ``num_terms`` and
            ``indexing`` taken from FT.INFO; when FT.INFO fails,
            ``index_name``, ``num_docs`` = 0 and an ``error`` message.
        """
        r = await self._get_redis()
        index_name = RAG_CONFIG["index_name"]

        try:
            info = await r.execute_command("FT.INFO", index_name)
            # Turn the flat list into a dict (names and byte values decoded)
            info_dict = {}
            for i in range(0, len(info) - 1, 2):
                info_dict[self._to_str(info[i])] = self._to_str(info[i + 1])

            return {
                "index_name": index_name,
                "num_docs": info_dict.get("num_docs", 0),
                "num_terms": info_dict.get("num_terms", 0),
                "indexing": info_dict.get("indexing", 0),
            }
        except redis.ResponseError:
            return {
                "index_name": index_name,
                "num_docs": 0,
                "error": "Index not found",
            }

    async def clear_index(self) -> bool:
        """
        Drop the index together with its documents (for rebuilding).

        Returns:
            bool: True on success, False when Redis rejects the command.
        """
        r = await self._get_redis()
        index_name = RAG_CONFIG["index_name"]

        try:
            await r.execute_command("FT.DROPINDEX", index_name, "DD")
            self._index_created = False
            logger.info("rag_index_cleared", index=index_name)
            return True
        except redis.ResponseError:
            return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Singleton Factory
|
||||
# =============================================================================
|
||||
|
||||
# Module-wide service instance, created on the first accessor call.
_rag_service: RAGService | None = None


def get_rag_service() -> RAGService:
    """
    Return the shared RAG service, creating it on first call.

    Returns:
        RAGService: The shared instance.
    """
    global _rag_service
    if _rag_service is not None:
        return _rag_service
    _rag_service = RAGService()
    return _rag_service
|
||||
Reference in New Issue
Block a user