Files
ewoooc/tests/test_learning_pipeline.py
OoO 84a8c07e4a
All checks were successful
CD Pipeline / deploy (push) Successful in 2m57s
feat(p11.5): learning_episodes embedding 寫入 — 解鎖 Stage 3 dedup
Operation Ollama-First v5.0 / Phase 11.5 收尾(A4 已知 limitation 補完)

問題:Phase 11 A4 完成時揭露:
> Stage 3 dedup 需 episode 先 embed:目前 LearningPipeline.enqueue 寫入時
> embedding 為 NULL,所有 episode 都會略過 Stage 3 dedup

修補:
- learning_pipeline.enqueue 內 episode INSERT commit 後 enqueue embedding worker
- 用既有 _enqueue_embedding('learning_episodes', episode_id, distilled_text)
- ADR-007 retry queue worker 自動處理(_process_one_embedding 已動態 UPDATE
  {target_table},已支援 learning_episodes 表)
- distilled_text 截 4000 字避免 retry queue 表膨脹
- 失敗 swallow,僅 log debug(不阻擋 episode_id 回傳)

落地 ADR-033 護欄 #1 完整版:
  Stage 1: quality_score >= 0.7         既有
  Stage 2: 無幻覺檢測(規則引擎)        既有
  Stage 3: 與既有 insight cosine < 0.95  解鎖 
  Stage 4: weight >= 0.8 必經 👍/👎      既有

regression: 70 unit tests 全綠(含修正 test_enqueue_returns_id_on_success
配合新增 _enqueue_embedding 的 commit 計數變化)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-04 09:16:39 +08:00

278 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tests/test_learning_pipeline.py
Operation Ollama-First v5.0 / Phase 11 — Distiller + LearningPipeline 單元測試
涵蓋:
- Distiller 各 quality_score 規則(mcp / llm_response / user_feedback / manual_curated)
- LearningPipeline.enqueue() DB 寫入路徑
- expire_stale_reviews() 24h 自動降級
- hash_human_approver() PII 保護
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
# ─────────────────────────────────────────────────────────────────────────────
# Distiller 各規則
# ─────────────────────────────────────────────────────────────────────────────
class TestDistillerMcpResult:
    """Quality-score expectations for ``Distiller`` on ``mcp_result`` episodes."""

    @staticmethod
    def _distill_mcp(raw_content):
        """Distill *raw_content* as an ``mcp_result`` with a fresh ``Distiller``."""
        # Imported at call time, so the project module is only loaded when a
        # test actually runs.
        from services.learning_pipeline import Distiller

        return Distiller().distill(
            episode_type='mcp_result',
            raw_content=raw_content,
        )

    def test_long_with_keywords_high_quality(self):
        """A long text that carries keywords is expected to score 0.8."""
        body = "本週業績分析顯示,建議聚焦保濕品類。" + "詳細說明 " * 80  # > 200 characters
        outcome = self._distill_mcp(body)
        assert outcome is not None
        assert outcome.quality_score == 0.8
        assert outcome.episode_type == 'mcp_result'

    def test_long_no_keywords_medium_quality(self):
        """A long text without any keyword is expected to score 0.65."""
        body = "啦啦啦" * 100  # > 200 characters but no keyword
        outcome = self._distill_mcp(body)
        assert outcome.quality_score == 0.65

    def test_short_low_quality(self):
        """A short text is expected to score 0.5."""
        outcome = self._distill_mcp("短內容")
        assert outcome.quality_score == 0.5

    def test_empty_returns_none(self):
        """Empty and whitespace-only content are expected to distill to ``None``."""
        for blank in ('', ' '):
            assert self._distill_mcp(blank) is None
class TestDistillerLlmResponse:
    """Quality-score expectations for ``Distiller`` on ``llm_response`` episodes."""

    @staticmethod
    def _distill_llm(raw_content):
        """Distill *raw_content* as an ``llm_response`` with a fresh ``Distiller``."""
        from services.learning_pipeline import Distiller

        return Distiller().distill(
            episode_type='llm_response',
            raw_content=raw_content,
        )

    def test_json_structured_high_quality(self):
        """A JSON object with status/summary fields is expected to score 0.9."""
        payload = json.dumps({"status": "ok", "summary": "本週重點"})
        outcome = self._distill_llm(payload)
        assert outcome.quality_score == 0.9

    def test_json_array_non_empty_high(self):
        """A non-empty JSON array is expected to score 0.9."""
        payload = json.dumps([{"sku": "A001", "risk": "HIGH"}])
        outcome = self._distill_llm(payload)
        assert outcome.quality_score == 0.9

    def test_json_dict_no_status_lower(self):
        """A non-empty JSON dict without a status key still scores 0.9.

        Despite the "lower" in the test name, the asserted score is 0.9.
        """
        payload = json.dumps({"some_field": "value"})
        outcome = self._distill_llm(payload)
        # Non-empty dict -> 0.9 (the status_ok condition includes "len(obj)>0").
        assert outcome.quality_score == 0.9

    def test_free_text_long_with_numbers(self):
        """Long free text that contains a number is expected to score 0.65."""
        body = "本週業績漲了 15.3%" + "詳細說明 " * 100  # > 500 characters + a number
        outcome = self._distill_llm(body)
        assert outcome.quality_score == 0.65

    def test_free_text_long_no_numbers(self):
        """Long free text without any number is expected to score 0.55."""
        body = "本週業績趨勢上升。" + "詳細說明 " * 100  # > 500 characters, no number
        outcome = self._distill_llm(body)
        assert outcome.quality_score == 0.55

    def test_free_text_short_below_quality_gate(self):
        """Short free text is expected to score 0.4."""
        body = "本週業績有變化"  # short text
        outcome = self._distill_llm(body)
        # 0.4 -> Stage 1 would reject it (per the original author's note).
        assert outcome.quality_score == 0.4
class TestDistillerUserFeedback:
    """Quality-score expectations for ``Distiller`` on ``user_feedback`` episodes."""

    @staticmethod
    def _distill_feedback(raw_content, score):
        """Distill *raw_content* as ``user_feedback`` carrying the given *score*."""
        from services.learning_pipeline import Distiller

        return Distiller().distill(
            episode_type='user_feedback',
            raw_content=raw_content,
            user_feedback_score=score,
        )

    def test_score_5_high_quality(self):
        """A feedback score of 5 is expected to give quality 1.0 and weight 0.9."""
        outcome = self._distill_feedback('這個建議幫我增加了 8% 銷量', 5)
        assert outcome.quality_score == 1.0
        # High weight -> Stage 4 human review (per the original author's note).
        assert outcome.weight == 0.9

    def test_score_1_negative_sample(self):
        """A feedback score of 1 is expected to give quality 0.0."""
        outcome = self._distill_feedback('完全沒幫助', 1)
        # Stage 1 would reject it (per the original author's note).
        assert outcome.quality_score == 0.0

    def test_default_score_3_mid(self):
        """A missing feedback score is expected to give quality 0.5."""
        outcome = self._distill_feedback('普通', None)
        # Defaults to 3 -> (3-1)/4 = 0.5
        assert outcome.quality_score == 0.5
class TestDistillerManualCurated:
    """Expectations for ``Distiller`` on ``manual_curated`` episodes."""

    def test_max_quality_and_weight(self):
        """Manually curated content is expected to get quality 1.0 and weight 1.0."""
        from services.learning_pipeline import Distiller

        outcome = Distiller().distill(
            episode_type='manual_curated',
            raw_content='手動入庫',
        )
        assert outcome.quality_score == 1.0
        assert outcome.weight == 1.0
class TestDistillerInvalidType:
    """Expectations for ``Distiller`` when given an unrecognised episode type."""

    def test_unknown_type_returns_none(self):
        """An unknown ``episode_type`` is expected to distill to ``None``."""
        from services.learning_pipeline import Distiller

        outcome = Distiller().distill(
            episode_type='garbage',
            raw_content='whatever',
        )
        assert outcome is None
class TestDistillerLengthGuard:
    """Size cap applied to the distilled text."""

    def test_distilled_text_truncated_to_16kb(self):
        """The UTF-8 size of the distilled text must not exceed the module limit."""
        from services.learning_pipeline import DISTILLED_TEXT_MAX_BYTES, Distiller

        oversized = '建議分析 ' * 5000  # far beyond 16KB
        outcome = Distiller().distill(
            episode_type='mcp_result',
            raw_content=oversized,
        )
        byte_length = len(outcome.distilled_text.encode('utf-8'))
        assert byte_length <= DISTILLED_TEXT_MAX_BYTES
# ─────────────────────────────────────────────────────────────────────────────
# LearningPipeline.enqueue
# ─────────────────────────────────────────────────────────────────────────────
class TestLearningPipelineEnqueue:
    """DB write path of ``learning_pipeline.enqueue`` against a mocked session."""

    def test_enqueue_returns_id_on_success(self, monkeypatch):
        """``enqueue`` returns the id found in the row fetched from the session."""
        from services.learning_pipeline import learning_pipeline

        row_stub = MagicMock()
        row_stub.__getitem__.return_value = 42
        session_stub = MagicMock()
        session_stub.execute.return_value.fetchone.return_value = row_stub
        monkeypatch.setattr('database.manager.get_session', lambda: session_stub)

        episode_id = learning_pipeline.enqueue(
            episode_type='manual_curated',
            raw_content='手動入庫測試內容',
        )

        assert episode_id == 42
        # Phase 11.5: after the episode INSERT, _enqueue_embedding reuses the
        # same fake session, so two commits are expected on the happy path
        # (1: episode INSERT, 2: embedding_retry_queue INSERT).
        # The embedding step is fail-safe: its failure is swallowed and does
        # not affect the returned episode id, hence only ">= 1" is asserted.
        assert session_stub.commit.call_count >= 1

    def test_enqueue_returns_none_when_distill_fails(self):
        """Empty content makes distill return ``None``, so ``enqueue`` does too."""
        from services.learning_pipeline import learning_pipeline

        episode_id = learning_pipeline.enqueue(
            episode_type='mcp_result',
            raw_content='',
        )
        assert episode_id is None

    def test_enqueue_db_failure_returns_none(self, monkeypatch):
        """When the session's ``execute`` raises, ``enqueue`` returns ``None``."""
        from services.learning_pipeline import learning_pipeline

        session_stub = MagicMock()
        session_stub.execute.side_effect = RuntimeError("db down")
        monkeypatch.setattr('database.manager.get_session', lambda: session_stub)

        episode_id = learning_pipeline.enqueue(
            episode_type='manual_curated',
            raw_content='測試內容',
        )
        assert episode_id is None
# ─────────────────────────────────────────────────────────────────────────────
# expire_stale_reviews
# ─────────────────────────────────────────────────────────────────────────────
class TestExpireStaleReviews:
    """Behaviour of ``expire_stale_reviews`` against a mocked DB session."""

    def test_expire_uses_correct_sql(self, monkeypatch):
        """The function returns the statement's ``rowcount`` and commits once.

        NOTE(review): despite the test name, the SQL text passed to
        ``execute`` is not inspected here.
        """
        from services.learning_pipeline import expire_stale_reviews

        session_stub = MagicMock()
        session_stub.execute.return_value = MagicMock(rowcount=3)
        monkeypatch.setattr('database.manager.get_session', lambda: session_stub)

        expired = expire_stale_reviews(hours=24)

        assert expired == 3
        # Make sure the change was committed.
        session_stub.commit.assert_called_once()

    def test_expire_db_failure_returns_zero(self, monkeypatch):
        """When the session's ``execute`` raises, the function returns 0."""
        from services.learning_pipeline import expire_stale_reviews

        session_stub = MagicMock()
        session_stub.execute.side_effect = RuntimeError("db down")
        monkeypatch.setattr('database.manager.get_session', lambda: session_stub)

        expired = expire_stale_reviews(hours=24)
        assert expired == 0
# ─────────────────────────────────────────────────────────────────────────────
# hash_human_approver
# ─────────────────────────────────────────────────────────────────────────────
class TestHashHumanApprover:
    """Shape and stability of the value returned by ``hash_human_approver``."""

    def test_returns_8_char_hex(self):
        """The hash is 8 characters long, all lowercase hexadecimal digits."""
        from services.learning_pipeline import hash_human_approver

        digest = hash_human_approver('owen.tsai')
        assert len(digest) == 8
        assert set(digest) <= set('0123456789abcdef')

    def test_empty_returns_empty(self):
        """An empty string or ``None`` yields an empty string."""
        from services.learning_pipeline import hash_human_approver

        assert hash_human_approver('') == ''
        assert hash_human_approver(None) == ''  # type: ignore

    def test_deterministic(self):
        """The same input hashes identically; different inputs hash differently."""
        from services.learning_pipeline import hash_human_approver

        first_alice = hash_human_approver('alice')
        second_alice = hash_human_approver('alice')
        bob = hash_human_approver('bob')
        assert first_alice == second_alice
        assert first_alice != bob
# ─────────────────────────────────────────────────────────────────────────────
# 工具函式: _detect_simple_contradiction
# ─────────────────────────────────────────────────────────────────────────────
class TestContradictionDetector:
    """Expectations for the ``_detect_simple_contradiction`` helper."""

    def test_no_contradiction_returns_none(self):
        """Two different subjects with one statement each yield ``None``."""
        from services.learning_pipeline import _detect_simple_contradiction

        # subject 業績 -> 上升, subject 市場 -> 競爭: nothing contradicts.
        sample = "業績是上升。市場是競爭。"
        assert _detect_simple_contradiction(sample) is None

    def test_contradiction_detected(self):
        """One subject given two different values yields a result naming it."""
        from services.learning_pipeline import _detect_simple_contradiction

        sample = "A是黑色。A是白色。"
        finding = _detect_simple_contradiction(sample)
        assert finding is not None
        assert 'A' in finding