# apps/api/tests/test_db_context_guard.py
from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from unittest.mock import patch

import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

from src.db.base import get_db_context
from src.main import app, db_context_guard, http_exception_handler
|
||
|
||
|
||
def test_db_context_guard_without_project_id_is_unauthorized():
    """Without a project_id, acquiring the DB context must fail closed with a 401."""

    async def _enter_db_context() -> None:
        async with get_db_context():
            pass

    # Scope the raises-block to the single call expected to raise, so that an
    # unrelated HTTPException during setup could not satisfy the assertion.
    with pytest.raises(HTTPException) as exc:
        asyncio.run(_enter_db_context())

    assert exc.value.status_code == 401
|
||
|
||
|
||
@asynccontextmanager
async def _fake_db_context():
    """Success-path stand-in for ``get_db_context`` that opens no real DB connection.

    Entering it succeeds immediately and binds ``None``. Exiting does nothing.
    """
    yield None
|
||
|
||
|
||
class _UnauthorizedDbContext:
    """Async context manager simulating ``get_db_context()`` failing on entry.

    ``__aenter__`` always raises a 401 ``HTTPException``, so the body of an
    ``async with`` using this object never runs.
    """

    async def __aenter__(self):
        # Build the same 401 payload that the endpoint test asserts on, then fail.
        error = HTTPException(
            status_code=401,
            detail="Missing tenant context: project_id is required",
        )
        raise error

    async def __aexit__(self, exc_type, exc_val, exc_tb):  # noqa: ARG001
        # A falsy return means exceptions are never suppressed.
        return False
|
||
|
||
|
||
def _build_guard_app() -> FastAPI:
    """Build a minimal FastAPI app that mounts only ``db_context_guard``.

    An HTTP middleware resolves the project id from the ``X-Project-ID`` header,
    then ``X-Tenant-ID``, then the ``project_id`` query parameter. It calls
    ``set_project_context`` (source ``"test.guard"``, request id ``"test-request"``)
    before handling the request and ``clear_project_context`` afterwards.

    Returns:
        A fresh app serving ``GET /api/v1/security/db-context-guard``.
    """
    # Import once per app build rather than once per request. The import is still
    # deferred until the app is built, as in the original per-request import.
    from src.core.context import clear_project_context, set_project_context

    # Named ``guard_app`` so it does not shadow the module-level ``app`` from src.main.
    guard_app = FastAPI()

    @guard_app.middleware("http")
    async def _project_ctx_middleware(request, call_next):
        project_id = (
            request.headers.get("X-Project-ID")
            or request.headers.get("X-Tenant-ID")
            or request.query_params.get("project_id")
        )
        tokens = set_project_context(
            project_id=project_id, source="test.guard", request_id="test-request"
        )
        try:
            return await call_next(request)
        finally:
            # Hand the tokens back even if the downstream handler raised.
            clear_project_context(tokens)

    guard_app.add_api_route(
        "/api/v1/security/db-context-guard", db_context_guard, methods=["GET"]
    )
    return guard_app
|
||
|
||
|
||
def test_db_context_guard_with_project_id_returns_snapshot():
    """With a project_id present, the endpoint returns a traceable context snapshot."""
    # Local name avoids shadowing the module-level ``app`` imported from src.main.
    guard_app = _build_guard_app()

    # NOTE(review): this only swaps out ``get_db_context`` if ``db_context_guard``
    # resolves it through the ``src.db.base`` module at call time. If src.main binds
    # it via ``from src.db.base import get_db_context``, the target should be
    # ``src.main.get_db_context``. Confirm against src.main.
    with patch("src.db.base.get_db_context", _fake_db_context):
        # Context-managed so the client is closed. guard_app defines no lifespan,
        # so entering the client starts nothing extra.
        with TestClient(guard_app) as client:
            response = client.get(
                "/api/v1/security/db-context-guard",
                headers={"X-Project-ID": "awoooi"},
            )

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "ok"
    assert body["project_context"]["project_id"] == "awoooi"
    assert body["project_context"]["source"] == "test.guard"
|
||
|
||
|
||
def test_http_exception_handler_is_registered():
    """The src.main app maps ``HTTPException`` to ``http_exception_handler``."""
    registered_handler = app.exception_handlers[HTTPException]
    assert registered_handler is http_exception_handler
|
||
|
||
|
||
def test_db_context_guard_endpoint_without_project_id_returns_401():
    """The endpoint must answer 401 (fail-closed) when the project context is missing."""
    failing_ctx = _UnauthorizedDbContext()

    # NOTE(review): if src.main binds ``get_db_context`` by name at import time,
    # this patch does not reach the endpoint. The test could then pass only because
    # the real get_db_context also fails closed. Confirm the patch target.
    with patch("src.db.base.get_db_context", return_value=failing_ctx):
        client = TestClient(app)
        response = client.get("/api/v1/security/db-context-guard")

    assert response.status_code == 401
    payload = response.json()
    assert payload["detail"] == "Missing tenant context: project_id is required"
|