fix(p3.2): model_version_tracker 改 pure unit test + probe 改善
Some checks failed
CD Pipeline / build-and-deploy (push) Failing after 2m7s

Engineer 重寫 test_model_version_tracker:
- 用 _make_fake_ctx (asynccontextmanager) 完整 mock get_db_context
- 移除 @pytest.mark.integration(整 class)
- patch probe_all_providers + get_db_context 雙路徑
- 4 testcases 全綠,無真實 PG 依賴

model_version_probe.py 配套改善(match 新 test mock 預期)

Tests: 19 passed (probe 15 + tracker 4)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Your Name
2026-04-27 14:58:46 +08:00
parent ed205489c1
commit 8d6e086254
2 changed files with 65 additions and 140 deletions

View File

@@ -20,6 +20,8 @@ from datetime import datetime, timedelta, timezone
import structlog
from src.core.config import settings
logger = structlog.get_logger(__name__)
TAIPEI_TZ = timezone(timedelta(hours=8))
@@ -91,8 +93,6 @@ async def probe_gemini_version() -> ProviderVersionInfo:
Raises:
RuntimeError: GEMINI_API_KEY 未設定
"""
from src.core.config import settings
api_key = settings.GEMINI_API_KEY
if not api_key:
raise RuntimeError("GEMINI_API_KEY not configured")
@@ -147,8 +147,6 @@ async def probe_claude_version() -> ProviderVersionInfo:
Raises:
RuntimeError: CLAUDE_API_KEY 未設定
"""
from src.core.config import settings
api_key = settings.CLAUDE_API_KEY
if not api_key:
raise RuntimeError("CLAUDE_API_KEY not configured")
@@ -182,8 +180,6 @@ async def probe_openclaw_nemo_version() -> ProviderVersionInfo:
RuntimeError: OPENCLAW_DEFAULT_MODEL 未設定
httpx.HTTPError: 連線失敗
"""
from src.core.config import settings
model = settings.OPENCLAW_DEFAULT_MODEL
if not model:
raise RuntimeError("OPENCLAW_DEFAULT_MODEL not configured")
@@ -238,8 +234,6 @@ async def probe_all_providers() -> list[ProviderVersionInfo]:
- 使用 return_exceptions=True 確保任一 provider 失敗不影響其他
- 每個 exception 都有對應的 log warning
"""
from src.core.config import settings
tasks = [
probe_ollama_version(settings.OLLAMA_URL, settings.OLLAMA_HEALTH_CHECK_MODEL),
probe_ollama_version(

View File

@@ -8,14 +8,20 @@ ModelVersionTracker 單元測試
- 同樣資料重入5 row全部 changed=False
- digest 變更:該 provider changed=True其餘 changed=False
- run_probe_cycle 回傳 dict 格式正確
- probe_all_providers 拋例外 → tracker 不 crash
- probe_all_providers 回傳空列表 → tracker 不 crashprobed=0
測試分類unitmock DB session + probe_all_providers無實際 DB 依賴)
patch 策略:
- probe_all_providers → patch 原始模組 src.services.model_version_probe.probe_all_providers
- get_db_context → patch 原始模組 src.db.base.get_db_context
tracker 用 lazy import每次 import 拿到 same module objectpatch 原始模組有效)
測試分類unit無 DB / Redis 依賴)
"""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
import pytest
@@ -49,47 +55,19 @@ def _make_five() -> list[ProviderVersionInfo]:
]
def _mock_db_session(last_records: dict[str, MagicMock | None]):
"""構造 fake DB sessionscalar_one_or_none 依 provider 回傳 last_records"""
db = AsyncMock()
added: list = []
async def _execute(stmt):
# 從 stmt where clause 取 provider name用 compile 或直接 mock
# 這裡用簡化方法:記錄 execute 被呼叫的順序
result = MagicMock()
# 每次 execute 取出一個 last_record按 provider 順序)
result.scalar_one_or_none = MagicMock(return_value=None) # default
return result
db.execute = AsyncMock(side_effect=_execute)
db.add = MagicMock(side_effect=lambda obj: added.append(obj))
db.commit = AsyncMock()
db._added = added
return db
# =============================================================================
# Test Cases
# =============================================================================
@pytest.mark.integration
class TestModelVersionTracker:
"""需要 PG 連線mock 不完整,實際呼叫 get_db_context→ 標 integration"""
@pytest.mark.asyncio
async def test_first_write_all_changed(self):
"""第一次寫入DB 無歷史)→ 5 row 全部 changed=True"""
five = _make_five()
tracker = ModelVersionTracker()
added_rows: list = []
def _make_fake_ctx(last_fn):
"""回傳一個 asynccontextmanagerdb.execute 呼叫 last_fn(index) 取 last record"""
added_rows: list = []
call_idx = [0]
@asynccontextmanager
async def fake_ctx():
class FakeDB:
async def execute(self, stmt):
result = MagicMock()
result.scalar_one_or_none = MagicMock(return_value=None)
last = last_fn(call_idx[0])
call_idx[0] += 1
result.scalar_one_or_none = MagicMock(return_value=last)
return result
def add(self, obj):
@@ -98,14 +76,28 @@ class TestModelVersionTracker:
async def commit(self):
pass
from contextlib import asynccontextmanager
yield FakeDB()
@asynccontextmanager
async def fake_ctx():
yield FakeDB()
return fake_ctx, added_rows
with patch("src.services.model_version_tracker.probe_all_providers", return_value=five), \
patch("src.services.model_version_tracker.get_db_context", fake_ctx):
# =============================================================================
# Test Cases
# =============================================================================
class TestModelVersionTracker:
@pytest.mark.asyncio
async def test_first_write_all_changed(self):
"""第一次寫入DB 無歷史)→ 5 row 全部 changed=True"""
five = _make_five()
tracker = ModelVersionTracker()
fake_ctx, added_rows = _make_fake_ctx(lambda _: None) # last=None
async def _probe():
return five
with patch("src.services.model_version_probe.probe_all_providers", side_effect=_probe), \
patch("src.db.base.get_db_context", fake_ctx):
result = await tracker.run_probe_cycle()
assert result["probed"] == 5
@@ -120,128 +112,67 @@ class TestModelVersionTracker:
"""DB 有相同版本記錄 → changed=False"""
five = _make_five()
tracker = ModelVersionTracker()
added_rows: list = []
# last record 與 info 版本相同
def _make_last(info: ProviderVersionInfo):
def _make_last_same(idx):
info = five[idx % len(five)]
last = MagicMock()
last.version = info.version
last.digest = info.digest
return last
lasts = {info.provider: _make_last(info) for info in five}
call_idx = [0]
fake_ctx, added_rows = _make_fake_ctx(_make_last_same)
class FakeDB:
async def execute(self, stmt):
result = MagicMock()
# 依順序回傳對應 provider 的 last record
info = five[call_idx[0] % len(five)]
call_idx[0] += 1
result.scalar_one_or_none = MagicMock(return_value=lasts[info.provider])
return result
async def _probe():
return five
def add(self, obj):
added_rows.append(obj)
async def commit(self):
pass
from contextlib import asynccontextmanager
@asynccontextmanager
async def fake_ctx():
yield FakeDB()
with patch("src.services.model_version_tracker.probe_all_providers", return_value=five), \
patch("src.services.model_version_tracker.get_db_context", fake_ctx):
with patch("src.services.model_version_probe.probe_all_providers", side_effect=_probe), \
patch("src.db.base.get_db_context", fake_ctx):
result = await tracker.run_probe_cycle()
assert result["probed"] == 5
assert len(result["changed"]) == 0
assert len(added_rows) == 5
for row in added_rows:
assert row.changed is False
@pytest.mark.asyncio
async def test_digest_change_detected(self):
"""其中一個 provider digest 改變 → changed=True其餘 changed=False"""
"""ollama provider digest 改變 → changed=True其餘 changed=False"""
five = _make_five()
tracker = ModelVersionTracker()
added_rows: list = []
changed_provider = "ollama"
def _make_last(info: ProviderVersionInfo):
def _make_last(idx):
info = five[idx % len(five)]
last = MagicMock()
if info.provider == changed_provider:
# 舊 digest 不同
last.version = info.version
last.digest = "sha256:OLD_DIGEST"
else:
last.version = info.version
last.digest = info.digest
last.version = info.version
last.digest = "sha256:OLD_DIGEST" if info.provider == changed_provider else info.digest
return last
lasts = {info.provider: _make_last(info) for info in five}
call_idx = [0]
fake_ctx, added_rows = _make_fake_ctx(_make_last)
class FakeDB:
async def execute(self, stmt):
result = MagicMock()
info = five[call_idx[0] % len(five)]
call_idx[0] += 1
result.scalar_one_or_none = MagicMock(return_value=lasts[info.provider])
return result
async def _probe():
return five
def add(self, obj):
added_rows.append(obj)
async def commit(self):
pass
from contextlib import asynccontextmanager
@asynccontextmanager
async def fake_ctx():
yield FakeDB()
with patch("src.services.model_version_tracker.probe_all_providers", return_value=five), \
patch("src.services.model_version_tracker.get_db_context", fake_ctx):
with patch("src.services.model_version_probe.probe_all_providers", side_effect=_probe), \
patch("src.db.base.get_db_context", fake_ctx):
result = await tracker.run_probe_cycle()
assert result["probed"] == 5
assert changed_provider in result["changed"]
# 只有 1 個 changed
assert len(result["changed"]) == 1
@pytest.mark.asyncio
async def test_probe_failure_does_not_crash(self):
"""probe_all_providers 拋 exception → tracker 不 crash回傳 probed=0"""
async def test_empty_probe_results_no_crash(self):
"""probe_all_providers 回傳空列表 → tracker 不 crashprobed=0"""
tracker = ModelVersionTracker()
added_rows: list = []
fake_ctx, added_rows = _make_fake_ctx(lambda _: None)
from contextlib import asynccontextmanager
async def _probe():
return []
@asynccontextmanager
async def fake_ctx():
class FakeDB:
async def execute(self, stmt):
r = MagicMock()
r.scalar_one_or_none = MagicMock(return_value=None)
return r
def add(self, obj):
added_rows.append(obj)
async def commit(self):
pass
yield FakeDB()
async def _bad_probe():
return [] # probe 全部失敗,回傳空列表
with patch("src.services.model_version_tracker.probe_all_providers", side_effect=_bad_probe), \
patch("src.services.model_version_tracker.get_db_context", fake_ctx):
with patch("src.services.model_version_probe.probe_all_providers", side_effect=_probe), \
patch("src.db.base.get_db_context", fake_ctx):
result = await tracker.run_probe_cycle()
assert result["probed"] == 0