#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
services/rag_service.py
Operation Ollama-First v5.0 / Phase 11 — RAG query service.

Design principles (constitution-level):
  1. Read-only on ai_insights; writes only rag_query_log (the ai_insights schema is untouched).
  2. Cosine-similarity threshold defaults to 0.85; weaker matches are discarded to avoid
     hallucination.
  3. Embeddings use bge-m3:latest, with the same signature as ai_insights
     (Phase 11.0 guardrail #3).
  4. Feature flag RAG_ENABLED defaults to OFF, so pre-launch behavior is unaffected.
  5. Fail-safe: if the DB is down, the embedding fails, or nothing reaches the threshold,
     the service returns RAGResult(hits=[]) and the caller falls back to the LLM itself.
  6. Fire-and-forget logging: the rag_query_log INSERT runs on a daemon thread and never
     blocks the main flow.
  7. PII protection: query_text is truncated to 4 KB before it is written (the DB CHECK
     constraint enforces this too).

Related:
  - migrations/027_create_rag_query_log.sql
  - migrations/026_add_embedding_signature.sql
  - docs/adr/ADR-029-hermes-first-twin-tower.md
  - docs/phase0_audit_report_20260503.md Section 3 (BGE-M3 consistency guardrail)

Entry points:
  - rag_service.query(...)                : main query interface
  - rag_service.feedback(log_id, ...)     : write back Telegram 👍/👎 feedback
  - rag_service.invalidate_by_caller(...) : reserved cache-invalidation hook
                                            (v5.0 has no cache layer yet)
"""
from __future__ import annotations

import hashlib
import itertools
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


# ─────────────────────────────────────────────────────────────────────────────
# Feature flag + default parameters
# ─────────────────────────────────────────────────────────────────────────────

def is_rag_enabled() -> bool:
    """Return whether RAG is enabled, reading the environment on every call.

    Reading at call time allows runtime toggling (same style as ai_call_logger).
    The flag defaults to OFF, so a deployment without it behaves exactly like v4.x.
    """
    val = os.environ.get('RAG_ENABLED', 'false').strip().lower()
    return val in ('true', '1', 'yes', 'on')


# Internal alias, kept for compatibility with the existing _is_rag_enabled name.
_is_rag_enabled = is_rag_enabled

RAG_DEFAULT_THRESHOLD = float(os.getenv('RAG_DEFAULT_THRESHOLD', '0.85'))
RAG_DEFAULT_TOP_K = int(os.getenv('RAG_DEFAULT_TOP_K', '5'))

# bge-m3 consistency parameters (same inputs as the ai_insights signature calculation).
RAG_EMBED_MODEL = os.getenv('RAG_EMBED_MODEL', 'bge-m3:latest')
RAG_EMBED_DIM = int(os.getenv('RAG_EMBED_DIM', '1024'))
RAG_EMBED_NORMALIZE = os.getenv('RAG_EMBED_NORMALIZE', 'true').strip().lower() in (
    'true', '1', 'yes', 'on',
)

# Byte cap for query_text. It matches the 027 CHECK octet_length <= 4096. A CJK
# character is 3 bytes in UTF-8, so this allows roughly 1300 characters.
_QUERY_TEXT_MAX_BYTES = 4096

# Consecutive-failure threshold for the log-writer kill-switch (same pattern as ai_call_logger).
_MAX_CONSECUTIVE_FAILURES = 10
_failure_counter_lock = threading.Lock()
_failure_state: Dict[str, Any] = {'count': 0, 'killed': False}


def _record_failure() -> None:
    """Count one failed rag_query_log write and trip the kill-switch at the threshold."""
    with _failure_counter_lock:
        _failure_state['count'] += 1
        if _failure_state['count'] >= _MAX_CONSECUTIVE_FAILURES and not _failure_state['killed']:
            _failure_state['killed'] = True
            logger.error(
                "[RAGService] consecutive write failures hit %d — kill-switch ON, "
                "downgrading rag_query_log writes to logger.info",
                _MAX_CONSECUTIVE_FAILURES,
            )


def _record_success() -> None:
    """Reset the consecutive-failure counter.

    Once the kill-switch has been tripped, this does not clear it.
    """
    with _failure_counter_lock:
        if _failure_state['count'] > 0:
            _failure_state['count'] = 0


def _is_killed() -> bool:
    """Return whether the kill-switch has been tripped."""
    with _failure_counter_lock:
        return _failure_state['killed']


def _reset_kill_switch() -> None:
    """Test helper: reset the kill-switch state."""
    with _failure_counter_lock:
        _failure_state['count'] = 0
        _failure_state['killed'] = False


def _truncate_utf8(text: Optional[str], max_bytes: int = _QUERY_TEXT_MAX_BYTES) -> str:
    """Return *text* cut to at most *max_bytes* of UTF-8 without splitting a character.

    Characters that UTF-8 cannot encode (e.g. lone surrogates) become '?'. The result
    is therefore always valid UTF-8 and cannot make the DB driver reject the row.
    ``None`` yields ``""``.
    """
    encoded = (text or '').encode('utf-8', errors='replace')
    # errors='ignore' drops a multi-byte sequence that the byte cut split in half.
    return encoded[:max_bytes].decode('utf-8', errors='ignore')


# ─────────────────────────────────────────────────────────────────────────────
# BGE-M3 consistency signature (v5.0 guardrail #3)
# The migration 026 comment defines SHA1({model}|{normalize}|{dim}|{ollama_digest})[:12].
# The Python side does not query the ollama digest, which would cost a GET /api/show
# per query. The model + normalize + dim triple is enough to catch the
# "upgraded bge-m3 / changed normalize" double-write drift.
# ─────────────────────────────────────────────────────────────────────────────

def get_embedding_signature(
    model: str = RAG_EMBED_MODEL,
    dim: int = RAG_EMBED_DIM,
    normalize: bool = RAG_EMBED_NORMALIZE,
) -> str:
    """Return the 12-hex-char BGE-M3 consistency signature.

    Compared against ai_insights.embedding_signature. Rows with a different signature
    are skipped and counted in a warning.
    """
    raw = f"{model}|{str(normalize).lower()}|{dim}"
    return hashlib.sha1(raw.encode('utf-8')).hexdigest()[:12]


# ─────────────────────────────────────────────────────────────────────────────
# Phase 11.0 guardrail #3: cross-host BGE-M3 consistency check at startup (ADR-033)
# ─────────────────────────────────────────────────────────────────────────────

EMBED_CONSISTENCY_TEST_TEXT = "momo電商競品分析測試向量一致性檢查"
EMBED_CONSISTENCY_MAX_DIFF = 1e-4        # max cosine distance tolerated (float noise)
EMBED_CONSISTENCY_TIMEOUT_SEC = 10.0     # per-host embedding probe timeout
EMBED_CONSISTENCY_INCLUDE_111 = os.getenv(
    'EMBED_CONSISTENCY_INCLUDE_111', 'false',
).strip().lower() in ('true', '1', 'yes', 'on')


def _cosine_distance(vec_a: List[float], vec_b: List[float]) -> float:
    """Return the cosine distance ``1 - cos(a, b)``, clamped at 0.

    Pure Python, so numpy is not an import dependency. Returns 1.0 for empty vectors,
    mismatched lengths, or a zero-norm vector.
    """
    if not vec_a or not vec_b or len(vec_a) != len(vec_b):
        return 1.0
    dot = sum(a * b for a, b in zip(vec_a, vec_b))
    norm_a = sum(a * a for a in vec_a) ** 0.5
    norm_b = sum(b * b for b in vec_b) ** 0.5
    if norm_a == 0 or norm_b == 0:
        return 1.0
    return max(0.0, 1.0 - dot / (norm_a * norm_b))


def verify_embedding_consistency(
    test_text: str = EMBED_CONSISTENCY_TEST_TEXT,
    max_diff: float = EMBED_CONSISTENCY_MAX_DIFF,
) -> Dict[str, Any]:
    """Check that the GCP Ollama nodes produce consistent BGE-M3 embeddings.

    Owen v5.0 guardrail #3 (ADR-033). This runs when RAG starts, and any
    inconsistency is logged.

    Fail-safe: a host that errors (connection, timeout, empty or wrong-dim vector) is
    skipped, and only the embeddings actually obtained are compared. At least 2
    reachable hosts are needed to compare; with fewer, it returns ``ok=True`` plus a
    "cannot cross-verify" entry in ``errors``.

    Host 111 is the Mac final fallback. It is excluded by default, to avoid loading
    bge-m3 on it, and is included only when EMBED_CONSISTENCY_INCLUDE_111=true.

    Returns:
        {
            'ok': bool,
            'signature': str,
            'reachable': [...],   # subset of ['gcp_ollama', 'ollama_secondary', 'ollama_111']
            'max_diff': float,    # largest pairwise cosine distance observed
            'errors': [...],
        }
    """
    from services.ollama_service import (
        OLLAMA_HOST_PRIMARY,
        OLLAMA_HOST_SECONDARY,
        OLLAMA_HOST_FALLBACK,
        ollama_service,
    )

    hosts = {
        'gcp_ollama': OLLAMA_HOST_PRIMARY,
        'ollama_secondary': OLLAMA_HOST_SECONDARY,
    }
    if EMBED_CONSISTENCY_INCLUDE_111:
        hosts['ollama_111'] = OLLAMA_HOST_FALLBACK

    embeddings: Dict[str, List[float]] = {}
    errors: List[str] = []

    for label, host in hosts.items():
        try:
            t0 = time.monotonic()
            vec = ollama_service.generate_embedding(
                text=test_text,
                model=RAG_EMBED_MODEL,
                host=host,  # pin the host explicitly so the retry chain cannot mask drift
                timeout=int(EMBED_CONSISTENCY_TIMEOUT_SEC),
                allow_111_fallback=(label == 'ollama_111'),
            )
            elapsed = time.monotonic() - t0
            if vec and len(vec) == RAG_EMBED_DIM:
                embeddings[label] = vec
                logger.info(
                    "[EmbedVerify] %s (%s) ok in %.2fs, dim=%d",
                    label, host, elapsed, len(vec),
                )
            else:
                errors.append(f"{label}: empty or wrong dim ({len(vec) if vec else 0})")
                logger.warning("[EmbedVerify] %s returned empty/wrong-dim vector", label)
        except Exception as exc:
            errors.append(f"{label}: {type(exc).__name__}: {str(exc)[:200]}")
            logger.warning("[EmbedVerify] %s failed: %s", label, exc)

    signature = get_embedding_signature()
    reachable = list(embeddings.keys())

    if len(embeddings) < 2:
        msg = f"only {len(embeddings)} host reachable, cannot cross-verify"
        logger.warning("[EmbedVerify] %s", msg)
        return {
            # Fail-safe: a single reachable host is not an error (hosts may be briefly
            # down during operations).
            'ok': True,
            'signature': signature,
            'reachable': reachable,
            'max_diff': 0.0,
            'errors': errors + [msg],
        }

    # Pairwise cosine distances; the worst pair decides consistency.
    max_diff_observed = 0.0
    for label_a, label_b in itertools.combinations(embeddings, 2):
        d = _cosine_distance(embeddings[label_a], embeddings[label_b])
        max_diff_observed = max(max_diff_observed, d)
        logger.debug("[EmbedVerify] %s vs %s: cosine_distance=%.6f", label_a, label_b, d)

    consistent = max_diff_observed <= max_diff
    if not consistent:
        logger.error(
            "[EmbedVerify] ⚠️ INCONSISTENT! max cosine distance %.6f > %s "
            "(signature=%s, reachable=%s). "
            "模型版本可能漂移;RAG 召回率將下降。",
            max_diff_observed, max_diff, signature, reachable,
        )
    else:
        logger.info(
            "[EmbedVerify] ✅ consistent across %d hosts "
            "(max_diff=%.2e, signature=%s)",
            len(reachable), max_diff_observed, signature,
        )

    return {
        'ok': consistent,
        'signature': signature,
        'reachable': reachable,
        'max_diff': max_diff_observed,
        'errors': errors,
    }


# ─────────────────────────────────────────────────────────────────────────────
# Result container
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class RAGResult:
    """Result of a RAG query.

    Callers use has_high_confidence / synthesize() to decide whether to skip the LLM.
    """
    query: str
    embedding_signature: str
    hits: List[Dict[str, Any]] = field(default_factory=list)
    threshold: float = RAG_DEFAULT_THRESHOLD
    saved_call: bool = False        # whether an LLM call was avoided (set once the caller confirms)
    duration_ms: int = 0
    log_id: Optional[int] = None    # rag_query_log.id; fire-and-forget, so may stay None

    @property
    def has_high_confidence(self) -> bool:
        """True when there is at least one hit and the top-1 score >= threshold."""
        if not self.hits:
            return False
        top_score = self.hits[0].get('score', 0.0) or 0.0
        return float(top_score) >= self.threshold

    def synthesize(self) -> str:
        """Join the non-empty ``content`` of the first 3 hits with ``\\n\\n---\\n\\n``.

        The separator matches the existing OCLearn style. Callers can present the
        result directly as the answer instead of making another LLM call.
        """
        if not self.hits:
            return ""
        parts = []
        for h in self.hits[:3]:
            content = h.get('content') or ""
            if content:
                parts.append(content)
        return "\n\n---\n\n".join(parts)


# ─────────────────────────────────────────────────────────────────────────────
# Main class
# ─────────────────────────────────────────────────────────────────────────────

class RAGService:
    """RAG query entry point: returns hits and logs each query to rag_query_log.

    Example:
        from services.rag_service import rag_service
        result = rag_service.query("本週業績趨勢", caller='openclaw_qa')
        if result.has_high_confidence:
            return result.synthesize()
        # otherwise take the existing LLM path
    """

    def query(
        self,
        text: str,
        caller: str,
        top_k: int = RAG_DEFAULT_TOP_K,
        threshold: float = RAG_DEFAULT_THRESHOLD,
        request_id: Optional[str] = None,
        insight_type: Optional[str] = None,
        mark_saved_call: bool = False,
    ) -> RAGResult:
        """Run RAG retrieval.

        Args:
            text: query text (the user's question or an LLM prompt).
            caller: caller name, presumably from the same whitelist as ai_calls.caller
                (hermes_qa / openclaw_qa / ...).
            top_k: number of hits to return; clamped to 1-50. ``None`` means the default.
            threshold: cosine-similarity cutoff; clamped to 0-1 (default 0.85). ``None``
                means the default. An explicit ``0.0`` is honored.
            request_id: correlation id linking to ai_calls.request_id.
            insight_type: restrict to this ai_insights.insight_type (None = all types).
            mark_saved_call: when True and the result has_high_confidence,
                rag_query_log.saved_call is written as True. Use this only for RAG-first
                callers that skip the LLM on a hit. Auxiliary lookups (e.g.
                similar-case panels) should leave it False.

        Returns:
            RAGResult. On any failure, hits=[] and duration_ms is still recorded.
        """
        signature = get_embedding_signature()
        start = time.monotonic()

        # ── Path 1: feature flag OFF → short-circuit (no DB query, no log write) ──
        if not _is_rag_enabled():
            return RAGResult(
                query=text or "",
                embedding_signature=signature,
                threshold=threshold,
                duration_ms=0,
            )

        # ── Path 2: empty text → return early (skip a pointless embedding call) ──
        if not text or not text.strip():
            return RAGResult(
                query="",
                embedding_signature=signature,
                threshold=threshold,
                duration_ms=int((time.monotonic() - start) * 1000),
            )

        # Guardrail: clamp top_k / threshold into range (aligned with the 027 CHECKs).
        # Only a missing value falls back to the default. Using `x or default` here
        # would turn a legitimate threshold of 0.0 into 0.85.
        top_k = max(1, min(int(RAG_DEFAULT_TOP_K if top_k is None else top_k), 50))
        threshold = max(
            0.0,
            min(float(RAG_DEFAULT_THRESHOLD if threshold is None else threshold), 1.0),
        )

        # ── Path 3: embedding ──
        query_vec: Optional[List[float]] = None
        try:
            from services.ollama_service import ollama_service
            query_vec = ollama_service.generate_embedding(
                text,
                model=RAG_EMBED_MODEL,
                allow_111_fallback=False,
            )
            if not query_vec:
                logger.warning(
                    "[RAGService] embedding empty (caller=%s, len=%d) — fallback LLM",
                    caller, len(text),
                )
        except Exception as exc:
            logger.warning(
                "[RAGService] embedding failed (caller=%s): %s — fallback LLM",
                caller, exc,
            )

        hits: List[Dict[str, Any]] = []

        # ── Path 4: DB retrieval (only when the embedding succeeded) ──
        if query_vec:
            try:
                hits = self._select_hits(
                    query_vec=query_vec,
                    threshold=threshold,
                    top_k=top_k,
                    insight_type=insight_type,
                    expected_signature=signature,
                )
            except Exception as exc:
                logger.warning(
                    "[RAGService] DB select failed (caller=%s): %s — fallback LLM",
                    caller, exc,
                )
                hits = []

        duration_ms = int((time.monotonic() - start) * 1000)
        result = RAGResult(
            query=text,
            embedding_signature=signature,
            hits=hits,
            threshold=threshold,
            duration_ms=duration_ms,
        )
        result.saved_call = bool(mark_saved_call and result.has_high_confidence)

        # ── Path 5: fire-and-forget rag_query_log ──
        self._async_log(
            caller=caller,
            text=text,
            query_vec=query_vec,
            top_k=top_k,
            threshold=threshold,
            hits=hits,
            saved_call=result.saved_call,
            request_id=request_id,
        )
        return result

    # ──────────────────────────────────────────────────────────────────────
    # DB retrieval
    # ──────────────────────────────────────────────────────────────────────

    def _select_hits(
        self,
        query_vec: List[float],
        threshold: float,
        top_k: int,
        insight_type: Optional[str],
        expected_signature: str,
    ) -> List[Dict[str, Any]]:
        """Fetch up to top_k ai_insights rows with cosine similarity >= threshold.

        Rows whose embedding_signature differs from *expected_signature* are skipped
        and summarized in a warning (v5.0 guardrail #3).
        """
        from sqlalchemy import text as sa_text
        from database.manager import get_session

        # cosine_distance = embedding <=> qvec; similarity = 1 - distance.
        # Over-fetch 2x so that rows dropped by the signature filter can be replaced,
        # then trim back to top_k.
        fetch_limit = top_k * 2
        filters = [
            "embedding IS NOT NULL",
            "status IN ('approved', 'active', 'executed')",
        ]
        params: Dict[str, Any] = {
            'qvec': str(query_vec),
            'lim': fetch_limit,
            'max_distance': 1.0 - threshold,
        }
        if insight_type:
            filters.append("insight_type = :insight_type")
            params['insight_type'] = insight_type

        # Only the fixed filter strings above are interpolated; user values are bound.
        sql = sa_text(f"""
            SELECT id, insight_type, period, content, embedding_signature,
                   embedding <=> CAST(:qvec AS vector) AS distance
            FROM ai_insights
            WHERE {' AND '.join(filters)}
              AND (embedding <=> CAST(:qvec AS vector)) <= :max_distance
            ORDER BY distance ASC
            LIMIT :lim
        """)

        session = get_session()
        try:
            rows = session.execute(sql, params).fetchall()
        finally:
            session.close()

        hits, signature_mismatch = self._rows_to_hits(rows, top_k, expected_signature)

        if signature_mismatch:
            logger.warning(
                "[RAGService] %d hits skipped due to embedding_signature mismatch "
                "(expected=%s); 建議跑批次回填腳本",
                signature_mismatch, expected_signature,
            )
        return hits

    @staticmethod
    def _rows_to_hits(
        rows: List[Any],
        top_k: int,
        expected_signature: str,
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Convert distance-ordered ai_insights rows into hit dicts.

        Returns ``(hits, signature_mismatch)``:
          - hits: at most *top_k* entries.
          - signature_mismatch: the number of rows skipped because their
            embedding_signature differs from *expected_signature*.

        Rows with a NULL signature (not yet backfilled) are kept, so pre-existing data
        stays usable.
        """
        hits: List[Dict[str, Any]] = []
        signature_mismatch = 0
        for row in rows:
            if len(hits) >= top_k:
                break
            row_signature = getattr(row, 'embedding_signature', None)
            # v5.0 guardrail #3: signature drift check (NULL passes through).
            if row_signature and row_signature != expected_signature:
                signature_mismatch += 1
                continue
            # A distance of exactly 0.0 (identical vectors) is the best possible match;
            # only a missing (NULL) distance is treated as "no similarity".
            distance = 1.0 if row.distance is None else float(row.distance)
            similarity = 1.0 - distance
            hits.append({
                'id': int(row.id),
                'insight_type': row.insight_type,
                'period': row.period,
                'content': row.content or '',
                'score': round(similarity, 4),
                'distance': round(distance, 4),
                'embedding_signature': row_signature,
            })
        return hits, signature_mismatch

    # ──────────────────────────────────────────────────────────────────────
    # Feedback (Telegram 👍/👎)
    # ──────────────────────────────────────────────────────────────────────

    def feedback(self, rag_query_log_id: int, score: int) -> bool:
        """Write rag_query_log.feedback_score.

        Args:
            rag_query_log_id: rag_query_log.id. Must be a truthy int; bools are rejected.
            score: 1-5 (1 = useless, 5 = very useful; typically 5 = 👍, 1 = 👎).
                Clamped into range; a non-numeric score is rejected.

        Returns:
            True when the UPDATE was committed; False on invalid input or a write error.
            Never raises; failures are logged as warnings.
        """
        # bool is a subclass of int; without this check True/False would act as ids 1/0.
        if (
            not rag_query_log_id
            or not isinstance(rag_query_log_id, int)
            or isinstance(rag_query_log_id, bool)
        ):
            return False
        try:
            score = max(1, min(int(score or 0), 5))
        except (TypeError, ValueError, OverflowError):
            logger.warning(
                "[RAGService] feedback ignored: invalid score %r (id=%s)",
                score, rag_query_log_id,
            )
            return False

        try:
            from sqlalchemy import text as sa_text
            from database.manager import get_session

            session = get_session()
            try:
                session.execute(
                    sa_text("""
                        UPDATE rag_query_log
                           SET feedback_score = :score
                         WHERE id = :id
                    """),
                    {'score': score, 'id': rag_query_log_id},
                )
                session.commit()
                return True
            except Exception:
                session.rollback()
                raise
            finally:
                session.close()
        except Exception as exc:
            logger.warning(
                "[RAGService] feedback write failed (id=%s, score=%s): %s",
                rag_query_log_id, score, exc,
            )
            return False

    # ──────────────────────────────────────────────────────────────────────
    # Reserved: per-caller invalidation (v5.0 has no in-memory cache layer)
    # ──────────────────────────────────────────────────────────────────────

    def invalidate_by_caller(self, caller: str) -> None:
        """Reserved hook, meant to be called when a caller's prompt template changes.

        v5.0 has no in-memory cache layer, so this is a no-op (it only emits a debug
        log). The API is kept stable for a future cache layer.
        """
        if caller:
            logger.debug(
                "[RAGService] invalidate_by_caller(%s) — no-op (no cache layer yet)",
                caller,
            )

    # ──────────────────────────────────────────────────────────────────────
    # Fire-and-forget log
    # ──────────────────────────────────────────────────────────────────────

    def _async_log(
        self,
        caller: str,
        text: str,
        query_vec: Optional[List[float]],
        top_k: int,
        threshold: float,
        hits: List[Dict[str, Any]],
        request_id: Optional[str] = None,
        saved_call: bool = False,
    ) -> None:
        """Write rag_query_log on a daemon thread so the caller is never blocked.

        If the kill-switch is tripped, this only emits a logger.info line.
        """
        if _is_killed():
            logger.info(
                "[RAGQuery|killed] caller=%s hits=%d threshold=%.3f request_id=%s",
                caller, len(hits), threshold, request_id,
            )
            return

        threading.Thread(
            target=self._write_log,
            args=(caller, text, query_vec, top_k, threshold, hits, request_id, saved_call),
            name=f"rag-query-log-{caller}",
            daemon=True,
        ).start()

    def _write_log(
        self,
        caller: str,
        text: str,
        query_vec: Optional[List[float]],
        top_k: int,
        threshold: float,
        hits: List[Dict[str, Any]],
        request_id: Optional[str] = None,
        saved_call: bool = False,
    ) -> None:
        """Insert one rag_query_log row.

        Every exception is caught. A DB failure only logs a warning and increments
        the kill-switch counter.
        """
        try:
            from sqlalchemy import text as sa_text
            from database.manager import get_session

            # PII protection: cap query_text at 4 KB (aligned with the 027 CHECK).
            safe_text = _truncate_utf8(text, _QUERY_TEXT_MAX_BYTES)

            used_results = [int(h['id']) for h in hits if h.get('id')]
            embedding_str = str(query_vec) if query_vec else None
            embedding_signature = (
                get_embedding_signature(dim=len(query_vec)) if query_vec else None
            )

            session = get_session()
            try:
                session.execute(
                    sa_text("""
                        INSERT INTO rag_query_log (
                            caller, query_text, query_embedding, embedding_signature,
                            top_k, threshold, hit_count, used_results,
                            saved_call, request_id
                        ) VALUES (
                            :caller, :query_text, CAST(:embedding AS vector), :embedding_signature,
                            :top_k, :threshold, :hit_count, CAST(:used_results AS BIGINT[]),
                            :saved_call, :request_id
                        )
                    """),
                    {
                        'caller': (caller or 'unknown')[:64],
                        'query_text': safe_text,
                        'embedding': embedding_str,
                        'embedding_signature': embedding_signature,
                        'top_k': int(top_k),
                        'threshold': round(float(threshold), 3),
                        'hit_count': len(hits),
                        'used_results': used_results if used_results else None,
                        'saved_call': bool(saved_call),
                        'request_id': (request_id or None),
                    },
                )
                session.commit()
                _record_success()
            except Exception:
                session.rollback()
                raise
            finally:
                session.close()
        except Exception as exc:
            _record_failure()
            logger.warning(
                "[RAGService] rag_query_log write failed (caller=%s): %s",
                caller, exc,
            )


# Module-level singleton (same pattern as ollama_service / ai_call_logger).
rag_service = RAGService()

__all__ = [
    'RAGService',
    'RAGResult',
    'rag_service',
    'get_embedding_signature',
    'verify_embedding_consistency',
    'is_rag_enabled',
]