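# NOTE(review): the lines below are CI / code-viewer residue that was pasted
# together with this test module; they are kept as comments so the module
# stays importable.
# CI status at capture time: some checks failed.
#   CD Pipeline / workflow-shape (push): successful in 0s
#   CD Pipeline / cancel-stale-cd (push): skipped
#   CD Pipeline / tests (push): failing after 1m46s
#   CD Pipeline / build-and-deploy (push): skipped
#   CD Pipeline / post-deploy-checks (push): skipped
# File info (shown twice by the viewer): 234 lines, 7.6 KiB, Python.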
"""
|
|
Runtime bootstrap guard tests.
|
|
|
|
這組測試鎖住 production rollout 曾踩到的兩個啟動序問題:
|
|
- API replicas 同時執行 DB bootstrap DDL 時必須有 advisory lock。
|
|
- SignalWorker 建立 Redis Stream 背景 task 前,必須先初始化 worker Redis pool。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Awaitable
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
|
|
class _FakeScalarResult:
    """Stand-in for the result of ``conn.execute()``; only ``scalar()`` is provided."""

    def __init__(self, value: bool) -> None:
        # Kept as a public attribute, exactly like the object it imitates exposes data.
        self.value: bool = value

    def scalar(self) -> bool:
        """Return the single wrapped boolean."""
        wrapped = self.value
        return wrapped
|
|
|
|
|
|
class _FakeLockConnection:
    """Async-connection double that records SQL and scripts advisory-lock answers.

    Every executed statement (stringified) and every commit is appended to
    ``statements`` so tests can assert on what ran and in which order.
    ``pg_try_advisory_lock`` calls consume ``lock_results`` in order; once only
    one scripted result is left it is returned for every further attempt, so a
    polling lock loop sees a stable answer instead of crashing with
    ``IndexError``.
    """

    def __init__(self, lock_results: list[bool] | None = None) -> None:
        self.statements: list[str] = []
        # Copy so the caller's list is never mutated; an omitted or empty
        # script means "the lock is always granted".
        self._lock_results: list[bool] = list(lock_results or [True])

    async def __aenter__(self) -> _FakeLockConnection:
        return self

    async def __aexit__(self, *_exc: object) -> None:
        # Returning None lets exceptions raised inside ``async with`` propagate.
        return None

    async def commit(self) -> None:
        self.statements.append("COMMIT")

    async def execute(
        self,
        statement: object,
        params: dict[str, str] | None = None,
    ) -> _FakeScalarResult:
        """Record ``statement`` and answer lock probes from the script.

        Raises:
            AssertionError: if the statement is not parameterized with the
                bootstrap lock name; every statement sent to this fake must be.
        """
        sql = str(statement)
        self.statements.append(sql)
        assert params == {"lock_name": "awoooi:init_db:ddl"}
        if "pg_try_advisory_lock" in sql:
            return _FakeScalarResult(self._next_lock_result())
        # Any other statement (e.g. the unlock) simply "succeeds".
        return _FakeScalarResult(True)

    def _next_lock_result(self) -> bool:
        """Pop the next scripted lock answer, repeating the last one forever."""
        if len(self._lock_results) > 1:
            return self._lock_results.pop(0)
        return self._lock_results[0]
|
|
|
|
|
|
class _FakeEngine:
    """Engine double whose ``connect()`` always returns one shared lock connection.

    Because the same connection object is handed out on every call, tests can
    inspect ``lock_conn.statements`` after the code under test has finished.
    """

    def __init__(self, lock_results: list[bool] | None = None) -> None:
        # ``lock_results`` scripts the answers to pg_try_advisory_lock probes.
        self.lock_conn = _FakeLockConnection(lock_results=lock_results)

    def connect(self) -> _FakeLockConnection:
        """Return the shared connection; the fake has no pool."""
        shared = self.lock_conn
        return shared
|
|
|
|
|
|
class _ConnectionBudgetEngine:
    """Engine double whose ``connect()`` always fails with a connection-limit error.

    Used to check that ``init_db`` skips bootstrap, instead of crashing, when
    no database connection can be opened.
    """

    def connect(self) -> _FakeLockConnection:
        """Raise the error Postgres reports when a role's connection limit is hit."""
        message = 'too many connections for role "awoooi"'
        raise RuntimeError(message)
|
|
|
|
|
|
def test_get_engine_uses_database_pool_budget(monkeypatch):
    """get_engine() must forward the configured pool budget to create_async_engine."""
    from src.db import base as db_base

    expected_url = "postgresql+asyncpg://u:p@localhost/db"
    sentinel_engine = object()
    seen: dict[str, object] = {}

    def recording_create_async_engine(database_url: str, **kwargs: object) -> object:
        # Record the URL and every engine option instead of building a real pool.
        seen["database_url"] = database_url
        seen.update(kwargs)
        return sentinel_engine

    # Reset the module-level engine/session-factory globals (presumably a lazy
    # cache) so get_engine() has to construct a fresh engine.
    monkeypatch.setattr(db_base, "_engine", None)
    monkeypatch.setattr(db_base, "_session_factory", None)
    pool_settings = {
        "DATABASE_URL": expected_url,
        "DATABASE_POOL_SIZE": 1,
        "DATABASE_MAX_OVERFLOW": 0,
        "DATABASE_POOL_TIMEOUT_SECONDS": 5.0,
    }
    for setting_name, setting_value in pool_settings.items():
        monkeypatch.setattr(db_base.settings, setting_name, setting_value)
    monkeypatch.setattr(db_base, "create_async_engine", recording_create_async_engine)

    assert db_base.get_engine() is sentinel_engine

    expected_options = {
        "database_url": expected_url,
        "pool_size": 1,
        "max_overflow": 0,
        "pool_timeout": 5.0,
    }
    for option_name, option_value in expected_options.items():
        assert seen[option_name] == option_value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_init_db_serializes_bootstrap_ddl_with_advisory_lock(monkeypatch):
    """init_db() runs the DDL on the locked connection, then unlocks and commits."""
    from src.db import base as db_base

    stub_engine = _FakeEngine()
    ddl_targets: list[object] = []

    async def record_ddl_target(engine: object) -> None:
        ddl_targets.append(engine)

    monkeypatch.setattr(db_base, "get_engine", lambda: stub_engine)
    monkeypatch.setattr(db_base, "_run_init_db_ddl", record_ddl_target)

    await db_base.init_db()

    lock_conn = stub_engine.lock_conn
    statements = lock_conn.statements
    # The DDL must receive the very connection that holds the advisory lock.
    assert ddl_targets == [lock_conn]
    assert "pg_try_advisory_lock" in statements[0]
    assert any("pg_advisory_unlock" in sql for sql in statements)
    assert "COMMIT" in statements
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_init_db_skips_bootstrap_when_connection_budget_exhausted(monkeypatch):
    """If no DB connection can be opened, init_db() returns without running DDL."""
    from src.db import base as db_base

    ddl_targets: list[object] = []

    async def record_ddl_target(engine: object) -> None:
        ddl_targets.append(engine)

    # Calling the class yields a fresh failing engine, just like a zero-arg lambda.
    monkeypatch.setattr(db_base, "get_engine", _ConnectionBudgetEngine)
    monkeypatch.setattr(db_base, "_run_init_db_ddl", record_ddl_target)

    # Must complete without raising even though connect() fails.
    await db_base.init_db()

    assert not ddl_targets
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_init_db_releases_bootstrap_lock_when_ddl_fails(monkeypatch):
    """A DDL error propagates out of init_db(), but the advisory lock is still released."""
    from src.db import base as db_base

    stub_engine = _FakeEngine()

    async def failing_ddl(_engine: object) -> None:
        raise RuntimeError("ddl failed")

    monkeypatch.setattr(db_base, "get_engine", lambda: stub_engine)
    monkeypatch.setattr(db_base, "_run_init_db_ddl", failing_ddl)

    with pytest.raises(RuntimeError, match="ddl failed"):
        await db_base.init_db()

    statements = stub_engine.lock_conn.statements
    assert "pg_try_advisory_lock" in statements[0]
    assert any("pg_advisory_unlock" in sql for sql in statements)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_init_db_skips_bootstrap_when_advisory_lock_times_out(monkeypatch):
    """If the lock is never granted, DDL is skipped and no unlock is issued."""
    from src.db import base as db_base

    stub_engine = _FakeEngine(lock_results=[False])
    ddl_targets: list[object] = []

    async def record_ddl_target(engine: object) -> None:
        ddl_targets.append(engine)

    monkeypatch.setattr(db_base, "get_engine", lambda: stub_engine)
    monkeypatch.setattr(db_base, "_run_init_db_ddl", record_ddl_target)
    # Zero wait budget so the test does not sit waiting for the lock.
    monkeypatch.setattr(db_base, "_DB_BOOTSTRAP_LOCK_WAIT_SECONDS", 0.0)

    await db_base.init_db()

    statements = stub_engine.lock_conn.statements
    assert not ddl_targets
    assert "pg_try_advisory_lock" in statements[0]
    # A lock that was never acquired must not be released.
    assert not any("pg_advisory_unlock" in sql for sql in statements)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_init_db_releases_bootstrap_lock_when_ddl_times_out(monkeypatch):
    """A DDL timeout is absorbed by init_db() and the advisory lock is still released.

    ``asyncio.wait_for`` is replaced so the DDL "times out" immediately. The
    fake raises ``asyncio.TimeoutError``, which is what the real ``wait_for``
    raises on every supported Python (it only became an alias of the builtin
    ``TimeoutError`` in 3.11). Timeouts are recorded and checked after
    ``init_db`` returns, so a broad ``except`` inside ``init_db`` cannot
    swallow the check.
    """
    from src.db import base as db_base

    fake_engine = _FakeEngine()
    wait_timeouts: list[float] = []

    async def fake_run_init_db_ddl(_engine: object) -> None:
        raise AssertionError("wait_for should own ddl execution in this test")

    async def fake_wait_for(coro: Awaitable[Any], timeout: float) -> None:
        wait_timeouts.append(timeout)
        # Close the DDL coroutine we never await, avoiding a "never awaited" warning.
        coro.close()
        raise db_base.asyncio.TimeoutError

    monkeypatch.setattr(db_base, "get_engine", lambda: fake_engine)
    monkeypatch.setattr(db_base, "_run_init_db_ddl", fake_run_init_db_ddl)
    # NOTE: db_base.asyncio is the shared asyncio module, so this patch is
    # process-wide until monkeypatch undoes it at teardown.
    monkeypatch.setattr(db_base.asyncio, "wait_for", fake_wait_for)

    # The timeout must be handled inside init_db(), not propagated.
    await db_base.init_db()

    # The DDL must have gone through wait_for exactly once, with a 120 s budget.
    assert wait_timeouts == [120.0]
    statements = fake_engine.lock_conn.statements
    assert "pg_try_advisory_lock" in statements[0]
    assert any("pg_advisory_unlock" in stmt for stmt in statements)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_signal_worker_initializes_worker_redis_pool_before_tasks(monkeypatch):
    """SignalWorker.start() must create the worker Redis pool before spawning tasks.

    Expected order: ensure the consumer group, initialize the worker Redis
    pool, then create the two background tasks.
    """
    from src.workers import signal_worker

    events: list[str] = []
    worker = signal_worker.SignalWorker()

    class _FakeTask:
        """Minimal stand-in for ``asyncio.Task``; only ``cancel()`` is provided."""

        def cancel(self) -> None:
            return None

    async def fake_ensure_consumer_group() -> None:
        events.append("consumer_group")

    async def fake_init_worker_redis_pool() -> None:
        events.append("worker_redis_pool")

    def fake_create_task(coro: Awaitable[Any], **_kwargs: object) -> object:
        # Accept create_task's keyword options (name=, context=) so the fake
        # keeps working if start() labels its tasks.
        events.append("task")
        # Close the coroutine so it is never scheduled and never warns.
        coro.close()
        return _FakeTask()

    monkeypatch.setattr(worker, "_ensure_consumer_group", fake_ensure_consumer_group)
    monkeypatch.setattr(signal_worker, "init_worker_redis_pool", fake_init_worker_redis_pool)
    # NOTE: signal_worker.asyncio is the shared asyncio module; the patch is
    # process-wide until monkeypatch undoes it at teardown.
    monkeypatch.setattr(signal_worker.asyncio, "create_task", fake_create_task)

    await worker.start()

    assert events == ["consumer_group", "worker_redis_pool", "task", "task"]
|
|
|
|
|
|
def test_api_lifespan_closes_worker_redis_pool_after_signal_worker() -> None:
    """The API lifespan must close, in order: signal worker, worker Redis pool, Redis pool.

    Order is taken from the first *code* reference to each name in the
    lifespan source (calls or attribute/name loads). Import statements,
    comments and string literals are not ``Name``/``Attribute`` nodes, so
    they cannot satisfy or skew the check, unlike a plain ``str.index``.
    """
    import ast
    import inspect
    import textwrap

    from src import main as api_main

    tree = ast.parse(textwrap.dedent(inspect.getsource(api_main.lifespan)))

    # Earliest (line, column) at which each identifier is referenced as code.
    first_use: dict[str, tuple[int, int]] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Name):
            name = node.id
        elif isinstance(node, ast.Attribute):
            name = node.attr
        else:
            continue
        position = (node.lineno, node.col_offset)
        if name not in first_use or position < first_use[name]:
            first_use[name] = position

    for required in ("close_signal_worker", "close_worker_redis_pool", "close_redis_pool"):
        assert required in first_use, f"lifespan never references {required}"
    assert first_use["close_signal_worker"] < first_use["close_worker_redis_pool"]
    assert first_use["close_worker_redis_pool"] < first_use["close_redis_pool"]
|