From bf32c4b1f25fe41d337b5deb7dc5a0e5f5fa1a53 Mon Sep 17 00:00:00 2001 From: OG T Date: Thu, 26 Mar 2026 15:52:57 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20Phase=2013.2=20AI=20Rate=20Limiter?= =?UTF-8?q?=20+=20RAG=20=E5=9F=BA=E7=A4=8E=E8=A8=AD=E6=96=BD=20(#84)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rate Limiter (防止 Gemini 用量暴衝): - ai_rate_limiter.py: RPM/Daily/Token 三層閥值 - openclaw.py: 整合 rate limit 檢查,超限自動降級 - health.py: /health/ai-usage 監控端點 RAG Tool 基礎 (#84 進行中): - embedding_service.py: Ollama embedding 封裝 - rag_service.py: Redis vector search 服務 閥值設定: - Gemini: 10 RPM, 500/day, 100K tokens/day - Claude: 5 RPM, 200/day, 50K tokens/day Co-Authored-By: Claude Opus 4.5 --- apps/api/src/api/v1/health.py | 25 ++ apps/api/src/services/ai_rate_limiter.py | 272 ++++++++++++ apps/api/src/services/embedding_service.py | 231 ++++++++++ apps/api/src/services/openclaw.py | 65 ++- apps/api/src/services/rag_service.py | 468 +++++++++++++++++++++ 5 files changed, 1048 insertions(+), 13 deletions(-) create mode 100644 apps/api/src/services/ai_rate_limiter.py create mode 100644 apps/api/src/services/embedding_service.py create mode 100644 apps/api/src/services/rag_service.py diff --git a/apps/api/src/api/v1/health.py b/apps/api/src/api/v1/health.py index 06e69074..4620225d 100644 --- a/apps/api/src/api/v1/health.py +++ b/apps/api/src/api/v1/health.py @@ -240,3 +240,28 @@ async def get_liveness() -> dict[str, str]: """ logger.debug("liveness_probe") return {"status": "alive"} + + +@router.get("/health/ai-usage") +async def get_ai_usage() -> dict: + """ + AI API 用量監控 + + Phase 13.2: Rate Limiter 整合 (2026-03-26) + 監控 Gemini/Claude API 使用量,防止暴衝。 + + Returns: + dict: 各 provider 用量統計 + """ + from src.services.ai_rate_limiter import get_ai_rate_limiter + + rate_limiter = get_ai_rate_limiter() + + gemini_stats = await rate_limiter.get_usage_stats("gemini") + claude_stats = await rate_limiter.get_usage_stats("claude") + + return { + "gemini": gemini_stats, + "claude": claude_stats, + "fallback_order": settings.AI_FALLBACK_ORDER, + } diff --git a/apps/api/src/services/ai_rate_limiter.py b/apps/api/src/services/ai_rate_limiter.py new file mode 100644 index 00000000..3be91964 --- /dev/null +++ b/apps/api/src/services/ai_rate_limiter.py @@ -0,0 +1,272 @@ +""" +AI Rate Limiter - Gemini API 用量閥值控制 +========================================= + +防止 API 用量暴衝,超過閥值自動降級回 Ollama。 + +功能: +- 每分鐘請求限制 (RPM) +- 每日請求限制 +- 每日 Token 限制 +- 超限自動降級 + +版本: v1.0 +建立日期: 2026-03-26 21:00 (台北時區) +建立者: Claude Code +""" + +import structlog + +from src.core.config import settings + +logger = structlog.get_logger(__name__) + + +# ============================================================================= +# Configuration - 閥值設定 +# ============================================================================= + +RATE_LIMITS = { + "gemini": { + "rpm": 10, # 每分鐘請求數 + "daily_requests": 500, # 每日請求數 + "daily_tokens": 100_000, # 每日 Token 數 + }, + "claude": { + "rpm": 5, + "daily_requests": 200, + "daily_tokens": 50_000, + }, +} + +# Redis Keys +REDIS_KEY_PREFIX = "ai_rate:" +RPM_KEY = f"{REDIS_KEY_PREFIX}rpm:{{provider}}" +DAILY_REQ_KEY = f"{REDIS_KEY_PREFIX}daily_req:{{provider}}:{{date}}" +DAILY_TOKEN_KEY = f"{REDIS_KEY_PREFIX}daily_token:{{provider}}:{{date}}" + + +# ============================================================================= +# Rate Limiter +# ============================================================================= + + +class AIRateLimiter: + """ + AI API 用量限制器 + + 使用 Redis 計數器追蹤用量,超限時返回降級建議。 + + Usage: + limiter = AIRateLimiter() + allowed, reason = await limiter.check_and_increment("gemini") + if not allowed: + # 降級到 Ollama + provider = "ollama" + """ + + def __init__(self) -> None: + self._redis = None + + async def _get_redis(self): + """Lazy load Redis""" + if self._redis is None: + from src.core.redis_client import get_redis + self._redis = get_redis() + return self._redis + + def _get_today(self) -> str: + """取得今日日期 (台北時區)""" + from src.utils.timezone import now_taipei + return now_taipei().strftime("%Y-%m-%d") + + async def check_and_increment( + self, + provider: str, + tokens: int = 0, + ) -> tuple[bool, str | None]: + """ + 檢查並遞增計數器 + + Args: + provider: AI 提供者 (gemini, claude) + tokens: 本次使用的 token 數 (事後更新用) + + Returns: + tuple[bool, str | None]: (是否允許, 拒絕原因) + """ + if provider not in RATE_LIMITS: + return True, None # 無限制的 provider (如 ollama) + + limits = RATE_LIMITS[provider] + r = await self._get_redis() + today = self._get_today() + + # 1. 檢查 RPM + rpm_key = RPM_KEY.format(provider=provider) + current_rpm = await r.get(rpm_key) + current_rpm = int(current_rpm) if current_rpm else 0 + + if current_rpm >= limits["rpm"]: + logger.warning( + "ai_rate_limit_rpm", + provider=provider, + current=current_rpm, + limit=limits["rpm"], + ) + return False, f"RPM limit exceeded ({current_rpm}/{limits['rpm']})" + + # 2. 檢查每日請求數 + daily_req_key = DAILY_REQ_KEY.format(provider=provider, date=today) + current_daily = await r.get(daily_req_key) + current_daily = int(current_daily) if current_daily else 0 + + if current_daily >= limits["daily_requests"]: + logger.warning( + "ai_rate_limit_daily", + provider=provider, + current=current_daily, + limit=limits["daily_requests"], + ) + return False, f"Daily request limit exceeded ({current_daily}/{limits['daily_requests']})" + + # 3. 檢查每日 Token (如果有追蹤) + daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today) + current_tokens = await r.get(daily_token_key) + current_tokens = int(current_tokens) if current_tokens else 0 + + if current_tokens >= limits["daily_tokens"]: + logger.warning( + "ai_rate_limit_tokens", + provider=provider, + current=current_tokens, + limit=limits["daily_tokens"], + ) + return False, f"Daily token limit exceeded ({current_tokens}/{limits['daily_tokens']})" + + # 4. 遞增計數器 + pipe = r.pipeline() + + # RPM: 60 秒過期 + pipe.incr(rpm_key) + pipe.expire(rpm_key, 60) + + # Daily requests: 明天過期 + pipe.incr(daily_req_key) + pipe.expire(daily_req_key, 86400) + + # Daily tokens + if tokens > 0: + pipe.incrby(daily_token_key, tokens) + pipe.expire(daily_token_key, 86400) + + await pipe.execute() + + logger.debug( + "ai_rate_check_passed", + provider=provider, + rpm=current_rpm + 1, + daily=current_daily + 1, + ) + + return True, None + + async def record_tokens(self, provider: str, tokens: int) -> None: + """ + 記錄 Token 用量 (回應後呼叫) + + Args: + provider: AI 提供者 + tokens: 使用的 token 數 + """ + if provider not in RATE_LIMITS or tokens <= 0: + return + + r = await self._get_redis() + today = self._get_today() + daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today) + + await r.incrby(daily_token_key, tokens) + await r.expire(daily_token_key, 86400) + + logger.debug( + "ai_tokens_recorded", + provider=provider, + tokens=tokens, + ) + + async def get_usage_stats(self, provider: str) -> dict: + """ + 取得用量統計 + + Args: + provider: AI 提供者 + + Returns: + dict: 用量統計 + """ + if provider not in RATE_LIMITS: + return {"provider": provider, "limited": False} + + limits = RATE_LIMITS[provider] + r = await self._get_redis() + today = self._get_today() + + rpm_key = RPM_KEY.format(provider=provider) + daily_req_key = DAILY_REQ_KEY.format(provider=provider, date=today) + daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today) + + current_rpm = await r.get(rpm_key) + current_daily = await r.get(daily_req_key) + current_tokens = await r.get(daily_token_key) + + return { + "provider": provider, + "date": today, + "rpm": { + "current": int(current_rpm) if current_rpm else 0, + "limit": limits["rpm"], + }, + "daily_requests": { + "current": int(current_daily) if current_daily else 0, + "limit": limits["daily_requests"], + }, + "daily_tokens": { + "current": int(current_tokens) if current_tokens else 0, + "limit": limits["daily_tokens"], + }, + } + + async def reset_limits(self, provider: str) -> None: + """ + 重置限制 (緊急用) + + Args: + provider: AI 提供者 + """ + r = await self._get_redis() + today = self._get_today() + + keys = [ + RPM_KEY.format(provider=provider), + DAILY_REQ_KEY.format(provider=provider, date=today), + DAILY_TOKEN_KEY.format(provider=provider, date=today), + ] + + await r.delete(*keys) + logger.info("ai_rate_limits_reset", provider=provider) + + +# ============================================================================= +# Singleton +# ============================================================================= + +_rate_limiter: AIRateLimiter | None = None + + +def get_ai_rate_limiter() -> AIRateLimiter: + """取得 Rate Limiter 單例""" + global _rate_limiter + if _rate_limiter is None: + _rate_limiter = AIRateLimiter() + return _rate_limiter diff --git a/apps/api/src/services/embedding_service.py b/apps/api/src/services/embedding_service.py new file mode 100644 index 00000000..92d27f42 --- /dev/null +++ b/apps/api/src/services/embedding_service.py @@ -0,0 +1,231 @@ +""" +Embedding Service - Ollama BGE-M3 替代方案 +========================================== + +使用 Ollama qwen2.5:7b-instruct 提供文本向量化功能。 +雖非專用 embedding 模型,但支援多語言 (繁中/英文)。 + +Phase 13.2 #84 - RAG Tool 基礎設施 + +版本: v1.0 +建立日期: 2026-03-26 20:30 (台北時區) +建立者: Claude Code +""" + +import asyncio +from typing import Protocol + +import httpx +import structlog + +from src.core.config import settings + +logger = structlog.get_logger(__name__) + + +# ============================================================================= +# Interface (DI Protocol) +# ============================================================================= + + +class IEmbeddingService(Protocol): + """Embedding 服務介面""" + + async def embed_text(self, text: str) -> list[float]: + """將單一文本轉換為向量""" + ... + + async def embed_batch(self, texts: list[str]) -> list[list[float]]: + """批次向量化多個文本""" + ... + + @property + def dimension(self) -> int: + """向量維度""" + ... + + +# ============================================================================= +# Implementation +# ============================================================================= + + +class OllamaEmbeddingService: + """ + Ollama Embedding Service + + 使用 Ollama API 進行文本向量化。 + 預設使用 qwen2.5:7b-instruct (3584 維向量)。 + + Usage: + service = OllamaEmbeddingService() + vector = await service.embed_text("維運手冊") + """ + + def __init__( + self, + model: str = "qwen2.5:7b-instruct", + ollama_url: str | None = None, + timeout: float = 30.0, + ) -> None: + """ + 初始化 Embedding Service + + Args: + model: Ollama 模型名稱 (必須支援 embedding) + ollama_url: Ollama API URL (預設從 config 讀取) + timeout: 請求超時 (秒) + """ + self._model = model + self._ollama_url = ollama_url or settings.OLLAMA_URL + self._timeout = timeout + self._dimension: int | None = None + self._client: httpx.AsyncClient | None = None + + @property + def dimension(self) -> int: + """ + 向量維度 + + 首次呼叫會自動偵測,之後快取。 + qwen2.5:7b-instruct = 3584 維 + """ + if self._dimension is None: + # 預設值,實際會在第一次 embed 時更新 + return 3584 + return self._dimension + + async def _get_client(self) -> httpx.AsyncClient: + """取得 HTTP Client (連線池共用)""" + if self._client is None: + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(self._timeout), + limits=httpx.Limits(max_connections=10), + ) + return self._client + + async def embed_text(self, text: str) -> list[float]: + """ + 將單一文本轉換為向量 + + Args: + text: 要向量化的文本 + + Returns: + list[float]: 向量 (3584 維) + + Raises: + EmbeddingError: 向量化失敗 + """ + client = await self._get_client() + + try: + response = await client.post( + f"{self._ollama_url}/api/embeddings", + json={ + "model": self._model, + "prompt": text, + }, + ) + response.raise_for_status() + + data = response.json() + embedding = data.get("embedding", []) + + # 更新維度快取 + if self._dimension is None and embedding: + self._dimension = len(embedding) + logger.info( + "embedding_dimension_detected", + model=self._model, + dimension=self._dimension, + ) + + return embedding + + except httpx.TimeoutException: + logger.error("embedding_timeout", model=self._model, text_len=len(text)) + raise EmbeddingError(f"Embedding timeout after {self._timeout}s") + except httpx.HTTPStatusError as e: + logger.error( + "embedding_http_error", + status=e.response.status_code, + model=self._model, + ) + raise EmbeddingError(f"Ollama API error: {e.response.status_code}") + except Exception as e: + logger.error("embedding_error", error=str(e), model=self._model) + raise EmbeddingError(f"Embedding failed: {e}") + + async def embed_batch( + self, + texts: list[str], + concurrency: int = 5, + ) -> list[list[float]]: + """ + 批次向量化多個文本 + + Args: + texts: 文本列表 + concurrency: 同時並行數 (避免過載 Ollama) + + Returns: + list[list[float]]: 向量列表 (與輸入順序對應) + """ + if not texts: + return [] + + results: list[list[float]] = [] + semaphore = asyncio.Semaphore(concurrency) + + async def embed_with_semaphore(text: str) -> list[float]: + async with semaphore: + return await self.embed_text(text) + + tasks = [embed_with_semaphore(text) for text in texts] + results = await asyncio.gather(*tasks, return_exceptions=False) + + logger.info( + "batch_embedding_complete", + count=len(texts), + model=self._model, + ) + + return results + + async def close(self) -> None: + """關閉連線""" + if self._client: + await self._client.aclose() + self._client = None + + +# ============================================================================= +# Errors +# ============================================================================= + + +class EmbeddingError(Exception): + """Embedding 操作錯誤""" + + pass + + +# ============================================================================= +# Singleton Factory +# ============================================================================= + +_embedding_service: OllamaEmbeddingService | None = None + + +def get_embedding_service() -> OllamaEmbeddingService: + """ + 取得 Embedding Service 單例 + + Returns: + OllamaEmbeddingService: 共用實例 + """ + global _embedding_service + if _embedding_service is None: + _embedding_service = OllamaEmbeddingService() + return _embedding_service diff --git a/apps/api/src/services/openclaw.py b/apps/api/src/services/openclaw.py index 98c3e266..dd7a7f31 100644 --- a/apps/api/src/services/openclaw.py +++ b/apps/api/src/services/openclaw.py @@ -35,6 +35,7 @@ from src.models.ai import ( OpenClawDecision, ) from src.services.langfuse_client import langfuse_trace +from src.services.model_registry import get_model_registry from src.services.signoz_client import GoldMetrics, get_signoz_client from src.utils.k8s_naming import normalize_resource_name from src.utils.timezone import now_taipei_iso @@ -270,17 +271,22 @@ class OpenClawService: prompt_length=len(prompt), ) + # 從 ModelRegistry 取得模型配置 + registry = get_model_registry() + model_name = registry.get_model("ollama", "rca") + options = registry.get_provider_options("ollama") + response = await client.post( f"{settings.OLLAMA_URL}/api/generate", json={ - "model": "qwen2.5:7b-instruct", # 使用更大的模型提高品質 + "model": model_name, "prompt": prompt, "stream": False, "format": "json", # 強制 JSON 輸出 "options": { - "num_predict": 1024, # 增加輸出長度 - "temperature": 0.1, # 低溫度確保穩定輸出 - "top_p": 0.9, + "num_predict": options.get("num_predict", 1024), + "temperature": options.get("temperature", 0.1), + "top_p": options.get("top_p", 0.9), }, }, timeout=httpx.Timeout(float(settings.OPENCLAW_TIMEOUT), connect=10.0), @@ -324,9 +330,12 @@ class OpenClawService: try: client = await self._get_client() - # Gemini 1.5 Flash 支援 JSON Mode + # 從 ModelRegistry 取得模型配置 + registry = get_model_registry() + model_name = registry.get_model("gemini", "rca") + response = await client.post( - f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={settings.GEMINI_API_KEY}", + f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent?key={settings.GEMINI_API_KEY}", json={ "contents": [{"parts": [{"text": prompt}]}], "generationConfig": { @@ -737,6 +746,24 @@ class OpenClawService: return response, provider, success, False # from_cache=False + # ========================================================================= + # Public LLM Interface (ILLMProvider Protocol) + # ========================================================================= + + async def call(self, prompt: str) -> tuple[str, str, bool]: + """ + 呼叫 LLM (ILLMProvider Protocol 實作) + + #39 Error Analyzer Agent 使用此方法 + + Args: + prompt: 完整的 prompt + + Returns: + (response, provider, success) + """ + return await self._call_with_fallback(prompt) + # ========================================================================= # Fallback Chain # ========================================================================= @@ -784,7 +811,23 @@ class OpenClawService: DeepLinking.langfuse_trace_url(trace.langfuse_trace_id), ) + # Phase 13.2: Rate Limiter 整合 (2026-03-26) + # 防止雲端 API 用量暴衝,超限自動降級 + from src.services.ai_rate_limiter import get_ai_rate_limiter + rate_limiter = get_ai_rate_limiter() + for provider in settings.AI_FALLBACK_ORDER: + # Rate Limit 檢查 (gemini/claude 需檢查,ollama 不限) + if provider in ("gemini", "claude"): + allowed, reason = await rate_limiter.check_and_increment(provider) + if not allowed: + logger.warning( + "ai_rate_limit_skip", + provider=provider, + reason=reason, + ) + continue # 跳過此 provider,嘗試下一個 + logger.info("ai_provider_attempt", provider=provider) start_time = time.time() @@ -829,13 +872,9 @@ class OpenClawService: return self._generate_mock_response(alert_context or {}, signoz_metrics), "mock_fallback", True def _get_model_name(self, provider: str) -> str: - """取得 provider 對應的模型名稱""" - model_map = { - "ollama": "qwen2.5:7b-instruct", - "gemini": "gemini-1.5-flash", - "claude": "claude-3-haiku-20240307", - } - return model_map.get(provider, provider) + """取得 provider 對應的模型名稱 (從 ModelRegistry)""" + registry = get_model_registry() + return registry.get_model(provider, "rca") # ========================================================================= # Response Parsing (防禦性解析) diff --git a/apps/api/src/services/rag_service.py b/apps/api/src/services/rag_service.py new file mode 100644 index 00000000..0103264a --- /dev/null +++ b/apps/api/src/services/rag_service.py @@ -0,0 +1,468 @@ +""" +RAG Service - 維運手冊向量搜尋 +============================== + +Phase 13.2 #84 - Runbook RAG Tool + +功能: +- 文檔分段 (Chunking) +- 向量索引 (Redis Stack FT.CREATE) +- 語義搜尋 (KNN Vector Search) + +版本: v1.0 +建立日期: 2026-03-26 20:45 (台北時區) +建立者: Claude Code +""" + +import hashlib +import struct +import uuid +from pathlib import Path +from typing import Protocol + +import redis.asyncio as redis +import structlog + +from src.core.config import settings +from src.services.embedding_service import IEmbeddingService, get_embedding_service + +logger = structlog.get_logger(__name__) + + +# ============================================================================= +# Configuration +# ============================================================================= + +RAG_CONFIG = { + "chunk_size": 500, # 每段字數 + "chunk_overlap": 50, # 重疊字數 + "index_name": "idx:runbooks", # Redis index 名稱 + "prefix": "runbook:", # Key prefix + "ttl_days": 30, # 文檔 TTL (天) +} + +# 維運手冊來源目錄 (相對於專案根目錄) +RUNBOOK_SOURCES = [ + "docs/operations/*.md", + "docs/troubleshooting/*.md", + "docs/adr/*.md", + ".agents/skills/*.md", +] + + +# ============================================================================= +# Interface +# ============================================================================= + + +class IRAGService(Protocol): + """RAG 服務介面""" + + async def index_documents(self, base_path: Path) -> int: + """索引文檔,回傳索引數量""" + ... + + async def search(self, query: str, top_k: int = 5) -> list[dict]: + """語義搜尋,回傳相關段落""" + ... + + async def get_index_stats(self) -> dict: + """取得索引統計""" + ... + + +# ============================================================================= +# Implementation +# ============================================================================= + + +class RAGService: + """ + RAG Service 實作 + + 使用 Redis Stack 進行向量索引與搜尋。 + 支援維運手冊的語義搜尋。 + """ + + def __init__( + self, + redis_client: redis.Redis | None = None, + embedding_service: IEmbeddingService | None = None, + ) -> None: + """ + 初始化 RAG Service + + Args: + redis_client: Redis 連線 (DI 注入) + embedding_service: Embedding 服務 (DI 注入) + """ + self._redis = redis_client + self._embedding_service = embedding_service + self._index_created = False + + async def _get_redis(self) -> redis.Redis: + """Lazy load Redis client""" + if self._redis is None: + from src.core.redis_client import get_redis + self._redis = get_redis() + return self._redis + + async def _get_embedding_service(self) -> IEmbeddingService: + """Lazy load Embedding service""" + if self._embedding_service is None: + self._embedding_service = get_embedding_service() + return self._embedding_service + + # ========================================================================= + # Document Processing + # ========================================================================= + + def _chunk_text(self, text: str, source: str) -> list[dict]: + """ + 將文本分段 + + Args: + text: 原始文本 + source: 來源檔案路徑 + + Returns: + list[dict]: 分段列表,每個包含 content, source, chunk_id + """ + chunk_size = RAG_CONFIG["chunk_size"] + overlap = RAG_CONFIG["chunk_overlap"] + + chunks = [] + start = 0 + chunk_idx = 0 + + while start < len(text): + end = start + chunk_size + + # 嘗試在句號/換行處斷開 (避免截斷句子) + if end < len(text): + # 往後找到最近的句號或換行 + for sep in ["\n\n", "。", "\n", ". ", ","]: + sep_pos = text.rfind(sep, start + chunk_size // 2, end + 50) + if sep_pos > start: + end = sep_pos + len(sep) + break + + chunk_content = text[start:end].strip() + + if chunk_content: + chunk_id = self._generate_chunk_id(source, chunk_idx) + chunks.append({ + "chunk_id": chunk_id, + "content": chunk_content, + "source": source, + "chunk_index": chunk_idx, + }) + chunk_idx += 1 + + start = end - overlap if end < len(text) else len(text) + + return chunks + + def _generate_chunk_id(self, source: str, chunk_idx: int) -> str: + """生成唯一 Chunk ID""" + content = f"{source}:{chunk_idx}" + return hashlib.md5(content.encode()).hexdigest()[:12] + + # ========================================================================= + # Redis Vector Index + # ========================================================================= + + async def _ensure_index(self) -> None: + """ + 確保向量索引存在 + + 使用 FT.CREATE 建立 HNSW 向量索引 + """ + if self._index_created: + return + + r = await self._get_redis() + embedding_service = await self._get_embedding_service() + dim = embedding_service.dimension + + index_name = RAG_CONFIG["index_name"] + prefix = RAG_CONFIG["prefix"] + + try: + # 檢查索引是否存在 + await r.execute_command("FT.INFO", index_name) + logger.info("rag_index_exists", index=index_name) + self._index_created = True + return + except redis.ResponseError as e: + if "Unknown index name" not in str(e): + raise + + # 建立向量索引 + # Schema: content (TEXT), source (TAG), embedding (VECTOR HNSW) + try: + await r.execute_command( + "FT.CREATE", index_name, + "ON", "HASH", + "PREFIX", "1", prefix, + "SCHEMA", + "content", "TEXT", "WEIGHT", "1.0", + "source", "TAG", + "chunk_index", "NUMERIC", + "embedding", "VECTOR", "HNSW", "6", + "TYPE", "FLOAT32", + "DIM", str(dim), + "DISTANCE_METRIC", "COSINE", + ) + logger.info( + "rag_index_created", + index=index_name, + dimension=dim, + ) + self._index_created = True + except redis.ResponseError as e: + if "Index already exists" in str(e): + self._index_created = True + else: + logger.error("rag_index_create_failed", error=str(e)) + raise + + async def _store_chunk( + self, + chunk: dict, + embedding: list[float], + ) -> None: + """ + 儲存分段到 Redis + + Args: + chunk: 分段資料 + embedding: 向量 + """ + r = await self._get_redis() + prefix = RAG_CONFIG["prefix"] + key = f"{prefix}{chunk['chunk_id']}" + + # 將 float list 轉換為 bytes (FLOAT32) + embedding_bytes = struct.pack(f"{len(embedding)}f", *embedding) + + await r.hset( + key, + mapping={ + "content": chunk["content"], + "source": chunk["source"], + "chunk_index": chunk["chunk_index"], + "embedding": embedding_bytes, + }, + ) + + # 設定 TTL + ttl_seconds = RAG_CONFIG["ttl_days"] * 24 * 60 * 60 + await r.expire(key, ttl_seconds) + + # ========================================================================= + # Public API + # ========================================================================= + + async def index_documents(self, base_path: Path) -> int: + """ + 索引維運手冊 + + Args: + base_path: 專案根目錄 + + Returns: + int: 索引的分段數量 + """ + await self._ensure_index() + embedding_service = await self._get_embedding_service() + + total_chunks = 0 + all_chunks: list[dict] = [] + + # 收集所有文檔 + for pattern in RUNBOOK_SOURCES: + for file_path in base_path.glob(pattern): + if file_path.is_file(): + try: + content = file_path.read_text(encoding="utf-8") + relative_path = str(file_path.relative_to(base_path)) + chunks = self._chunk_text(content, relative_path) + all_chunks.extend(chunks) + logger.debug( + "rag_file_chunked", + file=relative_path, + chunks=len(chunks), + ) + except Exception as e: + logger.warning( + "rag_file_read_error", + file=str(file_path), + error=str(e), + ) + + if not all_chunks: + logger.warning("rag_no_documents_found", patterns=RUNBOOK_SOURCES) + return 0 + + # 批次向量化 + logger.info("rag_embedding_start", chunks=len(all_chunks)) + texts = [c["content"] for c in all_chunks] + embeddings = await embedding_service.embed_batch(texts, concurrency=3) + + # 儲存到 Redis + for chunk, embedding in zip(all_chunks, embeddings): + await self._store_chunk(chunk, embedding) + total_chunks += 1 + + logger.info( + "rag_index_complete", + total_chunks=total_chunks, + sources=len(RUNBOOK_SOURCES), + ) + + return total_chunks + + async def search( + self, + query: str, + top_k: int = 5, + ) -> list[dict]: + """ + 語義搜尋維運手冊 + + Args: + query: 自然語言查詢 + top_k: 回傳數量 (預設 5) + + Returns: + list[dict]: 相關段落列表 + - content: 段落內容 + - source: 來源檔案 + - score: 相似度分數 + """ + await self._ensure_index() + r = await self._get_redis() + embedding_service = await self._get_embedding_service() + index_name = RAG_CONFIG["index_name"] + + # 向量化查詢 + query_embedding = await embedding_service.embed_text(query) + query_bytes = struct.pack(f"{len(query_embedding)}f", *query_embedding) + + # KNN 向量搜尋 + # *=>[KNN 5 @embedding $vec AS score] + try: + results = await r.execute_command( + "FT.SEARCH", index_name, + f"*=>[KNN {top_k} @embedding $vec AS score]", + "PARAMS", "2", "vec", query_bytes, + "SORTBY", "score", + "RETURN", "3", "content", "source", "score", + "DIALECT", "2", + ) + except redis.ResponseError as e: + logger.error("rag_search_error", error=str(e), query=query[:50]) + return [] + + # 解析結果 + # Results format: [total, key1, [field1, value1, ...], key2, ...] + if not results or results[0] == 0: + return [] + + parsed = [] + i = 1 + while i < len(results): + key = results[i] + fields = results[i + 1] if i + 1 < len(results) else [] + + # 將 fields list 轉為 dict + field_dict = {} + for j in range(0, len(fields), 2): + if j + 1 < len(fields): + field_dict[fields[j]] = fields[j + 1] + + parsed.append({ + "content": field_dict.get("content", ""), + "source": field_dict.get("source", ""), + "score": float(field_dict.get("score", 0)), + }) + + i += 2 + + logger.info( + "rag_search_complete", + query=query[:30], + results=len(parsed), + ) + + return parsed + + async def get_index_stats(self) -> dict: + """ + 取得索引統計 + + Returns: + dict: 索引資訊 + """ + r = await self._get_redis() + index_name = RAG_CONFIG["index_name"] + + try: + info = await r.execute_command("FT.INFO", index_name) + # 將 list 轉為 dict + info_dict = {} + for i in range(0, len(info), 2): + if i + 1 < len(info): + info_dict[info[i]] = info[i + 1] + + return { + "index_name": index_name, + "num_docs": info_dict.get("num_docs", 0), + "num_terms": info_dict.get("num_terms", 0), + "indexing": info_dict.get("indexing", 0), + } + except redis.ResponseError: + return { + "index_name": index_name, + "num_docs": 0, + "error": "Index not found", + } + + async def clear_index(self) -> bool: + """ + 清除索引 (重建用) + + Returns: + bool: 是否成功 + """ + r = await self._get_redis() + index_name = RAG_CONFIG["index_name"] + + try: + await r.execute_command("FT.DROPINDEX", index_name, "DD") + self._index_created = False + logger.info("rag_index_cleared", index=index_name) + return True + except redis.ResponseError: + return False + + +# ============================================================================= +# Singleton Factory +# ============================================================================= + +_rag_service: RAGService | None = None + + +def get_rag_service() -> RAGService: + """ + 取得 RAG Service 單例 + + Returns: + RAGService: 共用實例 + """ + global _rag_service + if _rag_service is None: + _rag_service = RAGService() + return _rag_service