From ff30c61c4c634effb073de253264597ff6d98d78 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 12 May 2026 19:55:13 +0800 Subject: [PATCH] =?UTF-8?q?fix(rls):=20=E6=94=B6=E6=96=82=20API=20DB=20acc?= =?UTF-8?q?ess=20context?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/api/src/core/unit_of_work.py | 20 ++- apps/api/src/db/base.py | 7 +- apps/api/src/jobs/kb_rot_cleaner.py | 8 +- apps/api/src/jobs/knowledge_decay_job.py | 5 +- apps/api/src/jobs/offline_replay_service.py | 8 +- apps/api/src/services/ai_router.py | 5 +- apps/api/src/services/ai_slo_calculator.py | 6 +- apps/api/src/services/decision_manager.py | 12 +- .../src/services/dynamic_baseline_service.py | 10 +- apps/api/src/services/finetune_exporter.py | 8 +- apps/api/src/services/log_anomaly_detector.py | 5 +- apps/api/src/services/trust_drift_detector.py | 6 +- apps/api/src/workers/aider_event_processor.py | 28 +-- docs/LOGBOOK.md | 42 +++++ docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md | 49 ++++++ scripts/ops/awooop-rls-access-audit.py | 163 ++++++++++++++++++ 16 files changed, 327 insertions(+), 55 deletions(-) create mode 100644 docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md create mode 100755 scripts/ops/awooop-rls-access-audit.py diff --git a/apps/api/src/core/unit_of_work.py b/apps/api/src/core/unit_of_work.py index 4674db3a..a192375e 100644 --- a/apps/api/src/core/unit_of_work.py +++ b/apps/api/src/core/unit_of_work.py @@ -17,6 +17,7 @@ PostgreSQL 事務管理器,確保多表操作原子性。 from typing import Any import structlog +from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker logger = structlog.get_logger(__name__) @@ -49,14 +50,20 @@ class UnitOfWork: - Redis 操作失敗時必須手動呼叫 rollback() """ - def __init__(self, session_factory: async_sessionmaker[AsyncSession]): + def __init__( + self, + session_factory: async_sessionmaker[AsyncSession], + project_id: str | None = None, + ): """ 初始化 UnitOfWork Args: session_factory: SQLAlchemy async 
session factory + project_id: RLS project context. None means contextvar/default awoooi. """ self._session_factory = session_factory + self._project_id = project_id self._session: AsyncSession | None = None self._committed = False @@ -74,9 +81,18 @@ class UnitOfWork: async def __aenter__(self) -> "UnitOfWork": """進入事務""" + from src.core.context import get_current_project_id + self._session = self._session_factory() + effective_pid = ( + self._project_id if self._project_id is not None else get_current_project_id() + ) + await self._session.execute( + text("SELECT set_config('app.project_id', :pid, TRUE)"), + {"pid": effective_pid}, + ) self._committed = False - logger.debug("uow_started") + logger.debug("uow_started", project_id=effective_pid) return self async def __aexit__( diff --git a/apps/api/src/db/base.py b/apps/api/src/db/base.py index b02b7660..108aeab0 100644 --- a/apps/api/src/db/base.py +++ b/apps/api/src/db/base.py @@ -106,10 +106,13 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]: factory = get_session_factory() async with factory() as session: try: + from src.core.context import get_current_project_id + # AwoooP Phase 2.3 (2026-05-04 ogt): SET LOCAL app.project_id 讓 RLS Policy 生效 - # 預設 'awoooi',多租戶路由將在 middleware 注入實際 project_id + # 預設 'awoooi',多租戶路由將透過 contextvar 注入實際 project_id await session.execute( - text("SELECT set_config('app.project_id', 'awoooi', TRUE)") + text("SELECT set_config('app.project_id', :pid, TRUE)"), + {"pid": get_current_project_id()}, ) yield session await session.commit() diff --git a/apps/api/src/jobs/kb_rot_cleaner.py b/apps/api/src/jobs/kb_rot_cleaner.py index 349e56f9..c1f85bed 100644 --- a/apps/api/src/jobs/kb_rot_cleaner.py +++ b/apps/api/src/jobs/kb_rot_cleaner.py @@ -28,7 +28,7 @@ from datetime import timedelta import structlog from sqlalchemy import select, update -from src.db.base import get_session_factory +from src.db.base import get_db_context from src.db.models import AiGovernanceEvent, 
KnowledgeEntryRecord from src.utils.timezone import now_taipei @@ -129,7 +129,7 @@ class KbRotCleaner: rot_reasons: dict[str, list[str]] = {} total = 0 - async with get_session_factory()() as session: + async with get_db_context() as session: # 只掃 active 狀態(非 archived) q = await session.execute( select(KnowledgeEntryRecord).where( @@ -193,7 +193,7 @@ class KbRotCleaner: if not result.stale_ids: return - async with get_session_factory()() as session: + async with get_db_context() as session: # 逐條更新(避免 bulk update 覆蓋 tags JSONB) q = await session.execute( select(KnowledgeEntryRecord).where( @@ -220,7 +220,7 @@ class KbRotCleaner: async def _save_event(self, result: RotScanResult) -> None: """寫 kb_stale 事件到 ai_governance_events。""" try: - async with get_session_factory()() as session: + async with get_db_context() as session: event = AiGovernanceEvent( event_type="kb_stale", details=result.to_dict(), diff --git a/apps/api/src/jobs/knowledge_decay_job.py b/apps/api/src/jobs/knowledge_decay_job.py index f002074e..a3b9cfc1 100644 --- a/apps/api/src/jobs/knowledge_decay_job.py +++ b/apps/api/src/jobs/knowledge_decay_job.py @@ -33,7 +33,7 @@ from datetime import timedelta import structlog from sqlalchemy import and_, select, update -from src.db.base import get_session_factory +from src.db.base import get_db_context from src.db.models import KnowledgeEntryRecord from src.models.knowledge import EntryStatus from src.utils.timezone import now_taipei @@ -112,8 +112,7 @@ class KnowledgeDecayJob: cutoff = now_taipei() - timedelta(days=DECAY_AGE_DAYS) decayable_statuses = [EntryStatus.DRAFT.value, EntryStatus.REVIEW.value] - session_factory = get_session_factory() - async with session_factory() as db: + async with get_db_context() as db: # 查:30 天未引用(view_count=0)且 updated_at < cutoff 的 draft/review 條目 stmt = select(KnowledgeEntryRecord).where( and_( diff --git a/apps/api/src/jobs/offline_replay_service.py b/apps/api/src/jobs/offline_replay_service.py index 700234dc..127581e6 
100644 --- a/apps/api/src/jobs/offline_replay_service.py +++ b/apps/api/src/jobs/offline_replay_service.py @@ -29,7 +29,7 @@ from datetime import timedelta import structlog from sqlalchemy import and_, select -from src.db.base import get_session_factory +from src.db.base import get_db_context from src.db.models import AgentSession, AiGovernanceEvent, AutoRepairExecution, IncidentEvidence from src.utils.timezone import now_taipei @@ -109,9 +109,7 @@ class OfflineReplayService: async def _run_replay(self) -> OfflineReplayReport: cutoff = now_taipei() - timedelta(days=REPLAY_LOOKBACK_DAYS) - session_factory = get_session_factory() - - async with session_factory() as db: + async with get_db_context() as db: # 1. 取最近 N 個有 AgentSession(coordinator) 的 Incident stmt = ( select(AgentSession.incident_id) @@ -137,7 +135,7 @@ class OfflineReplayService: ) results: list[IncidentReplayResult] = [] - async with session_factory() as db: + async with get_db_context() as db: for incident_id in incident_ids: r = await self._replay_one(db, incident_id) results.append(r) diff --git a/apps/api/src/services/ai_router.py b/apps/api/src/services/ai_router.py index ce9594cb..67b9a15c 100644 --- a/apps/api/src/services/ai_router.py +++ b/apps/api/src/services/ai_router.py @@ -842,14 +842,13 @@ class AIRouter: 空 dict 代表無資料或查詢失敗(caller 應降級為忽略)。 """ try: - from src.db.base import get_session_factory + from src.db.base import get_db_context from src.repositories.aider_event_repository import AiderEventRepository except ImportError: return {} try: - sf = get_session_factory() - async with sf() as sess: + async with get_db_context() as sess: repo_obj = AiderEventRepository(sess) stats = await repo_obj.model_stats_since(days=days) except Exception: diff --git a/apps/api/src/services/ai_slo_calculator.py b/apps/api/src/services/ai_slo_calculator.py index 84b39e40..aab4c5c6 100644 --- a/apps/api/src/services/ai_slo_calculator.py +++ b/apps/api/src/services/ai_slo_calculator.py @@ -28,7 +28,7 @@ from 
datetime import timedelta import structlog from sqlalchemy import func, select, text -from src.db.base import get_session_factory +from src.db.base import get_db_context from src.db.models import AiGovernanceEvent, AutoRepairExecution, ApprovalRecord from src.utils.timezone import now_taipei @@ -127,7 +127,7 @@ class AiSloCalculator: try: since = now_taipei() - timedelta(days=SLO_WINDOW_DAYS) - async with get_session_factory()() as session: + async with get_db_context() as session: slo1 = await self._calc_auto_success_rate(session, since) slo2 = await self._calc_human_override_rate(session, since) slo3 = await self._calc_false_neg_rate(session, since) @@ -210,7 +210,7 @@ class AiSloCalculator: 只在 any_violated=True 時呼叫。不管舊違反是否解決。 """ try: - async with get_session_factory()() as session: + async with get_db_context() as session: event = AiGovernanceEvent( event_type="slo_violation", details=report.to_dict(), diff --git a/apps/api/src/services/decision_manager.py b/apps/api/src/services/decision_manager.py index acc0d60d..c77b0af9 100644 --- a/apps/api/src/services/decision_manager.py +++ b/apps/api/src/services/decision_manager.py @@ -1933,14 +1933,14 @@ class DecisionManager: try: from src.core.feature_flags import aiops_flags as _p6_flags if _p6_flags.is_sub_flag_enabled("AIOPS_P6_SELF_DEMOTION"): - from src.db.base import get_session_factory as _p6_sf + from src.db.base import get_db_context as _p6_db_context from src.db.models import AiGovernanceEvent as _GovernanceEvent from sqlalchemy import select as _p6_select, func as _p6_func from datetime import timedelta as _p6_td _now = __import__("src.utils.timezone", fromlist=["now_taipei"]).now_taipei() - async with _p6_sf()() as _p6_sess: + async with _p6_db_context() as _p6_sess: # 過去 7 天有幾筆未解決的 slo_violation? 
_viol_7d_q = await _p6_sess.execute( _p6_select(_p6_func.count()).where( @@ -1980,8 +1980,8 @@ class DecisionManager: ) # 記錄保守模式事件 try: - from src.db.base import get_session_factory as _p6_sf2 - async with _p6_sf2()() as _s2: + from src.db.base import get_db_context as _p6_db_context2 + async with _p6_db_context2() as _s2: _s2.add(_GovernanceEvent( event_type="conservative_mode", details={ @@ -2021,8 +2021,8 @@ class DecisionManager: _push_decision_to_telegram(incident, token.proposal_data) ) try: - from src.db.base import get_session_factory as _p6_sf3 - async with _p6_sf3()() as _s3: + from src.db.base import get_db_context as _p6_db_context3 + async with _p6_db_context3() as _s3: _s3.add(_GovernanceEvent( event_type="self_demotion", details={ diff --git a/apps/api/src/services/dynamic_baseline_service.py b/apps/api/src/services/dynamic_baseline_service.py index 4cb110df..012572a2 100644 --- a/apps/api/src/services/dynamic_baseline_service.py +++ b/apps/api/src/services/dynamic_baseline_service.py @@ -424,11 +424,10 @@ class DynamicBaselineService: async def _pg_upsert_baseline(self, state: BaselineState, promql: str, lookback_hours: int) -> None: """寫入 DynamicBaselineRecord 到 PostgreSQL(INSERT,不更新舊記錄)""" try: - from src.db.base import get_session_factory + from src.db.base import get_db_context from src.db.models import DynamicBaselineRecord - factory = get_session_factory() - async with factory() as session: + async with get_db_context() as session: record = DynamicBaselineRecord( metric_name=state.metric_name, mean=state.mean, @@ -449,11 +448,10 @@ class DynamicBaselineService: try: from sqlalchemy import select - from src.db.base import get_session_factory + from src.db.base import get_db_context from src.db.models import DynamicBaselineRecord - factory = get_session_factory() - async with factory() as session: + async with get_db_context() as session: stmt = ( select(DynamicBaselineRecord) .where(DynamicBaselineRecord.metric_name == metric_name) diff --git 
a/apps/api/src/services/finetune_exporter.py b/apps/api/src/services/finetune_exporter.py index 79db14ca..cb4afe71 100644 --- a/apps/api/src/services/finetune_exporter.py +++ b/apps/api/src/services/finetune_exporter.py @@ -52,7 +52,7 @@ from pathlib import Path import structlog from sqlalchemy import and_, select, text as sql_text -from src.db.base import get_session_factory +from src.db.base import get_db_context from src.db.models import AgentSession, AutoRepairExecution, IncidentEvidence from src.utils.timezone import now_taipei @@ -107,9 +107,7 @@ class FineTuneExporter: async def _run_export(self) -> tuple[str | None, int]: cutoff = now_taipei() - timedelta(days=EXPORT_LOOKBACK_DAYS) - session_factory = get_session_factory() - - async with session_factory() as db: + async with get_db_context() as db: # 1. 取得成功驗證的 EvidenceSnapshot(有 evidence_summary + verification_result='success') stmt = select(IncidentEvidence).where( and_( @@ -153,7 +151,7 @@ class FineTuneExporter: with open(output_path, 'rb') as _f: _checksum = hashlib.sha256(_f.read()).hexdigest() _ids = [str(ev.id) for ev in evidences] - async with session_factory() as _db: + async with get_db_context() as _db: await _db.execute( sql_text(""" INSERT INTO finetune_exports ( diff --git a/apps/api/src/services/log_anomaly_detector.py b/apps/api/src/services/log_anomaly_detector.py index bdbdd71c..f6365df1 100644 --- a/apps/api/src/services/log_anomaly_detector.py +++ b/apps/api/src/services/log_anomaly_detector.py @@ -296,12 +296,11 @@ class LogAnomalyDetector: """ try: from sqlalchemy.dialects.postgresql import insert as pg_insert - from src.db.base import get_session_factory + from src.db.base import get_db_context from src.db.models import LogClusterRecord from src.utils.timezone import now_taipei - factory = get_session_factory() - async with factory() as session: + async with get_db_context() as session: stmt = pg_insert(LogClusterRecord).values( cluster_id=cluster.cluster_id, 
template=cluster.template, diff --git a/apps/api/src/services/trust_drift_detector.py b/apps/api/src/services/trust_drift_detector.py index 4ad54fea..19a15c68 100644 --- a/apps/api/src/services/trust_drift_detector.py +++ b/apps/api/src/services/trust_drift_detector.py @@ -39,7 +39,7 @@ from dataclasses import dataclass import structlog from sqlalchemy import func, select -from src.db.base import get_session_factory +from src.db.base import get_db_context from src.db.models import AiGovernanceEvent, PlaybookRecord from src.utils.timezone import now_taipei @@ -115,7 +115,7 @@ class TrustDriftDetector: TrustDistribution(樣本不足時 drift_detected=False) """ try: - async with get_session_factory()() as session: + async with get_db_context() as session: # 只計算 approved 狀態的 Playbook total_q = await session.execute( select(func.count()).where( @@ -215,7 +215,7 @@ class TrustDriftDetector: async def save_drift_event(self, dist: TrustDistribution) -> None: """將信任度漂移事件寫入 ai_governance_events。""" try: - async with get_session_factory()() as session: + async with get_db_context() as session: event = AiGovernanceEvent( event_type="trust_drift", details={ diff --git a/apps/api/src/workers/aider_event_processor.py b/apps/api/src/workers/aider_event_processor.py index 6754755c..c4b95240 100644 --- a/apps/api/src/workers/aider_event_processor.py +++ b/apps/api/src/workers/aider_event_processor.py @@ -21,7 +21,7 @@ from typing import Any import structlog from src.core.redis_client import get_redis, get_worker_redis, init_worker_redis_pool -from src.db.base import get_session_factory +from src.db.base import get_db_context from src.models.aider import AiderEventIn from src.repositories.aider_event_repository import AiderEventRepository from src.services.aider_event_service import build_signal_data, should_create_incident @@ -123,7 +123,7 @@ class AiderEventProcessor: self, stream_key: str, msg_id: Any, data: dict, _session_factory=None ) -> None: """處理單筆 message:parse → (maybe) incident → 
DB write → ACK。 - _session_factory: 可注入測試用 factory,預設使用 get_session_factory()。 + _session_factory: 可注入測試用 factory;production 預設使用 get_db_context() 設定 RLS context。 """ try: raw = data.get(b"payload") or data.get("payload") @@ -151,14 +151,22 @@ class AiderEventProcessor: # 不中斷 — 即使 incident 失敗,event 仍要持久化 try: - session_factory = _session_factory or get_session_factory() - async with session_factory() as session: - repo = AiderEventRepository(session) - await repo.insert( - session_id=ev.session_id, ts=ev.ts, type_=ev.type, - host=ev.host, payload=ev.payload, incident_id=incident_id, - ) - await session.commit() + if _session_factory is None: + async with get_db_context() as session: + repo = AiderEventRepository(session) + await repo.insert( + session_id=ev.session_id, ts=ev.ts, type_=ev.type, + host=ev.host, payload=ev.payload, incident_id=incident_id, + ) + else: + session_factory = _session_factory + async with session_factory() as session: + repo = AiderEventRepository(session) + await repo.insert( + session_id=ev.session_id, ts=ev.ts, type_=ev.type, + host=ev.host, payload=ev.payload, incident_id=incident_id, + ) + await session.commit() except Exception: logger.exception("aider_processor_db_write_failed", session_id=ev.session_id) diff --git a/docs/LOGBOOK.md b/docs/LOGBOOK.md index aab98997..2134d451 100644 --- a/docs/LOGBOOK.md +++ b/docs/LOGBOOK.md @@ -1,3 +1,45 @@ +## 2026-05-12 | RLS Access Path Audit 收斂 + +**背景**:RLS role bootstrap 已完成後,下一個 gate 是確認 API runtime DB access 都會設定 `app.project_id`;否則一旦 fail-closed policy 上線,直接 session factory 入口會讀不到資料或寫入失敗。 + +**runtime 修補**: +- `get_db()`: + - 和 `get_db_context()` 對齊,改讀 `src.core.context.get_current_project_id()` 並以 bind parameter 設定 `app.project_id`。 +- `UnitOfWork`: + - `__aenter__` 會讀 `src.core.context.get_current_project_id()`,並執行 `SELECT set_config('app.project_id', :pid, TRUE)`。 + - `IncidentApprovalService` 繼續注入 session factory,但經由 `UnitOfWork` 進入 RLS-safe path。 +- 將 production runtime 直接 
`get_session_factory()` call sites 改為 `get_db_context()`: + - `apps/api/src/jobs/kb_rot_cleaner.py` + - `apps/api/src/jobs/knowledge_decay_job.py` + - `apps/api/src/jobs/offline_replay_service.py` + - `apps/api/src/services/ai_router.py` + - `apps/api/src/services/ai_slo_calculator.py` + - `apps/api/src/services/decision_manager.py` + - `apps/api/src/services/dynamic_baseline_service.py` + - `apps/api/src/services/finetune_exporter.py` + - `apps/api/src/services/log_anomaly_detector.py` + - `apps/api/src/services/trust_drift_detector.py` + - `apps/api/src/workers/aider_event_processor.py` +- `aider_event_processor` 的 production path 改走 `get_db_context()`;測試注入 `_session_factory` 時仍保留測試隔離。 + +**新增 audit gate**: +- `scripts/ops/awooop-rls-access-audit.py`: + - static 掃描 `apps/api/src` runtime 中的 `get_session_factory()`、`create_async_engine()`、`asyncpg.connect()`、`settings.DATABASE_URL`。 + - 只允許 engine owner、health `SELECT 1`、sanitized log、UnitOfWork injection 等明確例外;allowlist 以 path/rule/text pattern 判斷,避免行號漂移造成誤報。 + - exit `2` 表示還有 runtime blocker。 +- `docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md` 記錄 gate 與例外。 + +**驗證**: +- `python3 scripts/ops/awooop-rls-access-audit.py --show-allowed` → `BLOCKED=0 ALLOW=10`。 +- `python3 -m py_compile` 對修改過的 runtime 檔與 audit script → passed。 +- `scripts/ops/awooop-rls-preflight.sh --exact-counts` → 仍為 `PASS=7 WARN=0 BLOCKED=1`;唯一 blocker 仍是尚未啟用 RLS policy,符合預期。 +- Production health `/api/v1/health` → 200 healthy。 +- 嘗試跑 `python3 -m pytest ...`,但本機 `/usr/bin/python3` 無 `pytest`,且 repo 內未找到可用 venv;本輪未安裝依賴,改以 compile/static/live smoke 驗證。 + +**下一步**: +- 針對 manual scripts (`apps/api/scripts/`、top-level `scripts/`) 補 operator review policy;它們不是 API runtime,但 RLS policy 上線後若直接拿 `DATABASE_URL` 操作 tenant tables,仍需明確 `SET LOCAL app.project_id` 或用 migration/operator role。 +- 產出第一批 staged policy enablement SQL,先從空表 / 低流量 AwoooP tables canary,不從 incidents / knowledge_entries 開始。 + ## 2026-05-12 | RLS Role Bootstrap 已套用 **背景**:上一輪已新增 
`scripts/ops/awooop-rls-role-bootstrap.sql`,但尚未執行;使用者批准後,本輪只執行 role bootstrap,不啟用 RLS policy。 diff --git a/docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md b/docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md new file mode 100644 index 00000000..52a9d4f1 --- /dev/null +++ b/docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md @@ -0,0 +1,49 @@ +# AwoooP RLS Access Path Audit + +> Purpose: verify API runtime DB access paths are ready for fail-closed RLS. + +Before enabling RLS policies, runtime database access must set +`app.project_id`. The approved paths are: + +- FastAPI dependency `get_db()`. +- Background/service context `get_db_context()`. +- `UnitOfWork`, which now sets `app.project_id` on entry. + +Both `get_db()` and `get_db_context()` derive `app.project_id` from +`src.core.context.get_current_project_id()` unless the caller passes an explicit +project id to `get_db_context()`. + +Run: + +```bash +python3 scripts/ops/awooop-rls-access-audit.py +``` + +To include accepted exceptions: + +```bash +python3 scripts/ops/awooop-rls-access-audit.py --show-allowed +``` + +## 2026-05-12 Result + +After fixing direct `get_session_factory()` runtime call sites: + +```text +AwoooP RLS access audit: BLOCKED=0 ALLOW=10 +``` + +Accepted exceptions: + +- `apps/api/src/db/base.py`: owns the shared engine/session factory and sets + `app.project_id` in `get_db()` / `get_db_context()`. +- `apps/api/src/routes/health.py`: raw `asyncpg` health check only runs + `SELECT 1`, not tenant table queries. +- `apps/api/src/main.py` and `apps/api/src/workers/signal_worker.py`: only log a + sanitized DB host suffix. +- `apps/api/src/services/incident_approval_service.py`: injects + `UnitOfWork`; `UnitOfWork` now sets `app.project_id`. + +Manual scripts under `apps/api/scripts/` and top-level `scripts/` are not API +runtime. They still need operator review before being used against production +after RLS policy enablement. 
#!/usr/bin/env python3
"""Static RLS access-path audit for AWOOOI API runtime code.

The goal is narrow: find production runtime DB access that may bypass
get_db()/get_db_context() and therefore miss SET LOCAL app.project_id.
It is intentionally conservative and read-only: every source line that
matches one of ``RULES`` is reported, and only lines matching an explicit
``ALLOW_RULES`` entry are accepted.

Exit codes:
    0 -- no blocked findings.
    1 -- the audit could not run (source root missing); nothing was verified.
    2 -- at least one matched line is not allow-listed.
"""

from __future__ import annotations

import argparse
import re
import sys
from dataclasses import dataclass
from pathlib import Path


_SCRIPT_PATH = Path(__file__).resolve()
# The script is expected at <repo>/scripts/ops/, which puts the repository root
# two directories above the script's own directory.  If the script is copied
# somewhere shallower, fall back to its directory instead of crashing with an
# IndexError at import time; scan() then reports the missing source root.
ROOT = _SCRIPT_PATH.parents[2] if len(_SCRIPT_PATH.parents) > 2 else _SCRIPT_PATH.parent
SRC_ROOT = ROOT / "apps/api/src"

BLOCKED_REASON = (
    "Runtime DB access must set app.project_id through get_db/get_db_context or UnitOfWork."
)


@dataclass(frozen=True)
class Finding:
    """One source line that matched a rule, together with its verdict."""

    severity: str  # "ALLOW" or "BLOCKED"
    path: Path  # file path relative to the audit root
    line: int  # 1-based line number inside the file
    rule: str  # name of the entry in RULES that matched
    text: str  # the matching source line, stripped of surrounding whitespace
    reason: str  # why the line was allowed or blocked


@dataclass(frozen=True)
class AllowRule:
    """An accepted exception, keyed by path + rule + text instead of line number."""

    path: str  # POSIX path relative to the audit root
    rule: str  # name of the entry in RULES this exception applies to
    text_pattern: re.Pattern[str]  # must match the source line for the exception to apply
    reason: str  # human-readable justification printed with the finding


# (rule name, pattern) pairs; each pattern is searched against every source line.
RULES: list[tuple[str, re.Pattern[str]]] = [
    ("session_factory", re.compile(r"\bget_session_factory\s*\(")),
    ("create_async_engine", re.compile(r"\bcreate_async_engine\s*\(")),
    ("asyncpg_connect", re.compile(r"\basyncpg\.connect\s*\(")),
    ("settings_database_url", re.compile(r"\bsettings\.DATABASE_URL\b")),
    # Covers os.environ[...], os.environ.get(...) and os.getenv(...); the last
    # form was previously undetected, which let an env-based bypass slip through.
    (
        "env_database_url",
        re.compile(
            r"os\.(?:getenv|environ(?:\.get)?)\(\s*[\"']DATABASE_URL[\"']"
            r"|os\.environ\[\s*[\"']DATABASE_URL[\"']"
        ),
    ),
]


ALLOW_RULES: tuple[AllowRule, ...] = (
    AllowRule(
        "apps/api/src/db/base.py",
        "settings_database_url",
        re.compile(r"\bdatabase_url\s*=\s*settings\.DATABASE_URL\b"),
        "DB engine owner reads DATABASE_URL and sets RLS context in get_db/get_db_context.",
    ),
    AllowRule(
        "apps/api/src/db/base.py",
        "create_async_engine",
        re.compile(r"\b_engine\s*=\s*create_async_engine\("),
        "DB engine owner creates the shared async engine.",
    ),
    AllowRule(
        "apps/api/src/db/base.py",
        "session_factory",
        re.compile(r"\bdef\s+get_session_factory\("),
        "Factory definition, not a call-site bypass.",
    ),
    AllowRule(
        "apps/api/src/db/base.py",
        "session_factory",
        re.compile(r"\bfactory\s*=\s*get_session_factory\(\)"),
        "get_db/get_db_context wrap factory and set app.project_id.",
    ),
    AllowRule(
        "apps/api/src/routes/health.py",
        "settings_database_url",
        re.compile(r"\bdb_url\s*=\s*settings\.DATABASE_URL\.replace\("),
        "Health check parses DATABASE_URL for SELECT 1 only.",
    ),
    AllowRule(
        "apps/api/src/routes/health.py",
        "asyncpg_connect",
        re.compile(r"\basyncpg\.connect\(db_url\)"),
        "Health check raw asyncpg SELECT 1 does not read tenant tables.",
    ),
    AllowRule(
        "apps/api/src/main.py",
        "settings_database_url",
        re.compile(r"\bdb_url\s*=\s*settings\.DATABASE_URL\b"),
        "Startup logs sanitized DB host suffix after init_db.",
    ),
    AllowRule(
        "apps/api/src/workers/signal_worker.py",
        "settings_database_url",
        re.compile(r"\bdatabase_url=settings\.DATABASE_URL\.split\(\"@\"\)\[-1\]"),
        "Structured log uses redacted DATABASE_URL suffix only.",
    ),
    AllowRule(
        "apps/api/src/services/incident_approval_service.py",
        "session_factory",
        re.compile(r"\bsession_factory=get_session_factory\(\),"),
        "IncidentApprovalService injects UnitOfWork; UnitOfWork now sets app.project_id.",
    ),
)


def classify(path: Path, rule: str, line_text: str, root: Path = ROOT) -> tuple[str, str]:
    """Decide whether a matched line is an accepted exception.

    Args:
        path: File containing the line; must be located under ``root``.
        rule: Name of the entry in ``RULES`` that matched.
        line_text: The source line that matched.
        root: Directory that ``AllowRule.path`` values are relative to.

    Returns:
        ``("ALLOW", reason)`` when an ``ALLOW_RULES`` entry matches the path,
        the rule name and the line text; otherwise ``("BLOCKED", reason)``.

    Raises:
        ValueError: If ``path`` is not located under ``root``.
    """
    rel = path.relative_to(root).as_posix()
    for allow in ALLOW_RULES:
        if allow.path == rel and allow.rule == rule and allow.text_pattern.search(line_text):
            return "ALLOW", allow.reason
    return "BLOCKED", BLOCKED_REASON


def scan(src_root: Path = SRC_ROOT, root: Path = ROOT) -> list[Finding]:
    """Scan every ``*.py`` file under ``src_root`` and classify each rule match.

    Args:
        src_root: Directory tree to scan; must be located under ``root``.
        root: Directory that reported paths and allow-list paths are relative to.

    Returns:
        Findings ordered by file path, then line number, then ``RULES`` order.

    Raises:
        NotADirectoryError: If ``src_root`` is not an existing directory.  An
            audit gate must not report "nothing blocked" when nothing was scanned.
    """
    if not src_root.is_dir():
        raise NotADirectoryError(f"RLS audit source root not found: {src_root}")

    findings: list[Finding] = []
    for path in sorted(src_root.rglob("*.py")):
        # Always decode as UTF-8 and substitute undecodable bytes, so the result
        # does not depend on the locale encoding of the machine running the audit.
        source = path.read_text(encoding="utf-8", errors="replace")
        rel_path = path.relative_to(root)
        for line_no, line in enumerate(source.splitlines(), start=1):
            for rule, pattern in RULES:
                if not pattern.search(line):
                    continue
                severity, reason = classify(path, rule, line, root)
                findings.append(
                    Finding(
                        severity=severity,
                        path=rel_path,
                        line=line_no,
                        rule=rule,
                        text=line.strip(),
                        reason=reason,
                    )
                )
    return findings


def _print_findings(items: list[Finding]) -> None:
    """Print each finding as a location line followed by its reason."""
    for item in items:
        print(f"{item.severity} {item.path}:{item.line} [{item.rule}] {item.text}")
        print(f"  reason: {item.reason}")


def main(argv: list[str] | None = None) -> int:
    """Run the audit and return the process exit code.

    Args:
        argv: Command-line arguments without the program name; ``None`` makes
            argparse read ``sys.argv``.

    Returns:
        0 when nothing is blocked, 2 when at least one finding is blocked, and
        1 when the source root is missing and the audit could not run.
    """
    parser = argparse.ArgumentParser(description="Audit API runtime DB access paths for RLS readiness.")
    parser.add_argument("--show-allowed", action="store_true", help="Print allowed findings too.")
    args = parser.parse_args(argv)

    try:
        findings = scan()
    except NotADirectoryError as exc:
        # Fail closed: an unscanned tree is not evidence of RLS readiness.
        print(f"AwoooP RLS access audit: ERROR {exc}", file=sys.stderr)
        return 1

    blocked = [item for item in findings if item.severity == "BLOCKED"]
    allowed = [item for item in findings if item.severity == "ALLOW"]

    print(f"AwoooP RLS access audit: BLOCKED={len(blocked)} ALLOW={len(allowed)}")
    _print_findings(blocked)
    if args.show_allowed:
        _print_findings(allowed)

    return 2 if blocked else 0


if __name__ == "__main__":
    raise SystemExit(main())