659 lines
27 KiB
Python
659 lines
27 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
services/rag_service.py
|
||
Operation Ollama-First v5.0 / Phase 11 — RAG 查詢服務
|
||
|
||
設計原則(憲法級):
|
||
1. 純讀 ai_insights + 寫 rag_query_log(不動 ai_insights schema)
|
||
2. cosine similarity threshold 預設 0.85,太低不採用避免幻覺
|
||
3. embedding 走 bge-m3:latest(與 ai_insights 一致簽名 — Phase 11.0 護欄 #3)
|
||
4. feature flag RAG_ENABLED 預設 OFF(避免影響戰前行為)
|
||
5. 失敗安全:DB 掛 / embedding 失敗 / threshold 不到 → 回 RAGResult(hits=[])
|
||
caller 自行 fallback LLM
|
||
6. fire-and-forget log:rag_query_log INSERT async daemon thread,不阻塞主流程
|
||
7. PII 保護:query_text 寫入時截 4KB(CHECK constraint 也會擋)
|
||
|
||
對應:
|
||
- migrations/027_create_rag_query_log.sql
|
||
- migrations/026_add_embedding_signature.sql
|
||
- docs/adr/ADR-029-hermes-first-twin-tower.md
|
||
- docs/phase0_audit_report_20260503.md Section 3 (BGE-M3 一致性護欄)
|
||
|
||
主入口:
|
||
- rag_service.query(...) : 主查詢介面
|
||
- rag_service.feedback(log_id, ...) : Telegram 👍/👎 反饋寫回
|
||
- rag_service.invalidate_by_caller(...) : 預留快取失效鉤(v5.0 暫無 cache 層)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import hashlib
|
||
import logging
|
||
import os
|
||
import threading
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Feature flag + 預設參數
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
def is_rag_enabled() -> bool:
    """Report whether the RAG feature flag is currently switched on.

    ``RAG_ENABLED`` is read from the environment on every call, so the flag
    can be toggled at runtime without re-importing this module.  The flag is
    OFF unless the variable holds one of ``true`` / ``1`` / ``yes`` / ``on``
    (case-insensitive, surrounding whitespace ignored).
    """
    raw_flag = os.environ.get('RAG_ENABLED', 'false')
    normalized = raw_flag.strip().lower()
    return normalized in {'true', '1', 'yes', 'on'}


# Private alias kept so existing references to ``_is_rag_enabled`` keep working.
_is_rag_enabled = is_rag_enabled
|
||
|
||
|
||
def _env_number(name, default, cast):
    """Read a numeric setting from the environment, tolerating bad input.

    Args:
        name: environment variable to read.
        default: value returned when the variable is unset, blank, or cannot
            be parsed.
        cast: ``int`` or ``float`` — applied to the stripped raw string.

    Returns:
        The parsed number, or ``default``.  Never raises for malformed input,
        so a typo in an env var cannot break ``import`` of this module.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return cast(raw.strip())
    except (TypeError, ValueError):
        logging.getLogger(__name__).warning(
            "[RAGService] invalid %s=%r — using default %r", name, raw, default,
        )
        return default


RAG_DEFAULT_THRESHOLD = _env_number('RAG_DEFAULT_THRESHOLD', 0.85, float)
RAG_DEFAULT_TOP_K = _env_number('RAG_DEFAULT_TOP_K', 5, int)

# bge-m3 consistency parameters (same inputs as the ai_insights signature).
RAG_EMBED_MODEL = os.getenv('RAG_EMBED_MODEL', 'bge-m3:latest')
RAG_EMBED_DIM = _env_number('RAG_EMBED_DIM', 1024, int)
RAG_EMBED_NORMALIZE = os.getenv('RAG_EMBED_NORMALIZE', 'true').strip().lower() in (
    'true', '1', 'yes', 'on',
)

# Upper bound on the stored query_text, in UTF-8 bytes (aligned with the
# migration 027 CHECK octet_length <= 4096; at 3 bytes per CJK character
# that is roughly 1300 characters).
_QUERY_TEXT_MAX_BYTES = 4096

# Consecutive-failure limit for the log-write kill-switch
# (same pattern as ai_call_logger).
_MAX_CONSECUTIVE_FAILURES = 10
_failure_counter_lock = threading.Lock()
_failure_state = {'count': 0, 'killed': False}
|
||
|
||
|
||
def _record_failure() -> None:
    """Count one failed write and trip the kill-switch at the failure limit.

    The switch is tripped (and the error logged) only once: when the counter
    reaches ``_MAX_CONSECUTIVE_FAILURES`` while the switch is still off.
    """
    with _failure_counter_lock:
        _failure_state['count'] += 1
        below_limit = _failure_state['count'] < _MAX_CONSECUTIVE_FAILURES
        if below_limit or _failure_state['killed']:
            return
        _failure_state['killed'] = True
        logger.error(
            "[RAGService] consecutive write failures hit %d — kill-switch ON, "
            "downgrading rag_query_log writes to logger.info",
            _MAX_CONSECUTIVE_FAILURES,
        )
|
||
|
||
|
||
def _record_success() -> None:
    """Clear the failure counter after a successful write.

    Only the counter is reset; the ``killed`` flag is left untouched.
    """
    with _failure_counter_lock:
        if _failure_state['count'] <= 0:
            return
        _failure_state['count'] = 0
|
||
|
||
|
||
def _is_killed() -> bool:
    """Return the current kill-switch flag, read while holding the state lock."""
    with _failure_counter_lock:
        killed = _failure_state['killed']
    return killed
|
||
|
||
|
||
def _reset_kill_switch() -> None:
    """Test helper: put the kill-switch state back to its initial values."""
    with _failure_counter_lock:
        _failure_state.update(count=0, killed=False)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# BGE-M3 一致性簽名(v5.0 護欄 #3)
|
||
# 與 migration 026 註解一致:SHA1({model}|{normalize}|{dim}|{ollama_digest})[:12]
|
||
# Python 端不查 ollama digest(避免每次 query 都 GET /api/show),
|
||
# 改用 model+normalize+dim 三元組已足以擋住「升級 bge-m3 / 改 normalize」雙寫漂移。
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
def get_embedding_signature(
    model: str = RAG_EMBED_MODEL,
    dim: int = RAG_EMBED_DIM,
    normalize: bool = RAG_EMBED_NORMALIZE,
) -> str:
    """Build the 12-character embedding-consistency signature.

    The signature is the first 12 hex digits of the SHA-1 of
    ``"{model}|{normalize}|{dim}"``, with ``normalize`` rendered in lowercase.
    Rows of ``ai_insights`` whose stored ``embedding_signature`` differs from
    this value are skipped during retrieval.

    Args:
        model: embedding model name.
        dim: embedding dimensionality.
        normalize: whether embeddings are normalized.

    Returns:
        A 12-character lowercase hex string.
    """
    normalize_token = str(normalize).lower()
    payload = "|".join(str(part) for part in (model, normalize_token, dim))
    digest = hashlib.sha1(payload.encode('utf-8')).hexdigest()
    return digest[:12]
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Phase 11.0 護欄 #3:BGE-M3 跨主機一致性啟動驗證(ADR-033)
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Fixed probe text embedded on every host during the consistency check.
EMBED_CONSISTENCY_TEST_TEXT = "momo電商競品分析測試向量一致性檢查"
EMBED_CONSISTENCY_MAX_DIFF = 1e-4  # upper bound on cosine distance (floating-point tolerance)
EMBED_CONSISTENCY_TIMEOUT_SEC = 10.0  # per-host embedding probe timeout, in seconds
|
||
|
||
|
||
def _cosine_distance(vec_a: List[float], vec_b: List[float]) -> float:
    """Cosine distance (``1 - cosine similarity``) in pure Python, no numpy.

    Returns ``1.0`` when either vector is empty, when the lengths differ, or
    when either vector has zero norm.  The result is never negative.
    """
    comparable = vec_a and vec_b and len(vec_a) == len(vec_b)
    if not comparable:
        return 1.0

    magnitude_a = sum(x * x for x in vec_a) ** 0.5
    magnitude_b = sum(y * y for y in vec_b) ** 0.5
    if magnitude_a == 0 or magnitude_b == 0:
        return 1.0

    inner = sum(x * y for x, y in zip(vec_a, vec_b))
    distance = 1.0 - inner / (magnitude_a * magnitude_b)
    # Rounding can push the value slightly below zero; clamp it.
    return distance if distance > 0.0 else 0.0
|
||
|
||
|
||
def verify_embedding_consistency(
    test_text: str = EMBED_CONSISTENCY_TEST_TEXT,
    max_diff: float = EMBED_CONSISTENCY_MAX_DIFF,
) -> Dict[str, Any]:
    """Embed the same text on each Ollama host and compare the vectors.

    Three hosts are probed, labelled ``gcp_ollama``, ``ollama_secondary`` and
    ``ollama_111``.  A host that raises, returns nothing, or returns a vector
    whose length is not ``RAG_EMBED_DIM`` is skipped and noted in ``errors``.
    When fewer than two hosts yield a usable vector no comparison is possible
    and the function reports ``ok=True`` with an explanatory error entry.
    Otherwise every pair of vectors is compared by cosine distance and the
    largest distance is checked against ``max_diff``.

    Args:
        test_text: text embedded on every host.
        max_diff: largest pairwise cosine distance still considered consistent.

    Returns:
        A dict with keys ``ok`` (bool), ``signature`` (str), ``reachable``
        (labels of hosts that produced a usable vector), ``max_diff`` (largest
        observed pairwise cosine distance, ``0.0`` when not compared) and
        ``errors`` (list of messages).
    """
    import itertools
    import time
    from services.ollama_service import (
        OLLAMA_HOST_PRIMARY, OLLAMA_HOST_SECONDARY, OLLAMA_HOST_FALLBACK,
        ollama_service,
    )

    probe_targets = (
        ('gcp_ollama', OLLAMA_HOST_PRIMARY),
        ('ollama_secondary', OLLAMA_HOST_SECONDARY),
        ('ollama_111', OLLAMA_HOST_FALLBACK),
    )

    embeddings: Dict[str, List[float]] = {}
    problems: List[str] = []

    for label, host in probe_targets:
        try:
            started = time.monotonic()
            vec = ollama_service.generate_embedding(
                text=test_text,
                model=RAG_EMBED_MODEL,
                host=host,  # pinned explicitly so the retry chain cannot skew the check
                timeout=int(EMBED_CONSISTENCY_TIMEOUT_SEC),
            )
            elapsed = time.monotonic() - started
            if not vec or len(vec) != RAG_EMBED_DIM:
                problems.append(f"{label}: empty or wrong dim ({len(vec) if vec else 0})")
                logger.warning(f"[EmbedVerify] {label} returned empty/wrong-dim vector")
                continue
            embeddings[label] = vec
            logger.info(f"[EmbedVerify] {label} ({host}) ok in {elapsed:.2f}s, dim={len(vec)}")
        except Exception as exc:
            problems.append(f"{label}: {type(exc).__name__}: {str(exc)[:200]}")
            logger.warning(f"[EmbedVerify] {label} failed: {exc}")

    signature = get_embedding_signature()
    reachable = list(embeddings.keys())

    def _report(ok: bool, observed: float, messages: List[str]) -> Dict[str, Any]:
        # Single place that fixes the shape of the returned dict.
        return {
            'ok': ok,
            'signature': signature,
            'reachable': reachable,
            'max_diff': observed,
            'errors': messages,
        }

    if len(embeddings) < 2:
        msg = f"only {len(embeddings)} host reachable, cannot cross-verify"
        logger.warning(f"[EmbedVerify] {msg}")
        # Fail-safe: too few reachable hosts is not treated as an inconsistency.
        return _report(True, 0.0, problems + [msg])

    # Pairwise cosine distance across every pair of reachable hosts.
    max_diff_observed = 0.0
    for label_a, label_b in itertools.combinations(embeddings, 2):
        d = _cosine_distance(embeddings[label_a], embeddings[label_b])
        if d > max_diff_observed:
            max_diff_observed = d
        logger.debug(f"[EmbedVerify] {label_a} vs {label_b}: cosine_distance={d:.6f}")

    within_tolerance = max_diff_observed <= max_diff
    if within_tolerance:
        logger.info(
            f"[EmbedVerify] ✅ consistent across {len(reachable)} hosts "
            f"(max_diff={max_diff_observed:.2e}, signature={signature})"
        )
    else:
        logger.error(
            f"[EmbedVerify] ⚠️ INCONSISTENT! max cosine distance {max_diff_observed:.6f} > {max_diff} "
            f"(signature={signature}, reachable={reachable}). "
            f"模型版本可能漂移;RAG 召回率將下降。"
        )

    return _report(within_tolerance, max_diff_observed, problems)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# 結果容器
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
@dataclass
class RAGResult:
    """Outcome of one RAG lookup.

    Callers inspect ``has_high_confidence`` and, when it is true, may use
    ``synthesize()`` as the answer instead of calling an LLM.
    """

    query: str
    embedding_signature: str
    hits: List[Dict[str, Any]] = field(default_factory=list)
    threshold: float = RAG_DEFAULT_THRESHOLD
    saved_call: bool = False  # whether an LLM call was avoided (set by the caller)
    duration_ms: int = 0
    log_id: Optional[int] = None  # rag_query_log.id; may stay None (fire-and-forget write)

    @property
    def has_high_confidence(self) -> bool:
        """True when there is at least one hit and the first hit's score meets the threshold."""
        if self.hits:
            best_score = self.hits[0].get('score', 0.0) or 0.0
            return float(best_score) >= self.threshold
        return False

    def synthesize(self) -> str:
        """Join the non-empty ``content`` of the first three hits.

        Pieces are separated by a blank line, ``---``, and another blank line.
        Returns an empty string when there are no hits.
        """
        if not self.hits:
            return ""
        leading = (hit.get('content') or "" for hit in self.hits[:3])
        return "\n\n---\n\n".join(piece for piece in leading if piece)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# 主類別
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
class RAGService:
    """Entry point for RAG lookups against ``ai_insights``.

    ``query`` embeds the text, selects similar rows and records the lookup in
    ``rag_query_log`` on a background thread; ``feedback`` stores a rating on
    an existing log row.

    Example:
        from services.rag_service import rag_service
        result = rag_service.query("本週業績趨勢", caller='openclaw_qa')
        if result.has_high_confidence:
            return result.synthesize()
        # otherwise fall through to the existing LLM path
    """

    def query(
        self,
        text: str,
        caller: str,
        top_k: int = RAG_DEFAULT_TOP_K,
        threshold: float = RAG_DEFAULT_THRESHOLD,
        request_id: Optional[str] = None,
        insight_type: Optional[str] = None,
    ) -> RAGResult:
        """Run one retrieval and return the hits.

        Args:
            text: query text.
            caller: identifier of the calling component; used in log messages
                and stored in ``rag_query_log.caller``.
            top_k: number of rows to return; clamped to 1..50, a falsy value
                selects ``RAG_DEFAULT_TOP_K``.
            threshold: minimum cosine similarity; clamped to 0..1.  ``None``
                selects ``RAG_DEFAULT_THRESHOLD``; an explicit ``0.0`` is
                honoured.
            request_id: stored in ``rag_query_log.request_id``.
            insight_type: when given, only rows with this
                ``ai_insights.insight_type`` are considered.

        Returns:
            A ``RAGResult``.  ``hits`` is empty when the feature flag is off,
            the text is blank, the embedding is empty or fails, or the DB
            select fails.  Nothing is logged to ``rag_query_log`` in the
            flag-off and blank-text paths.
        """
        signature = get_embedding_signature()
        start = time.monotonic()

        # Path 1: feature flag OFF -> short-circuit (no DB access, no log row).
        if not _is_rag_enabled():
            return RAGResult(
                query=text or "",
                embedding_signature=signature,
                threshold=threshold,
                duration_ms=0,
            )

        # Path 2: blank text -> return early, skipping the embedding call.
        if not text or not text.strip():
            return RAGResult(
                query="",
                embedding_signature=signature,
                threshold=threshold,
                duration_ms=int((time.monotonic() - start) * 1000),
            )

        # Clamp to the ranges documented above.
        top_k = max(1, min(int(top_k or RAG_DEFAULT_TOP_K), 50))
        # Only None means "unset": ``threshold or default`` would replace a
        # legitimate explicit 0.0 with the default.
        if threshold is None:
            threshold = RAG_DEFAULT_THRESHOLD
        threshold = max(0.0, min(float(threshold), 1.0))

        # Path 3: embed the query.
        query_vec: Optional[List[float]] = None
        try:
            from services.ollama_service import ollama_service
            query_vec = ollama_service.generate_embedding(text, model=RAG_EMBED_MODEL)
            if not query_vec:
                logger.warning(
                    "[RAGService] embedding empty (caller=%s, len=%d) — fallback LLM",
                    caller, len(text),
                )
        except Exception as exc:
            logger.warning(
                "[RAGService] embedding failed (caller=%s): %s — fallback LLM",
                caller, exc,
            )

        hits: List[Dict[str, Any]] = []

        # Path 4: DB retrieval, only when an embedding was obtained.
        if query_vec:
            try:
                hits = self._select_hits(
                    query_vec=query_vec,
                    threshold=threshold,
                    top_k=top_k,
                    insight_type=insight_type,
                    expected_signature=signature,
                )
            except Exception as exc:
                logger.warning(
                    "[RAGService] DB select failed (caller=%s): %s — fallback LLM",
                    caller, exc,
                )
                hits = []

        duration_ms = int((time.monotonic() - start) * 1000)
        result = RAGResult(
            query=text,
            embedding_signature=signature,
            hits=hits,
            threshold=threshold,
            saved_call=False,  # left for the caller to set after it decides
            duration_ms=duration_ms,
        )

        # Path 5: fire-and-forget write to rag_query_log.
        self._async_log(
            caller=caller,
            text=text,
            query_vec=query_vec,
            top_k=top_k,
            threshold=threshold,
            hits=hits,
            request_id=request_id,
        )

        return result

    # ------------------------------------------------------------------
    # DB retrieval
    # ------------------------------------------------------------------
    def _select_hits(
        self,
        query_vec: List[float],
        threshold: float,
        top_k: int,
        insight_type: Optional[str],
        expected_signature: str,
    ) -> List[Dict[str, Any]]:
        """Select up to ``top_k`` rows of ``ai_insights`` similar to ``query_vec``.

        Rows must have a non-NULL embedding, a status of ``approved`` /
        ``active`` / ``executed`` and a ``<=>`` distance of at most
        ``1 - threshold``.  Rows whose ``embedding_signature`` is set but
        differs from ``expected_signature`` are skipped and counted in a
        single warning; rows with a NULL signature are accepted.

        Returns:
            A list of dicts with keys ``id``, ``insight_type``, ``period``,
            ``content``, ``score`` (``1 - distance``), ``distance`` and
            ``embedding_signature``, in ascending distance order.

        Raises:
            Whatever the session / driver raises; ``query`` catches it.
        """
        from sqlalchemy import text as sa_text
        from database.manager import get_session

        # Fetch twice top_k so that rows dropped by the signature check can
        # be replaced; the loop below trims back to top_k.
        fetch_limit = max(top_k * 2, top_k)
        filters = [
            "embedding IS NOT NULL",
            "status IN ('approved', 'active', 'executed')",
        ]
        params: Dict[str, Any] = {
            'qvec': str(query_vec),
            'lim': fetch_limit,
            'max_distance': 1.0 - threshold,
        }
        if insight_type:
            filters.append("insight_type = :insight_type")
            params['insight_type'] = insight_type

        # Only fixed filter fragments are interpolated; every value is bound.
        sql = sa_text(f"""
            SELECT id, insight_type, period, content,
                   embedding_signature,
                   embedding <=> CAST(:qvec AS vector) AS distance
            FROM ai_insights
            WHERE {' AND '.join(filters)}
              AND (embedding <=> CAST(:qvec AS vector)) <= :max_distance
            ORDER BY distance ASC
            LIMIT :lim
        """)

        session = get_session()
        try:
            rows = session.execute(sql, params).fetchall()
        finally:
            session.close()

        hits: List[Dict[str, Any]] = []
        signature_mismatch = 0
        for row in rows:
            if len(hits) >= top_k:
                break
            row_signature = getattr(row, 'embedding_signature', None)
            # NULL signature = legacy rows not yet backfilled; let them through.
            if row_signature and row_signature != expected_signature:
                signature_mismatch += 1
                continue
            # Test for None explicitly: ``row.distance or 1.0`` would turn an
            # exact match (distance 0.0) into distance 1.0 / score 0.0.
            raw_distance = row.distance
            distance = 1.0 if raw_distance is None else float(raw_distance)
            similarity = 1.0 - distance
            hits.append({
                'id': int(row.id),
                'insight_type': row.insight_type,
                'period': row.period,
                'content': row.content or '',
                'score': round(similarity, 4),
                'distance': round(distance, 4),
                'embedding_signature': row_signature,
            })

        if signature_mismatch:
            logger.warning(
                "[RAGService] %d hits skipped due to embedding_signature mismatch "
                "(expected=%s); 建議跑批次回填腳本",
                signature_mismatch, expected_signature,
            )

        return hits

    # ------------------------------------------------------------------
    # Feedback
    # ------------------------------------------------------------------
    def feedback(self, rag_query_log_id: int, score: int) -> bool:
        """Store a rating in ``rag_query_log.feedback_score``.

        Args:
            rag_query_log_id: ``rag_query_log.id``; must be a non-zero int.
            score: rating, clamped to 1..5.

        Returns:
            True when the UPDATE was executed and committed; False on an
            invalid id, a non-numeric score, or any write error.  Never
            raises.
        """
        if not rag_query_log_id or not isinstance(rag_query_log_id, int):
            return False
        # Coerce inside a guard so a garbage score cannot raise.
        try:
            score = max(1, min(int(score or 0), 5))
        except (TypeError, ValueError):
            logger.warning(
                "[RAGService] feedback ignored: non-numeric score %r (id=%s)",
                score, rag_query_log_id,
            )
            return False

        try:
            from sqlalchemy import text as sa_text
            from database.manager import get_session

            session = get_session()
            try:
                session.execute(
                    sa_text("""
                        UPDATE rag_query_log
                        SET feedback_score = :score
                        WHERE id = :id
                    """),
                    {'score': score, 'id': rag_query_log_id},
                )
                session.commit()
                return True
            except Exception:
                session.rollback()
                raise
            finally:
                session.close()
        except Exception as exc:
            logger.warning(
                "[RAGService] feedback write failed (id=%s, score=%s): %s",
                rag_query_log_id, score, exc,
            )
            return False

    # ------------------------------------------------------------------
    # Reserved: per-caller invalidation (no cache layer exists yet)
    # ------------------------------------------------------------------
    def invalidate_by_caller(self, caller: str) -> None:
        """Reserved hook; currently a no-op apart from a debug log line.

        Kept so the API is already in place if a cache layer is added later.
        """
        if caller:
            logger.debug("[RAGService] invalidate_by_caller(%s) — no-op (no cache layer yet)", caller)

    # ------------------------------------------------------------------
    # Fire-and-forget log
    # ------------------------------------------------------------------
    def _async_log(
        self,
        caller: str,
        text: str,
        query_vec: Optional[List[float]],
        top_k: int,
        threshold: float,
        hits: List[Dict[str, Any]],
        request_id: Optional[str],
    ) -> None:
        """Write the lookup to ``rag_query_log`` on a daemon thread.

        When the kill-switch is on, no thread is started and a summary is
        emitted through ``logger.info`` instead.
        """
        if _is_killed():
            logger.info(
                "[RAGQuery|killed] caller=%s hits=%d threshold=%.3f request_id=%s",
                caller, len(hits), threshold, request_id,
            )
            return

        threading.Thread(
            target=self._write_log,
            args=(caller, text, query_vec, top_k, threshold, hits, request_id),
            name=f"rag-query-log-{caller}",
            daemon=True,
        ).start()

    def _write_log(
        self,
        caller: str,
        text: str,
        query_vec: Optional[List[float]],
        top_k: int,
        threshold: float,
        hits: List[Dict[str, Any]],
        request_id: Optional[str],
    ) -> None:
        """Insert one row into ``rag_query_log``; never raises.

        Any failure is logged as a warning and counted by
        ``_record_failure``; a successful commit calls ``_record_success``.
        """
        try:
            from sqlalchemy import text as sa_text
            from database.manager import get_session

            # Cap the stored query text at _QUERY_TEXT_MAX_BYTES UTF-8 bytes.
            safe_text = (text or '')
            encoded = safe_text.encode('utf-8', errors='replace')
            if len(encoded) > _QUERY_TEXT_MAX_BYTES:
                # errors='ignore' drops a multi-byte character cut in half
                # by the byte slice.
                safe_text = encoded[:_QUERY_TEXT_MAX_BYTES].decode('utf-8', errors='ignore')

            used_results = [int(h['id']) for h in hits if h.get('id')]
            embedding_str = str(query_vec) if query_vec else None
            embedding_signature = (
                get_embedding_signature(dim=len(query_vec))
                if query_vec else None
            )

            session = get_session()
            try:
                session.execute(
                    sa_text("""
                        INSERT INTO rag_query_log (
                            caller, query_text, query_embedding,
                            embedding_signature,
                            top_k, threshold,
                            hit_count, used_results,
                            saved_call, request_id
                        ) VALUES (
                            :caller, :query_text,
                            CAST(:embedding AS vector),
                            :embedding_signature,
                            :top_k, :threshold,
                            :hit_count, CAST(:used_results AS BIGINT[]),
                            :saved_call, :request_id
                        )
                    """),
                    {
                        'caller': (caller or 'unknown')[:64],
                        'query_text': safe_text,
                        'embedding': embedding_str,
                        'embedding_signature': embedding_signature,
                        'top_k': int(top_k),
                        'threshold': round(float(threshold), 3),
                        'hit_count': len(hits),
                        'used_results': used_results if used_results else None,
                        'saved_call': False,  # always False at INSERT time
                        'request_id': (request_id or None),
                    },
                )
                session.commit()
                _record_success()
            except Exception:
                session.rollback()
                raise
            finally:
                session.close()
        except Exception as exc:
            _record_failure()
            logger.warning(
                "[RAGService] rag_query_log write failed (caller=%s): %s",
                caller, exc,
            )
|
||
|
||
|
||
# Module-level singleton (same pattern as ollama_service / ai_call_logger).
rag_service = RAGService()


# Names exported by ``from services.rag_service import *``.
__all__ = [
    'RAGService',
    'RAGResult',
    'rag_service',
    'get_embedding_signature',
    'verify_embedding_consistency',
    'is_rag_enabled',
]
|