Files
ewoooc/services/rag_service.py
OoO 353e565e52
All checks were successful
CD Pipeline / deploy (push) Successful in 1m4s
V10.417 protect embedding fallback routing
2026-05-24 14:53:43 +08:00

671 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
services/rag_service.py
Operation Ollama-First v5.0 / Phase 11 — RAG 查詢服務
設計原則(憲法級):
1. 純讀 ai_insights + 寫 rag_query_log不動 ai_insights schema
2. cosine similarity threshold 預設 0.85,太低不採用避免幻覺
3. embedding 走 bge-m3:latest與 ai_insights 一致簽名 — Phase 11.0 護欄 #3
4. feature flag RAG_ENABLED 預設 OFF避免影響戰前行為
5. 失敗安全DB 掛 / embedding 失敗 / threshold 不到 → 回 RAGResult(hits=[])
caller 自行 fallback LLM
6. fire-and-forget lograg_query_log INSERT async daemon thread不阻塞主流程
7. PII 保護query_text 寫入時截 4KBCHECK constraint 也會擋)
對應:
- migrations/027_create_rag_query_log.sql
- migrations/026_add_embedding_signature.sql
- docs/adr/ADR-029-hermes-first-twin-tower.md
- docs/phase0_audit_report_20260503.md Section 3 (BGE-M3 一致性護欄)
主入口:
- rag_service.query(...) : 主查詢介面
- rag_service.feedback(log_id, ...) : Telegram 👍/👎 反饋寫回
- rag_service.invalidate_by_caller(...) : 預留快取失效鉤v5.0 暫無 cache 層)
"""
from __future__ import annotations
import hashlib
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
# Module-scoped logger; the logger name follows this module's import path.
logger = logging.getLogger(name=__name__)
# ─────────────────────────────────────────────────────────────────────────────
# Feature flag + 預設參數
# ─────────────────────────────────────────────────────────────────────────────
def is_rag_enabled() -> bool:
    """Return whether the RAG feature flag is switched on.

    ``RAG_ENABLED`` is re-read from the environment on every call, so the flag
    can be toggled at runtime. Matching is case-insensitive and ignores
    surrounding whitespace; ``true``, ``1``, ``yes`` and ``on`` enable it.
    Anything else, including an unset variable, leaves it OFF (the default).
    """
    raw_flag = os.environ.get('RAG_ENABLED', 'false')
    normalized = raw_flag.strip().lower()
    return normalized in {'true', '1', 'yes', 'on'}


# Internal alias so existing references to `_is_rag_enabled` keep working.
_is_rag_enabled = is_rag_enabled
# Retrieval defaults; each can be overridden through the environment.
RAG_DEFAULT_THRESHOLD = float(os.environ.get('RAG_DEFAULT_THRESHOLD', '0.85'))
RAG_DEFAULT_TOP_K = int(os.environ.get('RAG_DEFAULT_TOP_K', '5'))

# bge-m3 embedding parameters. They feed get_embedding_signature(), so they
# must match whatever produced ai_insights.embedding_signature.
RAG_EMBED_MODEL = os.environ.get('RAG_EMBED_MODEL', 'bge-m3:latest')
RAG_EMBED_DIM = int(os.environ.get('RAG_EMBED_DIM', '1024'))
RAG_EMBED_NORMALIZE = (
    os.environ.get('RAG_EMBED_NORMALIZE', 'true').strip().lower()
    in ('true', '1', 'yes', 'on')
)

# Byte cap for rag_query_log.query_text. It mirrors the octet_length <= 4096
# CHECK in migration 027 (roughly 1300 CJK characters at 3 bytes each).
_QUERY_TEXT_MAX_BYTES = 4096

# rag_query_log write kill-switch (same pattern as ai_call_logger). After this
# many consecutive write failures, writes are replaced by a logger.info line.
_MAX_CONSECUTIVE_FAILURES = 10
_failure_counter_lock = threading.Lock()
_failure_state = dict(count=0, killed=False)
def _record_failure() -> None:
    """Count one failed rag_query_log write and trip the kill-switch at the limit.

    The switch latches. Once ``killed`` is set, this function never clears it
    (``_reset_kill_switch`` does), and the error line is logged only once.
    """
    with _failure_counter_lock:
        _failure_state['count'] = _failure_state['count'] + 1
        if _failure_state['killed']:
            return
        if _failure_state['count'] < _MAX_CONSECUTIVE_FAILURES:
            return
        _failure_state['killed'] = True
        logger.error(
            "[RAGService] consecutive write failures hit %d — kill-switch ON, "
            "downgrading rag_query_log writes to logger.info",
            _MAX_CONSECUTIVE_FAILURES,
        )
def _record_success() -> None:
    """Reset the consecutive-failure counter after a successful write.

    A kill-switch that has already tripped stays tripped.
    """
    with _failure_counter_lock:
        if _failure_state['count'] <= 0:
            return
        _failure_state['count'] = 0
def _is_killed() -> bool:
    """Return True once the rag_query_log write kill-switch has tripped."""
    with _failure_counter_lock:
        killed = _failure_state['killed']
    return killed
def _reset_kill_switch() -> None:
    """Test-only helper: clear both the failure counter and the kill-switch flag."""
    with _failure_counter_lock:
        _failure_state.update(count=0, killed=False)
# ─────────────────────────────────────────────────────────────────────────────
# BGE-M3 一致性簽名 (v5.0 護欄 #3)
# 對應 migration 026 註解: SHA1({model}|{normalize}|{dim}|{ollama_digest})[:12],
# 但 Python 端省略 ollama_digest, 不查 ollama digest (避免每次 query 都 GET /api/show);
# 改用 model|normalize|dim 三元組, 已足以擋住「升級 bge-m3 / 改 normalize」雙寫漂移。
# ─────────────────────────────────────────────────────────────────────────────
def get_embedding_signature(
    model: str = RAG_EMBED_MODEL,
    dim: int = RAG_EMBED_DIM,
    normalize: bool = RAG_EMBED_NORMALIZE,
) -> str:
    """Return the 12-hex-character BGE-M3 consistency signature.

    The value is the first 12 hex digits of ``sha1("{model}|{normalize}|{dim}")``.
    In that string, ``normalize`` is rendered as lowercase ``true`` or ``false``.
    It is compared against ``ai_insights.embedding_signature``. The retrieval
    path skips rows whose non-NULL signature differs.
    """
    normalize_token = str(normalize).lower()
    payload = "{}|{}|{}".format(model, normalize_token, dim)
    digest = hashlib.sha1(payload.encode('utf-8')).hexdigest()
    return digest[:12]
# ─────────────────────────────────────────────────────────────────────────────
# Phase 11.0 護欄 #3: BGE-M3 跨主機一致性啟動驗證 (ADR-033)
# ─────────────────────────────────────────────────────────────────────────────
# Fixed probe text embedded on every host during the cross-host check.
EMBED_CONSISTENCY_TEST_TEXT = "momo電商競品分析測試向量一致性檢查"
# Largest tolerated pairwise cosine distance (absorbs floating-point noise).
EMBED_CONSISTENCY_MAX_DIFF = 1.0e-4
# Per-host embedding probe timeout, in seconds.
EMBED_CONSISTENCY_TIMEOUT_SEC = 10.0
def _cosine_distance(vec_a: List[float], vec_b: List[float]) -> float:
    """Return ``1 - cos(vec_a, vec_b)``, clamped below at 0, in pure Python.

    Pure Python avoids a numpy import. Degenerate input returns 1.0: either
    vector empty or None, lengths differing, or either vector having zero norm.
    """
    if not (vec_a and vec_b) or len(vec_a) != len(vec_b):
        return 1.0

    def _norm(vec: List[float]) -> float:
        return sum(c * c for c in vec) ** 0.5

    dot = sum(x * y for x, y in zip(vec_a, vec_b))
    denom_a, denom_b = _norm(vec_a), _norm(vec_b)
    if 0 in (denom_a, denom_b):
        return 1.0
    similarity = dot / (denom_a * denom_b)
    # Clamp only guards against tiny negative results from float rounding.
    return max(0.0, 1.0 - similarity)
def verify_embedding_consistency(
    test_text: str = EMBED_CONSISTENCY_TEST_TEXT,
    max_diff: float = EMBED_CONSISTENCY_MAX_DIFF,
) -> Dict[str, Any]:
    """Cross-check BGE-M3 embeddings across the three Ollama hosts.

    Owen v5.0 guard #3 (ADR-033): run when RAG starts; inconsistency is logged.

    Embeds ``test_text`` with ``RAG_EMBED_MODEL`` on each host: GCP primary,
    secondary, and the 111 fallback. Each call pins ``host=`` explicitly. Every
    pair of returned vectors is then compared by cosine distance.

    Fail-safe behavior:
    - A host that raises, times out, or returns an empty or wrong-length
      vector (length != ``RAG_EMBED_DIM``) is skipped and reported in
      ``errors``.
    - A host with no configured address is also skipped.
    - A host URL already probed under another label is skipped. A host
      compared with itself always looks consistent and would hide real drift.
    - With fewer than two usable embeddings nothing can be compared. The
      result is then ``ok=True`` plus an explanatory entry in ``errors``.

    Args:
        test_text: Probe text to embed on every host.
        max_diff: Largest tolerated pairwise cosine distance.

    Returns:
        dict with keys:
            ``ok`` (bool): True when every compared pair is within
                ``max_diff``, or when fewer than two hosts were usable.
            ``signature`` (str): current ``get_embedding_signature()``.
            ``reachable`` (list[str]): labels that produced a usable vector,
                drawn from ``gcp_ollama`` / ``ollama_secondary`` / ``ollama_111``.
            ``max_diff`` (float): largest observed pairwise cosine distance.
            ``errors`` (list[str]): per-host problems, human readable.
    """
    import itertools

    from services.ollama_service import (
        OLLAMA_HOST_PRIMARY, OLLAMA_HOST_SECONDARY, OLLAMA_HOST_FALLBACK,
        ollama_service,
    )

    hosts = {
        'gcp_ollama': OLLAMA_HOST_PRIMARY,
        'ollama_secondary': OLLAMA_HOST_SECONDARY,
        'ollama_111': OLLAMA_HOST_FALLBACK,
    }
    embeddings: Dict[str, List[float]] = {}
    errors: List[str] = []
    # Normalized host URL -> label that already probed it.
    probed_by: Dict[str, str] = {}

    for label, host in hosts.items():
        if not host:
            # NOTE(review): with host=None the client presumably uses its
            # default routing, i.e. embeds on some other host under this
            # label and fakes a cross-host comparison -- so skip instead.
            errors.append(f"{label}: host not configured")
            logger.warning("[EmbedVerify] %s has no host configured, skipped", label)
            continue
        host_key = str(host).rstrip('/')
        if host_key in probed_by:
            # Comparing a host with itself always "passes" and masks drift.
            errors.append(
                f"{label}: same host as {probed_by[host_key]} ({host}), skipped"
            )
            logger.warning(
                "[EmbedVerify] %s shares host %s with %s, skipped",
                label, host, probed_by[host_key],
            )
            continue
        probed_by[host_key] = label
        try:
            t0 = time.monotonic()
            vec = ollama_service.generate_embedding(
                text=test_text,
                model=RAG_EMBED_MODEL,
                host=host,  # pinned explicitly so retry chains cannot interfere
                timeout=int(EMBED_CONSISTENCY_TIMEOUT_SEC),
            )
            elapsed = time.monotonic() - t0
            if vec and len(vec) == RAG_EMBED_DIM:
                embeddings[label] = vec
                logger.info(
                    "[EmbedVerify] %s (%s) ok in %.2fs, dim=%d",
                    label, host, elapsed, len(vec),
                )
            else:
                errors.append(f"{label}: empty or wrong dim ({len(vec) if vec else 0})")
                logger.warning("[EmbedVerify] %s returned empty/wrong-dim vector", label)
        except Exception as exc:  # fail-safe: one bad host must not abort the check
            errors.append(f"{label}: {type(exc).__name__}: {str(exc)[:200]}")
            logger.warning("[EmbedVerify] %s failed: %s", label, exc)

    signature = get_embedding_signature()
    reachable = list(embeddings)

    if len(embeddings) < 2:
        msg = f"only {len(embeddings)} host reachable, cannot cross-verify"
        logger.warning("[EmbedVerify] %s", msg)
        return {
            # Fail-safe: a single reachable host is not treated as an error
            # (e.g. while two hosts are temporarily down).
            'ok': True,
            'signature': signature,
            'reachable': reachable,
            'max_diff': 0.0,
            'errors': errors + [msg],
        }

    # Pairwise cosine distance across every usable host.
    max_diff_observed = 0.0
    for label_a, label_b in itertools.combinations(embeddings, 2):
        d = _cosine_distance(embeddings[label_a], embeddings[label_b])
        max_diff_observed = max(max_diff_observed, d)
        logger.debug(
            "[EmbedVerify] %s vs %s: cosine_distance=%.6f", label_a, label_b, d,
        )

    consistent = max_diff_observed <= max_diff
    if consistent:
        logger.info(
            "[EmbedVerify] ✅ consistent across %d hosts (max_diff=%.2e, signature=%s)",
            len(reachable), max_diff_observed, signature,
        )
    else:
        logger.error(
            "[EmbedVerify] ⚠️ INCONSISTENT! max cosine distance %.6f > %s "
            "(signature=%s, reachable=%s). "
            "模型版本可能漂移RAG 召回率將下降。",
            max_diff_observed, max_diff, signature, reachable,
        )
    return {
        'ok': consistent,
        'signature': signature,
        'reachable': reachable,
        'max_diff': max_diff_observed,
        'errors': errors,
    }
# ─────────────────────────────────────────────────────────────────────────────
# 結果容器
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class RAGResult:
    """Outcome of one RAG lookup.

    Callers check ``has_high_confidence``. When it is True, ``synthesize()``
    can be shown to the user directly instead of making another LLM call.

    Fields, in declaration order:
        query: query text as passed in ("" for blank input).
        embedding_signature: signature the lookup was performed under.
        hits: retrieved rows as dicts carrying at least ``score`` and
            ``content``; expected best-first, because only ``hits[0]`` is
            consulted for confidence.
        threshold: similarity cutoff used by ``has_high_confidence``.
        saved_call: whether this result let the caller skip an LLM call.
        duration_ms: time spent in the lookup, in milliseconds.
        log_id: rag_query_log.id; may stay None because logging is
            fire-and-forget.
    """
    query: str
    embedding_signature: str
    hits: List[Dict[str, Any]] = field(default_factory=list)
    threshold: float = RAG_DEFAULT_THRESHOLD
    saved_call: bool = False  # set once the caller confirms the LLM call was skipped
    duration_ms: int = 0
    log_id: Optional[int] = None

    @property
    def has_high_confidence(self) -> bool:
        """True when a top hit exists and its ``score`` reaches ``threshold``.

        A missing, None or zero score counts as 0.0.
        """
        if self.hits:
            best = self.hits[0].get('score', 0.0) or 0.0
            return float(best) >= self.threshold
        return False

    def synthesize(self) -> str:
        """Join the non-empty ``content`` of the first three hits.

        Parts are separated by ``"\\n\\n---\\n\\n"`` (the OCLearn convention).
        Returns "" when there are no hits. Hits past the third are never used,
        even if an earlier one has empty content.
        """
        first_three = self.hits[:3]
        contents = (hit.get('content') or "" for hit in first_three)
        return "\n\n---\n\n".join(text for text in contents if text)
# ─────────────────────────────────────────────────────────────────────────────
# 主類別
# ─────────────────────────────────────────────────────────────────────────────
class RAGService:
    """RAG query entry point: retrieves ai_insights hits and logs each query.

    ``query`` returns a :class:`RAGResult`. Unless the feature flag is off or
    the text is blank, it also schedules a fire-and-forget INSERT into
    ``rag_query_log``.

    Example::

        from services.rag_service import rag_service
        result = rag_service.query("本週業績趨勢", caller='openclaw_qa')
        if result.has_high_confidence:
            return result.synthesize()
        # otherwise fall through to the existing LLM path
    """

    def query(
        self,
        text: str,
        caller: str,
        top_k: int = RAG_DEFAULT_TOP_K,
        threshold: float = RAG_DEFAULT_THRESHOLD,
        request_id: Optional[str] = None,
        insight_type: Optional[str] = None,
        mark_saved_call: bool = False,
    ) -> RAGResult:
        """Run RAG retrieval for ``text``.

        Args:
            text: Query text (user question or LLM prompt).
            caller: Caller tag, sharing the ai_calls.caller whitelist
                (hermes_qa / openclaw_qa / ...). Truncated to 64 chars when
                logged.
            top_k: Number of hits, clamped to 1..50. ``0``/``None`` mean
                "use RAG_DEFAULT_TOP_K".
            threshold: Minimum cosine similarity, clamped to 0..1. ``None``
                means "use RAG_DEFAULT_THRESHOLD". An explicit ``0.0`` is
                honoured.
            request_id: Correlation id shared with ai_calls.request_id.
            insight_type: Restrict to one ai_insights.insight_type
                (None = all types).
            mark_saved_call: When True and the result has_high_confidence,
                rag_query_log.saved_call is written as True. Only RAG-first
                callers that skip the LLM on a hit should pass True.
                Auxiliary lookups (e.g. similar-case search) should keep False.

        Returns:
            RAGResult. Embedding and DB failures are logged and swallowed.
            They yield empty ``hits``; ``duration_ms`` is still recorded.
        """
        signature = get_embedding_signature()
        start = time.monotonic()

        # Path 1: feature flag OFF -> short-circuit (no DB read, no log write).
        if not _is_rag_enabled():
            return RAGResult(
                query=text or "",
                embedding_signature=signature,
                threshold=threshold,
                duration_ms=0,
            )

        # Path 2: blank text -> return early without an embedding call.
        if not text or not text.strip():
            return RAGResult(
                query="",
                embedding_signature=signature,
                threshold=threshold,
                duration_ms=int((time.monotonic() - start) * 1000),
            )

        # Clamp into the ranges allowed by migration 027's CHECK constraints.
        top_k = max(1, min(int(top_k or RAG_DEFAULT_TOP_K), 50))
        # Only None falls back to the default. `threshold or DEFAULT` would
        # silently turn an explicit 0.0 ("accept everything") into 0.85.
        raw_threshold = RAG_DEFAULT_THRESHOLD if threshold is None else threshold
        threshold = max(0.0, min(float(raw_threshold), 1.0))

        # Path 3: embed the query.
        query_vec: Optional[List[float]] = None
        try:
            from services.ollama_service import ollama_service
            query_vec = ollama_service.generate_embedding(
                text,
                model=RAG_EMBED_MODEL,
                # NOTE(review): presumably keeps query embeddings off the 111
                # fallback host -- confirm against ollama_service.
                allow_111_fallback=False,
            )
            if not query_vec:
                logger.warning(
                    "[RAGService] embedding empty (caller=%s, len=%d) — fallback LLM",
                    caller, len(text),
                )
        except Exception as exc:
            logger.warning(
                "[RAGService] embedding failed (caller=%s): %s — fallback LLM",
                caller, exc,
            )

        # Path 4: DB retrieval, only when an embedding was obtained.
        hits: List[Dict[str, Any]] = []
        if query_vec:
            try:
                hits = self._select_hits(
                    query_vec=query_vec,
                    threshold=threshold,
                    top_k=top_k,
                    insight_type=insight_type,
                    expected_signature=signature,
                )
            except Exception as exc:
                logger.warning(
                    "[RAGService] DB select failed (caller=%s): %s — fallback LLM",
                    caller, exc,
                )
                hits = []

        duration_ms = int((time.monotonic() - start) * 1000)
        result = RAGResult(
            query=text,
            embedding_signature=signature,
            hits=hits,
            threshold=threshold,
            duration_ms=duration_ms,
        )
        result.saved_call = bool(mark_saved_call and result.has_high_confidence)

        # Path 5: fire-and-forget rag_query_log write.
        self._async_log(
            caller=caller,
            text=text,
            query_vec=query_vec,
            top_k=top_k,
            threshold=threshold,
            hits=hits,
            saved_call=result.saved_call,
            request_id=request_id,
        )
        return result

    # ──────────────────────────────────────────────────────────────────────
    # DB retrieval
    # ──────────────────────────────────────────────────────────────────────
    def _select_hits(
        self,
        query_vec: List[float],
        threshold: float,
        top_k: int,
        insight_type: Optional[str],
        expected_signature: str,
    ) -> List[Dict[str, Any]]:
        """Fetch up to ``top_k`` ai_insights rows with similarity >= ``threshold``.

        Only rows with a non-NULL embedding and status approved/active/executed
        are considered, nearest first. Rows whose non-NULL embedding_signature
        differs from ``expected_signature`` are skipped and counted in one
        warning (v5.0 guard #3). NULL signatures (not yet backfilled) are
        accepted.
        """
        from sqlalchemy import text as sa_text
        from database.manager import get_session

        # pgvector `<=>` is cosine distance; similarity = 1 - distance.
        # Over-fetch 2x so rows dropped by the signature check can be replaced;
        # the loop below trims back to top_k.
        fetch_limit = top_k * 2
        filters = [
            "embedding IS NOT NULL",
            "status IN ('approved', 'active', 'executed')",
        ]
        params: Dict[str, Any] = {
            'qvec': str(query_vec),
            'lim': fetch_limit,
            'max_distance': 1.0 - threshold,
        }
        if insight_type:
            filters.append("insight_type = :insight_type")
            params['insight_type'] = insight_type
        # Only the fixed filter fragments above are interpolated; every value
        # is a bound parameter.
        sql = sa_text(f"""
            SELECT id, insight_type, period, content,
                   embedding_signature,
                   embedding <=> CAST(:qvec AS vector) AS distance
            FROM ai_insights
            WHERE {' AND '.join(filters)}
              AND (embedding <=> CAST(:qvec AS vector)) <= :max_distance
            ORDER BY distance ASC
            LIMIT :lim
        """)
        session = get_session()
        try:
            rows = session.execute(sql, params).fetchall()
        finally:
            session.close()

        hits: List[Dict[str, Any]] = []
        signature_mismatch = 0
        for row in rows:
            if len(hits) >= top_k:
                break
            row_signature = getattr(row, 'embedding_signature', None)
            if row_signature and row_signature != expected_signature:
                signature_mismatch += 1
                continue
            # `row.distance or 1.0` would turn an exact match (distance 0.0)
            # into similarity 0.0, so only a NULL distance maps to 1.0.
            distance = 1.0 if row.distance is None else float(row.distance)
            similarity = 1.0 - distance
            hits.append({
                'id': int(row.id),
                'insight_type': row.insight_type,
                'period': row.period,
                'content': row.content or '',
                'score': round(similarity, 4),
                'distance': round(distance, 4),
                'embedding_signature': row_signature,
            })

        if signature_mismatch:
            logger.warning(
                "[RAGService] %d hits skipped due to embedding_signature mismatch "
                "(expected=%s); 建議跑批次回填腳本",
                signature_mismatch, expected_signature,
            )
        return hits

    # ──────────────────────────────────────────────────────────────────────
    # Feedback (Telegram 👍/👎)
    # ──────────────────────────────────────────────────────────────────────
    def feedback(self, rag_query_log_id: int, score: int) -> bool:
        """Write a rating back to rag_query_log.feedback_score.

        Args:
            rag_query_log_id: rag_query_log.id; must be a positive int.
                bool is rejected even though it subclasses int.
            score: 1..5 (1 = useless, 5 = very useful; 👍 is usually 5 and
                👎 1). Out-of-range values are clamped; None/0 become 1.

        Returns:
            True when the UPDATE was committed, even if no row matched the
            id. False otherwise. Never raises: invalid input and DB errors
            are logged and reported as False.
        """
        if (
            isinstance(rag_query_log_id, bool)
            or not isinstance(rag_query_log_id, int)
            or rag_query_log_id <= 0
        ):
            return False
        try:
            score = max(1, min(int(score or 0), 5))
        except (TypeError, ValueError, OverflowError):
            logger.warning(
                "[RAGService] feedback ignored: invalid score %r (id=%s)",
                score, rag_query_log_id,
            )
            return False
        try:
            from sqlalchemy import text as sa_text
            from database.manager import get_session
            session = get_session()
            try:
                session.execute(
                    sa_text("""
                        UPDATE rag_query_log
                        SET feedback_score = :score
                        WHERE id = :id
                    """),
                    {'score': score, 'id': rag_query_log_id},
                )
                session.commit()
                return True
            except Exception:
                session.rollback()
                raise
            finally:
                session.close()
        except Exception as exc:
            logger.warning(
                "[RAGService] feedback write failed (id=%s, score=%s): %s",
                rag_query_log_id, score, exc,
            )
            return False

    # ──────────────────────────────────────────────────────────────────────
    # Reserved: per-caller invalidation (v5.0 has no in-memory cache layer)
    # ──────────────────────────────────────────────────────────────────────
    def invalidate_by_caller(self, caller: str) -> None:
        """Reserved hook for when a caller's prompt template changes.

        There is no in-memory cache layer yet, so this only emits a debug
        log. It exists to keep the API stable for a future cache.
        """
        if caller:
            logger.debug("[RAGService] invalidate_by_caller(%s) — no-op (no cache layer yet)", caller)

    # ──────────────────────────────────────────────────────────────────────
    # Fire-and-forget log
    # ──────────────────────────────────────────────────────────────────────
    def _async_log(
        self,
        caller: str,
        text: str,
        query_vec: Optional[List[float]],
        top_k: int,
        threshold: float,
        hits: List[Dict[str, Any]],
        request_id: Optional[str] = None,
        saved_call: bool = False,
    ) -> None:
        """Write rag_query_log on a daemon thread so the caller never blocks.

        If the kill-switch has tripped, only a logger.info line is emitted.
        If a thread cannot be started, the failure is counted and logged
        instead of propagating out of ``query``.
        """
        if _is_killed():
            logger.info(
                "[RAGQuery|killed] caller=%s hits=%d threshold=%.3f request_id=%s",
                caller, len(hits), threshold, request_id,
            )
            return
        try:
            threading.Thread(
                target=self._write_log,
                args=(caller, text, query_vec, top_k, threshold, hits, request_id, saved_call),
                name=f"rag-query-log-{caller}",
                daemon=True,
            ).start()
        except RuntimeError as exc:  # e.g. "can't start new thread"
            _record_failure()
            logger.warning(
                "[RAGService] rag_query_log writer thread not started (caller=%s): %s",
                caller, exc,
            )

    def _write_log(
        self,
        caller: str,
        text: str,
        query_vec: Optional[List[float]],
        top_k: int,
        threshold: float,
        hits: List[Dict[str, Any]],
        request_id: Optional[str] = None,
        saved_call: bool = False,
    ) -> None:
        """INSERT one rag_query_log row; any failure is logged, never raised.

        A successful commit resets the failure counter. Any exception counts
        toward the kill-switch.
        """
        try:
            from sqlalchemy import text as sa_text
            from database.manager import get_session

            # PII guard: cap query_text at 4 KB (matches migration 027's CHECK).
            safe_text = text or ''
            encoded = safe_text.encode('utf-8', errors='replace')
            if len(encoded) > _QUERY_TEXT_MAX_BYTES:
                # Cut at the byte limit; errors='ignore' drops a partially
                # cut multi-byte UTF-8 character instead of failing.
                safe_text = encoded[:_QUERY_TEXT_MAX_BYTES].decode('utf-8', errors='ignore')

            used_results = [int(h['id']) for h in hits if h.get('id')]
            embedding_str = str(query_vec) if query_vec else None
            embedding_signature = (
                get_embedding_signature(dim=len(query_vec))
                if query_vec else None
            )
            session = get_session()
            try:
                session.execute(
                    sa_text("""
                        INSERT INTO rag_query_log (
                            caller, query_text, query_embedding,
                            embedding_signature,
                            top_k, threshold,
                            hit_count, used_results,
                            saved_call, request_id
                        ) VALUES (
                            :caller, :query_text,
                            CAST(:embedding AS vector),
                            :embedding_signature,
                            :top_k, :threshold,
                            :hit_count, CAST(:used_results AS BIGINT[]),
                            :saved_call, :request_id
                        )
                    """),
                    {
                        'caller': (caller or 'unknown')[:64],
                        'query_text': safe_text,
                        'embedding': embedding_str,
                        'embedding_signature': embedding_signature,
                        'top_k': int(top_k),
                        'threshold': round(float(threshold), 3),
                        'hit_count': len(hits),
                        'used_results': used_results if used_results else None,
                        'saved_call': bool(saved_call),
                        'request_id': (request_id or None),
                    },
                )
                session.commit()
                _record_success()
            except Exception:
                session.rollback()
                raise
            finally:
                session.close()
        except Exception as exc:
            _record_failure()
            logger.warning(
                "[RAGService] rag_query_log write failed (caller=%s): %s",
                caller, exc,
            )
# Process-wide singleton shared by all callers (same pattern as
# ollama_service / ai_call_logger).
rag_service: RAGService = RAGService()

# Names exported by `from services.rag_service import *`.
__all__ = [
    # classes and the shared instance
    'RAGService', 'RAGResult', 'rag_service',
    # module-level helpers
    'get_embedding_signature', 'verify_embedding_consistency', 'is_rag_enabled',
]