Files
awoooi/apps/api/tests/test_model_version_tracker.py
Your Name 4111ea4f9f
All checks were successful
Code Review / ai-code-review (push) Successful in 12s
CD Pipeline / tests (push) Successful in 1m13s
CD Pipeline / build-and-deploy (push) Successful in 3m36s
CD Pipeline / post-deploy-checks (push) Successful in 1m20s
fix(ai): remove 188 ollama provider
2026-05-06 14:34:48 +08:00

181 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# apps/api/tests/test_model_version_tracker.py
# 2026-04-27 P3.2.2 by Claude
"""
ModelVersionTracker 單元測試
==============================
測試覆蓋:
- 第一次寫入: 5 row 全部 changed=True, prev_version=None
- 同樣資料重入: 5 row 全部 changed=False
- digest 變更: 該 provider changed=True, 其餘 changed=False
- run_probe_cycle 回傳 dict 格式正確
- probe_all_providers 回傳空列表 → tracker 不 crash, probed=0
patch 策略:
- probe_all_providers → patch 原始模組 src.services.model_version_probe.probe_all_providers
- get_db_context → patch 原始模組 src.db.base.get_db_context
(tracker 用 lazy import, 每次 import 拿到 same module object, 所以 patch 原始模組有效)
測試分類: unit (無 DB / Redis 依賴)
"""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
import pytest
from src.services.model_version_probe import ProviderVersionInfo
from src.services.model_version_tracker import ModelVersionTracker
# Fixed UTC+08:00 offset used to stamp fixture ``captured_at`` timestamps.
TAIPEI_TZ = timezone(timedelta(seconds=8 * 60 * 60))
# =============================================================================
# Helpers
# =============================================================================
def _make_info(provider: str, version: str = "v1", digest: str | None = "sha256:abc") -> ProviderVersionInfo:
    """Build a ProviderVersionInfo fixture for *provider*.

    The model name is ``model-<provider>``, and ``captured_at`` is the current
    time in TAIPEI_TZ. *version* and *digest* are passed through unchanged.
    """
    fields = {
        "provider": provider,
        "model": f"model-{provider}",
        "version": version,
        "digest": digest,
        "captured_at": datetime.now(TAIPEI_TZ),
    }
    return ProviderVersionInfo(**fields)
def _make_five() -> list[ProviderVersionInfo]:
    """Return the five provider fixtures shared by the tracker tests.

    gemini and claude carry ``digest=None``. The others keep _make_info's
    default digest. Order is significant: the tests' last-record callbacks
    index this list by ``db.execute`` call order.
    """
    providers = ("ollama", "ollama_local", "gemini", "claude", "openclaw_nemo")
    without_digest = {"gemini", "claude"}
    return [
        _make_info(name, digest=None) if name in without_digest else _make_info(name)
        for name in providers
    ]
def _make_fake_ctx(last_fn):
"""回傳一個 asynccontextmanagerdb.execute 呼叫 last_fn(index) 取 last record"""
added_rows: list = []
call_idx = [0]
@asynccontextmanager
async def fake_ctx():
class FakeDB:
async def execute(self, stmt):
result = MagicMock()
last = last_fn(call_idx[0])
call_idx[0] += 1
result.scalar_one_or_none = MagicMock(return_value=last)
return result
def add(self, obj):
added_rows.append(obj)
async def commit(self):
pass
yield FakeDB()
return fake_ctx, added_rows
# =============================================================================
# Test Cases
# =============================================================================
class TestModelVersionTracker:
    """Unit tests for ModelVersionTracker.run_probe_cycle.

    probe_all_providers and get_db_context are patched on their source modules
    (src.services.model_version_probe / src.db.base), so no DB or Redis is used.
    """

    @staticmethod
    async def _run_cycle(probe_results, last_fn):
        """Run one probe cycle against a fake DB.

        Args:
            probe_results: list the patched probe_all_providers returns.
            last_fn: ``index -> last record`` callback forwarded to
                _make_fake_ctx. It is called once per ``db.execute``, in call order.

        Returns:
            ``(result, added_rows)``: the value returned by run_probe_cycle and
            the rows the tracker passed to ``db.add``.
        """
        tracker = ModelVersionTracker()
        fake_ctx, added_rows = _make_fake_ctx(last_fn)

        # An async side_effect keeps the patched probe awaitable when called.
        async def _probe():
            return probe_results

        with patch("src.services.model_version_probe.probe_all_providers", side_effect=_probe), \
                patch("src.db.base.get_db_context", fake_ctx):
            result = await tracker.run_probe_cycle()
        return result, added_rows

    @staticmethod
    def _last_records(infos, stale_provider=None, stale_digest="sha256:OLD_DIGEST"):
        """Build a last_fn whose i-th record mirrors ``infos[i % len(infos)]``.

        The mapping only lines up if the tracker issues one ``db.execute`` per
        probed provider, in the same order as *infos* (assumed from how these
        tests are written). When *stale_provider* is given, that provider's
        record reports *stale_digest* instead of its current digest.
        """
        def _last(idx):
            info = infos[idx % len(infos)]
            last = MagicMock()
            last.version = info.version
            last.digest = stale_digest if info.provider == stale_provider else info.digest
            return last

        return _last

    @pytest.mark.asyncio
    async def test_first_write_all_changed(self):
        """No history in the DB: all 5 rows are written with changed=True and prev_version=None."""
        result, added_rows = await self._run_cycle(_make_five(), lambda _: None)

        assert result["probed"] == 5
        assert len(result["changed"]) == 5
        assert len(added_rows) == 5
        for row in added_rows:
            assert row.changed is True
            assert row.prev_version is None

    @pytest.mark.asyncio
    async def test_same_data_no_change(self):
        """The DB already holds identical versions: every row is written with changed=False."""
        five = _make_five()

        result, added_rows = await self._run_cycle(five, self._last_records(five))

        assert result["probed"] == 5
        assert len(result["changed"]) == 0
        assert len(added_rows) == 5
        for row in added_rows:
            assert row.changed is True or row.changed is False
            assert row.changed is False

    @pytest.mark.asyncio
    async def test_digest_change_detected(self):
        """Only ollama's digest differs from history: only ollama is reported as changed."""
        five = _make_five()
        changed_provider = "ollama"

        result, added_rows = await self._run_cycle(
            five, self._last_records(five, stale_provider=changed_provider)
        )

        assert result["probed"] == 5
        assert changed_provider in result["changed"]
        assert len(result["changed"]) == 1
        # Every probe is still persisted, and exactly one row is flagged as changed.
        assert len(added_rows) == 5
        assert sum(row.changed is True for row in added_rows) == 1
        assert sum(row.changed is False for row in added_rows) == 4

    @pytest.mark.asyncio
    async def test_empty_probe_results_no_crash(self):
        """probe_all_providers returns an empty list: the tracker must not crash and reports probed=0."""
        result, added_rows = await self._run_cycle([], lambda _: None)

        assert result["probed"] == 0
        assert result["changed"] == []
        assert len(added_rows) == 0