#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tests/test_rag_service.py
Operation Ollama-First v5.0 / Phase 11 — RAGService unit tests.

Covers:
  - feature flag RAG_ENABLED defaults to OFF → skip (no DB query / no log write)
  - hit / miss branches when RAG_ENABLED=true
  - embedding signature mismatch → log a warning and drop that row
  - a failing fire-and-forget rag_query_log write does not affect the main flow
  - feedback() / invalidate_by_caller() / get_embedding_signature()

Mock strategy:
  - ollama_service.generate_embedding is replaced via monkeypatch
  - database.manager.get_session returns a MagicMock session (no real DB access)
"""
from __future__ import annotations

import os
import time
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch

import pytest


# ─────────────────────────────────────────────────────────────────────────────
# Shared fixtures
# ─────────────────────────────────────────────────────────────────────────────


@pytest.fixture(autouse=True)
def _reset_rag_kill_switch():
    """Reset the RAG kill-switch before and after every test so tests don't contaminate each other."""
    from services.rag_service import _reset_kill_switch

    _reset_kill_switch()
    yield
    _reset_kill_switch()


@pytest.fixture
def rag_disabled(monkeypatch):
    """Set RAG_ENABLED=false in the environment for the duration of one test."""
    monkeypatch.setenv('RAG_ENABLED', 'false')


@pytest.fixture
def rag_enabled(monkeypatch):
    """Set RAG_ENABLED=true in the environment for the duration of one test."""
    monkeypatch.setenv('RAG_ENABLED', 'true')


def _fake_embedding(dim: int = 1024) -> List[float]:
    """Return a deterministic fake embedding of length ``dim``.

    Values cycle through 0.00, 0.01, ..., 0.99 so every run sees the same
    vector (no randomness).
    """
    return [0.01 * (i % 100) for i in range(dim)]


def _make_row_obj(**fields):
    """Build a MagicMock standing in for a SQLAlchemy result row.

    Every keyword argument becomes an attribute, so code under test can read
    it attribute-style (``row.id``, ``row.distance``, ...).
    """
    obj = MagicMock()
    for k, v in fields.items():
        setattr(obj, k, v)
    return obj


# ─────────────────────────────────────────────────────────────────────────────
# Test 1: feature flag defaults to OFF → short-circuit
# ─────────────────────────────────────────────────────────────────────────────


class TestFeatureFlagOff:
    def test_query_skips_when_disabled(self, rag_disabled):
        """RAG_ENABLED=false → no DB query, no embedding call, empty hits."""
        from services.rag_service import rag_service

        with patch('services.ollama_service.ollama_service.generate_embedding') as mock_embed:
            result = rag_service.query("本週業績", caller='test_caller')
        assert result.hits == []
        assert result.has_high_confidence is False
        mock_embed.assert_not_called()

    def test_synthesize_returns_empty_when_no_hits(self, rag_disabled):
        """With the flag off there are no hits, so synthesize() yields an empty string."""
        from services.rag_service import rag_service

        result = rag_service.query("本週業績", caller='test_caller')
        assert result.synthesize() == ""

    def test_default_flag_is_off(self, monkeypatch):
        """RAG_ENABLED unset → the flag defaults to OFF."""
        monkeypatch.delenv('RAG_ENABLED', raising=False)
        from services.rag_service import is_rag_enabled

        assert is_rag_enabled() is False


# ─────────────────────────────────────────────────────────────────────────────
# Test 2: hit / miss behaviour when RAG_ENABLED=true
# ─────────────────────────────────────────────────────────────────────────────


class TestRagEnabledHits:
    def test_high_confidence_hit_returns_synthesized(self, rag_enabled, monkeypatch):
        """top-1 distance=0.05 → similarity=0.95 >= threshold 0.85 → has_high_confidence."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        sig = rs.get_embedding_signature()
        fake_rows = [
            _make_row_obj(
                id=101,
                insight_type='weekly_strategy',
                period='2026-W17',
                content='本週業績漲 5%,建議聚焦保濕品類。',
                embedding_signature=sig,
                distance=0.05,
            ),
            _make_row_obj(
                id=102,
                insight_type='weekly_strategy',
                period='2026-W17',
                content='競品 PChome 同類降價 8%。',
                embedding_signature=sig,
                distance=0.10,
            ),
        ]
        fake_session = MagicMock()
        fake_session.execute.return_value.fetchall.return_value = fake_rows
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        result = rs.rag_service.query("本週業績趨勢", caller='openclaw_qa')
        assert len(result.hits) == 2
        assert result.hits[0]['score'] == pytest.approx(0.95, abs=1e-3)
        assert result.has_high_confidence is True
        assert '本週業績漲' in result.synthesize()

    def test_mark_saved_call_only_when_requested_and_confident(self, rag_enabled, monkeypatch):
        """A high-confidence hit counts as saved_call only when the RAG-first caller explicitly opts in."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        monkeypatch.setattr(
            rs.rag_service,
            '_select_hits',
            lambda **_kw: [{'id': 101, 'content': '命中內容', 'score': 0.95}],
        )
        captured = {}

        def _capture_async_log(**kwargs):
            captured.update(kwargs)

        # Replace the background logger so its kwargs can be inspected synchronously.
        monkeypatch.setattr(rs.rag_service, '_async_log', _capture_async_log)

        result = rs.rag_service.query(
            "本週業績趨勢",
            caller='openclaw_qa',
            mark_saved_call=True,
        )
        assert result.saved_call is True
        assert captured['saved_call'] is True

        # Clear so the next assertion reflects only the second call's log payload.
        captured.clear()
        result = rs.rag_service.query(
            "本週業績趨勢",
            caller='admin_quality_trend',
            mark_saved_call=False,
        )
        assert result.saved_call is False
        assert captured['saved_call'] is False

    def test_low_confidence_no_hit(self, rag_enabled, monkeypatch):
        """All distances > 0.15 (similarity < 0.85) are filtered by SQL → no rows → has_high_confidence False."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        fake_session = MagicMock()
        fake_session.execute.return_value.fetchall.return_value = []
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        result = rs.rag_service.query("非常冷僻的問題", caller='openclaw_qa')
        assert result.hits == []
        assert result.has_high_confidence is False
        assert result.synthesize() == ""

    def test_empty_query_short_circuits(self, rag_enabled, monkeypatch):
        """An empty query returns no hits without ever calling generate_embedding."""
        from services import rag_service as rs

        embed_called = {'count': 0}

        def _spy(*a, **kw):
            embed_called['count'] += 1
            return _fake_embedding()

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            _spy,
        )
        result = rs.rag_service.query("", caller='openclaw_qa')
        assert result.hits == []
        assert embed_called['count'] == 0
# ─────────────────────────────────────────────────────────────────────────────
# Test 3: embedding signature mismatch → row is not used
# ─────────────────────────────────────────────────────────────────────────────


class TestEmbeddingSignature:
    def test_signature_format(self):
        """The signature is a 12-character lowercase hex string."""
        from services.rag_service import get_embedding_signature

        sig = get_embedding_signature()
        assert len(sig) == 12
        assert all(c in '0123456789abcdef' for c in sig)

    def test_signature_changes_with_model(self):
        """Changing model, dim, or normalize each produces a different signature."""
        from services.rag_service import get_embedding_signature

        a = get_embedding_signature(model='bge-m3:latest', dim=1024, normalize=True)
        b = get_embedding_signature(model='bge-m3:v2', dim=1024, normalize=True)
        c = get_embedding_signature(model='bge-m3:latest', dim=512, normalize=True)
        d = get_embedding_signature(model='bge-m3:latest', dim=1024, normalize=False)
        assert a != b and a != c and a != d

    def test_mismatched_signature_rows_skipped(self, rag_enabled, monkeypatch, caplog):
        """row.embedding_signature != expected → row dropped and a warning is logged."""
        from services import rag_service as rs
        import logging

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        sig = rs.get_embedding_signature()
        fake_rows = [
            _make_row_obj(
                id=201,
                insight_type='x',
                period=None,
                content='舊簽名資料',
                embedding_signature='deadbeef0001',  # deliberately different
                distance=0.05,
            ),
            _make_row_obj(
                id=202,
                insight_type='x',
                period=None,
                content='新簽名資料',
                embedding_signature=sig,
                distance=0.06,
            ),
        ]
        fake_session = MagicMock()
        fake_session.execute.return_value.fetchall.return_value = fake_rows
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        with caplog.at_level(logging.WARNING, logger='services.rag_service'):
            result = rs.rag_service.query("query", caller='openclaw_qa')

        # Only the row whose signature matches survives.
        assert len(result.hits) == 1
        assert result.hits[0]['id'] == 202
        # The mismatch was reported.
        assert any('signature mismatch' in r.getMessage() for r in caplog.records)

    def test_null_signature_passes_through(self, rag_enabled, monkeypatch):
        """Legacy rows not yet backfilled (signature=NULL) are let through for now so existing data stays usable."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        fake_rows = [
            _make_row_obj(
                id=301,
                insight_type='x',
                period=None,
                content='legacy data',
                embedding_signature=None,
                distance=0.1,
            ),
        ]
        fake_session = MagicMock()
        fake_session.execute.return_value.fetchall.return_value = fake_rows
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        result = rs.rag_service.query("query", caller='openclaw_qa')
        assert len(result.hits) == 1


class TestEmbeddingConsistencyRouting:
    def test_consistency_check_skips_111_by_default(self, monkeypatch):
        """With EMBED_CONSISTENCY_INCLUDE_111 off, only PRIMARY and SECONDARY hosts are probed, without 111 fallback."""
        from services import rag_service as rs
        from services import ollama_service as oss

        calls = []

        def fake_embed(text, model, host, timeout, **kwargs):
            calls.append((host, kwargs))
            return _fake_embedding()

        monkeypatch.setattr(rs, 'EMBED_CONSISTENCY_INCLUDE_111', False)
        monkeypatch.setattr(oss.ollama_service, 'generate_embedding', fake_embed)

        result = rs.verify_embedding_consistency()
        hosts = [host for host, _kwargs in calls]

        assert result['ok'] is True
        assert hosts == [oss.OLLAMA_HOST_PRIMARY, oss.OLLAMA_HOST_SECONDARY]
        assert oss.OLLAMA_HOST_FALLBACK not in hosts
        assert all(kwargs.get('allow_111_fallback') is False for _host, kwargs in calls)

    def test_consistency_check_can_include_111_when_explicitly_enabled(self, monkeypatch):
        """With EMBED_CONSISTENCY_INCLUDE_111 on, the FALLBACK host is probed last with allow_111_fallback=True."""
        from services import rag_service as rs
        from services import ollama_service as oss

        calls = []

        def fake_embed(text, model, host, timeout, **kwargs):
            calls.append((host, kwargs))
            return _fake_embedding()

        monkeypatch.setattr(rs, 'EMBED_CONSISTENCY_INCLUDE_111', True)
        monkeypatch.setattr(oss.ollama_service, 'generate_embedding', fake_embed)

        result = rs.verify_embedding_consistency()
        hosts = [host for host, _kwargs in calls]

        assert result['ok'] is True
        assert hosts == [
            oss.OLLAMA_HOST_PRIMARY,
            oss.OLLAMA_HOST_SECONDARY,
            oss.OLLAMA_HOST_FALLBACK,
        ]
        assert calls[-1][1].get('allow_111_fallback') is True


# ─────────────────────────────────────────────────────────────────────────────
# Test 4: a failing fire-and-forget log does not affect the main flow
# ─────────────────────────────────────────────────────────────────────────────


class TestFireAndForgetLog:
    def test_log_write_failure_does_not_crash(self, rag_enabled, monkeypatch):
        """rag_query_log INSERT fails → query() still returns a RAGResult."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        # The first session (SELECT) behaves normally; the second (INSERT) raises.
        select_session = MagicMock()
        select_session.execute.return_value.fetchall.return_value = []
        insert_session = MagicMock()
        insert_session.execute.side_effect = RuntimeError("DB unavailable")

        sessions_iter = iter([select_session, insert_session])
        monkeypatch.setattr(
            'database.manager.get_session',
            lambda: next(sessions_iter),
        )

        # Must not raise: log-write errors are expected to stay inside the background log thread.
        result = rs.rag_service.query("query", caller='openclaw_qa')
        assert result.hits == []
        # Give the background log thread time to finish before monkeypatch teardown.
        time.sleep(0.2)

    def test_log_write_includes_embedding_signature(self, monkeypatch):
        """Writing query_embedding to rag_query_log also stores the BGE-M3 embedding signature."""
        from services import rag_service as rs

        captured = {}
        fake_session = MagicMock()

        def _exec(stmt, params):
            captured["sql"] = str(stmt)
            captured["params"] = params
            return MagicMock()

        fake_session.execute.side_effect = _exec
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        rs.rag_service._write_log(
            caller='openclaw_qa',
            text='query',
            query_vec=_fake_embedding(),
            top_k=5,
            threshold=0.85,
            hits=[{'id': 101}],
            request_id='req-1',
            saved_call=True,
        )

        assert "embedding_signature" in captured["sql"]
        assert captured["params"]["embedding_signature"] == rs.get_embedding_signature()
        assert captured["params"]["saved_call"] is True
        fake_session.commit.assert_called_once()

    def test_embedding_failure_falls_back_to_empty(self, rag_enabled, monkeypatch):
        """Embedding returns [] → no DB SELECT → empty hits so the caller can fall back to the LLM."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: [],
        )
        # Count get_session() calls to detect whether the SELECT path ran.
        called = {'count': 0}

        def _spy_session():
            called['count'] += 1
            return MagicMock()

        monkeypatch.setattr('database.manager.get_session', _spy_session)

        result = rs.rag_service.query("query", caller='openclaw_qa')
        assert result.hits == []
        # Let any background rag_query_log write finish before checking the counter.
        time.sleep(0.2)
        # The log write may still open one session (it logs even with embedding=None),
        # so 0 or 1 is acceptable; 2+ would mean the SELECT also ran.
        assert called['count'] <= 1


# ─────────────────────────────────────────────────────────────────────────────
# Test 5: feedback() / invalidate_by_caller()
# ─────────────────────────────────────────────────────────────────────────────


class TestFeedback:
    def test_feedback_writes_score(self, monkeypatch):
        """A valid feedback call executes one statement, commits, and returns True."""
        from services.rag_service import rag_service

        fake_session = MagicMock()
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        ok = rag_service.feedback(123, 5)
        assert ok is True
        fake_session.execute.assert_called_once()
        fake_session.commit.assert_called_once()

    def test_feedback_clamps_score(self, monkeypatch):
        """score=99 is clamped to 5; score=0 is clamped to 1."""
        from services.rag_service import rag_service

        captured = {}
        fake_session = MagicMock()

        def _exec(stmt, params):
            captured.update(params)
            return MagicMock()

        fake_session.execute.side_effect = _exec
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        rag_service.feedback(1, 99)
        assert captured['score'] == 5
        captured.clear()
        rag_service.feedback(1, 0)
        assert captured['score'] == 1

    def test_feedback_invalid_id_returns_false(self):
        """A zero or None query-log id is rejected with False."""
        from services.rag_service import rag_service

        assert rag_service.feedback(0, 5) is False
        assert rag_service.feedback(None, 5) is False  # type: ignore

    def test_feedback_db_failure_returns_false(self, monkeypatch):
        """A DB error during the update is reported as False rather than raised."""
        from services.rag_service import rag_service

        fake_session = MagicMock()
        fake_session.execute.side_effect = RuntimeError("db fail")
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        ok = rag_service.feedback(123, 5)
        assert ok is False

    def test_invalidate_by_caller_is_noop(self):
        """v5.0 has no cache layer yet → no-op that must not raise."""
        from services.rag_service import rag_service

        rag_service.invalidate_by_caller('openclaw_qa')  # must not raise


# ─────────────────────────────────────────────────────────────────────────────
# Test 6: RAGResult attributes and synthesize()
# ─────────────────────────────────────────────────────────────────────────────


class TestRAGResultStructure:
    def test_synthesize_takes_top_3(self):
        """synthesize() includes only the first three hits' content."""
        from services.rag_service import RAGResult

        result = RAGResult(
            query='q',
            embedding_signature='abc',
            hits=[{'content': f'第 {i} 筆內容'} for i in range(5)],
        )
        out = result.synthesize()
        assert '第 0 筆' in out
        assert '第 1 筆' in out
        assert '第 2 筆' in out
        assert '第 3 筆' not in out
        assert '第 4 筆' not in out

    def test_has_high_confidence_threshold(self):
        """has_high_confidence compares the top-1 score against the threshold (inclusive)."""
        from services.rag_service import RAGResult

        # top-1 score=0.86 >= 0.85 → True
        r1 = RAGResult(
            query='q',
            embedding_signature='abc',
            hits=[{'score': 0.86}],
            threshold=0.85,
        )
        assert r1.has_high_confidence is True

        # top-1 score=0.84 < 0.85 → False
        r2 = RAGResult(
            query='q',
            embedding_signature='abc',
            hits=[{'score': 0.84}],
            threshold=0.85,
        )
        assert r2.has_high_confidence is False

        # no hits → False
        r3 = RAGResult(query='q', embedding_signature='abc', hits=[])
        assert r3.has_high_confidence is False


# ─────────────────────────────────────────────────────────────────────────────
# Test 7: top_k / threshold guard-rail clamping
# ─────────────────────────────────────────────────────────────────────────────


class TestParamGuards:
    def test_top_k_clamped_to_50(self, rag_enabled, monkeypatch):
        """An oversized top_k is clamped so the SQL 'lim' bind parameter stays <= 100."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        fake_session = MagicMock()
        captured = {}

        def _exec(stmt, params):
            captured.update(params)
            return MagicMock(fetchall=lambda: [])

        fake_session.execute.side_effect = _exec
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        rs.rag_service.query("query", caller='openclaw_qa', top_k=999)
        # Expected: top_k clamped to 50 → fetch_limit 100.
        # Require the key so a missing/renamed bind parameter cannot pass vacuously.
        assert 'lim' in captured
        assert captured['lim'] <= 100

    def test_threshold_clamped_to_unit_interval(self, rag_enabled, monkeypatch):
        """threshold=2.0 is clamped to 1.0, so max_distance = 1.0 - 1.0 = 0.0."""
        from services import rag_service as rs

        monkeypatch.setattr(
            'services.ollama_service.ollama_service.generate_embedding',
            lambda text, model="bge-m3:latest", **_kwargs: _fake_embedding(),
        )
        fake_session = MagicMock()
        captured = {}

        def _exec(stmt, params):
            captured.update(params)
            return MagicMock(fetchall=lambda: [])

        fake_session.execute.side_effect = _exec
        monkeypatch.setattr('database.manager.get_session', lambda: fake_session)

        rs.rag_service.query("query", caller='openclaw_qa', threshold=2.0)
        assert captured.get('max_distance', 1.0) == pytest.approx(0.0)