552 lines
23 KiB
Python
552 lines
23 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
tests/test_rag_service.py
|
||
Operation Ollama-First v5.0 / Phase 11 — RAGService 單元測試
|
||
|
||
涵蓋:
|
||
- feature flag RAG_ENABLED 預設 OFF → skip 不查 DB / 不寫 log
|
||
- RAG_ENABLED=true 時的 hit / miss 分支
|
||
- embedding signature 不一致 → log warning + 不採該筆
|
||
- fire-and-forget rag_query_log 失敗不影響主流程
|
||
- feedback() / invalidate_by_caller() / get_embedding_signature()
|
||
|
||
Mock 策略:
|
||
- ollama_service.generate_embedding 用 monkeypatch
|
||
- database.manager.get_session 用 MagicMock 回 session(不真的查 DB)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
import time
|
||
from typing import Any, Dict, List
|
||
from unittest.mock import MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# 共用 fixture
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
@pytest.fixture(autouse=True)
def _reset_rag_kill_switch():
    """Reset the RAG kill-switch around every test so state cannot leak between tests."""
    from services.rag_service import _reset_kill_switch as reset

    reset()
    try:
        yield
    finally:
        reset()
|
||
|
||
|
||
@pytest.fixture
def rag_disabled(monkeypatch):
    """Force the RAG feature flag off (RAG_ENABLED='false') for one test."""
    monkeypatch.setenv(name='RAG_ENABLED', value='false')
|
||
|
||
|
||
@pytest.fixture
def rag_enabled(monkeypatch):
    """Force the RAG feature flag on (RAG_ENABLED='true') for one test."""
    monkeypatch.setenv(name='RAG_ENABLED', value='true')
|
||
|
||
|
||
def _fake_embedding(dim: int = 1024) -> List[float]:
|
||
"""產生穩定的假 embedding(避免每次 random)。"""
|
||
return [0.01 * (i % 100) for i in range(dim)]
|
||
|
||
|
||
def _make_row_obj(**fields):
|
||
"""模擬 SQLAlchemy row(支援屬性與 row.id 風格)。"""
|
||
obj = MagicMock()
|
||
for k, v in fields.items():
|
||
setattr(obj, k, v)
|
||
return obj
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Test 1: feature flag 預設 OFF → 短路
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
class TestFeatureFlagOff:
    """With the feature flag off (the default), query() must short-circuit."""

    def test_query_skips_when_disabled(self, rag_disabled, monkeypatch):
        """RAG_ENABLED=false -> no embedding call, no DB session, empty hits."""
        from services.rag_service import rag_service

        # Spy on the DB entry point too: the disabled path must neither run the
        # similarity SELECT nor write a rag_query_log row. Only the embedding
        # call used to be checked, so an unexpected DB access went unnoticed.
        session_spy = MagicMock()
        monkeypatch.setattr('database.manager.get_session', session_spy)

        with patch('services.ollama_service.ollama_service.generate_embedding') as mock_embed:
            result = rag_service.query("本週業績", caller='test_caller')

        assert result.hits == []
        assert result.has_high_confidence is False
        mock_embed.assert_not_called()
        session_spy.assert_not_called()

    def test_synthesize_returns_empty_when_no_hits(self, rag_disabled):
        """A short-circuited result synthesizes to the empty string."""
        from services.rag_service import rag_service

        result = rag_service.query("本週業績", caller='test_caller')
        assert result.synthesize() == ""

    def test_default_flag_is_off(self, monkeypatch):
        """RAG_ENABLED unset -> the flag defaults to off."""
        monkeypatch.delenv('RAG_ENABLED', raising=False)
        from services.rag_service import is_rag_enabled

        assert is_rag_enabled() is False
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Test 2: RAG_ENABLED=true 時的 hit / miss
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
class TestRagEnabledHits:
    """Hit / miss branches of query() when RAG_ENABLED=true."""

    @pytest.fixture(autouse=True)
    def _silence_async_log(self, monkeypatch):
        """Replace the fire-and-forget rag_query_log writer with a no-op.

        query() hands the log write to a background daemon thread via
        ``_async_log``. Left running, that thread can outlive the test and call
        the *real* ``database.manager.get_session`` once monkeypatch has undone
        the fake. None of these tests assert on the background write;
        ``test_mark_saved_call_only_when_requested_and_confident`` installs its
        own capturing stub, which overrides this one.
        """
        from services import rag_service as rs

        monkeypatch.setattr(rs.rag_service, '_async_log', lambda **_kw: None)

    @staticmethod
    def _stub_embedding(monkeypatch):
        """Make generate_embedding return the deterministic fake vector."""
        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )

    def test_high_confidence_hit_returns_synthesized(self, rag_enabled, monkeypatch):
        """top-1 distance=0.05 -> similarity=0.95 >= threshold 0.85 -> has_high_confidence."""
        from services import rag_service as rs

        self._stub_embedding(monkeypatch)

        sig = rs.get_embedding_signature()
        fake_rows = [
            _make_row_obj(
                id=101, insight_type='weekly_strategy', period='2026-W17',
                content='本週業績漲 5%,建議聚焦保濕品類。',
                embedding_signature=sig, distance=0.05,
            ),
            _make_row_obj(
                id=102, insight_type='weekly_strategy', period='2026-W17',
                content='競品 PChome 同類降價 8%。',
                embedding_signature=sig, distance=0.10,
            ),
        ]
        fake_session = MagicMock()
        fake_session.execute.return_value.fetchall.return_value = fake_rows
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        result = rs.rag_service.query("本週業績趨勢", caller='openclaw_qa')

        assert len(result.hits) == 2
        # score = 1 - distance
        assert result.hits[0]['score'] == pytest.approx(0.95, abs=1e-3)
        assert result.has_high_confidence is True
        assert '本週業績漲' in result.synthesize()

    def test_mark_saved_call_only_when_requested_and_confident(self, rag_enabled, monkeypatch):
        """A high-confidence hit counts as saved_call only when the caller asks for it."""
        from services import rag_service as rs

        self._stub_embedding(monkeypatch)
        monkeypatch.setattr(
            rs.rag_service,
            '_select_hits',
            lambda **_kw: [{'id': 101, 'content': '命中內容', 'score': 0.95}],
        )
        captured = {}

        def _capture_async_log(**kwargs):
            captured.update(kwargs)

        # Overrides the class-level no-op stub for this test only.
        monkeypatch.setattr(rs.rag_service, '_async_log', _capture_async_log)

        result = rs.rag_service.query(
            "本週業績趨勢",
            caller='openclaw_qa',
            mark_saved_call=True,
        )

        assert result.saved_call is True
        assert captured['saved_call'] is True

        result = rs.rag_service.query(
            "本週業績趨勢",
            caller='admin_quality_trend',
            mark_saved_call=False,
        )

        assert result.saved_call is False
        assert captured['saved_call'] is False

    def test_low_confidence_no_hit(self, rag_enabled, monkeypatch):
        """All distances > 0.15 (similarity < 0.85) are filtered in SQL -> empty, not confident."""
        from services import rag_service as rs

        self._stub_embedding(monkeypatch)
        fake_session = MagicMock()
        fake_session.execute.return_value.fetchall.return_value = []
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        result = rs.rag_service.query("非常冷僻的問題", caller='openclaw_qa')

        assert result.hits == []
        assert result.has_high_confidence is False
        assert result.synthesize() == ""

    def test_empty_query_short_circuits(self, rag_enabled, monkeypatch):
        """An empty query returns no hits without calling generate_embedding."""
        from services import rag_service as rs

        # generate_embedding must not be called.
        embed_called = {'count': 0}

        def _spy(*a, **kw):
            embed_called['count'] += 1
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding', _spy,
        )

        result = rs.rag_service.query("", caller='openclaw_qa')
        assert result.hits == []
        assert embed_called['count'] == 0
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Test 3: embedding signature 不一致 → 不採用
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
class TestEmbeddingSignature:
    """Embedding-signature format and filtering of rows with a mismatched signature."""

    @pytest.fixture(autouse=True)
    def _silence_async_log(self, monkeypatch):
        """Stub the fire-and-forget rag_query_log writer with a no-op.

        Keeps the background log thread from outliving the test and reaching
        the real ``database.manager.get_session`` after monkeypatch is undone.
        No test here asserts on the log write.
        """
        from services import rag_service as rs

        monkeypatch.setattr(rs.rag_service, '_async_log', lambda **_kw: None)

    def test_signature_format(self):
        """The default signature is a 12-character lowercase hex string."""
        from services.rag_service import get_embedding_signature

        sig = get_embedding_signature()
        assert len(sig) == 12
        assert all(c in '0123456789abcdef' for c in sig)

    def test_signature_changes_with_model(self):
        """Changing model, dim or normalize each yields a different signature."""
        from services.rag_service import get_embedding_signature

        a = get_embedding_signature(model='bge-m3:latest', dim=1024, normalize=True)
        b = get_embedding_signature(model='bge-m3:v2', dim=1024, normalize=True)
        c = get_embedding_signature(model='bge-m3:latest', dim=512, normalize=True)
        d = get_embedding_signature(model='bge-m3:latest', dim=1024, normalize=False)
        assert a != b and a != c and a != d

    def test_mismatched_signature_rows_skipped(self, rag_enabled, monkeypatch, caplog):
        """row.embedding_signature != expected -> row dropped and a warning logged."""
        from services import rag_service as rs
        import logging

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        sig = rs.get_embedding_signature()
        fake_rows = [
            _make_row_obj(
                id=201, insight_type='x', period=None, content='舊簽名資料',
                embedding_signature='deadbeef0001',  # stale signature
                distance=0.05,
            ),
            _make_row_obj(
                id=202, insight_type='x', period=None, content='新簽名資料',
                embedding_signature=sig,
                distance=0.06,
            ),
        ]
        fake_session = MagicMock()
        fake_session.execute.return_value.fetchall.return_value = fake_rows
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        with caplog.at_level(logging.WARNING, logger='services.rag_service'):
            result = rs.rag_service.query("query", caller='openclaw_qa')

        # Only the row whose signature matches survives.
        assert len(result.hits) == 1
        assert result.hits[0]['id'] == 202
        # getMessage() is the documented accessor; ``.message`` is only
        # populated as a side effect of a Formatter having run.
        assert any('signature mismatch' in r.getMessage() for r in caplog.records)

    def test_null_signature_passes_through(self, rag_enabled, monkeypatch):
        """Legacy rows with signature=NULL (not yet backfilled) are still accepted for now."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        fake_rows = [
            _make_row_obj(
                id=301, insight_type='x', period=None, content='legacy data',
                embedding_signature=None,
                distance=0.1,
            ),
        ]
        fake_session = MagicMock()
        fake_session.execute.return_value.fetchall.return_value = fake_rows
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        result = rs.rag_service.query("query", caller='openclaw_qa')
        assert len(result.hits) == 1
|
||
|
||
|
||
class TestEmbeddingConsistencyRouting:
    """Which Ollama hosts verify_embedding_consistency() queries, with and without the 111 host."""

    @staticmethod
    def _run_consistency_check(monkeypatch, include_111):
        """Run verify_embedding_consistency() against a recording embed stub.

        Returns ``(result, calls)``; ``calls`` holds one ``(host, extra_kwargs)``
        tuple per generate_embedding call, in call order.
        """
        from services import rag_service as rs
        from services import ollama_service as oss

        calls = []

        def fake_embed(text, model, host, timeout, **kwargs):
            calls.append((host, kwargs))
            return _fake_embedding()

        monkeypatch.setattr(rs, 'EMBED_CONSISTENCY_INCLUDE_111', include_111)
        monkeypatch.setattr(oss.ollama_service, 'generate_embedding', fake_embed)
        return rs.verify_embedding_consistency(), calls

    def test_consistency_check_skips_111_by_default(self, monkeypatch):
        from services import ollama_service as oss

        result, calls = self._run_consistency_check(monkeypatch, include_111=False)
        hosts = [host for host, _kwargs in calls]

        assert result['ok'] is True
        assert hosts == [oss.OLLAMA_HOST_PRIMARY, oss.OLLAMA_HOST_SECONDARY]
        assert oss.OLLAMA_HOST_FALLBACK not in hosts
        # Every embed call must pass allow_111_fallback=False explicitly.
        assert all(kwargs.get('allow_111_fallback') is False for _host, kwargs in calls)

    def test_consistency_check_can_include_111_when_explicitly_enabled(self, monkeypatch):
        from services import ollama_service as oss

        result, calls = self._run_consistency_check(monkeypatch, include_111=True)
        hosts = [host for host, _kwargs in calls]

        assert result['ok'] is True
        assert hosts == [
            oss.OLLAMA_HOST_PRIMARY,
            oss.OLLAMA_HOST_SECONDARY,
            oss.OLLAMA_HOST_FALLBACK,
        ]
        # The final (fallback-host) call must opt in to the 111 fallback.
        assert calls[-1][1].get('allow_111_fallback') is True
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Test 4: fire-and-forget log 失敗不影響主流程
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
class TestFireAndForgetLog:
    """The rag_query_log write is best-effort and must never break query()."""

    @staticmethod
    def _wait_until(predicate, timeout=2.0, interval=0.01):
        """Poll ``predicate`` until it is truthy or ``timeout`` seconds elapse.

        Replaces a fixed ``time.sleep``: it returns as soon as the background log
        thread has done its work, and tolerates a slow scheduler instead of
        racing a hard-coded delay. Returns the final truthiness of ``predicate``.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            time.sleep(interval)
        return bool(predicate())

    def test_log_write_failure_does_not_crash(self, rag_enabled, monkeypatch):
        """A failing rag_query_log INSERT must not stop query() returning a RAGResult."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        # The SELECT session behaves normally; the INSERT session raises.
        select_session = MagicMock()
        select_session.execute.return_value.fetchall.return_value = []
        insert_session = MagicMock()
        insert_session.execute.side_effect = RuntimeError("DB unavailable")

        sessions_iter = iter([select_session, insert_session])
        monkeypatch.setattr(
            'database.manager.get_session', lambda: next(sessions_iter),
        )

        # Must not raise: the fire-and-forget thread wraps the write in try/except.
        result = rs.rag_service.query("query", caller='openclaw_qa')
        assert result.hits == []

        # Wait for the background writer to actually attempt (and fail) the
        # INSERT. Without this the test passed even if the failing write was
        # never exercised, and the thread could outlive monkeypatch.
        assert self._wait_until(lambda: insert_session.execute.called)

    def test_log_write_includes_embedding_signature(self, monkeypatch):
        """The rag_query_log row stores the BGE-M3 signature alongside query_embedding."""
        from services import rag_service as rs

        captured = {}
        fake_session = MagicMock()

        def _exec(stmt, params):
            captured["sql"] = str(stmt)
            captured["params"] = params
            return MagicMock()

        fake_session.execute.side_effect = _exec
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        rs.rag_service._write_log(
            caller='openclaw_qa',
            text='query',
            query_vec=_fake_embedding(),
            top_k=5,
            threshold=0.85,
            hits=[{'id': 101}],
            request_id='req-1',
            saved_call=True,
        )

        assert "embedding_signature" in captured["sql"]
        assert captured["params"]["embedding_signature"] == rs.get_embedding_signature()
        assert captured["params"]["saved_call"] is True
        fake_session.commit.assert_called_once()

    def test_embedding_failure_falls_back_to_empty(self, rag_enabled, monkeypatch):
        """Embedding returns [] -> no similarity SELECT -> empty hits so the caller falls back to the LLM."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: [],
        )
        called = {'count': 0}

        def _spy_session():
            called['count'] += 1
            return MagicMock()

        monkeypatch.setattr('database.manager.get_session', _spy_session)

        result = rs.rag_service.query("query", caller='openclaw_qa')
        assert result.hits == []

        # Give the background log thread a chance to run, then check the count.
        # No SELECT may happen without a query vector; the rag_query_log write
        # may still open one session (it can log even without an embedding),
        # so 0 or 1 sessions are acceptable. This count was previously never
        # asserted, so the test could not catch a stray SELECT.
        time.sleep(0.2)
        assert called['count'] <= 1
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Test 5: feedback() / invalidate_by_caller()
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
class TestFeedback:
    """feedback() score writes and the invalidate_by_caller() no-op."""

    def test_feedback_writes_score(self, monkeypatch):
        from services.rag_service import rag_service

        session = MagicMock()
        monkeypatch.setattr('database.manager.get_session', lambda: session)

        assert rag_service.feedback(123, 5) is True
        session.execute.assert_called_once()
        session.commit.assert_called_once()

    def test_feedback_clamps_score(self, monkeypatch):
        """Scores are clamped into [1, 5]: 99 becomes 5 and 0 becomes 1."""
        from services.rag_service import rag_service

        bound = {}
        session = MagicMock()

        def _record(stmt, params):
            bound.update(params)
            return MagicMock()

        session.execute.side_effect = _record
        monkeypatch.setattr('database.manager.get_session', lambda: session)

        for raw_score, clamped in ((99, 5), (0, 1)):
            bound.clear()
            rag_service.feedback(1, raw_score)
            assert bound['score'] == clamped

    def test_feedback_invalid_id_returns_false(self):
        from services.rag_service import rag_service

        for bad_id in (0, None):
            assert rag_service.feedback(bad_id, 5) is False  # type: ignore

    def test_feedback_db_failure_returns_false(self, monkeypatch):
        from services.rag_service import rag_service

        session = MagicMock()
        session.execute.side_effect = RuntimeError("db fail")
        monkeypatch.setattr('database.manager.get_session', lambda: session)

        assert rag_service.feedback(123, 5) is False

    def test_invalidate_by_caller_is_noop(self):
        """v5.0 has no cache layer yet, so this is a no-op and must not raise."""
        from services.rag_service import rag_service

        rag_service.invalidate_by_caller('openclaw_qa')
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Test 6: RAGResult 屬性與 synthesize
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
class TestRAGResultStructure:
    """RAGResult.synthesize() and has_high_confidence."""

    def test_synthesize_takes_top_3(self):
        from services.rag_service import RAGResult

        hits = [{'content': f'第 {i} 筆內容'} for i in range(5)]
        out = RAGResult(query='q', embedding_signature='abc', hits=hits).synthesize()

        for included in ('第 0 筆', '第 1 筆', '第 2 筆'):
            assert included in out
        for excluded in ('第 3 筆', '第 4 筆'):
            assert excluded not in out

    def test_has_high_confidence_threshold(self):
        from services.rag_service import RAGResult

        cases = (
            ([{'score': 0.86}], True),   # top-1 0.86 >= 0.85
            ([{'score': 0.84}], False),  # top-1 0.84 <  0.85
        )
        for hits, expected in cases:
            res = RAGResult(
                query='q', embedding_signature='abc',
                hits=hits, threshold=0.85,
            )
            assert res.has_high_confidence is expected

        # No hits at all -> never confident.
        empty = RAGResult(query='q', embedding_signature='abc', hits=[])
        assert empty.has_high_confidence is False
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Test 7: top_k / threshold 護欄夾擠
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
class TestParamGuards:
    """query() must clamp out-of-range top_k / threshold before binding SQL params."""

    @staticmethod
    def _install_recording_session(monkeypatch):
        """Patch get_session with a fake that records each execute() params dict.

        Returns the list the params are appended to, in call order. Recording
        per call (rather than merging into one dict) keeps the SELECT's bind
        params from being overwritten by the background rag_query_log INSERT.
        """
        calls = []
        fake_session = MagicMock()

        def _exec(stmt, params):
            calls.append(dict(params))
            return MagicMock(fetchall=lambda: [])

        fake_session.execute.side_effect = _exec
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)
        return calls

    def test_top_k_clamped_to_50(self, rag_enabled, monkeypatch):
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        calls = self._install_recording_session(monkeypatch)

        rs.rag_service.query("query", caller='openclaw_qa', top_k=999)

        # The SELECT must actually bind 'lim'. The old `captured.get('lim', 0)`
        # passed vacuously when it never did.
        select_params = [p for p in calls if 'lim' in p]
        assert select_params, "similarity SELECT was not executed with a 'lim' param"
        assert select_params[0]['lim'] <= 100  # top_k clamped to 50 -> fetch_limit 100

    def test_threshold_clamped_to_unit_interval(self, rag_enabled, monkeypatch):
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        calls = self._install_recording_session(monkeypatch)

        rs.rag_service.query("query", caller='openclaw_qa', threshold=2.0)

        # 2.0 -> clamped to 1.0; max_distance = 1.0 - 1.0 = 0.0
        select_params = [p for p in calls if 'max_distance' in p]
        assert select_params, "similarity SELECT was not executed with a 'max_distance' param"
        assert select_params[0]['max_distance'] == pytest.approx(0.0)
|