Files
awoooi/apps/api/tests/test_db_context_guard.py
Your Name cfb866d055
Some checks failed
Ansible Lint / lint (push) Successful in 35s
CD Pipeline / tests (push) Failing after 13s
CD Pipeline / build-and-deploy (push) Has been skipped
CD Pipeline / post-deploy-checks (push) Has been skipped
Code Review / ai-code-review (push) Failing after 11s
feat(governance): add agent market automation surfaces
2026-06-04 21:50:55 +08:00

98 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# apps/api/tests/test_db_context_guard.py
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import HTTPException
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from unittest.mock import patch
from src.db.base import get_db_context
from src.main import db_context_guard, app, http_exception_handler
def test_db_context_guard_without_project_id_is_unauthorized():
    """Acquiring a DB context without a project_id must fail closed with 401."""
    import asyncio

    async def _enter_db_context() -> None:
        # Only entering the context matters; the test expects entry to raise,
        # so the body is not expected to run.
        async with get_db_context():
            pass

    # Keep pytest.raises scoped to the single call that is expected to raise.
    # NOTE(review): assumes no project context is active in this thread;
    # nothing in this test sets one.
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(_enter_db_context())

    assert exc_info.value.status_code == 401
@asynccontextmanager
async def _fake_db_context():
"""避免真實 DB 連線的可驗證 success mock。"""
yield
class _UnauthorizedDbContext:
"""Simulate get_db_context() entering a failure path."""
async def __aenter__(self):
raise HTTPException(
status_code=401, detail="Missing tenant context: project_id is required"
)
async def __aexit__(self, exc_type, exc_val, exc_tb): # noqa: ARG001
return False
def _build_guard_app() -> FastAPI:
    """Build a minimal FastAPI app exposing only the db-context-guard endpoint.

    A test middleware resolves the project id from the ``X-Project-ID`` header,
    then ``X-Tenant-ID``, then the ``project_id`` query parameter (None if none
    is present). It installs that id via ``set_project_context`` with source
    ``"test.guard"`` and clears it again after the request.

    Returns:
        A fresh FastAPI instance with the guard route registered.
    """
    # Imported once per app build instead of on every request.
    from src.core.context import clear_project_context, set_project_context

    # Named guard_app so it does not shadow the module-level production ``app``.
    guard_app = FastAPI()

    @guard_app.middleware("http")
    async def _project_ctx_middleware(request, call_next):
        project_id = (
            request.headers.get("X-Project-ID")
            or request.headers.get("X-Tenant-ID")
            or request.query_params.get("project_id")
        )
        tokens = set_project_context(
            project_id=project_id, source="test.guard", request_id="test-request"
        )
        try:
            return await call_next(request)
        finally:
            # Runs even when the handler raises, so the context set above is
            # always cleared.
            clear_project_context(tokens)

    guard_app.add_api_route(
        "/api/v1/security/db-context-guard", db_context_guard, methods=["GET"]
    )
    return guard_app
def test_db_context_guard_with_project_id_returns_snapshot():
    """With a project_id, the guard returns a traceable context snapshot."""
    # Named guard_app so it does not shadow the module-level production ``app``.
    guard_app = _build_guard_app()
    # NOTE(review): this patch only reaches db_context_guard if it looks up
    # get_db_context through src.db.base at call time. If src.main binds the
    # name via ``from src.db.base import get_db_context``, patch
    # ``src.main.get_db_context`` instead. Confirm against src/main.py.
    fake_db = patch("src.db.base.get_db_context", _fake_db_context)
    # Context-managing the client guarantees it is closed. The guard app has
    # no lifespan handlers, so this adds no startup work.
    with fake_db, TestClient(guard_app) as client:
        response = client.get(
            "/api/v1/security/db-context-guard",
            headers={"X-Project-ID": "awoooi"},
        )

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "ok"
    assert body["project_context"]["project_id"] == "awoooi"
    assert body["project_context"]["source"] == "test.guard"
def test_http_exception_handler_is_registered():
    """The production app maps HTTPException to the project's custom handler."""
    registered_handler = app.exception_handlers[HTTPException]
    assert registered_handler is http_exception_handler
def test_db_context_guard_endpoint_without_project_id_returns_401():
    """Without a project context, the production endpoint fails closed with 401."""
    failing_ctx = _UnauthorizedDbContext()
    # NOTE(review): this patch only reaches the endpoint if it looks up
    # get_db_context through src.db.base at call time. Confirm against
    # src/main.py.
    with patch("src.db.base.get_db_context", return_value=failing_ctx):
        # The client is not used as a context manager, so the real app's
        # lifespan startup/shutdown is not run here.
        guard_client = TestClient(app)
        response = guard_client.get("/api/v1/security/db-context-guard")

    assert response.status_code == 401
    payload = response.json()
    assert payload["detail"] == "Missing tenant context: project_id is required"