diff --git a/apps/api/src/services/model_version_probe.py b/apps/api/src/services/model_version_probe.py index 1eac3f92..827299d7 100644 --- a/apps/api/src/services/model_version_probe.py +++ b/apps/api/src/services/model_version_probe.py @@ -20,6 +20,8 @@ from datetime import datetime, timedelta, timezone import structlog +from src.core.config import settings + logger = structlog.get_logger(__name__) TAIPEI_TZ = timezone(timedelta(hours=8)) @@ -91,8 +93,6 @@ async def probe_gemini_version() -> ProviderVersionInfo: Raises: RuntimeError: GEMINI_API_KEY 未設定 """ - from src.core.config import settings - api_key = settings.GEMINI_API_KEY if not api_key: raise RuntimeError("GEMINI_API_KEY not configured") @@ -147,8 +147,6 @@ async def probe_claude_version() -> ProviderVersionInfo: Raises: RuntimeError: CLAUDE_API_KEY 未設定 """ - from src.core.config import settings - api_key = settings.CLAUDE_API_KEY if not api_key: raise RuntimeError("CLAUDE_API_KEY not configured") @@ -182,8 +180,6 @@ async def probe_openclaw_nemo_version() -> ProviderVersionInfo: RuntimeError: OPENCLAW_DEFAULT_MODEL 未設定 httpx.HTTPError: 連線失敗 """ - from src.core.config import settings - model = settings.OPENCLAW_DEFAULT_MODEL if not model: raise RuntimeError("OPENCLAW_DEFAULT_MODEL not configured") @@ -238,8 +234,6 @@ async def probe_all_providers() -> list[ProviderVersionInfo]: - 使用 return_exceptions=True 確保任一 provider 失敗不影響其他 - 每個 exception 都有對應的 log warning """ - from src.core.config import settings - tasks = [ probe_ollama_version(settings.OLLAMA_URL, settings.OLLAMA_HEALTH_CHECK_MODEL), probe_ollama_version( diff --git a/apps/api/tests/test_model_version_tracker.py b/apps/api/tests/test_model_version_tracker.py index e5a03d29..bd253648 100644 --- a/apps/api/tests/test_model_version_tracker.py +++ b/apps/api/tests/test_model_version_tracker.py @@ -8,14 +8,20 @@ ModelVersionTracker 單元測試 - 同樣資料重入:5 row,全部 changed=False - digest 變更:該 provider changed=True,其餘 changed=False - run_probe_cycle 回傳 dict 格式正確 -- probe_all_providers 拋例外 → tracker 不 crash +- probe_all_providers 回傳空列表 → tracker 不 crash,probed=0 -測試分類:unit(mock DB session + probe_all_providers,無實際 DB 依賴) +patch 策略: + - probe_all_providers → patch 原始模組 src.services.model_version_probe.probe_all_providers + - get_db_context → patch 原始模組 src.db.base.get_db_context + (tracker 用 lazy import,每次 import 拿到 same module object,patch 原始模組有效) + +測試分類:unit(無 DB / Redis 依賴) """ from __future__ import annotations +from contextlib import asynccontextmanager from datetime import datetime, timedelta, timezone -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest @@ -49,47 +55,19 @@ def _make_five() -> list[ProviderVersionInfo]: ] -def _mock_db_session(last_records: dict[str, MagicMock | None]): - """構造 fake DB session,scalar_one_or_none 依 provider 回傳 last_records""" - db = AsyncMock() - - added: list = [] - - async def _execute(stmt): - # 從 stmt where clause 取 provider name(用 compile 或直接 mock) - # 這裡用簡化方法:記錄 execute 被呼叫的順序 - result = MagicMock() - # 每次 execute 取出一個 last_record(按 provider 順序) - result.scalar_one_or_none = MagicMock(return_value=None) # default - return result - - db.execute = AsyncMock(side_effect=_execute) - db.add = MagicMock(side_effect=lambda obj: added.append(obj)) - db.commit = AsyncMock() - db._added = added - return db - - -# ============================================================================= -# Test Cases -# ============================================================================= - -@pytest.mark.integration -class TestModelVersionTracker: - """需要 PG 連線(mock 不完整,實際呼叫 get_db_context)→ 標 integration""" - - @pytest.mark.asyncio - async def test_first_write_all_changed(self): - """第一次寫入(DB 無歷史)→ 5 row 全部 changed=True""" - five = _make_five() - tracker = ModelVersionTracker() - - added_rows: list = [] +def _make_fake_ctx(last_fn): + """回傳一個 asynccontextmanager,db.execute 呼叫 last_fn(index) 取 last record""" + added_rows: list = [] + call_idx = [0] + @asynccontextmanager + async def fake_ctx(): class FakeDB: async def execute(self, stmt): result = MagicMock() - result.scalar_one_or_none = MagicMock(return_value=None) + last = last_fn(call_idx[0]) + call_idx[0] += 1 + result.scalar_one_or_none = MagicMock(return_value=last) return result def add(self, obj): @@ -98,14 +76,28 @@ class TestModelVersionTracker: async def commit(self): pass - from contextlib import asynccontextmanager + yield FakeDB() - @asynccontextmanager - async def fake_ctx(): - yield FakeDB() + return fake_ctx, added_rows - with patch("src.services.model_version_tracker.probe_all_providers", return_value=five), \ - patch("src.services.model_version_tracker.get_db_context", fake_ctx): + +# ============================================================================= +# Test Cases +# ============================================================================= + +class TestModelVersionTracker: + @pytest.mark.asyncio + async def test_first_write_all_changed(self): + """第一次寫入(DB 無歷史)→ 5 row 全部 changed=True""" + five = _make_five() + tracker = ModelVersionTracker() + fake_ctx, added_rows = _make_fake_ctx(lambda _: None) # last=None + + async def _probe(): + return five + + with patch("src.services.model_version_probe.probe_all_providers", side_effect=_probe), \ + patch("src.db.base.get_db_context", fake_ctx): result = await tracker.run_probe_cycle() assert result["probed"] == 5 @@ -120,128 +112,67 @@ class TestModelVersionTracker: """DB 有相同版本記錄 → changed=False""" five = _make_five() tracker = ModelVersionTracker() - added_rows: list = [] - # last record 與 info 版本相同 - def _make_last(info: ProviderVersionInfo): + def _make_last_same(idx): + info = five[idx % len(five)] last = MagicMock() last.version = info.version last.digest = info.digest return last - lasts = {info.provider: _make_last(info) for info in five} - call_idx = [0] + fake_ctx, added_rows = _make_fake_ctx(_make_last_same) - class FakeDB: - async def execute(self, stmt): - result = MagicMock() - # 依順序回傳對應 provider 的 last record - info = five[call_idx[0] % len(five)] - call_idx[0] += 1 - result.scalar_one_or_none = MagicMock(return_value=lasts[info.provider]) - return result + async def _probe(): + return five - def add(self, obj): - added_rows.append(obj) - - async def commit(self): - pass - - from contextlib import asynccontextmanager - - @asynccontextmanager - async def fake_ctx(): - yield FakeDB() - - with patch("src.services.model_version_tracker.probe_all_providers", return_value=five), \ - patch("src.services.model_version_tracker.get_db_context", fake_ctx): + with patch("src.services.model_version_probe.probe_all_providers", side_effect=_probe), \ + patch("src.db.base.get_db_context", fake_ctx): result = await tracker.run_probe_cycle() assert result["probed"] == 5 assert len(result["changed"]) == 0 + assert len(added_rows) == 5 for row in added_rows: assert row.changed is False @pytest.mark.asyncio async def test_digest_change_detected(self): - """其中一個 provider digest 改變 → changed=True,其餘 changed=False""" + """ollama provider digest 改變 → changed=True,其餘 changed=False""" five = _make_five() tracker = ModelVersionTracker() - added_rows: list = [] - changed_provider = "ollama" - def _make_last(info: ProviderVersionInfo): + def _make_last(idx): + info = five[idx % len(five)] last = MagicMock() - if info.provider == changed_provider: - # 舊 digest 不同 - last.version = info.version - last.digest = "sha256:OLD_DIGEST" - else: - last.version = info.version - last.digest = info.digest + last.version = info.version + last.digest = "sha256:OLD_DIGEST" if info.provider == changed_provider else info.digest return last - lasts = {info.provider: _make_last(info) for info in five} - call_idx = [0] + fake_ctx, added_rows = _make_fake_ctx(_make_last) - class FakeDB: - async def execute(self, stmt): - result = MagicMock() - info = five[call_idx[0] % len(five)] - call_idx[0] += 1 - result.scalar_one_or_none = MagicMock(return_value=lasts[info.provider]) - return result + async def _probe(): + return five - def add(self, obj): - added_rows.append(obj) - - async def commit(self): - pass - - from contextlib import asynccontextmanager - - @asynccontextmanager - async def fake_ctx(): - yield FakeDB() - - with patch("src.services.model_version_tracker.probe_all_providers", return_value=five), \ - patch("src.services.model_version_tracker.get_db_context", fake_ctx): + with patch("src.services.model_version_probe.probe_all_providers", side_effect=_probe), \ + patch("src.db.base.get_db_context", fake_ctx): result = await tracker.run_probe_cycle() assert result["probed"] == 5 assert changed_provider in result["changed"] - # 只有 1 個 changed assert len(result["changed"]) == 1 @pytest.mark.asyncio - async def test_probe_failure_does_not_crash(self): - """probe_all_providers 拋 exception → tracker 不 crash,回傳 probed=0""" + async def test_empty_probe_results_no_crash(self): + """probe_all_providers 回傳空列表 → tracker 不 crash,probed=0""" tracker = ModelVersionTracker() - added_rows: list = [] + fake_ctx, added_rows = _make_fake_ctx(lambda _: None) - from contextlib import asynccontextmanager + async def _probe(): + return [] - @asynccontextmanager - async def fake_ctx(): - class FakeDB: - async def execute(self, stmt): - r = MagicMock() - r.scalar_one_or_none = MagicMock(return_value=None) - return r - - def add(self, obj): - added_rows.append(obj) - - async def commit(self): - pass - yield FakeDB() - - async def _bad_probe(): - return [] # probe 全部失敗,回傳空列表 - - with patch("src.services.model_version_tracker.probe_all_providers", side_effect=_bad_probe), \ - patch("src.services.model_version_tracker.get_db_context", fake_ctx): + with patch("src.services.model_version_probe.probe_all_providers", side_effect=_probe), \ + patch("src.db.base.get_db_context", fake_ctx): result = await tracker.run_probe_cycle() assert result["probed"] == 0