# apps/api/tests/test_model_version_tracker.py
# 2026-04-27 P3.2.2 by Claude
"""
ModelVersionTracker 單元測試
==============================
測試覆蓋:
  - 第一次寫入:5 row,全部 changed=True(prev_version=None)
  - 同樣資料重入:5 row,全部 changed=False
  - digest 變更:該 provider changed=True,其餘 changed=False
  - run_probe_cycle 回傳 dict 格式正確
  - probe_all_providers 回傳空列表 → tracker 不 crash,probed=0

patch 策略:
  - probe_all_providers → patch 原始模組 src.services.model_version_probe.probe_all_providers
  - get_db_context → patch 原始模組 src.db.base.get_db_context
  (tracker 用 lazy import,每次 import 都拿到同一個 module object,因此 patch 原始模組即可生效)

測試分類:unit(無 DB / Redis 依賴)
"""
from __future__ import annotations

from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch

import pytest

from src.services.model_version_probe import ProviderVersionInfo
from src.services.model_version_tracker import ModelVersionTracker
# Fixed UTC+08:00 offset used for the fixtures' captured_at timestamps.
TAIPEI_TZ = timezone(offset=timedelta(hours=8))
# =============================================================================
# Helpers
# =============================================================================
def _make_info(provider: str, version: str = "v1", digest: str | None = "sha256:abc") -> ProviderVersionInfo:
    """Build one ``ProviderVersionInfo`` fixture for *provider*.

    The model name is derived as ``"model-<provider>"`` and ``captured_at`` is
    the current time in ``TAIPEI_TZ``. Pass ``digest=None`` to build a
    provider entry that carries no digest.
    """
    captured = datetime.now(TAIPEI_TZ)
    return ProviderVersionInfo(
        provider=provider,
        version=version,
        digest=digest,
        model=f"model-{provider}",
        captured_at=captured,
    )
def _make_five() -> list[ProviderVersionInfo]:
    """Return fixtures for the five tracked providers, always in the same order.

    ``gemini`` and ``claude`` are built with ``digest=None``; the other three
    keep ``_make_info``'s default digest.
    """
    digestless = {"gemini", "claude"}
    order = ("ollama", "ollama_local", "gemini", "claude", "openclaw_nemo")
    return [
        _make_info(name, digest=None) if name in digestless else _make_info(name)
        for name in order
    ]
def _make_fake_ctx(last_fn):
    """Build a fake ``get_db_context`` plus the list that captures added rows.

    Args:
        last_fn: called as ``last_fn(n)`` for the n-th ``db.execute`` call
            (0-based, counted across every entry of the returned context
            manager). Its return value becomes that result's
            ``scalar_one_or_none()``.

    Returns:
        ``(fake_ctx, added_rows)``: ``fake_ctx`` is an async context manager
        factory that yields a fresh fake DB session on each entry;
        ``added_rows`` collects every object passed to ``db.add`` and is
        shared by all sessions.
    """
    added_rows: list = []
    call_count = 0  # shared by every session, so indices keep increasing

    class FakeDB:
        """Minimal async session stand-in exposing only execute / add / commit."""

        async def execute(self, stmt):
            # stmt is ignored: the answer depends only on how many queries came before.
            nonlocal call_count
            result = MagicMock()
            last = last_fn(call_count)
            call_count += 1
            result.scalar_one_or_none = MagicMock(return_value=last)
            return result

        def add(self, obj):
            added_rows.append(obj)

        async def commit(self):
            pass

    @asynccontextmanager
    async def fake_ctx():
        yield FakeDB()

    return fake_ctx, added_rows
# =============================================================================
# Test Cases
# =============================================================================
# Patch targets: the tracker imports these lazily, so patching the source
# modules is sufficient (see the module docstring).
_PROBE_TARGET = "src.services.model_version_probe.probe_all_providers"
_DB_CTX_TARGET = "src.db.base.get_db_context"


async def _run_cycle(probe_results: list[ProviderVersionInfo], fake_ctx) -> dict:
    """Run one ``ModelVersionTracker.run_probe_cycle()`` with its collaborators patched.

    ``probe_all_providers`` is replaced by an async stub returning
    *probe_results*; ``get_db_context`` is replaced by *fake_ctx*.
    """
    tracker = ModelVersionTracker()

    async def _probe():
        return probe_results

    with patch(_PROBE_TARGET, side_effect=_probe), patch(_DB_CTX_TARGET, fake_ctx):
        return await tracker.run_probe_cycle()


def _last_records(infos: list[ProviderVersionInfo], stale_digest_for: str | None = None):
    """Build a ``last_fn`` for ``_make_fake_ctx`` that mirrors *infos* by call index.

    The n-th ``db.execute`` returns a record whose version/digest equal those of
    ``infos[n % len(infos)]``; if *stale_digest_for* names a provider, that
    provider's record carries ``"sha256:OLD_DIGEST"`` instead.
    NOTE(review): relies on the tracker issuing one lookup per provider, in
    probe order.
    """

    def _last(idx):
        info = infos[idx % len(infos)]
        last = MagicMock()
        last.version = info.version
        last.digest = "sha256:OLD_DIGEST" if info.provider == stale_digest_for else info.digest
        return last

    return _last


class TestModelVersionTracker:
    """Unit tests for ``ModelVersionTracker.run_probe_cycle`` (no real DB or probes)."""

    @pytest.mark.asyncio
    async def test_first_write_all_changed(self):
        """First write (no history in DB) -> all 5 rows changed=True, prev_version=None."""
        five = _make_five()
        fake_ctx, added_rows = _make_fake_ctx(lambda _: None)  # no previous record

        result = await _run_cycle(five, fake_ctx)

        assert result["probed"] == 5
        assert len(result["changed"]) == 5
        # every probed provider is reported as changed
        assert set(result["changed"]) == {info.provider for info in five}
        assert len(added_rows) == 5
        for row in added_rows:
            assert row.changed is True
            assert row.prev_version is None

    @pytest.mark.asyncio
    async def test_same_data_no_change(self):
        """DB already holds identical version records -> every row changed=False."""
        five = _make_five()
        fake_ctx, added_rows = _make_fake_ctx(_last_records(five))

        result = await _run_cycle(five, fake_ctx)

        assert result["probed"] == 5
        assert len(result["changed"]) == 0
        assert len(added_rows) == 5
        for row in added_rows:
            assert row.changed is False

    @pytest.mark.asyncio
    async def test_digest_change_detected(self):
        """ollama digest changed -> changed=True for it, changed=False for the rest."""
        five = _make_five()
        changed_provider = "ollama"
        fake_ctx, added_rows = _make_fake_ctx(
            _last_records(five, stale_digest_for=changed_provider)
        )

        result = await _run_cycle(five, fake_ctx)

        assert result["probed"] == 5
        assert changed_provider in result["changed"]
        assert len(result["changed"]) == 1
        # The docstring also promises the other providers stay unchanged:
        # verify the persisted rows, not only the summary dict.
        assert len(added_rows) == 5
        assert sum(row.changed is True for row in added_rows) == 1
        assert sum(row.changed is False for row in added_rows) == 4

    @pytest.mark.asyncio
    async def test_empty_probe_results_no_crash(self):
        """probe_all_providers returns [] -> tracker does not crash and probed=0."""
        fake_ctx, added_rows = _make_fake_ctx(lambda _: None)

        result = await _run_cycle([], fake_ctx)

        assert result["probed"] == 0
        assert result["changed"] == []
        assert len(added_rows) == 0