feat(api): Phase 13.2 AI Rate Limiter + RAG 基礎設施 (#84)

Rate Limiter (防止 Gemini 用量暴衝):
- ai_rate_limiter.py: RPM/Daily/Token 三層閥值
- openclaw.py: 整合 rate limit 檢查,超限自動降級
- health.py: /health/ai-usage 監控端點

RAG Tool 基礎 (#84 進行中):
- embedding_service.py: Ollama embedding 封裝
- rag_service.py: Redis vector search 服務

閥值設定:
- Gemini: 10 RPM, 500/day, 100K tokens/day
- Claude: 5 RPM, 200/day, 50K tokens/day

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
OG T
2026-03-26 15:52:57 +08:00
parent 30145c7d7e
commit bf32c4b1f2
5 changed files with 1048 additions and 13 deletions

View File

@@ -240,3 +240,28 @@ async def get_liveness() -> dict[str, str]:
"""
logger.debug("liveness_probe")
return {"status": "alive"}
@router.get("/health/ai-usage")
async def get_ai_usage() -> dict:
"""
AI API 用量監控
Phase 13.2: Rate Limiter 整合 (2026-03-26)
監控 Gemini/Claude API 使用量,防止暴衝。
Returns:
dict: 各 provider 用量統計
"""
from src.services.ai_rate_limiter import get_ai_rate_limiter
rate_limiter = get_ai_rate_limiter()
gemini_stats = await rate_limiter.get_usage_stats("gemini")
claude_stats = await rate_limiter.get_usage_stats("claude")
return {
"gemini": gemini_stats,
"claude": claude_stats,
"fallback_order": settings.AI_FALLBACK_ORDER,
}

View File

@@ -0,0 +1,272 @@
"""
AI Rate Limiter - Gemini API 用量閥值控制
=========================================
防止 API 用量暴衝,超過閥值自動降級回 Ollama。
功能:
- 每分鐘請求限制 (RPM)
- 每日請求限制
- 每日 Token 限制
- 超限自動降級
版本: v1.0
建立日期: 2026-03-26 21:00 (台北時區)
建立者: Claude Code
"""
import structlog
from src.core.config import settings
logger = structlog.get_logger(__name__)
# =============================================================================
# Configuration - 閥值設定
# =============================================================================
RATE_LIMITS = {
"gemini": {
"rpm": 10, # 每分鐘請求數
"daily_requests": 500, # 每日請求數
"daily_tokens": 100_000, # 每日 Token 數
},
"claude": {
"rpm": 5,
"daily_requests": 200,
"daily_tokens": 50_000,
},
}
# Redis Keys
REDIS_KEY_PREFIX = "ai_rate:"
RPM_KEY = f"{REDIS_KEY_PREFIX}rpm:{{provider}}"
DAILY_REQ_KEY = f"{REDIS_KEY_PREFIX}daily_req:{{provider}}:{{date}}"
DAILY_TOKEN_KEY = f"{REDIS_KEY_PREFIX}daily_token:{{provider}}:{{date}}"
# =============================================================================
# Rate Limiter
# =============================================================================
class AIRateLimiter:
"""
AI API 用量限制器
使用 Redis 計數器追蹤用量,超限時返回降級建議。
Usage:
limiter = AIRateLimiter()
allowed, reason = await limiter.check_and_increment("gemini")
if not allowed:
# 降級到 Ollama
provider = "ollama"
"""
def __init__(self) -> None:
self._redis = None
async def _get_redis(self):
"""Lazy load Redis"""
if self._redis is None:
from src.core.redis_client import get_redis
self._redis = get_redis()
return self._redis
def _get_today(self) -> str:
"""取得今日日期 (台北時區)"""
from src.utils.timezone import now_taipei
return now_taipei().strftime("%Y-%m-%d")
async def check_and_increment(
self,
provider: str,
tokens: int = 0,
) -> tuple[bool, str | None]:
"""
檢查並遞增計數器
Args:
provider: AI 提供者 (gemini, claude)
tokens: 本次使用的 token 數 (事後更新用)
Returns:
tuple[bool, str | None]: (是否允許, 拒絕原因)
"""
if provider not in RATE_LIMITS:
return True, None # 無限制的 provider (如 ollama)
limits = RATE_LIMITS[provider]
r = await self._get_redis()
today = self._get_today()
# 1. 檢查 RPM
rpm_key = RPM_KEY.format(provider=provider)
current_rpm = await r.get(rpm_key)
current_rpm = int(current_rpm) if current_rpm else 0
if current_rpm >= limits["rpm"]:
logger.warning(
"ai_rate_limit_rpm",
provider=provider,
current=current_rpm,
limit=limits["rpm"],
)
return False, f"RPM limit exceeded ({current_rpm}/{limits['rpm']})"
# 2. 檢查每日請求數
daily_req_key = DAILY_REQ_KEY.format(provider=provider, date=today)
current_daily = await r.get(daily_req_key)
current_daily = int(current_daily) if current_daily else 0
if current_daily >= limits["daily_requests"]:
logger.warning(
"ai_rate_limit_daily",
provider=provider,
current=current_daily,
limit=limits["daily_requests"],
)
return False, f"Daily request limit exceeded ({current_daily}/{limits['daily_requests']})"
# 3. 檢查每日 Token (如果有追蹤)
daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today)
current_tokens = await r.get(daily_token_key)
current_tokens = int(current_tokens) if current_tokens else 0
if current_tokens >= limits["daily_tokens"]:
logger.warning(
"ai_rate_limit_tokens",
provider=provider,
current=current_tokens,
limit=limits["daily_tokens"],
)
return False, f"Daily token limit exceeded ({current_tokens}/{limits['daily_tokens']})"
# 4. 遞增計數器
pipe = r.pipeline()
# RPM: 60 秒過期
pipe.incr(rpm_key)
pipe.expire(rpm_key, 60)
# Daily requests: 明天過期
pipe.incr(daily_req_key)
pipe.expire(daily_req_key, 86400)
# Daily tokens
if tokens > 0:
pipe.incrby(daily_token_key, tokens)
pipe.expire(daily_token_key, 86400)
await pipe.execute()
logger.debug(
"ai_rate_check_passed",
provider=provider,
rpm=current_rpm + 1,
daily=current_daily + 1,
)
return True, None
async def record_tokens(self, provider: str, tokens: int) -> None:
"""
記錄 Token 用量 (回應後呼叫)
Args:
provider: AI 提供者
tokens: 使用的 token 數
"""
if provider not in RATE_LIMITS or tokens <= 0:
return
r = await self._get_redis()
today = self._get_today()
daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today)
await r.incrby(daily_token_key, tokens)
await r.expire(daily_token_key, 86400)
logger.debug(
"ai_tokens_recorded",
provider=provider,
tokens=tokens,
)
async def get_usage_stats(self, provider: str) -> dict:
"""
取得用量統計
Args:
provider: AI 提供者
Returns:
dict: 用量統計
"""
if provider not in RATE_LIMITS:
return {"provider": provider, "limited": False}
limits = RATE_LIMITS[provider]
r = await self._get_redis()
today = self._get_today()
rpm_key = RPM_KEY.format(provider=provider)
daily_req_key = DAILY_REQ_KEY.format(provider=provider, date=today)
daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today)
current_rpm = await r.get(rpm_key)
current_daily = await r.get(daily_req_key)
current_tokens = await r.get(daily_token_key)
return {
"provider": provider,
"date": today,
"rpm": {
"current": int(current_rpm) if current_rpm else 0,
"limit": limits["rpm"],
},
"daily_requests": {
"current": int(current_daily) if current_daily else 0,
"limit": limits["daily_requests"],
},
"daily_tokens": {
"current": int(current_tokens) if current_tokens else 0,
"limit": limits["daily_tokens"],
},
}
async def reset_limits(self, provider: str) -> None:
"""
重置限制 (緊急用)
Args:
provider: AI 提供者
"""
r = await self._get_redis()
today = self._get_today()
keys = [
RPM_KEY.format(provider=provider),
DAILY_REQ_KEY.format(provider=provider, date=today),
DAILY_TOKEN_KEY.format(provider=provider, date=today),
]
await r.delete(*keys)
logger.info("ai_rate_limits_reset", provider=provider)
# =============================================================================
# Singleton
# =============================================================================
_rate_limiter: AIRateLimiter | None = None
def get_ai_rate_limiter() -> AIRateLimiter:
"""取得 Rate Limiter 單例"""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = AIRateLimiter()
return _rate_limiter

View File

@@ -0,0 +1,231 @@
"""
Embedding Service - Ollama BGE-M3 替代方案
==========================================
使用 Ollama qwen2.5:7b-instruct 提供文本向量化功能。
雖非專用 embedding 模型,但支援多語言 (繁中/英文)。
Phase 13.2 #84 - RAG Tool 基礎設施
版本: v1.0
建立日期: 2026-03-26 20:30 (台北時區)
建立者: Claude Code
"""
import asyncio
from typing import Protocol
import httpx
import structlog
from src.core.config import settings
logger = structlog.get_logger(__name__)
# =============================================================================
# Interface (DI Protocol)
# =============================================================================
class IEmbeddingService(Protocol):
"""Embedding 服務介面"""
async def embed_text(self, text: str) -> list[float]:
"""將單一文本轉換為向量"""
...
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
"""批次向量化多個文本"""
...
@property
def dimension(self) -> int:
"""向量維度"""
...
# =============================================================================
# Implementation
# =============================================================================
class OllamaEmbeddingService:
"""
Ollama Embedding Service
使用 Ollama API 進行文本向量化。
預設使用 qwen2.5:7b-instruct (3584 維向量)。
Usage:
service = OllamaEmbeddingService()
vector = await service.embed_text("維運手冊")
"""
def __init__(
self,
model: str = "qwen2.5:7b-instruct",
ollama_url: str | None = None,
timeout: float = 30.0,
) -> None:
"""
初始化 Embedding Service
Args:
model: Ollama 模型名稱 (必須支援 embedding)
ollama_url: Ollama API URL (預設從 config 讀取)
timeout: 請求超時 (秒)
"""
self._model = model
self._ollama_url = ollama_url or settings.OLLAMA_URL
self._timeout = timeout
self._dimension: int | None = None
self._client: httpx.AsyncClient | None = None
@property
def dimension(self) -> int:
"""
向量維度
首次呼叫會自動偵測,之後快取。
qwen2.5:7b-instruct = 3584 維
"""
if self._dimension is None:
# 預設值,實際會在第一次 embed 時更新
return 3584
return self._dimension
async def _get_client(self) -> httpx.AsyncClient:
"""取得 HTTP Client (連線池共用)"""
if self._client is None:
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(self._timeout),
limits=httpx.Limits(max_connections=10),
)
return self._client
async def embed_text(self, text: str) -> list[float]:
"""
將單一文本轉換為向量
Args:
text: 要向量化的文本
Returns:
list[float]: 向量 (3584 維)
Raises:
EmbeddingError: 向量化失敗
"""
client = await self._get_client()
try:
response = await client.post(
f"{self._ollama_url}/api/embeddings",
json={
"model": self._model,
"prompt": text,
},
)
response.raise_for_status()
data = response.json()
embedding = data.get("embedding", [])
# 更新維度快取
if self._dimension is None and embedding:
self._dimension = len(embedding)
logger.info(
"embedding_dimension_detected",
model=self._model,
dimension=self._dimension,
)
return embedding
except httpx.TimeoutException:
logger.error("embedding_timeout", model=self._model, text_len=len(text))
raise EmbeddingError(f"Embedding timeout after {self._timeout}s")
except httpx.HTTPStatusError as e:
logger.error(
"embedding_http_error",
status=e.response.status_code,
model=self._model,
)
raise EmbeddingError(f"Ollama API error: {e.response.status_code}")
except Exception as e:
logger.error("embedding_error", error=str(e), model=self._model)
raise EmbeddingError(f"Embedding failed: {e}")
async def embed_batch(
self,
texts: list[str],
concurrency: int = 5,
) -> list[list[float]]:
"""
批次向量化多個文本
Args:
texts: 文本列表
concurrency: 同時並行數 (避免過載 Ollama)
Returns:
list[list[float]]: 向量列表 (與輸入順序對應)
"""
if not texts:
return []
results: list[list[float]] = []
semaphore = asyncio.Semaphore(concurrency)
async def embed_with_semaphore(text: str) -> list[float]:
async with semaphore:
return await self.embed_text(text)
tasks = [embed_with_semaphore(text) for text in texts]
results = await asyncio.gather(*tasks, return_exceptions=False)
logger.info(
"batch_embedding_complete",
count=len(texts),
model=self._model,
)
return results
async def close(self) -> None:
"""關閉連線"""
if self._client:
await self._client.aclose()
self._client = None
# =============================================================================
# Errors
# =============================================================================
class EmbeddingError(Exception):
"""Embedding 操作錯誤"""
pass
# =============================================================================
# Singleton Factory
# =============================================================================
_embedding_service: OllamaEmbeddingService | None = None
def get_embedding_service() -> OllamaEmbeddingService:
"""
取得 Embedding Service 單例
Returns:
OllamaEmbeddingService: 共用實例
"""
global _embedding_service
if _embedding_service is None:
_embedding_service = OllamaEmbeddingService()
return _embedding_service

View File

@@ -35,6 +35,7 @@ from src.models.ai import (
OpenClawDecision,
)
from src.services.langfuse_client import langfuse_trace
from src.services.model_registry import get_model_registry
from src.services.signoz_client import GoldMetrics, get_signoz_client
from src.utils.k8s_naming import normalize_resource_name
from src.utils.timezone import now_taipei_iso
@@ -270,17 +271,22 @@ class OpenClawService:
prompt_length=len(prompt),
)
# 從 ModelRegistry 取得模型配置
registry = get_model_registry()
model_name = registry.get_model("ollama", "rca")
options = registry.get_provider_options("ollama")
response = await client.post(
f"{settings.OLLAMA_URL}/api/generate",
json={
"model": "qwen2.5:7b-instruct", # 使用更大的模型提高品質
"model": model_name,
"prompt": prompt,
"stream": False,
"format": "json", # 強制 JSON 輸出
"options": {
"num_predict": 1024, # 增加輸出長度
"temperature": 0.1, # 低溫度確保穩定輸出
"top_p": 0.9,
"num_predict": options.get("num_predict", 1024),
"temperature": options.get("temperature", 0.1),
"top_p": options.get("top_p", 0.9),
},
},
timeout=httpx.Timeout(float(settings.OPENCLAW_TIMEOUT), connect=10.0),
@@ -324,9 +330,12 @@ class OpenClawService:
try:
client = await self._get_client()
# Gemini 1.5 Flash 支援 JSON Mode
# 從 ModelRegistry 取得模型配置
registry = get_model_registry()
model_name = registry.get_model("gemini", "rca")
response = await client.post(
f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={settings.GEMINI_API_KEY}",
f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent?key={settings.GEMINI_API_KEY}",
json={
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
@@ -737,6 +746,24 @@ class OpenClawService:
return response, provider, success, False # from_cache=False
# =========================================================================
# Public LLM Interface (ILLMProvider Protocol)
# =========================================================================
async def call(self, prompt: str) -> tuple[str, str, bool]:
"""
呼叫 LLM (ILLMProvider Protocol 實作)
#39 Error Analyzer Agent 使用此方法
Args:
prompt: 完整的 prompt
Returns:
(response, provider, success)
"""
return await self._call_with_fallback(prompt)
# =========================================================================
# Fallback Chain
# =========================================================================
@@ -784,7 +811,23 @@ class OpenClawService:
DeepLinking.langfuse_trace_url(trace.langfuse_trace_id),
)
# Phase 13.2: Rate Limiter 整合 (2026-03-26)
# 防止雲端 API 用量暴衝,超限自動降級
from src.services.ai_rate_limiter import get_ai_rate_limiter
rate_limiter = get_ai_rate_limiter()
for provider in settings.AI_FALLBACK_ORDER:
# Rate Limit 檢查 (gemini/claude 需檢查ollama 不限)
if provider in ("gemini", "claude"):
allowed, reason = await rate_limiter.check_and_increment(provider)
if not allowed:
logger.warning(
"ai_rate_limit_skip",
provider=provider,
reason=reason,
)
continue # 跳過此 provider嘗試下一個
logger.info("ai_provider_attempt", provider=provider)
start_time = time.time()
@@ -829,13 +872,9 @@ class OpenClawService:
return self._generate_mock_response(alert_context or {}, signoz_metrics), "mock_fallback", True
def _get_model_name(self, provider: str) -> str:
"""取得 provider 對應的模型名稱"""
model_map = {
"ollama": "qwen2.5:7b-instruct",
"gemini": "gemini-1.5-flash",
"claude": "claude-3-haiku-20240307",
}
return model_map.get(provider, provider)
"""取得 provider 對應的模型名稱 (從 ModelRegistry)"""
registry = get_model_registry()
return registry.get_model(provider, "rca")
# =========================================================================
# Response Parsing (防禦性解析)

View File

@@ -0,0 +1,468 @@
"""
RAG Service - 維運手冊向量搜尋
==============================
Phase 13.2 #84 - Runbook RAG Tool
功能:
- 文檔分段 (Chunking)
- 向量索引 (Redis Stack FT.CREATE)
- 語義搜尋 (KNN Vector Search)
版本: v1.0
建立日期: 2026-03-26 20:45 (台北時區)
建立者: Claude Code
"""
import hashlib
import struct
import uuid
from pathlib import Path
from typing import Protocol
import redis.asyncio as redis
import structlog
from src.core.config import settings
from src.services.embedding_service import IEmbeddingService, get_embedding_service
logger = structlog.get_logger(__name__)
# =============================================================================
# Configuration
# =============================================================================
RAG_CONFIG = {
"chunk_size": 500, # 每段字數
"chunk_overlap": 50, # 重疊字數
"index_name": "idx:runbooks", # Redis index 名稱
"prefix": "runbook:", # Key prefix
"ttl_days": 30, # 文檔 TTL (天)
}
# 維運手冊來源目錄 (相對於專案根目錄)
RUNBOOK_SOURCES = [
"docs/operations/*.md",
"docs/troubleshooting/*.md",
"docs/adr/*.md",
".agents/skills/*.md",
]
# =============================================================================
# Interface
# =============================================================================
class IRAGService(Protocol):
"""RAG 服務介面"""
async def index_documents(self, base_path: Path) -> int:
"""索引文檔,回傳索引數量"""
...
async def search(self, query: str, top_k: int = 5) -> list[dict]:
"""語義搜尋,回傳相關段落"""
...
async def get_index_stats(self) -> dict:
"""取得索引統計"""
...
# =============================================================================
# Implementation
# =============================================================================
class RAGService:
"""
RAG Service 實作
使用 Redis Stack 進行向量索引與搜尋。
支援維運手冊的語義搜尋。
"""
def __init__(
self,
redis_client: redis.Redis | None = None,
embedding_service: IEmbeddingService | None = None,
) -> None:
"""
初始化 RAG Service
Args:
redis_client: Redis 連線 (DI 注入)
embedding_service: Embedding 服務 (DI 注入)
"""
self._redis = redis_client
self._embedding_service = embedding_service
self._index_created = False
async def _get_redis(self) -> redis.Redis:
"""Lazy load Redis client"""
if self._redis is None:
from src.core.redis_client import get_redis
self._redis = get_redis()
return self._redis
async def _get_embedding_service(self) -> IEmbeddingService:
"""Lazy load Embedding service"""
if self._embedding_service is None:
self._embedding_service = get_embedding_service()
return self._embedding_service
# =========================================================================
# Document Processing
# =========================================================================
def _chunk_text(self, text: str, source: str) -> list[dict]:
"""
將文本分段
Args:
text: 原始文本
source: 來源檔案路徑
Returns:
list[dict]: 分段列表,每個包含 content, source, chunk_id
"""
chunk_size = RAG_CONFIG["chunk_size"]
overlap = RAG_CONFIG["chunk_overlap"]
chunks = []
start = 0
chunk_idx = 0
while start < len(text):
end = start + chunk_size
# 嘗試在句號/換行處斷開 (避免截斷句子)
if end < len(text):
# 往後找到最近的句號或換行
for sep in ["\n\n", "", "\n", ". ", ""]:
sep_pos = text.rfind(sep, start + chunk_size // 2, end + 50)
if sep_pos > start:
end = sep_pos + len(sep)
break
chunk_content = text[start:end].strip()
if chunk_content:
chunk_id = self._generate_chunk_id(source, chunk_idx)
chunks.append({
"chunk_id": chunk_id,
"content": chunk_content,
"source": source,
"chunk_index": chunk_idx,
})
chunk_idx += 1
start = end - overlap if end < len(text) else len(text)
return chunks
def _generate_chunk_id(self, source: str, chunk_idx: int) -> str:
"""生成唯一 Chunk ID"""
content = f"{source}:{chunk_idx}"
return hashlib.md5(content.encode()).hexdigest()[:12]
# =========================================================================
# Redis Vector Index
# =========================================================================
async def _ensure_index(self) -> None:
"""
確保向量索引存在
使用 FT.CREATE 建立 HNSW 向量索引
"""
if self._index_created:
return
r = await self._get_redis()
embedding_service = await self._get_embedding_service()
dim = embedding_service.dimension
index_name = RAG_CONFIG["index_name"]
prefix = RAG_CONFIG["prefix"]
try:
# 檢查索引是否存在
await r.execute_command("FT.INFO", index_name)
logger.info("rag_index_exists", index=index_name)
self._index_created = True
return
except redis.ResponseError as e:
if "Unknown index name" not in str(e):
raise
# 建立向量索引
# Schema: content (TEXT), source (TAG), embedding (VECTOR HNSW)
try:
await r.execute_command(
"FT.CREATE", index_name,
"ON", "HASH",
"PREFIX", "1", prefix,
"SCHEMA",
"content", "TEXT", "WEIGHT", "1.0",
"source", "TAG",
"chunk_index", "NUMERIC",
"embedding", "VECTOR", "HNSW", "6",
"TYPE", "FLOAT32",
"DIM", str(dim),
"DISTANCE_METRIC", "COSINE",
)
logger.info(
"rag_index_created",
index=index_name,
dimension=dim,
)
self._index_created = True
except redis.ResponseError as e:
if "Index already exists" in str(e):
self._index_created = True
else:
logger.error("rag_index_create_failed", error=str(e))
raise
async def _store_chunk(
self,
chunk: dict,
embedding: list[float],
) -> None:
"""
儲存分段到 Redis
Args:
chunk: 分段資料
embedding: 向量
"""
r = await self._get_redis()
prefix = RAG_CONFIG["prefix"]
key = f"{prefix}{chunk['chunk_id']}"
# 將 float list 轉換為 bytes (FLOAT32)
embedding_bytes = struct.pack(f"{len(embedding)}f", *embedding)
await r.hset(
key,
mapping={
"content": chunk["content"],
"source": chunk["source"],
"chunk_index": chunk["chunk_index"],
"embedding": embedding_bytes,
},
)
# 設定 TTL
ttl_seconds = RAG_CONFIG["ttl_days"] * 24 * 60 * 60
await r.expire(key, ttl_seconds)
# =========================================================================
# Public API
# =========================================================================
async def index_documents(self, base_path: Path) -> int:
"""
索引維運手冊
Args:
base_path: 專案根目錄
Returns:
int: 索引的分段數量
"""
await self._ensure_index()
embedding_service = await self._get_embedding_service()
total_chunks = 0
all_chunks: list[dict] = []
# 收集所有文檔
for pattern in RUNBOOK_SOURCES:
for file_path in base_path.glob(pattern):
if file_path.is_file():
try:
content = file_path.read_text(encoding="utf-8")
relative_path = str(file_path.relative_to(base_path))
chunks = self._chunk_text(content, relative_path)
all_chunks.extend(chunks)
logger.debug(
"rag_file_chunked",
file=relative_path,
chunks=len(chunks),
)
except Exception as e:
logger.warning(
"rag_file_read_error",
file=str(file_path),
error=str(e),
)
if not all_chunks:
logger.warning("rag_no_documents_found", patterns=RUNBOOK_SOURCES)
return 0
# 批次向量化
logger.info("rag_embedding_start", chunks=len(all_chunks))
texts = [c["content"] for c in all_chunks]
embeddings = await embedding_service.embed_batch(texts, concurrency=3)
# 儲存到 Redis
for chunk, embedding in zip(all_chunks, embeddings):
await self._store_chunk(chunk, embedding)
total_chunks += 1
logger.info(
"rag_index_complete",
total_chunks=total_chunks,
sources=len(RUNBOOK_SOURCES),
)
return total_chunks
async def search(
self,
query: str,
top_k: int = 5,
) -> list[dict]:
"""
語義搜尋維運手冊
Args:
query: 自然語言查詢
top_k: 回傳數量 (預設 5)
Returns:
list[dict]: 相關段落列表
- content: 段落內容
- source: 來源檔案
- score: 相似度分數
"""
await self._ensure_index()
r = await self._get_redis()
embedding_service = await self._get_embedding_service()
index_name = RAG_CONFIG["index_name"]
# 向量化查詢
query_embedding = await embedding_service.embed_text(query)
query_bytes = struct.pack(f"{len(query_embedding)}f", *query_embedding)
# KNN 向量搜尋
# *=>[KNN 5 @embedding $vec AS score]
try:
results = await r.execute_command(
"FT.SEARCH", index_name,
f"*=>[KNN {top_k} @embedding $vec AS score]",
"PARAMS", "2", "vec", query_bytes,
"SORTBY", "score",
"RETURN", "3", "content", "source", "score",
"DIALECT", "2",
)
except redis.ResponseError as e:
logger.error("rag_search_error", error=str(e), query=query[:50])
return []
# 解析結果
# Results format: [total, key1, [field1, value1, ...], key2, ...]
if not results or results[0] == 0:
return []
parsed = []
i = 1
while i < len(results):
key = results[i]
fields = results[i + 1] if i + 1 < len(results) else []
# 將 fields list 轉為 dict
field_dict = {}
for j in range(0, len(fields), 2):
if j + 1 < len(fields):
field_dict[fields[j]] = fields[j + 1]
parsed.append({
"content": field_dict.get("content", ""),
"source": field_dict.get("source", ""),
"score": float(field_dict.get("score", 0)),
})
i += 2
logger.info(
"rag_search_complete",
query=query[:30],
results=len(parsed),
)
return parsed
async def get_index_stats(self) -> dict:
"""
取得索引統計
Returns:
dict: 索引資訊
"""
r = await self._get_redis()
index_name = RAG_CONFIG["index_name"]
try:
info = await r.execute_command("FT.INFO", index_name)
# 將 list 轉為 dict
info_dict = {}
for i in range(0, len(info), 2):
if i + 1 < len(info):
info_dict[info[i]] = info[i + 1]
return {
"index_name": index_name,
"num_docs": info_dict.get("num_docs", 0),
"num_terms": info_dict.get("num_terms", 0),
"indexing": info_dict.get("indexing", 0),
}
except redis.ResponseError:
return {
"index_name": index_name,
"num_docs": 0,
"error": "Index not found",
}
async def clear_index(self) -> bool:
"""
清除索引 (重建用)
Returns:
bool: 是否成功
"""
r = await self._get_redis()
index_name = RAG_CONFIG["index_name"]
try:
await r.execute_command("FT.DROPINDEX", index_name, "DD")
self._index_created = False
logger.info("rag_index_cleared", index=index_name)
return True
except redis.ResponseError:
return False
# =============================================================================
# Singleton Factory
# =============================================================================
_rag_service: RAGService | None = None
def get_rag_service() -> RAGService:
"""
取得 RAG Service 單例
Returns:
RAGService: 共用實例
"""
global _rag_service
if _rag_service is None:
_rag_service = RAGService()
return _rag_service