Files
awoooi/apps/api/tests/test_ai_governance_endpoints.py
Your Name 6220f52266
All checks were successful
Code Review / ai-code-review (push) Successful in 10s
CD Pipeline / tests (push) Successful in 1m25s
CD Pipeline / build-and-deploy (push) Successful in 3m46s
CD Pipeline / post-deploy-checks (push) Successful in 1m16s
fix(governance): cast dispatch status filter
2026-05-14 18:39:11 +08:00

420 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# apps/api/tests/test_ai_governance_endpoints.py | 2026-05-02 @ Asia/Taipei
"""
Unit Tests — AI Governance Endpoints (PR 1)
覆蓋範圍:
1. events endpoint 分頁邏輯正確
2. events endpoint severity 映射正確critical / warning / info
3. queue endpoint graceful fallbackmock ProgrammingError
4. summary endpoint compliance_rate 計算(含 total=0 邊界)
5. summary endpoint compliance_rate 計算(有 unresolved 的正常情況)
測試策略mock service 層函式,不依賴 DB確保 Router 邏輯正確。
"""
from __future__ import annotations
from datetime import datetime, timezone, timedelta
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.api.v1.ai_governance import router
from src.models.governance import (
DailyCount,
DispatchItem,
GovernanceEvent,
GovernanceEventsResponse,
GovernanceQueueResponse,
GovernanceSummaryResponse,
map_severity,
)
from src.services.governance_query_service import (
_extract_remediation,
_query_dispatch_table,
_to_governance_event,
)
TAIPEI = timezone(timedelta(hours=8))
NOW = datetime(2026, 5, 2, 12, 0, tzinfo=TAIPEI)
# =============================================================================
# Fixture
# =============================================================================
@pytest.fixture
def client():
    """Return a TestClient bound to a fresh app that mounts only this router.

    The governance router is included under the ``/api/v1`` prefix, so the
    tests address endpoints as ``/api/v1/ai/governance/...``.
    """
    api = FastAPI()
    api.include_router(router, prefix="/api/v1")
    return TestClient(api)
def _make_event(
    event_id: str = "evt-001",
    event_type: str = "slo_violation",
    resolved: bool = False,
) -> GovernanceEvent:
    """Build a GovernanceEvent with fixed sample values for use in tests.

    Args:
        event_id: Value stored in the event's ``id`` field.
        event_type: Event type; also determines ``severity`` via ``map_severity``.
        resolved: Value stored in the event's ``resolved`` field.

    Returns:
        A GovernanceEvent triggered at the module-level ``NOW`` timestamp, with
        no resolution time, no remediation and no dispatch ids.
    """
    fields = {
        "id": event_id,
        "event_type": event_type,
        # Severity is derived from the type rather than passed in, so the
        # sample data always agrees with the mapping under test.
        "severity": map_severity(event_type),
        "triggered_at": NOW,
        "resolved": resolved,
        "resolved_at": None,
        "impact": "SLO violated",
        "details": {"message": "test"},
        "remediation": None,
        "dispatch_ids": [],
    }
    return GovernanceEvent(**fields)
# =============================================================================
# 1. severity 映射單元測試
# =============================================================================
class TestSeverityMapping:
    """Check ``map_severity`` for a set of event types per severity level."""

    def test_critical_types(self):
        """Each listed event type maps to ``critical``."""
        critical_types = (
            "slo_violation",
            "conservative_mode",
            "governance_slo_data_gap",
        )
        for event_type in critical_types:
            assert map_severity(event_type) == "critical", f"{event_type} should be critical"

    def test_warning_types(self):
        """Each listed event type maps to ``warning``."""
        warning_types = (
            "trust_drift",
            "kb_stale",
            "knowledge_degradation",
            "execution_blast_radius",
        )
        for event_type in warning_types:
            assert map_severity(event_type) == "warning", f"{event_type} should be warning"

    def test_info_types(self):
        """Each listed event type maps to ``info``.

        NOTE(review): ``unknown_event`` presumably exercises a default/fallback
        branch of ``map_severity`` — confirm against its implementation.
        """
        info_types = (
            "replay_degraded",
            "self_demotion",
            "llm_hallucination",
            "unknown_event",
        )
        for event_type in info_types:
            assert map_severity(event_type) == "info", f"{event_type} should be info"
# =============================================================================
# 2. events endpoint 分頁
# =============================================================================
class TestEventsEndpoint:
    """Router-level tests for ``GET /api/v1/ai/governance/events``.

    Tests that expect a 200 replace ``query_governance_events`` in the router
    module's namespace with a stand-in, then inspect either the JSON body or
    the keyword arguments the router passed to the stand-in.
    """

    def test_pagination_default(self, client):
        """A request without paging params returns the service's paging fields."""
        fake_response = GovernanceEventsResponse(
            items=[_make_event(str(i)) for i in range(5)],
            total=5,
            page=1,
            size=20,
        )
        # One patch is enough. An earlier version wrapped this in a second
        # patch of the same target (a plain, non-async lambda) that this one
        # immediately shadowed, so it had no effect on the test.
        with patch(
            "src.api.v1.ai_governance.query_governance_events",
            new=AsyncMock(return_value=fake_response),
        ):
            r = client.get("/api/v1/ai/governance/events")
        assert r.status_code == 200
        data = r.json()
        assert data["total"] == 5
        assert data["page"] == 1
        assert data["size"] == 20
        assert len(data["items"]) == 5

    def test_pagination_custom(self, client):
        """Custom ``page``/``size`` query params are forwarded to the service."""
        fake_response = GovernanceEventsResponse(
            items=[_make_event()],
            total=50,
            page=3,
            size=10,
        )
        captured: dict = {}

        async def mock_query(**kwargs):
            # Record what the router passed so the test can assert on it.
            captured.update(kwargs)
            return fake_response

        with patch("src.api.v1.ai_governance.query_governance_events", new=mock_query):
            r = client.get("/api/v1/ai/governance/events?page=3&size=10")
        assert r.status_code == 200
        assert captured["page"] == 3
        assert captured["size"] == 10
        data = r.json()
        assert data["total"] == 50

    def test_severity_filter_passed(self, client):
        """The ``severity`` query param is forwarded to the service unchanged."""
        fake_response = GovernanceEventsResponse(items=[], total=0, page=1, size=20)
        captured: dict = {}

        async def mock_query(**kwargs):
            # Record what the router passed so the test can assert on it.
            captured.update(kwargs)
            return fake_response

        with patch("src.api.v1.ai_governance.query_governance_events", new=mock_query):
            r = client.get("/api/v1/ai/governance/events?severity=critical")
        assert r.status_code == 200
        assert captured["severity"] == "critical"

    def test_invalid_severity_rejected(self, client):
        """An unsupported ``severity`` value is rejected with HTTP 422."""
        # No patch here: the test expects rejection without a service stand-in.
        r = client.get("/api/v1/ai/governance/events?severity=bad_value")
        assert r.status_code == 422

    def test_invalid_status_rejected(self, client):
        """An unsupported ``status`` value is rejected with HTTP 422."""
        r = client.get("/api/v1/ai/governance/events?status=invalid")
        assert r.status_code == 422

    def test_severity_in_response(self, client):
        """Each returned item carries the severity mapped from its event type."""
        events = [
            _make_event("e1", "slo_violation"),  # critical
            _make_event("e2", "trust_drift"),  # warning
            _make_event("e3", "self_demotion"),  # info
        ]
        fake_response = GovernanceEventsResponse(items=events, total=3, page=1, size=20)
        with patch(
            "src.api.v1.ai_governance.query_governance_events",
            new=AsyncMock(return_value=fake_response),
        ):
            r = client.get("/api/v1/ai/governance/events")
        assert r.status_code == 200
        items = r.json()["items"]
        assert items[0]["severity"] == "critical"
        assert items[1]["severity"] == "warning"
        assert items[2]["severity"] == "info"
class TestEventsReadSideNormalization:
    """Checks that a dict-shaped ``remediation`` in event details becomes a string."""

    def test_remediation_dict_is_normalized_to_string(self):
        """``_extract_remediation`` turns a ``{"items": [...]}`` dict into one string."""
        details = {
            "remediation": {
                "items": [
                    "補齊 ADR-100 SLO emitter",
                    "設置 PROMETHEUS_MULTIPROC_DIR",
                ]
            }
        }
        normalized = _extract_remediation(details)
        # NOTE(review): the expected value shows the two items joined with no
        # visible separator; a full-width separator may have been lost when
        # this file was extracted — verify against the repository copy.
        assert normalized == "補齊 ADR-100 SLO emitter設置 PROMETHEUS_MULTIPROC_DIR"

    def test_governance_event_accepts_dict_remediation(self):
        """``_to_governance_event`` accepts a row whose remediation is a dict.

        The row is a plain attribute holder standing in for a database row;
        the resulting event must expose remediation as a string and take its
        impact from ``details["message"]``.
        """

        class Row:
            id = "evt-001"
            event_type = "governance_slo_data_gap"
            triggered_at = NOW
            resolved = False
            resolved_at = None
            details = {
                "message": "SLO metrics missing",
                "remediation": {"items": ["補齊 SLO emitter"]},
            }

        converted = _to_governance_event(Row())
        assert converted.remediation == "補齊 SLO emitter"
        assert converted.impact == "SLO metrics missing"
# =============================================================================
# 3. queue endpoint graceful fallback
# =============================================================================
class TestQueueEndpoint:
    """Router-level tests for ``GET /api/v1/ai/governance/queue``."""

    def test_graceful_fallback_on_programming_error(self, client):
        """A ``table_pending=True`` fallback from the service is served as HTTP 200.

        The service is replaced by a mock returning the fallback response; the
        ProgrammingError named in the test title is not raised here.
        """
        pending_response = GovernanceQueueResponse(
            items=[], total=0, page=1, size=10, table_pending=True,
        )
        service_mock = AsyncMock(return_value=pending_response)
        with patch("src.api.v1.ai_governance.query_governance_queue", new=service_mock):
            r = client.get("/api/v1/ai/governance/queue")
        assert r.status_code == 200
        body = r.json()
        assert body["table_pending"] is True
        assert body["items"] == []
        assert body["total"] == 0

    def test_normal_response_when_table_ready(self, client):
        """Items returned by the service appear in the response body."""
        pending_dispatch = DispatchItem(
            id="d-001",
            governance_event_id="evt-001",
            event_type="slo_violation",
            dispatch_status="pending",
            proposed_action="restart deployment",
            playbook_id=None,
            playbook_trust=None,
            created_at=NOW,
            dispatched_at=None,
            completed_at=None,
            operator_note=None,
        )
        ready_response = GovernanceQueueResponse(
            items=[pending_dispatch], total=1, page=1, size=10, table_pending=False,
        )
        service_mock = AsyncMock(return_value=ready_response)
        with patch("src.api.v1.ai_governance.query_governance_queue", new=service_mock):
            r = client.get("/api/v1/ai/governance/queue")
        assert r.status_code == 200
        body = r.json()
        assert body["table_pending"] is False
        assert len(body["items"]) == 1
        assert body["items"][0]["dispatch_status"] == "pending"

    def test_invalid_dispatch_status_rejected(self, client):
        """An unsupported ``dispatch_status`` value is rejected with HTTP 422."""
        r = client.get("/api/v1/ai/governance/queue?dispatch_status=unknown")
        assert r.status_code == 422

    def test_queue_query_uses_production_dispatch_schema(self):
        """The source text of ``_query_dispatch_table`` contains the expected SQL.

        This is a textual check on the function's source: certain fragments
        must be present and references to ``d.created_at`` and
        ``d.operator_note`` must be absent.
        """
        import inspect

        query_source = inspect.getsource(_query_dispatch_table)
        required_fragments = (
            "d.dispatched_at AS created_at",
            "ORDER BY d.dispatched_at DESC",
            "NULL::text AS operator_note",
            "CAST(:dispatch_status AS governance_dispatch_status)",
        )
        forbidden_fragments = ("d.created_at", "d.operator_note")
        for fragment in required_fragments:
            assert fragment in query_source
        for fragment in forbidden_fragments:
            assert fragment not in query_source
# =============================================================================
# 4. summary endpoint compliance_rate
# =============================================================================
class TestSummaryEndpoint:
    """Router-level tests for ``GET /api/v1/ai/governance/summary``."""

    @staticmethod
    def _get_with_summary(client, summary, url):
        """GET ``url`` while ``query_governance_summary`` is mocked to return ``summary``."""
        with patch(
            "src.api.v1.ai_governance.query_governance_summary",
            new=AsyncMock(return_value=summary),
        ):
            return client.get(url)

    def test_compliance_rate_normal(self, client):
        """Summary values with unresolved events are returned unchanged."""
        summary = GovernanceSummaryResponse(
            compliance_rate=0.8,
            total_events=10,
            unresolved_count=2,
            daily_counts=[],
        )
        r = self._get_with_summary(client, summary, "/api/v1/ai/governance/summary")
        assert r.status_code == 200
        body = r.json()
        assert body["compliance_rate"] == pytest.approx(0.8)
        assert body["total_events"] == 10
        assert body["unresolved_count"] == 2

    def test_compliance_rate_all_resolved(self, client):
        """A compliance_rate of 1.0 is returned when nothing is unresolved."""
        summary = GovernanceSummaryResponse(
            compliance_rate=1.0,
            total_events=5,
            unresolved_count=0,
            daily_counts=[],
        )
        r = self._get_with_summary(client, summary, "/api/v1/ai/governance/summary?days=7")
        assert r.status_code == 200
        assert r.json()["compliance_rate"] == pytest.approx(1.0)

    def test_compliance_rate_total_zero(self, client):
        """Boundary case: total_events=0 is returned with compliance_rate 1.0."""
        summary = GovernanceSummaryResponse(
            compliance_rate=1.0,
            total_events=0,
            unresolved_count=0,
            daily_counts=[],
        )
        r = self._get_with_summary(client, summary, "/api/v1/ai/governance/summary")
        assert r.status_code == 200
        body = r.json()
        assert body["compliance_rate"] == pytest.approx(1.0)
        assert body["total_events"] == 0

    def test_days_max_boundary(self, client):
        """``days=90`` is accepted."""
        summary = GovernanceSummaryResponse(
            compliance_rate=1.0, total_events=0, unresolved_count=0, daily_counts=[],
        )
        r = self._get_with_summary(client, summary, "/api/v1/ai/governance/summary?days=90")
        assert r.status_code == 200

    def test_days_over_max_rejected(self, client):
        """``days=91`` is rejected with HTTP 422."""
        r = client.get("/api/v1/ai/governance/summary?days=91")
        assert r.status_code == 422

    def test_daily_counts_structure(self, client):
        """``daily_counts`` entries keep their date and per-type breakdown."""
        summary = GovernanceSummaryResponse(
            compliance_rate=0.9,
            total_events=10,
            unresolved_count=1,
            daily_counts=[
                DailyCount(date="2026-05-01", total=3, by_type={"slo_violation": 2, "trust_drift": 1}),
                DailyCount(date="2026-05-02", total=7, by_type={"slo_violation": 7}),
            ],
        )
        r = self._get_with_summary(client, summary, "/api/v1/ai/governance/summary")
        assert r.status_code == 200
        daily = r.json()["daily_counts"]
        assert len(daily) == 2
        assert daily[0]["date"] == "2026-05-01"
        assert daily[0]["by_type"]["slo_violation"] == 2
# =============================================================================
# 5. Service-layer compliance_rate pure-function tests (no HTTP)
# =============================================================================
class TestComplianceRateCalculation:
    """Arithmetic checks of the compliance-rate formula, without going through the Router.

    NOTE(review): these tests restate the formula inline and do not call any
    service function, so they cannot detect a change in the service's own
    calculation — consider exercising the real implementation instead.
    """

    def test_formula_normal(self):
        """1 - 2/10 = 0.8"""
        unresolved, total = 2, 10
        rate = round(1.0 - unresolved / total, 4)
        assert rate == pytest.approx(0.8)

    def test_formula_zero_total(self):
        """total=0 → 1.0"""
        total = 0
        # Guard the division: an empty data set counts as fully compliant.
        if total == 0:
            rate = 1.0
        else:
            rate = round(1.0 - 0 / total, 4)
        assert rate == pytest.approx(1.0)

    def test_formula_all_unresolved(self):
        """1 - 5/5 = 0.0"""
        unresolved, total = 5, 5
        rate = round(1.0 - unresolved / total, 4)
        assert rate == pytest.approx(0.0)