371 lines
14 KiB
Python
371 lines
14 KiB
Python
# apps/api/tests/test_p2_db_fixes.py
|
||
# 2026-04-26 P2-DB-Fix by Claude — db-expert P0 三修 驗收測試
|
||
"""
|
||
P0.1 / P0.2 / P0.3 三修驗收測試
|
||
================================
|
||
|
||
測試分類:unit(全部 mock DB,無真實 PG 依賴)
|
||
|
||
覆蓋:
|
||
P0.1 — test_governance_agent_writes_to_pg
|
||
GovernanceAgent._alert() 呼叫時,AiGovernanceEvent INSERT 被執行
|
||
P0.2 — test_consensus_engine_persists_to_pg
|
||
ConsensusEngine._save_consensus() 寫入 N+1 行到 agent_sessions
|
||
P0.3 — migration SQL syntax check(pyparsing-free,用 re 驗證關鍵字)
|
||
— approval_db.update_decision_fusion 呼叫正確欄位
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Any
|
||
from unittest.mock import ANY, AsyncMock, MagicMock, call, patch
|
||
|
||
import pytest
|
||
|
||
|
||
# =============================================================================
|
||
# P0.1 — GovernanceAgent._alert() 寫入 ai_governance_events
|
||
# =============================================================================
|
||
|
||
|
||
class TestGovernanceAgentWritesToPg:
    """P0.1: _alert() must write to PG before the logger + Telegram alert."""

    @staticmethod
    def _make_agent() -> tuple[Any, AsyncMock]:
        """Build a GovernanceAgent around a mock alerter; return ``(agent, alerter)``."""
        from src.services.governance_agent import GovernanceAgent

        alerter = AsyncMock()
        alerter.alert_governance = AsyncMock()
        return GovernanceAgent(alerter=alerter), alerter

    @staticmethod
    def _bind_session(mock_ctx: MagicMock, mock_db: AsyncMock) -> None:
        """Make the patched ``get_db_context()`` yield ``mock_db`` from ``async with``."""
        mock_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db)
        # Returning False from __aexit__ means exceptions inside the block are not suppressed.
        mock_ctx.return_value.__aexit__ = AsyncMock(return_value=False)

    @pytest.mark.asyncio
    async def test_pg_insert_called_on_alert(self):
        """_alert() -> the AiGovernanceEvent INSERT runs, and runs before the Telegram alert."""
        from unittest.mock import DEFAULT

        agent, alerter = self._make_agent()
        order: list[str] = []

        def _record(label: str):
            """Build a side effect that logs ``label`` and keeps the mock's default return value."""

            def _side_effect(*_args: Any, **_kwargs: Any) -> Any:
                order.append(label)
                return DEFAULT

            return _side_effect

        alerter.alert_governance.side_effect = _record("telegram")

        mock_db = AsyncMock()
        mock_db.execute = AsyncMock(side_effect=_record("pg"))
        mock_db.commit = AsyncMock()

        with patch("src.services.governance_agent.get_db_context") as mock_ctx:
            self._bind_session(mock_ctx, mock_db)
            await agent._alert("llm_hallucination", {"rate": 0.15, "failed": 15})

        # The PG write must be triggered.
        mock_db.execute.assert_called_once()
        mock_db.commit.assert_called_once()

        # The Telegram alert must still fire (existing behaviour is not broken).
        alerter.alert_governance.assert_called_once_with(
            "llm_hallucination", {"rate": 0.15, "failed": 15}
        )

        # "PG first": the INSERT must be issued before the Telegram alert.
        assert order == ["pg", "telegram"], f"unexpected call order: {order}"

    @pytest.mark.asyncio
    async def test_pg_failure_does_not_block_telegram(self):
        """A failed PG write must not block the Telegram alert (ADR-085 fallback design)."""
        agent, alerter = self._make_agent()

        mock_db = AsyncMock()
        mock_db.execute = AsyncMock(side_effect=RuntimeError("PG down"))

        with patch("src.services.governance_agent.get_db_context") as mock_ctx:
            self._bind_session(mock_ctx, mock_db)

            # Must not raise.
            await agent._alert("execution_blast_radius", {"rate": 0.25})

        # The PG write was really attempted, so the failure path was exercised.
        mock_db.execute.assert_awaited()

        # Telegram is still called, with the same arguments _alert() received.
        alerter.alert_governance.assert_called_once_with(
            "execution_blast_radius", {"rate": 0.25}
        )

    @pytest.mark.asyncio
    async def test_pg_insert_uses_correct_event_type(self):
        """The INSERT's event_type column must equal the event type passed to _alert()."""
        agent, _alerter = self._make_agent()

        captured: list[Any] = []

        async def capture_execute(stmt: Any, *_args: Any, **_kwargs: Any) -> None:
            # Accept any extra positional/keyword args so the capture itself never raises.
            captured.append(stmt)

        mock_db = AsyncMock()
        mock_db.execute = capture_execute
        mock_db.commit = AsyncMock()

        with patch("src.services.governance_agent.get_db_context") as mock_ctx:
            self._bind_session(mock_ctx, mock_db)
            await agent._alert("trust_drift", {"drifted_count": 3})

        # Exactly one INSERT statement must have been issued.
        assert len(captured) == 1, f"expected one INSERT, got {len(captured)}"

        # Assumes the statement is a SQLAlchemy insert(...).values(event_type=...):
        # compiling it exposes the bound values keyed by column name.
        params = captured[0].compile().params
        assert params.get("event_type") == "trust_drift", params
|
||
|
||
|
||
# =============================================================================
|
||
# P0.2 — ConsensusEngine._save_consensus() 寫入 agent_sessions
|
||
# =============================================================================
|
||
|
||
|
||
class TestConsensusEnginePersistsToPg:
    """P0.2: _save_consensus() must write to both Redis and PG (N opinions + 1 coordinator)."""

    # Redis Phase A dual write: new namespaced key plus the legacy hot-cache key.
    _EXPECTED_REDIS_WRITES = [
        call("__platform__:consensus:CS-TEST-001", ANY, ex=3600),
        call("consensus:CS-TEST-001", ANY, ex=3600),
    ]

    def _make_result(self, n_opinions: int = 3) -> Any:
        """Build a ConsensusResult with ``n_opinions`` opinions and consensus_score 0.75."""
        from src.services.consensus_engine import (
            AgentOpinion,
            AgentType,
            ConsensusResult,
        )

        agent_types = [AgentType.SRE, AgentType.SECURITY, AgentType.COST, AgentType.PERFORMANCE]
        opinions = [
            AgentOpinion(
                agent_type=agent_types[i % len(agent_types)],
                action=f"action_{i}",
                reasoning=f"reasoning_{i}",
                confidence=0.8,
                risk_assessment="medium",
            )
            for i in range(n_opinions)
        ]

        return ConsensusResult(
            consensus_id="CS-TEST-001",
            incident_id="INC-TEST-001",
            opinions=opinions,
            consensus_score=0.75,
            recommended_action="restart service",
            final_reasoning="consensus reached",
            risk_level="medium",
        )

    @staticmethod
    def _bind_session(mock_ctx: MagicMock, mock_db: AsyncMock) -> None:
        """Make the patched ``get_db_context()`` yield ``mock_db`` from ``async with``."""
        mock_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db)
        mock_ctx.return_value.__aexit__ = AsyncMock(return_value=False)

    @staticmethod
    def _extract_rows(call_args: Any) -> Any:
        """Return the row list given to ``execute(stmt, rows)``, or None if there is none.

        Checks the second positional argument first, then SQLAlchemy's ``params=``
        keyword, then a ``rows=`` keyword.
        """
        args, kwargs = call_args
        if len(args) > 1:
            return args[1]
        return kwargs.get("params", kwargs.get("rows"))

    @pytest.mark.asyncio
    async def test_pg_insert_called_with_n_plus_1_rows(self):
        """3 opinions -> 4 rows inserted (3 agents + 1 coordinator)."""
        from src.services.consensus_engine import ConsensusEngine

        result = self._make_result(n_opinions=3)
        engine = ConsensusEngine()

        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock()

        mock_db = AsyncMock()
        mock_db.execute = AsyncMock()
        mock_db.commit = AsyncMock()

        # get_db_context is imported lazily from src.db.base, so patch the source module.
        with patch("src.services.consensus_engine.get_redis", return_value=mock_redis):
            with patch("src.db.base.get_db_context") as mock_ctx:
                self._bind_session(mock_ctx, mock_db)
                await engine._save_consensus(result)

        mock_redis.set.assert_has_calls(self._EXPECTED_REDIS_WRITES)

        # PG write (permanent record).
        mock_db.execute.assert_called_once()
        mock_db.commit.assert_called_once()

        # Rows passed to execute = opinions + 1 coordinator. This is asserted
        # unconditionally so a missing row list fails instead of passing silently.
        rows = self._extract_rows(mock_db.execute.call_args)
        assert rows is not None, "execute() was not given a row list"
        assert len(rows) == 4  # 3 opinions + 1 coordinator

    @pytest.mark.asyncio
    async def test_coordinator_row_has_correct_vote(self):
        """Coordinator row: consensus_score >= 0.6 -> vote='approve'."""
        from src.services.consensus_engine import ConsensusEngine

        # consensus_score=0.75 >= 0.6 -> approve
        result = self._make_result(n_opinions=2)
        engine = ConsensusEngine()

        captured_rows: list[dict] = []

        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock()

        async def capture_execute(_stmt: Any, rows: Any = None, **kwargs: Any) -> None:
            # Rows may arrive positionally or as SQLAlchemy's params= keyword.
            batch = rows if rows is not None else kwargs.get("params")
            if batch:
                captured_rows.extend(batch)

        mock_db = AsyncMock()
        mock_db.execute = capture_execute
        mock_db.commit = AsyncMock()

        with patch("src.services.consensus_engine.get_redis", return_value=mock_redis):
            with patch("src.db.base.get_db_context") as mock_ctx:
                self._bind_session(mock_ctx, mock_db)
                await engine._save_consensus(result)

        # Exactly one coordinator row must be written. Asserting this prevents a
        # vacuous pass when nothing was captured.
        coordinator_rows = [r for r in captured_rows if r.get("agent_role") == "coordinator"]
        assert len(coordinator_rows) == 1, f"coordinator rows: {coordinator_rows}"
        assert coordinator_rows[0]["vote"] == "approve"

    @pytest.mark.asyncio
    async def test_pg_failure_does_not_block_redis(self):
        """A failed PG write must still leave the Redis writes completed (ADR-085 fallback)."""
        from src.services.consensus_engine import ConsensusEngine

        result = self._make_result(n_opinions=2)
        engine = ConsensusEngine()

        mock_redis = AsyncMock()
        mock_redis.set = AsyncMock()

        mock_db = AsyncMock()
        mock_db.execute = AsyncMock(side_effect=RuntimeError("PG down"))

        with patch("src.services.consensus_engine.get_redis", return_value=mock_redis):
            with patch("src.db.base.get_db_context") as mock_ctx:
                self._bind_session(mock_ctx, mock_db)

                # Must not raise.
                await engine._save_consensus(result)

        # The PG write was really attempted, so the failure path was exercised.
        mock_db.execute.assert_awaited()

        # The Redis Phase A dual write completed (it happens before the PG attempt).
        mock_redis.set.assert_has_calls(self._EXPECTED_REDIS_WRITES)
|
||
|
||
|
||
# =============================================================================
|
||
# P0.3 — Migration SQL syntax smoke test
|
||
# =============================================================================
|
||
|
||
|
||
class TestMigrationSqlSyntax:
    """P0.3: the migration SQL must contain the required keywords and be well-formed."""

    # Standalone transaction-control statements (`BEGIN;`, `COMMIT;`, optionally
    # followed by TRANSACTION/WORK). The required `;` keeps these from matching
    # the `BEGIN` that opens a PL/pgSQL `DO $$ ... $$` body.
    _TX_BEGIN_RE = re.compile(
        r"^\s*BEGIN(?:\s+(?:TRANSACTION|WORK))?\s*;", re.IGNORECASE | re.MULTILINE
    )
    _TX_COMMIT_RE = re.compile(
        r"^\s*COMMIT(?:\s+(?:TRANSACTION|WORK))?\s*;", re.IGNORECASE | re.MULTILINE
    )
    # Capture everything after `CHECK (` up to the end of the statement that
    # defines chk_complexity_tier. A `DROP CONSTRAINT ... chk_complexity_tier;`
    # has no CHECK before its `;`, so it is skipped.
    _TIER_CHECK_RE = re.compile(r"chk_complexity_tier[^;]*?CHECK\s*\(([^;]*)", re.IGNORECASE)
    _TIERS = ("low", "medium", "high", "critical")

    def _read_sql(self, filename: str) -> str:
        """Read ``apps/api/migrations/<filename>`` as UTF-8 (independent of platform locale)."""
        path = Path(__file__).parent.parent / "migrations" / filename
        return path.read_text(encoding="utf-8")

    @staticmethod
    def _strip_sql_comments(sql: str) -> str:
        """Remove `/* ... */` and `-- ...` comments so keywords in comments are not counted."""
        without_block_comments = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
        return re.sub(r"--[^\n]*", "", without_block_comments)

    def test_migration_contains_required_statements(self):
        """p2_decision_fusion_columns.sql must contain ALTER TABLE + 3 columns + 2 indexes."""
        sql = self._read_sql("p2_decision_fusion_columns.sql")

        required = (
            "ALTER TABLE approval_records",
            "composite_score",
            "complexity_tier",
            "decision_fusion_details",
            "chk_complexity_tier",
            "ix_approval_composite_score",
            "ix_approval_complexity_tier",
            "CONCURRENTLY",
        )
        for fragment in required:
            assert fragment in sql, f"missing {fragment!r} in migration"

    def test_rollback_contains_drop_statements(self):
        """p2_decision_fusion_columns_rollback.sql must contain DROP COLUMN + DROP INDEX."""
        sql = self._read_sql("p2_decision_fusion_columns_rollback.sql")

        required = (
            "DROP COLUMN",
            "composite_score",
            "complexity_tier",
            "decision_fusion_details",
            "DROP INDEX",
            "ix_approval_composite_score",
            "ix_approval_complexity_tier",
        )
        for fragment in required:
            assert fragment in sql, f"missing {fragment!r} in rollback"

    def test_migration_has_transaction_boundary(self):
        """The migration SQL must wrap its DDL in BEGIN/COMMIT statements."""
        sql = self._strip_sql_comments(self._read_sql("p2_decision_fusion_columns.sql"))
        assert self._TX_BEGIN_RE.search(sql), "no standalone BEGIN; statement"
        assert self._TX_COMMIT_RE.search(sql), "no standalone COMMIT; statement"

    def test_concurrent_indexes_outside_transaction(self):
        """CREATE INDEX CONCURRENTLY cannot run inside a transaction block in PostgreSQL."""
        sql = self._strip_sql_comments(self._read_sql("p2_decision_fusion_columns.sql"))
        begins = [m.end() for m in self._TX_BEGIN_RE.finditer(sql)]
        commits = [m.start() for m in self._TX_COMMIT_RE.finditer(sql)]
        for start, end in zip(begins, commits):
            assert "CONCURRENTLY" not in sql[start:end].upper(), (
                "CREATE INDEX CONCURRENTLY found inside a BEGIN/COMMIT transaction block"
            )

    def test_check_constraint_values_match_orm(self):
        """The CHECK constraint's allowed values must match the ORM complexity_tier String(16)."""
        sql = self._strip_sql_comments(self._read_sql("p2_decision_fusion_columns.sql"))
        match = self._TIER_CHECK_RE.search(sql)
        assert match, "chk_complexity_tier CHECK clause not found"
        clause = match.group(1)
        # All four tiers must appear as quoted literals inside the CHECK clause
        # itself. A plain substring search would let "low" match "allow" anywhere.
        for tier in self._TIERS:
            assert f"'{tier}'" in clause, f"Missing tier '{tier}' in CHECK constraint"
|
||
|
||
|
||
# =============================================================================
|
||
# P0.3 — approval_db.update_decision_fusion 方法驗收
|
||
# =============================================================================
|
||
|
||
|
||
class TestApprovalDbUpdateDecisionFusion:
    """P0.3: update_decision_fusion must update by incident_id + PENDING status."""

    @staticmethod
    def _make_session(rowcount: int) -> AsyncMock:
        """Build a mock session whose ``execute()`` returns a result with ``rowcount``."""
        mock_result = MagicMock()
        mock_result.rowcount = rowcount

        mock_db = AsyncMock()
        mock_db.execute = AsyncMock(return_value=mock_result)
        mock_db.commit = AsyncMock()  # get_db_context autocommits on exit
        return mock_db

    @staticmethod
    def _bind_session(mock_ctx: MagicMock, mock_db: AsyncMock) -> None:
        """Make the patched ``get_db_context()`` yield ``mock_db`` from ``async with``."""
        mock_ctx.return_value.__aenter__ = AsyncMock(return_value=mock_db)
        mock_ctx.return_value.__aexit__ = AsyncMock(return_value=False)

    @pytest.mark.asyncio
    async def test_update_called_with_correct_values(self):
        """update_decision_fusion -> UPDATE approval_records touching the fusion columns."""
        from src.services.approval_db import ApprovalDBService

        mock_db = self._make_session(rowcount=1)

        with patch("src.services.approval_db.get_db_context") as mock_ctx:
            self._bind_session(mock_ctx, mock_db)

            svc = ApprovalDBService()
            rowcount = await svc.update_decision_fusion(
                incident_id="INC-20260426-001",
                composite_score=0.82,
                complexity_tier="medium",
                fusion_details={"composite": 0.82, "openclaw": 0.85},
            )

        assert rowcount == 1
        mock_db.execute.assert_called_once()

        # Render the executed statement. This works for both a SQLAlchemy update()
        # construct and text() SQL. It must target approval_records, set the three
        # fusion columns, and filter on incident_id and status.
        sql = str(mock_db.execute.call_args.args[0])
        expected_fragments = (
            "approval_records",
            "composite_score",
            "complexity_tier",
            "decision_fusion_details",
            "incident_id",
            "status",
        )
        for fragment in expected_fragments:
            assert fragment in sql, f"{fragment!r} not found in UPDATE statement: {sql}"

    @pytest.mark.asyncio
    async def test_update_returns_zero_when_no_pending(self):
        """No PENDING approval found -> rowcount=0 (no exception raised)."""
        from src.services.approval_db import ApprovalDBService

        mock_db = self._make_session(rowcount=0)

        with patch("src.services.approval_db.get_db_context") as mock_ctx:
            self._bind_session(mock_ctx, mock_db)

            svc = ApprovalDBService()
            rowcount = await svc.update_decision_fusion(
                incident_id="INC-NONEXISTENT",
                composite_score=0.5,
                complexity_tier="low",
                fusion_details={},
            )

        assert rowcount == 0
        mock_db.execute.assert_called_once()
|