Files
ewoooc/tests/test_rag_service.py
OoO 353e565e52
All checks were successful
CD Pipeline / deploy (push) Successful in 1m4s
V10.417 protect embedding fallback routing
2026-05-24 14:53:43 +08:00

505 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tests/test_rag_service.py
Operation Ollama-First v5.0 / Phase 11 — RAGService 單元測試
涵蓋:
- feature flag RAG_ENABLED 預設 OFF → skip 不查 DB / 不寫 log
- RAG_ENABLED=true 時的 hit / miss 分支
- embedding signature 不一致 → log warning + 不採該筆
- fire-and-forget rag_query_log 失敗不影響主流程
- feedback() / invalidate_by_caller() / get_embedding_signature()
Mock 策略:
- ollama_service.generate_embedding 用 monkeypatch
- database.manager.get_session 用 MagicMock 回 session（不真的查 DB）
"""
from __future__ import annotations
import os
import time
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch
import pytest
# ─────────────────────────────────────────────────────────────────────────────
# 共用 fixture
# ─────────────────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _reset_rag_kill_switch():
    """Reset the RAG kill-switch before and after every test.

    Declared ``autouse`` so each test starts from, and leaves behind, a freshly
    reset kill-switch and tests cannot pollute one another.
    """
    from services.rag_service import _reset_kill_switch as reset_switch

    reset_switch()  # setup: clear whatever an earlier test left behind
    yield
    reset_switch()  # teardown: clear whatever this test changed
@pytest.fixture
def rag_disabled(monkeypatch):
    """Set the ``RAG_ENABLED`` environment variable to ``'false'`` for the test."""
    monkeypatch.setenv(name='RAG_ENABLED', value='false')
@pytest.fixture
def rag_enabled(monkeypatch):
    """Set the ``RAG_ENABLED`` environment variable to ``'true'`` for the test."""
    monkeypatch.setenv(name='RAG_ENABLED', value='true')
def _fake_embedding(dim: int = 1024) -> List[float]:
"""產生穩定的假 embedding避免每次 random"""
return [0.01 * (i % 100) for i in range(dim)]
def _make_row_obj(**fields):
"""模擬 SQLAlchemy row支援屬性與 row.id 風格)。"""
obj = MagicMock()
for k, v in fields.items():
setattr(obj, k, v)
return obj
# ─────────────────────────────────────────────────────────────────────────────
# Test 1: feature flag 預設 OFF → 短路
# ─────────────────────────────────────────────────────────────────────────────
class TestFeatureFlagOff:
    """Behaviour when the RAG feature flag is off, explicitly or by default."""

    def test_query_skips_when_disabled(self, rag_disabled):
        """RAG_ENABLED=false: no embedding call is made and the hit list is empty."""
        from services.rag_service import rag_service

        embed_target = 'services.ollama_service.ollama_service.generate_embedding'
        with patch(embed_target) as embed_spy:
            outcome = rag_service.query("本週業績", caller='test_caller')

        assert outcome.hits == []
        assert outcome.has_high_confidence is False
        embed_spy.assert_not_called()

    def test_synthesize_returns_empty_when_no_hits(self, rag_disabled):
        """With the flag off, synthesize() on the query result is an empty string."""
        from services.rag_service import rag_service

        outcome = rag_service.query("本週業績", caller='test_caller')
        synthesized = outcome.synthesize()
        assert synthesized == ""

    def test_default_flag_is_off(self, monkeypatch):
        """When RAG_ENABLED is not set at all, the flag reads as OFF."""
        # raising=False: do not fail if the variable was never set.
        monkeypatch.delenv('RAG_ENABLED', raising=False)
        from services.rag_service import is_rag_enabled

        flag = is_rag_enabled()
        assert flag is False
# ─────────────────────────────────────────────────────────────────────────────
# Test 2: RAG_ENABLED=true 時的 hit / miss
# ─────────────────────────────────────────────────────────────────────────────
class TestRagEnabledHits:
    """Hit / miss branches of ``rag_service.query`` when RAG_ENABLED=true."""

    def test_high_confidence_hit_returns_synthesized(self, rag_enabled, monkeypatch):
        """A top-1 row at distance 0.05 is expected to score 0.95 and be high-confidence."""
        from services import rag_service as rs

        def _embed(text, model="bge-m3:latest", **_kwargs):
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _embed,
        )

        signature = rs.get_embedding_signature()
        # Fields shared by both fake rows.
        shared = {
            'insight_type': 'weekly_strategy',
            'period': '2026-W17',
            'embedding_signature': signature,
        }
        rows = [
            _make_row_obj(
                id=101, content='本週業績漲 5%,建議聚焦保濕品類。', distance=0.05, **shared,
            ),
            _make_row_obj(
                id=102, content='競品 PChome 同類降價 8%', distance=0.10, **shared,
            ),
        ]

        db_session = MagicMock()
        db_session.execute.return_value.fetchall.return_value = rows
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        outcome = rs.rag_service.query("本週業績趨勢", caller='openclaw_qa')

        assert len(outcome.hits) == 2
        top_hit = outcome.hits[0]
        assert top_hit['score'] == pytest.approx(0.95, abs=1e-3)
        assert outcome.has_high_confidence is True
        assert '本週業績漲' in outcome.synthesize()

    def test_mark_saved_call_only_when_requested_and_confident(self, rag_enabled, monkeypatch):
        """saved_call is reported (and logged) only when the caller asks for it."""
        from services import rag_service as rs

        def _embed(text, model="bge-m3:latest", **_kwargs):
            return _fake_embedding()

        def _stub_select_hits(**_kw):
            # One canned hit with score 0.95, returned regardless of arguments.
            return [{'id': 101, 'content': '命中內容', 'score': 0.95}]

        logged = {}

        def _record_async_log(**kwargs):
            logged.update(kwargs)

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _embed,
        )
        monkeypatch.setattr(rs.rag_service, '_select_hits', _stub_select_hits)
        monkeypatch.setattr(rs.rag_service, '_async_log', _record_async_log)

        # Caller opts in -> saved_call is True on the result and in the log kwargs.
        opted_in = rs.rag_service.query(
            "本週業績趨勢",
            caller='openclaw_qa',
            mark_saved_call=True,
        )
        assert opted_in.saved_call is True
        assert logged['saved_call'] is True

        # Caller opts out -> saved_call is False even though the same hit is returned.
        opted_out = rs.rag_service.query(
            "本週業績趨勢",
            caller='admin_quality_trend',
            mark_saved_call=False,
        )
        assert opted_out.saved_call is False
        assert logged['saved_call'] is False

    def test_low_confidence_no_hit(self, rag_enabled, monkeypatch):
        """The DB returns no rows: hits are empty and nothing is synthesized.

        This models the case where every candidate was already filtered out
        by the SQL distance cut-off.
        """
        from services import rag_service as rs

        def _embed(text, model="bge-m3:latest", **_kwargs):
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _embed,
        )

        db_session = MagicMock()
        db_session.execute.return_value.fetchall.return_value = []
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        outcome = rs.rag_service.query("非常冷僻的問題", caller='openclaw_qa')

        assert outcome.hits == []
        assert outcome.has_high_confidence is False
        assert outcome.synthesize() == ""

    def test_empty_query_short_circuits(self, rag_enabled, monkeypatch):
        """An empty query string returns no hits without ever calling the embedder."""
        from services import rag_service as rs

        # generate_embedding must not be invoked; count the calls to prove it.
        embed_calls = {'count': 0}

        def _counting_embed(*a, **kw):
            embed_calls['count'] += 1
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _counting_embed,
        )

        outcome = rs.rag_service.query("", caller='openclaw_qa')

        assert outcome.hits == []
        assert embed_calls['count'] == 0
# ─────────────────────────────────────────────────────────────────────────────
# Test 3: embedding signature 不一致 → 不採用
# ─────────────────────────────────────────────────────────────────────────────
class TestEmbeddingSignature:
    """Embedding-signature format and how mismatched / missing signatures are treated."""

    def test_signature_format(self):
        """The signature is a 12-character lowercase hex string."""
        from services.rag_service import get_embedding_signature

        signature = get_embedding_signature()
        assert len(signature) == 12
        hex_digits = '0123456789abcdef'
        assert all(ch in hex_digits for ch in signature)

    def test_signature_changes_with_model(self):
        """Changing model, dim, or normalize each yields a signature unlike the baseline."""
        from services.rag_service import get_embedding_signature

        baseline = get_embedding_signature(model='bge-m3:latest', dim=1024, normalize=True)
        other_model = get_embedding_signature(model='bge-m3:v2', dim=1024, normalize=True)
        other_dim = get_embedding_signature(model='bge-m3:latest', dim=512, normalize=True)
        other_norm = get_embedding_signature(model='bge-m3:latest', dim=1024, normalize=False)

        assert baseline != other_model
        assert baseline != other_dim
        assert baseline != other_norm

    def test_mismatched_signature_rows_skipped(self, rag_enabled, monkeypatch, caplog):
        """Rows whose embedding_signature differs from the expected one are dropped and a warning is logged."""
        from services import rag_service as rs
        import logging

        def _embed(text, model="bge-m3:latest", **_kwargs):
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _embed,
        )

        expected_signature = rs.get_embedding_signature()
        stale_row = _make_row_obj(
            id=201, insight_type='x', period=None, content='舊簽名資料',
            embedding_signature='deadbeef0001',  # deliberately different
            distance=0.05,
        )
        current_row = _make_row_obj(
            id=202, insight_type='x', period=None, content='新簽名資料',
            embedding_signature=expected_signature,
            distance=0.06,
        )

        db_session = MagicMock()
        db_session.execute.return_value.fetchall.return_value = [stale_row, current_row]
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        with caplog.at_level(logging.WARNING, logger='services.rag_service'):
            outcome = rs.rag_service.query("query", caller='openclaw_qa')

        # Only the row with the matching signature survives.
        assert len(outcome.hits) == 1
        assert outcome.hits[0]['id'] == 202
        # A warning was emitted for the mismatch.
        assert any('signature mismatch' in record.message for record in caplog.records)

    def test_null_signature_passes_through(self, rag_enabled, monkeypatch):
        """Legacy rows not yet back-filled (signature=NULL) are let through for now.

        This keeps pre-existing data from becoming entirely unusable.
        """
        from services import rag_service as rs

        def _embed(text, model="bge-m3:latest", **_kwargs):
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _embed,
        )

        legacy_row = _make_row_obj(
            id=301, insight_type='x', period=None, content='legacy data',
            embedding_signature=None,
            distance=0.1,
        )
        db_session = MagicMock()
        db_session.execute.return_value.fetchall.return_value = [legacy_row]
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        outcome = rs.rag_service.query("query", caller='openclaw_qa')
        assert len(outcome.hits) == 1
# ─────────────────────────────────────────────────────────────────────────────
# Test 4: fire-and-forget log 失敗不影響主流程
# ─────────────────────────────────────────────────────────────────────────────
class TestFireAndForgetLog:
    """The rag_query_log write is best-effort and must never break the main query path."""

    def test_log_write_failure_does_not_crash(self, rag_enabled, monkeypatch):
        """If the rag_query_log INSERT fails, query() still returns a result."""
        from services import rag_service as rs

        def _embed(text, model="bge-m3:latest", **_kwargs):
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _embed,
        )

        # First session (SELECT) behaves normally; second (INSERT) raises on purpose.
        healthy_session = MagicMock()
        healthy_session.execute.return_value.fetchall.return_value = []
        broken_session = MagicMock()
        broken_session.execute.side_effect = RuntimeError("DB unavailable")
        session_queue = iter([healthy_session, broken_session])

        def _next_session():
            return next(session_queue)

        monkeypatch.setattr('database.manager.get_session', _next_session)

        # Must not raise: the log write is expected to swallow its own errors.
        outcome = rs.rag_service.query("query", caller='openclaw_qa')
        assert outcome.hits == []
        # Give the background log write time to finish.
        time.sleep(0.2)

    def test_log_write_includes_embedding_signature(self, monkeypatch):
        """_write_log stores the embedding signature alongside the query embedding."""
        from services import rag_service as rs

        seen = {}

        def _record_execute(stmt, params):
            seen["sql"] = str(stmt)
            seen["params"] = params
            return MagicMock()

        db_session = MagicMock()
        db_session.execute.side_effect = _record_execute
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        rs.rag_service._write_log(
            caller='openclaw_qa',
            text='query',
            query_vec=_fake_embedding(),
            top_k=5,
            threshold=0.85,
            hits=[{'id': 101}],
            request_id='req-1',
            saved_call=True,
        )

        bound = seen["params"]
        assert "embedding_signature" in seen["sql"]
        assert bound["embedding_signature"] == rs.get_embedding_signature()
        assert bound["saved_call"] is True
        db_session.commit.assert_called_once()

    def test_embedding_failure_falls_back_to_empty(self, rag_enabled, monkeypatch):
        """An empty embedding yields empty hits so the caller can fall back to the LLM."""
        from services import rag_service as rs

        def _failing_embed(text, model="bge-m3:latest", **_kwargs):
            return []

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _failing_embed,
        )

        # Count session requests. The SELECT should not happen, but the log
        # write may still open a session, so the count may be 0 or 1 and is
        # intentionally not asserted.
        session_requests = {'count': 0}

        def _counting_session():
            session_requests['count'] += 1
            return MagicMock()

        monkeypatch.setattr('database.manager.get_session', _counting_session)

        outcome = rs.rag_service.query("query", caller='openclaw_qa')
        assert outcome.hits == []
        # Give the background log write time to finish.
        time.sleep(0.2)
# ─────────────────────────────────────────────────────────────────────────────
# Test 5: feedback() / invalidate_by_caller()
# ─────────────────────────────────────────────────────────────────────────────
class TestFeedback:
    """``rag_service.feedback()`` and ``rag_service.invalidate_by_caller()``."""

    def test_feedback_writes_score(self, monkeypatch):
        """A valid feedback call executes one statement, commits once, and returns True."""
        from services.rag_service import rag_service

        db_session = MagicMock()
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        succeeded = rag_service.feedback(123, 5)

        assert succeeded is True
        db_session.execute.assert_called_once()
        db_session.commit.assert_called_once()

    def test_feedback_clamps_score(self, monkeypatch):
        """score=99 is clamped down to 5; score=0 is clamped up to 1."""
        from services.rag_service import rag_service

        bound_params = {}

        def _record_execute(stmt, params):
            bound_params.update(params)
            return MagicMock()

        db_session = MagicMock()
        db_session.execute.side_effect = _record_execute
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        rag_service.feedback(1, 99)
        assert bound_params['score'] == 5

        # Start from a clean slate so the second check cannot see stale values.
        bound_params.clear()
        rag_service.feedback(1, 0)
        assert bound_params['score'] == 1

    def test_feedback_invalid_id_returns_false(self):
        """An id of 0 or None is rejected with False."""
        from services.rag_service import rag_service

        rejected_zero = rag_service.feedback(0, 5)
        assert rejected_zero is False
        rejected_none = rag_service.feedback(None, 5)  # type: ignore
        assert rejected_none is False

    def test_feedback_db_failure_returns_false(self, monkeypatch):
        """If execute() raises, feedback() reports False."""
        from services.rag_service import rag_service

        db_session = MagicMock()
        db_session.execute.side_effect = RuntimeError("db fail")
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        succeeded = rag_service.feedback(123, 5)
        assert succeeded is False

    def test_invalidate_by_caller_is_noop(self):
        """invalidate_by_caller() must simply not raise (author's note: v5.0 has no cache layer yet)."""
        from services.rag_service import rag_service

        rag_service.invalidate_by_caller('openclaw_qa')  # must not raise
# ─────────────────────────────────────────────────────────────────────────────
# Test 6: RAGResult 屬性與 synthesize
# ─────────────────────────────────────────────────────────────────────────────
class TestRAGResultStructure:
    """Direct checks on ``RAGResult``: ``synthesize()`` and ``has_high_confidence``."""

    def test_synthesize_takes_top_3(self):
        """synthesize() uses only the first three hits.

        Each hit's content carries a unique ``第 {i} 筆`` marker so the
        assertions below can tell which hits made it into the output. The
        marker must be part of the fixture content itself; otherwise the
        positive assertions look for text that was never supplied.
        """
        from services.rag_service import RAGResult

        result = RAGResult(
            query='q', embedding_signature='abc',
            hits=[{'content': f'第 {i} 筆內容'} for i in range(5)],
        )
        out = result.synthesize()

        # Hits 0..2 are included ...
        for i in range(3):
            assert f'第 {i} 筆' in out
        # ... hits 3..4 are not.
        for i in range(3, 5):
            assert f'第 {i} 筆' not in out

    def test_has_high_confidence_threshold(self):
        """has_high_confidence compares the top-1 score against the threshold."""
        from services.rag_service import RAGResult

        # top-1 score 0.86 >= 0.85 -> True
        above = RAGResult(
            query='q', embedding_signature='abc',
            hits=[{'score': 0.86}], threshold=0.85,
        )
        assert above.has_high_confidence is True

        # top-1 score 0.84 < 0.85 -> False
        below = RAGResult(
            query='q', embedding_signature='abc',
            hits=[{'score': 0.84}], threshold=0.85,
        )
        assert below.has_high_confidence is False

        # no hits at all -> False
        empty = RAGResult(query='q', embedding_signature='abc', hits=[])
        assert empty.has_high_confidence is False
# ─────────────────────────────────────────────────────────────────────────────
# Test 7: top_k / threshold 護欄夾擠
# ─────────────────────────────────────────────────────────────────────────────
class TestParamGuards:
    """Guard rails: out-of-range ``top_k`` / ``threshold`` are clamped before reaching SQL."""

    def test_top_k_clamped_to_50(self, rag_enabled, monkeypatch):
        """An absurd top_k=999 must not produce a SQL limit above 100."""
        from services import rag_service as rs

        def _embed(text, model="bge-m3:latest", **_kwargs):
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _embed,
        )

        bound_params = {}

        def _record_execute(stmt, params):
            bound_params.update(params)
            return MagicMock(fetchall=lambda: [])

        db_session = MagicMock()
        db_session.execute.return_value.fetchall.return_value = []
        db_session.execute.side_effect = _record_execute
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        rs.rag_service.query("query", caller='openclaw_qa', top_k=999)

        # Author's note: top_k is clamped to 50 and the fetch limit is 100.
        # NOTE(review): with the default of 0 this passes even if 'lim' is never
        # bound — confirm the parameter name against the service's SQL.
        assert bound_params.get('lim', 0) <= 100

    def test_threshold_clamped_to_unit_interval(self, rag_enabled, monkeypatch):
        """threshold=2.0 is clamped to 1.0, so the bound max_distance is 0.0."""
        from services import rag_service as rs

        def _embed(text, model="bge-m3:latest", **_kwargs):
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _embed,
        )

        bound_params = {}

        def _record_execute(stmt, params):
            bound_params.update(params)
            return MagicMock(fetchall=lambda: [])

        db_session = MagicMock()
        db_session.execute.side_effect = _record_execute
        monkeypatch.setattr('database.manager.get_session', lambda: db_session)

        rs.rag_service.query("query", caller='openclaw_qa', threshold=2.0)

        # 2.0 -> clamped to 1.0; max_distance = 1.0 - 1.0 = 0.0.
        # The default of 1.0 makes this fail if max_distance is never bound.
        assert bound_params.get('max_distance', 1.0) == pytest.approx(0.0)