250 lines
9.1 KiB
Python
250 lines
9.1 KiB
Python
"""
|
||
AWOOOI — Knowledge RAG Service (Phase 33, ADR-067)
|
||
==================================================
|
||
本地 RAG 知識庫:bge-m3 1024維向量 + pgvector
|
||
|
||
索引策略:
|
||
- 初期 < 100 筆: 線性搜尋
|
||
- 超過 100 筆: 執行 CREATE INDEX ivfflat (手動觸發)
|
||
|
||
向量模型: bge-m3 (GCP-A/GCP-B/111 Ollama lane, 1024維)
|
||
生成模型: qwen2.5:7b-instruct (Ollama GCP-A/GCP-B/111)
|
||
|
||
leWOOOgo: Service 層只處理業務邏輯,DB 存取委派 rag_chunk_repository
|
||
架構審查 C1 修正: 2026-04-10 Claude Sonnet 4.6 Asia/Taipei
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
import structlog
|
||
|
||
import src.repositories.rag_chunk_repository as rag_repo
|
||
from src.core.config import settings
|
||
from src.services.ollama_endpoint_circuit_breaker import (
|
||
record_ollama_endpoint_failure,
|
||
record_ollama_endpoint_success,
|
||
resolve_ollama_order_with_cooldown,
|
||
)
|
||
|
||
logger = structlog.get_logger(__name__)

# Default Ollama model names. The embedding model may be overridden through
# settings.OLLAMA_EMBEDDING_MODEL; the generation model is fixed here.
_EMBED_MODEL = "bge-m3:latest"
_GEN_MODEL = "qwen2.5:7b-instruct"

# Default number of nearest chunks fetched per RAG query.
_TOP_K = 5

# Directories scanned recursively for *.md files by index_all_sources(), in order.
# NOTE: "docs" contains "docs/runbooks" and "docs/adr", so a recursive scan of it
# overlaps the first two entries.
_INDEX_SOURCES = [
    Path(raw_path)
    for raw_path in (
        "docs/runbooks",
        "docs/adr",
        "docs",
        ".agents/skills",
    )
]
|
||
|
||
|
||
class KnowledgeRAGService:
    """Local RAG knowledge-base service.

    Text is embedded through an Ollama ``/api/embeddings`` endpoint. Chunks are
    stored and searched exclusively through ``rag_chunk_repository`` (leWOOOgo:
    no direct DB access here). Answers are produced through an Ollama
    ``/api/generate`` endpoint. Endpoints are tried in the order returned by
    ``resolve_ollama_order_with_cooldown`` and their outcome is reported back to
    the circuit breaker.
    """

    def __init__(self) -> None:
        # Created lazily by _get_http(); recreated whenever it has been closed.
        self._http: httpx.AsyncClient | None = None

    async def _get_http(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating a new one if absent or closed."""
        if self._http is None or self._http.is_closed:
            self._http = httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0))
        return self._http

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def query(self, question: str, top_k: int = _TOP_K) -> str:
        """Answer *question* from the knowledge base.

        Pipeline: embed the question, fetch the *top_k* nearest chunks through
        ``rag_repo.search_chunks`` (pgvector kNN), then ask the generation model
        to answer from those chunks.

        Returns:
            The generated answer, or a user-facing warning string when the
            question cannot be embedded or no matching chunks exist.
        """
        embedding = await self._embed(question)
        # An empty vector is as unusable as a failed call; index_document applies
        # the same truthiness check before storing a chunk.
        if not embedding:
            return "⚠️ 無法生成向量,RAG 查詢失敗"

        chunks = await rag_repo.search_chunks(embedding, top_k)
        if not chunks:
            return "📭 知識庫尚無相關資料,請先執行索引建立"

        context = "\n\n---\n\n".join(
            f"[{c.get('source', '?')}] {c.get('title', '')}\n{c.get('chunk_text', '')}"
            for c in chunks
        )
        return await self._generate_answer(question, context)

    async def index_document(
        self,
        source: str,
        source_id: str,
        title: str,
        text: str,
        metadata: dict | None = None,
    ) -> bool:
        """Chunk, embed and store a document, replacing any previous version.

        The text is split into 500-character chunks with a 100-character overlap.
        Every chunk is embedded *before* the previous version is deleted, so an
        outage of all embedding endpoints leaves an already-indexed document
        intact instead of wiping it.

        Args:
            source: Source label stored with every chunk.
            source_id: Key used to delete the previous version of this document.
            title: Title stored with every chunk.
            text: Full document text.
            metadata: Extra metadata stored with every chunk (``{}`` if omitted).

        Returns:
            True if at least one chunk was stored, False otherwise.
        """
        chunks = self._chunk_text(text, chunk_size=500, overlap=100)

        embedded: list[tuple[str, list[float]]] = []
        for chunk in chunks:
            emb = await self._embed(chunk)
            if emb:
                embedded.append((chunk, emb))

        if chunks and not embedded:
            # Nothing could be embedded: keep the old version rather than
            # replacing it with an empty one.
            logger.warning(
                "rag_index_embed_failed",
                source=source,
                source_id=source_id,
                chunks=len(chunks),
            )
            return False

        await rag_repo.delete_by_source_id(source_id)
        success = 0
        for chunk, emb in embedded:
            ok = await rag_repo.insert_chunk(source, source_id, title, chunk, emb, metadata or {})
            if ok:
                success += 1
        logger.info("rag_indexed", source=source, source_id=source_id, chunks=success)
        return success > 0

    async def index_all_sources(self) -> int:
        """Scan every knowledge source and index its Markdown files.

        Called by the ``/rag/index`` endpoint. Sources are ``_INDEX_SOURCES``
        (docs/runbooks/, docs/adr/, docs/, .agents/skills/), each scanned
        recursively for ``*.md``. Because ``docs/`` contains the first two
        directories, each file is indexed only the first time it is seen. It
        keeps the label of the earliest matching source and is embedded once
        rather than twice. Per-file failures are logged and skipped.

        Returns:
            Number of files for which at least one chunk was stored.
        """
        total = 0
        seen: set[Path] = set()
        for source_dir in _INDEX_SOURCES:
            if not source_dir.exists():
                continue
            # Sorted for a deterministic indexing order across runs.
            for md_file in sorted(source_dir.rglob("*.md")):
                try:
                    key = md_file.resolve()
                    if key in seen:
                        continue
                    seen.add(key)
                    text = md_file.read_text(encoding="utf-8", errors="ignore")
                    ok = await self.index_document(
                        source=source_dir.parts[-1],
                        source_id=str(md_file),
                        title=md_file.stem,
                        text=text,
                    )
                    if ok:
                        total += 1
                        logger.debug("rag_source_indexed", file=str(md_file))
                except Exception as e:
                    # Best-effort scan: one bad file must not abort the whole run.
                    logger.warning("rag_source_index_failed", file=str(md_file), error=str(e))
        logger.info("rag_index_all_complete", total_docs=total)
        return total

    async def get_stats(self) -> dict:
        """Return knowledge-base statistics as reported by ``rag_repo.get_stats()``."""
        return await rag_repo.get_stats()

    # ------------------------------------------------------------------
    # Embedding + generation
    # ------------------------------------------------------------------

    async def _embed(self, text: str) -> list[float] | None:
        """Embed *text* using the first Ollama endpoint that answers.

        Endpoints for the ``"embedding"`` lane are tried in circuit-breaker order.
        A 200 response is returned immediately and recorded as a success. 5xx
        responses, timeouts and transport errors are recorded as endpoint
        failures. Any other non-200 status (e.g. a 4xx for an unknown model) is
        logged without penalising the endpoint. After any failure the next
        endpoint is tried.

        Returns:
            The ``"embedding"`` field of the first 200 response, or None if no
            endpoint produced a response.
        """
        http = await self._get_http()
        # Fall back to the built-in model when the setting is missing or empty.
        model = getattr(settings, "OLLAMA_EMBEDDING_MODEL", None) or _EMBED_MODEL
        for endpoint in resolve_ollama_order_with_cooldown("embedding"):
            if not endpoint.url:
                continue
            try:
                resp = await http.post(
                    f"{endpoint.url}/api/embeddings",
                    json={"model": model, "prompt": text},
                )
                if resp.status_code == 200:
                    record_ollama_endpoint_success(endpoint.url)
                    logger.debug("rag_embed_success", provider=endpoint.provider_name)
                    return resp.json().get("embedding")
                if resp.status_code >= 500:
                    record_ollama_endpoint_failure(endpoint.url)
                logger.warning(
                    "rag_embed_http_error",
                    provider=endpoint.provider_name,
                    status=resp.status_code,
                )
            except Exception as e:
                # Only network-level problems count against the endpoint's health.
                if isinstance(e, (httpx.TimeoutException, httpx.TransportError)):
                    record_ollama_endpoint_failure(endpoint.url)
                logger.warning(
                    "rag_embed_failed",
                    provider=endpoint.provider_name,
                    error=str(e),
                )
        return None

    async def _generate_answer(self, question: str, context: str) -> str:
        """Generate an answer to *question* grounded in *context*.

        The prompt asks the model to answer in Traditional Chinese and to reply
        "資料不足" (insufficient data) instead of guessing. Only the first 6000
        characters of *context* are sent. Endpoints for the ``"rag"`` lane are
        tried in circuit-breaker order with a 90 s timeout (10 s to connect).
        Error handling per endpoint mirrors ``_embed``.

        Returns:
            The stripped model response, or a user-facing failure message if no
            endpoint produced a response.
        """
        prompt = (
            "你是 AWOOOI AIOps 知識庫助手,請用繁體中文根據以下資料回答問題。\n"
            "如果資料不足以回答,明確說「資料不足」,不要猜測。\n\n"
            f"=== 相關資料 ===\n{context[:6000]}\n\n"
            f"=== 問題 ===\n{question}"
        )
        http = await self._get_http()
        for endpoint in resolve_ollama_order_with_cooldown("rag"):
            if not endpoint.url:
                continue
            try:
                resp = await http.post(
                    f"{endpoint.url}/api/generate",
                    json={
                        "model": _GEN_MODEL,
                        "prompt": prompt,
                        "stream": False,
                        "options": {"num_predict": 512, "temperature": 0.2},
                    },
                    timeout=httpx.Timeout(90.0, connect=10.0),
                )
                if resp.status_code == 200:
                    record_ollama_endpoint_success(endpoint.url)
                    logger.debug("rag_generate_success", provider=endpoint.provider_name)
                    return resp.json().get("response", "").strip()
                if resp.status_code >= 500:
                    record_ollama_endpoint_failure(endpoint.url)
                logger.warning(
                    "rag_generate_http_error",
                    provider=endpoint.provider_name,
                    status=resp.status_code,
                )
            except Exception as e:
                if isinstance(e, (httpx.TimeoutException, httpx.TransportError)):
                    record_ollama_endpoint_failure(endpoint.url)
                logger.error(
                    "rag_generate_failed",
                    provider=endpoint.provider_name,
                    error=str(e),
                )
        return "⚠️ RAG 生成失敗,請稍後再試"

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    @staticmethod
    def _chunk_text(text: str, chunk_size: int = 500, overlap: int = 100) -> list[str]:
        """Split *text* into fixed-size character windows that overlap.

        Consecutive chunks start ``chunk_size - overlap`` characters apart.
        Splitting stops once a chunk reaches the end of the text, so no trailing
        chunk is emitted that lies entirely inside its predecessor's overlap.
        Whitespace-only chunks are dropped.

        Raises:
            ValueError: if ``chunk_size`` is not positive or ``overlap`` is not
                in ``[0, chunk_size)``. Otherwise the window would never advance
                (infinite loop) or would silently skip text.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= overlap < chunk_size:
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < chunk_size, got overlap={overlap}, "
                f"chunk_size={chunk_size}"
            )
        step = chunk_size - overlap
        length = len(text)
        chunks: list[str] = []
        start = 0
        while start < length:
            end = start + chunk_size
            chunks.append(text[start:end])
            if end >= length:
                break
            start += step
        return [c for c in chunks if c.strip()]

    async def close(self) -> None:
        """Close the shared HTTP client if open; _get_http() recreates it on demand."""
        if self._http is not None and not self._http.is_closed:
            await self._http.aclose()
        self._http = None
|
||
|
||
|
||
# Module-level shared instance; populated lazily by get_knowledge_rag_service()
# or replaced explicitly by set_knowledge_rag_service().
_instance: KnowledgeRAGService | None = None


def get_knowledge_rag_service() -> KnowledgeRAGService:
    """Return the shared KnowledgeRAGService, constructing it on the first call."""
    global _instance
    if _instance is not None:
        return _instance
    _instance = KnowledgeRAGService()
    return _instance
|
||
|
||
|
||
def set_knowledge_rag_service(svc: KnowledgeRAGService) -> None:
    """Install *svc* as the module-level shared instance.

    Subsequent calls to get_knowledge_rag_service() return *svc*, which makes
    this the hook for injecting a preconfigured or fake service.
    """
    global _instance
    _instance = svc
|