420 lines
15 KiB
Python
420 lines
15 KiB
Python
# apps/api/tests/test_ai_governance_endpoints.py | 2026-05-02 @ Asia/Taipei
"""
Unit Tests — AI Governance Endpoints (PR 1)

Coverage:
1. events endpoint: pagination logic is correct
2. events endpoint: severity mapping is correct (critical / warning / info)
3. queue endpoint: graceful fallback (mock ProgrammingError)
4. summary endpoint: compliance_rate calculation (including the total=0 boundary)
5. summary endpoint: compliance_rate calculation (normal case with unresolved events)

Test strategy: mock the service-layer functions so the tests do not depend on
a database and only verify the router logic.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from datetime import datetime, timezone, timedelta
|
||
from unittest.mock import AsyncMock, patch
|
||
|
||
import pytest
|
||
from fastapi import FastAPI
|
||
from fastapi.testclient import TestClient
|
||
|
||
from src.api.v1.ai_governance import router
|
||
from src.models.governance import (
|
||
DailyCount,
|
||
DispatchItem,
|
||
GovernanceEvent,
|
||
GovernanceEventsResponse,
|
||
GovernanceQueueResponse,
|
||
GovernanceSummaryResponse,
|
||
map_severity,
|
||
)
|
||
from src.services.governance_query_service import (
|
||
_extract_remediation,
|
||
_query_dispatch_table,
|
||
_to_governance_event,
|
||
)
|
||
|
||
# Fixed UTC+08:00 offset attached to every timestamp built in this module.
TAIPEI = timezone(timedelta(hours=8))
# Deterministic reference instant (2026-05-02 12:00 at +08:00) shared by the test data.
NOW = datetime(year=2026, month=5, day=2, hour=12, minute=0, tzinfo=TAIPEI)
|
||
|
||
|
||
# =============================================================================
|
||
# Fixture
|
||
# =============================================================================
|
||
|
||
@pytest.fixture
def client():
    """Return a TestClient for a fresh FastAPI app that mounts only the governance router.

    The router is included under the ``/api/v1`` prefix, which is why every
    request in this module targets ``/api/v1/...``.
    """
    test_app = FastAPI()
    test_app.include_router(router, prefix="/api/v1")
    http_client = TestClient(test_app)
    return http_client
|
||
|
||
|
||
def _make_event(
    event_id: str = "evt-001",
    event_type: str = "slo_violation",
    resolved: bool = False,
) -> GovernanceEvent:
    """Build a GovernanceEvent with fixed filler values for use as canned service output.

    Args:
        event_id: Stored in the event's ``id`` field.
        event_type: Stored in ``event_type`` and passed to ``map_severity`` to
            derive the ``severity`` field.
        resolved: Stored in the event's ``resolved`` field.

    Returns:
        A GovernanceEvent triggered at ``NOW`` with no resolution time, no
        remediation and an empty ``dispatch_ids`` list.
    """
    # Collect the constructor arguments first so the derived severity sits
    # next to the event_type it is computed from.
    fields = {
        "id": event_id,
        "event_type": event_type,
        "severity": map_severity(event_type),
        "triggered_at": NOW,
        "resolved": resolved,
        "resolved_at": None,
        "impact": "SLO violated",
        "details": {"message": "test"},
        "remediation": None,
        "dispatch_ids": [],
    }
    return GovernanceEvent(**fields)
|
||
|
||
|
||
# =============================================================================
|
||
# 1. severity 映射單元測試
|
||
# =============================================================================
|
||
|
||
class TestSeverityMapping:
    """Checks the severity label ``map_severity`` returns for each event type."""

    def test_critical_types(self):
        """Each listed event type must map to ``critical``."""
        critical_types = ("slo_violation", "conservative_mode", "governance_slo_data_gap")
        for event_type in critical_types:
            assert map_severity(event_type) == "critical", f"{event_type} should be critical"

    def test_warning_types(self):
        """Each listed event type must map to ``warning``."""
        warning_types = (
            "trust_drift",
            "kb_stale",
            "knowledge_degradation",
            "execution_blast_radius",
        )
        for event_type in warning_types:
            assert map_severity(event_type) == "warning", f"{event_type} should be warning"

    def test_info_types(self):
        """Each listed event type, including an unrecognised one, must map to ``info``."""
        info_types = (
            "replay_degraded",
            "self_demotion",
            "llm_hallucination",
            "unknown_event",
        )
        for event_type in info_types:
            assert map_severity(event_type) == "info", f"{event_type} should be info"
|
||
|
||
|
||
# =============================================================================
|
||
# 2. events endpoint 分頁
|
||
# =============================================================================
|
||
|
||
class TestEventsEndpoint:
    """Router-level tests for ``GET /api/v1/ai/governance/events``.

    ``query_governance_events`` is patched on the router module
    (``src.api.v1.ai_governance``), so these tests never touch a database.
    """

    def test_pagination_default(self, client):
        """A request without query params returns the canned page=1 / size=20 payload."""
        fake_response = GovernanceEventsResponse(
            items=[_make_event(str(i)) for i in range(5)],
            total=5,
            page=1,
            size=20,
        )
        # One patch with an awaitable stub is all that is needed here.
        with patch(
            "src.api.v1.ai_governance.query_governance_events",
            new=AsyncMock(return_value=fake_response),
        ):
            r = client.get("/api/v1/ai/governance/events")

        assert r.status_code == 200
        data = r.json()
        assert data["total"] == 5
        assert data["page"] == 1
        assert data["size"] == 20
        assert len(data["items"]) == 5

    def test_pagination_custom(self, client):
        """Custom ``page`` / ``size`` query params are forwarded to the service."""
        fake_response = GovernanceEventsResponse(
            items=[_make_event()],
            total=50,
            page=3,
            size=10,
        )
        captured: dict = {}

        async def mock_query(**kwargs):
            # Record the keyword arguments the router passes to the service.
            captured.update(kwargs)
            return fake_response

        with patch("src.api.v1.ai_governance.query_governance_events", new=mock_query):
            r = client.get("/api/v1/ai/governance/events?page=3&size=10")

        assert r.status_code == 200
        assert captured["page"] == 3
        assert captured["size"] == 10
        data = r.json()
        assert data["total"] == 50

    def test_severity_filter_passed(self, client):
        """The ``severity`` query param is forwarded to the service unchanged."""
        fake_response = GovernanceEventsResponse(items=[], total=0, page=1, size=20)
        captured: dict = {}

        async def mock_query(**kwargs):
            # Record the keyword arguments the router passes to the service.
            captured.update(kwargs)
            return fake_response

        with patch("src.api.v1.ai_governance.query_governance_events", new=mock_query):
            r = client.get("/api/v1/ai/governance/events?severity=critical")

        assert r.status_code == 200
        assert captured["severity"] == "critical"

    def test_invalid_severity_rejected(self, client):
        """An unsupported ``severity`` value is rejected with HTTP 422."""
        r = client.get("/api/v1/ai/governance/events?severity=bad_value")
        assert r.status_code == 422

    def test_invalid_status_rejected(self, client):
        """An unsupported ``status`` value is rejected with HTTP 422."""
        r = client.get("/api/v1/ai/governance/events?status=invalid")
        assert r.status_code == 422

    def test_severity_in_response(self, client):
        """Each returned event carries the severity derived from its event_type."""
        events = [
            _make_event("e1", "slo_violation"),  # critical
            _make_event("e2", "trust_drift"),  # warning
            _make_event("e3", "self_demotion"),  # info
        ]
        fake_response = GovernanceEventsResponse(items=events, total=3, page=1, size=20)

        with patch(
            "src.api.v1.ai_governance.query_governance_events",
            new=AsyncMock(return_value=fake_response),
        ):
            r = client.get("/api/v1/ai/governance/events")

        assert r.status_code == 200
        items = r.json()["items"]
        assert items[0]["severity"] == "critical"
        assert items[1]["severity"] == "warning"
        assert items[2]["severity"] == "info"
|
||
|
||
|
||
class TestEventsReadSideNormalization:
    """Checks that dict-shaped ``remediation`` details are flattened to a string on read."""

    def test_remediation_dict_is_normalized_to_string(self):
        """A dict ``details.remediation`` must still be returned as a string."""
        details = {
            "remediation": {
                "items": [
                    "補齊 ADR-100 SLO emitter",
                    "設置 PROMETHEUS_MULTIPROC_DIR",
                ]
            }
        }

        normalized = _extract_remediation(details)

        assert normalized == "補齊 ADR-100 SLO emitter;設置 PROMETHEUS_MULTIPROC_DIR"

    def test_governance_event_accepts_dict_remediation(self):
        """A dict remediation must not make GovernanceEvent validation fail (which would surface as a 500)."""
        # Minimal attribute-only stand-in for a result row: an ad-hoc class
        # whose class attributes are the columns the converter reads.
        row_attrs = {
            "id": "evt-001",
            "event_type": "governance_slo_data_gap",
            "triggered_at": NOW,
            "resolved": False,
            "resolved_at": None,
            "details": {
                "message": "SLO metrics missing",
                "remediation": {"items": ["補齊 SLO emitter"]},
            },
        }
        row_cls = type("Row", (), row_attrs)
        row = row_cls()

        converted = _to_governance_event(row)

        assert converted.remediation == "補齊 SLO emitter"
        assert converted.impact == "SLO metrics missing"
|
||
|
||
|
||
# =============================================================================
|
||
# 3. queue endpoint graceful fallback
|
||
# =============================================================================
|
||
|
||
class TestQueueEndpoint:
    """Router-level tests for ``GET /api/v1/ai/governance/queue`` plus a query-text check."""

    # Import site the router resolves the service function from.
    _SERVICE_TARGET = "src.api.v1.ai_governance.query_governance_queue"

    def test_graceful_fallback_on_programming_error(self, client):
        """A service result flagged ``table_pending=True`` is served as HTTP 200 with an empty page, not a 500."""
        pending_page = GovernanceQueueResponse(
            items=[], total=0, page=1, size=10, table_pending=True,
        )
        service_stub = AsyncMock(return_value=pending_page)
        with patch(self._SERVICE_TARGET, new=service_stub):
            response = client.get("/api/v1/ai/governance/queue")

        assert response.status_code == 200
        payload = response.json()
        assert payload["table_pending"] is True
        assert payload["items"] == []
        assert payload["total"] == 0

    def test_normal_response_when_table_ready(self, client):
        """When the service returns items they are passed through in the response."""
        pending_dispatch = DispatchItem(
            id="d-001",
            governance_event_id="evt-001",
            event_type="slo_violation",
            dispatch_status="pending",
            proposed_action="restart deployment",
            playbook_id=None,
            playbook_trust=None,
            created_at=NOW,
            dispatched_at=None,
            completed_at=None,
            operator_note=None,
        )
        ready_page = GovernanceQueueResponse(
            items=[pending_dispatch], total=1, page=1, size=10, table_pending=False,
        )
        service_stub = AsyncMock(return_value=ready_page)
        with patch(self._SERVICE_TARGET, new=service_stub):
            response = client.get("/api/v1/ai/governance/queue")

        assert response.status_code == 200
        payload = response.json()
        assert payload["table_pending"] is False
        assert len(payload["items"]) == 1
        assert payload["items"][0]["dispatch_status"] == "pending"

    def test_invalid_dispatch_status_rejected(self, client):
        """An unsupported ``dispatch_status`` value is rejected with HTTP 422."""
        response = client.get("/api/v1/ai/governance/queue?dispatch_status=unknown")
        assert response.status_code == 422

    def test_queue_query_uses_production_dispatch_schema(self):
        """The queue query must match the migration schema.

        It has to use ``dispatched_at`` and must not read the non-existent
        ``created_at`` / ``operator_note`` columns. The check is textual: it
        inspects the source of ``_query_dispatch_table``.
        """
        import inspect

        query_source = inspect.getsource(_query_dispatch_table)

        required_fragments = (
            "d.dispatched_at AS created_at",
            "ORDER BY d.dispatched_at DESC",
            "NULL::text AS operator_note",
            "CAST(:dispatch_status AS governance_dispatch_status)",
        )
        forbidden_fragments = ("d.created_at", "d.operator_note")

        for fragment in required_fragments:
            assert fragment in query_source
        for fragment in forbidden_fragments:
            assert fragment not in query_source
|
||
|
||
|
||
# =============================================================================
|
||
# 4. summary endpoint compliance_rate
|
||
# =============================================================================
|
||
|
||
class TestSummaryEndpoint:
    """Router-level tests for ``GET /api/v1/ai/governance/summary`` with the service patched out."""

    # Import site the router resolves the service function from.
    _SERVICE_TARGET = "src.api.v1.ai_governance.query_governance_summary"

    def test_compliance_rate_normal(self, client):
        """A summary with unresolved events (rate 0.8 for 2 of 10) is returned unchanged."""
        canned = GovernanceSummaryResponse(
            compliance_rate=0.8,
            total_events=10,
            unresolved_count=2,
            daily_counts=[],
        )
        service_stub = AsyncMock(return_value=canned)
        with patch(self._SERVICE_TARGET, new=service_stub):
            response = client.get("/api/v1/ai/governance/summary")

        assert response.status_code == 200
        payload = response.json()
        assert payload["compliance_rate"] == pytest.approx(0.8)
        assert payload["total_events"] == 10
        assert payload["unresolved_count"] == 2

    def test_compliance_rate_all_resolved(self, client):
        """With every event resolved the reported compliance_rate is 1.0."""
        canned = GovernanceSummaryResponse(
            compliance_rate=1.0,
            total_events=5,
            unresolved_count=0,
            daily_counts=[],
        )
        service_stub = AsyncMock(return_value=canned)
        with patch(self._SERVICE_TARGET, new=service_stub):
            response = client.get("/api/v1/ai/governance/summary?days=7")

        assert response.status_code == 200
        assert response.json()["compliance_rate"] == pytest.approx(1.0)

    def test_compliance_rate_total_zero(self, client):
        """Boundary case: with total_events=0 the reported compliance_rate is 1.0."""
        canned = GovernanceSummaryResponse(
            compliance_rate=1.0,
            total_events=0,
            unresolved_count=0,
            daily_counts=[],
        )
        service_stub = AsyncMock(return_value=canned)
        with patch(self._SERVICE_TARGET, new=service_stub):
            response = client.get("/api/v1/ai/governance/summary")

        assert response.status_code == 200
        payload = response.json()
        assert payload["compliance_rate"] == pytest.approx(1.0)
        assert payload["total_events"] == 0

    def test_days_max_boundary(self, client):
        """The boundary value ``days=90`` is accepted."""
        canned = GovernanceSummaryResponse(
            compliance_rate=1.0, total_events=0, unresolved_count=0, daily_counts=[],
        )
        service_stub = AsyncMock(return_value=canned)
        with patch(self._SERVICE_TARGET, new=service_stub):
            response = client.get("/api/v1/ai/governance/summary?days=90")
        assert response.status_code == 200

    def test_days_over_max_rejected(self, client):
        """``days=91`` is rejected with HTTP 422."""
        response = client.get("/api/v1/ai/governance/summary?days=91")
        assert response.status_code == 422

    def test_daily_counts_structure(self, client):
        """``daily_counts`` entries keep their date and per-type breakdown in the response."""
        per_day = [
            DailyCount(date="2026-05-01", total=3, by_type={"slo_violation": 2, "trust_drift": 1}),
            DailyCount(date="2026-05-02", total=7, by_type={"slo_violation": 7}),
        ]
        canned = GovernanceSummaryResponse(
            compliance_rate=0.9,
            total_events=10,
            unresolved_count=1,
            daily_counts=per_day,
        )
        service_stub = AsyncMock(return_value=canned)
        with patch(self._SERVICE_TARGET, new=service_stub):
            response = client.get("/api/v1/ai/governance/summary")

        assert response.status_code == 200
        returned_counts = response.json()["daily_counts"]
        assert len(returned_counts) == 2
        first_day = returned_counts[0]
        assert first_day["date"] == "2026-05-01"
        assert first_day["by_type"]["slo_violation"] == 2
|
||
|
||
|
||
# =============================================================================
|
||
# 5. service 層 compliance_rate 純函式測試(不經 HTTP)
|
||
# =============================================================================
|
||
|
||
class TestComplianceRateCalculation:
    """Pin the arithmetic of the compliance-rate formula (no HTTP, no router).

    NOTE(review): these tests evaluate a local copy of the formula in
    ``_compliance_rate``; they do not call a function from the service module.
    Confirm the service uses the same expression, or point these tests at it.
    """

    @staticmethod
    def _compliance_rate(unresolved: int, total: int) -> float:
        """Return ``1 - unresolved / total`` rounded to 4 places, or 1.0 when ``total`` is 0."""
        # Guard first: with no events there is nothing to divide by, and the
        # expected rate is 1.0.
        if total == 0:
            return 1.0
        return round(1.0 - unresolved / total, 4)

    def test_formula_normal(self):
        """1 - 2/10 = 0.8"""
        assert self._compliance_rate(2, 10) == pytest.approx(0.8)

    def test_formula_zero_total(self):
        """total=0 -> 1.0"""
        assert self._compliance_rate(0, 0) == pytest.approx(1.0)

    def test_formula_all_unresolved(self):
        """1 - 5/5 = 0.0"""
        assert self._compliance_rate(5, 5) == pytest.approx(0.0)
|