# apps/api/tests/test_db_context_guard.py from __future__ import annotations from contextlib import asynccontextmanager from fastapi import HTTPException import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from unittest.mock import patch from src.db.base import get_db_context from src.main import db_context_guard, app, http_exception_handler def test_db_context_guard_without_project_id_is_unauthorized(): """未提供 project_id 時,DB context 取得應 fail-closed。""" with pytest.raises(HTTPException) as exc: async def _run(): async with get_db_context(): pass import asyncio asyncio.run(_run()) assert exc.value.status_code == 401 @asynccontextmanager async def _fake_db_context(): """避免真實 DB 連線的可驗證 success mock。""" yield class _UnauthorizedDbContext: """Simulate get_db_context() entering a failure path.""" async def __aenter__(self): raise HTTPException( status_code=401, detail="Missing tenant context: project_id is required" ) async def __aexit__(self, exc_type, exc_val, exc_tb): # noqa: ARG001 return False def _build_guard_app() -> FastAPI: app = FastAPI() @app.middleware("http") async def _project_ctx_middleware(request, call_next): project_id = ( request.headers.get("X-Project-ID") or request.headers.get("X-Tenant-ID") or request.query_params.get("project_id") ) from src.core.context import clear_project_context, set_project_context tokens = set_project_context(project_id=project_id, source="test.guard", request_id="test-request") try: response = await call_next(request) return response finally: clear_project_context(tokens) app.add_api_route("/api/v1/security/db-context-guard", db_context_guard, methods=["GET"]) return app def test_db_context_guard_with_project_id_returns_snapshot(): """有 project_id 時,應回傳可追溯的 context snapshot。""" app = _build_guard_app() with patch("src.db.base.get_db_context", _fake_db_context): client = TestClient(app) response = client.get("/api/v1/security/db-context-guard", headers={"X-Project-ID": "awoooi"}) assert response.status_code == 200 body = response.json() assert body["status"] == "ok" assert body["project_context"]["project_id"] == "awoooi" assert body["project_context"]["source"] == "test.guard" def test_http_exception_handler_is_registered(): assert app.exception_handlers[HTTPException] is http_exception_handler def test_db_context_guard_endpoint_without_project_id_returns_401(): """端點缺少 project context 時應回傳 401(fail-closed)。""" with patch("src.db.base.get_db_context", return_value=_UnauthorizedDbContext()): test_client = TestClient(app) response = test_client.get("/api/v1/security/db-context-guard") assert response.status_code == 401 assert response.json()["detail"] == "Missing tenant context: project_id is required"