""" Runtime bootstrap guard tests. 這組測試鎖住 production rollout 曾踩到的兩個啟動序問題: - API replicas 同時執行 DB bootstrap DDL 時必須有 advisory lock。 - SignalWorker 建立 Redis Stream 背景 task 前,必須先初始化 worker Redis pool。 """ from __future__ import annotations from collections.abc import Awaitable from typing import Any import pytest class _FakeScalarResult: def __init__(self, value: bool) -> None: self.value = value def scalar(self) -> bool: return self.value class _FakeLockConnection: def __init__(self, lock_results: list[bool] | None = None) -> None: self.statements: list[str] = [] self._lock_results = lock_results or [True] async def __aenter__(self) -> _FakeLockConnection: return self async def __aexit__(self, *_exc: object) -> None: return None async def commit(self) -> None: self.statements.append("COMMIT") async def execute( self, statement: object, params: dict[str, str] | None = None, ) -> _FakeScalarResult: sql = str(statement) self.statements.append(sql) assert params == {"lock_name": "awoooi:init_db:ddl"} if "pg_try_advisory_lock" in sql: return _FakeScalarResult(self._lock_results.pop(0)) return _FakeScalarResult(True) class _FakeEngine: def __init__(self, lock_results: list[bool] | None = None) -> None: self.lock_conn = _FakeLockConnection(lock_results=lock_results) def connect(self) -> _FakeLockConnection: return self.lock_conn class _ConnectionBudgetEngine: def connect(self) -> _FakeLockConnection: raise RuntimeError('too many connections for role "awoooi"') def test_get_engine_uses_database_pool_budget(monkeypatch): from src.db import base as db_base captured: dict[str, object] = {} fake_engine = object() def fake_create_async_engine(database_url: str, **kwargs: object) -> object: captured["database_url"] = database_url captured.update(kwargs) return fake_engine monkeypatch.setattr(db_base, "_engine", None) monkeypatch.setattr(db_base, "_session_factory", None) monkeypatch.setattr(db_base.settings, "DATABASE_URL", "postgresql+asyncpg://u:p@localhost/db") monkeypatch.setattr(db_base.settings, "DATABASE_POOL_SIZE", 1) monkeypatch.setattr(db_base.settings, "DATABASE_MAX_OVERFLOW", 0) monkeypatch.setattr(db_base.settings, "DATABASE_POOL_TIMEOUT_SECONDS", 5.0) monkeypatch.setattr(db_base, "create_async_engine", fake_create_async_engine) assert db_base.get_engine() is fake_engine assert captured["database_url"] == "postgresql+asyncpg://u:p@localhost/db" assert captured["pool_size"] == 1 assert captured["max_overflow"] == 0 assert captured["pool_timeout"] == 5.0 @pytest.mark.asyncio async def test_init_db_serializes_bootstrap_ddl_with_advisory_lock(monkeypatch): from src.db import base as db_base fake_engine = _FakeEngine() calls: list[object] = [] async def fake_run_init_db_ddl(engine: object) -> None: calls.append(engine) monkeypatch.setattr(db_base, "get_engine", lambda: fake_engine) monkeypatch.setattr(db_base, "_run_init_db_ddl", fake_run_init_db_ddl) await db_base.init_db() assert calls == [fake_engine.lock_conn] assert "pg_try_advisory_lock" in fake_engine.lock_conn.statements[0] assert any("pg_advisory_unlock" in stmt for stmt in fake_engine.lock_conn.statements) assert "COMMIT" in fake_engine.lock_conn.statements @pytest.mark.asyncio async def test_init_db_skips_bootstrap_when_connection_budget_exhausted(monkeypatch): from src.db import base as db_base calls: list[object] = [] async def fake_run_init_db_ddl(engine: object) -> None: calls.append(engine) monkeypatch.setattr(db_base, "get_engine", lambda: _ConnectionBudgetEngine()) monkeypatch.setattr(db_base, "_run_init_db_ddl", fake_run_init_db_ddl) await db_base.init_db() assert calls == [] @pytest.mark.asyncio async def test_init_db_releases_bootstrap_lock_when_ddl_fails(monkeypatch): from src.db import base as db_base fake_engine = _FakeEngine() async def fake_run_init_db_ddl(_engine: object) -> None: raise RuntimeError("ddl failed") monkeypatch.setattr(db_base, "get_engine", lambda: fake_engine) monkeypatch.setattr(db_base, "_run_init_db_ddl", fake_run_init_db_ddl) with pytest.raises(RuntimeError, match="ddl failed"): await db_base.init_db() assert "pg_try_advisory_lock" in fake_engine.lock_conn.statements[0] assert any("pg_advisory_unlock" in stmt for stmt in fake_engine.lock_conn.statements) @pytest.mark.asyncio async def test_init_db_skips_bootstrap_when_advisory_lock_times_out(monkeypatch): from src.db import base as db_base fake_engine = _FakeEngine(lock_results=[False]) calls: list[object] = [] async def fake_run_init_db_ddl(engine: object) -> None: calls.append(engine) monkeypatch.setattr(db_base, "get_engine", lambda: fake_engine) monkeypatch.setattr(db_base, "_run_init_db_ddl", fake_run_init_db_ddl) monkeypatch.setattr(db_base, "_DB_BOOTSTRAP_LOCK_WAIT_SECONDS", 0.0) await db_base.init_db() assert calls == [] assert "pg_try_advisory_lock" in fake_engine.lock_conn.statements[0] assert all("pg_advisory_unlock" not in stmt for stmt in fake_engine.lock_conn.statements) @pytest.mark.asyncio async def test_init_db_releases_bootstrap_lock_when_ddl_times_out(monkeypatch): from src.db import base as db_base fake_engine = _FakeEngine() async def fake_run_init_db_ddl(_engine: object) -> None: raise AssertionError("wait_for should own ddl execution in this test") async def fake_wait_for(coro: Awaitable[Any], timeout: float) -> None: assert timeout == 120.0 coro.close() raise TimeoutError monkeypatch.setattr(db_base, "get_engine", lambda: fake_engine) monkeypatch.setattr(db_base, "_run_init_db_ddl", fake_run_init_db_ddl) monkeypatch.setattr(db_base.asyncio, "wait_for", fake_wait_for) await db_base.init_db() assert "pg_try_advisory_lock" in fake_engine.lock_conn.statements[0] assert any("pg_advisory_unlock" in stmt for stmt in fake_engine.lock_conn.statements) @pytest.mark.asyncio async def test_signal_worker_initializes_worker_redis_pool_before_tasks(monkeypatch): from src.workers import signal_worker events: list[str] = [] worker = signal_worker.SignalWorker() async def fake_ensure_consumer_group() -> None: events.append("consumer_group") async def fake_init_worker_redis_pool() -> None: events.append("worker_redis_pool") def fake_create_task(coro: Awaitable[Any]) -> object: events.append("task") coro.close() class _FakeTask: def cancel(self) -> None: return None return _FakeTask() monkeypatch.setattr(worker, "_ensure_consumer_group", fake_ensure_consumer_group) monkeypatch.setattr(signal_worker, "init_worker_redis_pool", fake_init_worker_redis_pool) monkeypatch.setattr(signal_worker.asyncio, "create_task", fake_create_task) await worker.start() assert events == ["consumer_group", "worker_redis_pool", "task", "task"] def test_api_lifespan_closes_worker_redis_pool_after_signal_worker() -> None: import inspect from src import main as api_main source = inspect.getsource(api_main.lifespan) assert "close_worker_redis_pool" in source assert source.index("close_signal_worker") < source.index("close_worker_redis_pool") assert source.index("close_worker_redis_pool") < source.index("close_redis_pool")