fix(rls): 收斂 API DB access context
This commit is contained in:
@@ -17,6 +17,7 @@ PostgreSQL 事務管理器,確保多表操作原子性。
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
@@ -49,14 +50,20 @@ class UnitOfWork:
|
||||
- Redis 操作失敗時必須手動呼叫 rollback()
|
||||
"""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]):
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
project_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
初始化 UnitOfWork
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy async session factory
|
||||
project_id: RLS project context. None means contextvar/default awoooi.
|
||||
"""
|
||||
self._session_factory = session_factory
|
||||
self._project_id = project_id
|
||||
self._session: AsyncSession | None = None
|
||||
self._committed = False
|
||||
|
||||
@@ -74,9 +81,18 @@ class UnitOfWork:
|
||||
|
||||
async def __aenter__(self) -> "UnitOfWork":
|
||||
"""進入事務"""
|
||||
from src.core.context import get_current_project_id
|
||||
|
||||
self._session = self._session_factory()
|
||||
effective_pid = (
|
||||
self._project_id if self._project_id is not None else get_current_project_id()
|
||||
)
|
||||
await self._session.execute(
|
||||
text("SELECT set_config('app.project_id', :pid, TRUE)"),
|
||||
{"pid": effective_pid},
|
||||
)
|
||||
self._committed = False
|
||||
logger.debug("uow_started")
|
||||
logger.debug("uow_started", project_id=effective_pid)
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
|
||||
@@ -106,10 +106,13 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
factory = get_session_factory()
|
||||
async with factory() as session:
|
||||
try:
|
||||
from src.core.context import get_current_project_id
|
||||
|
||||
# AwoooP Phase 2.3 (2026-05-04 ogt): SET LOCAL app.project_id 讓 RLS Policy 生效
|
||||
# 預設 'awoooi',多租戶路由將在 middleware 注入實際 project_id
|
||||
# 預設 'awoooi',多租戶路由將透過 contextvar 注入實際 project_id
|
||||
await session.execute(
|
||||
text("SELECT set_config('app.project_id', 'awoooi', TRUE)")
|
||||
text("SELECT set_config('app.project_id', :pid, TRUE)"),
|
||||
{"pid": get_current_project_id()},
|
||||
)
|
||||
yield session
|
||||
await session.commit()
|
||||
|
||||
@@ -28,7 +28,7 @@ from datetime import timedelta
|
||||
import structlog
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.db.models import AiGovernanceEvent, KnowledgeEntryRecord
|
||||
from src.utils.timezone import now_taipei
|
||||
|
||||
@@ -129,7 +129,7 @@ class KbRotCleaner:
|
||||
rot_reasons: dict[str, list[str]] = {}
|
||||
total = 0
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
async with get_db_context() as session:
|
||||
# 只掃 active 狀態(非 archived)
|
||||
q = await session.execute(
|
||||
select(KnowledgeEntryRecord).where(
|
||||
@@ -193,7 +193,7 @@ class KbRotCleaner:
|
||||
if not result.stale_ids:
|
||||
return
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
async with get_db_context() as session:
|
||||
# 逐條更新(避免 bulk update 覆蓋 tags JSONB)
|
||||
q = await session.execute(
|
||||
select(KnowledgeEntryRecord).where(
|
||||
@@ -220,7 +220,7 @@ class KbRotCleaner:
|
||||
async def _save_event(self, result: RotScanResult) -> None:
|
||||
"""寫 kb_stale 事件到 ai_governance_events。"""
|
||||
try:
|
||||
async with get_session_factory()() as session:
|
||||
async with get_db_context() as session:
|
||||
event = AiGovernanceEvent(
|
||||
event_type="kb_stale",
|
||||
details=result.to_dict(),
|
||||
|
||||
@@ -33,7 +33,7 @@ from datetime import timedelta
|
||||
import structlog
|
||||
from sqlalchemy import and_, select, update
|
||||
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.db.models import KnowledgeEntryRecord
|
||||
from src.models.knowledge import EntryStatus
|
||||
from src.utils.timezone import now_taipei
|
||||
@@ -112,8 +112,7 @@ class KnowledgeDecayJob:
|
||||
cutoff = now_taipei() - timedelta(days=DECAY_AGE_DAYS)
|
||||
decayable_statuses = [EntryStatus.DRAFT.value, EntryStatus.REVIEW.value]
|
||||
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as db:
|
||||
async with get_db_context() as db:
|
||||
# 查:30 天未引用(view_count=0)且 updated_at < cutoff 的 draft/review 條目
|
||||
stmt = select(KnowledgeEntryRecord).where(
|
||||
and_(
|
||||
|
||||
@@ -29,7 +29,7 @@ from datetime import timedelta
|
||||
import structlog
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.db.models import AgentSession, AiGovernanceEvent, AutoRepairExecution, IncidentEvidence
|
||||
from src.utils.timezone import now_taipei
|
||||
|
||||
@@ -109,9 +109,7 @@ class OfflineReplayService:
|
||||
|
||||
async def _run_replay(self) -> OfflineReplayReport:
|
||||
cutoff = now_taipei() - timedelta(days=REPLAY_LOOKBACK_DAYS)
|
||||
session_factory = get_session_factory()
|
||||
|
||||
async with session_factory() as db:
|
||||
async with get_db_context() as db:
|
||||
# 1. 取最近 N 個有 AgentSession(coordinator) 的 Incident
|
||||
stmt = (
|
||||
select(AgentSession.incident_id)
|
||||
@@ -137,7 +135,7 @@ class OfflineReplayService:
|
||||
)
|
||||
|
||||
results: list[IncidentReplayResult] = []
|
||||
async with session_factory() as db:
|
||||
async with get_db_context() as db:
|
||||
for incident_id in incident_ids:
|
||||
r = await self._replay_one(db, incident_id)
|
||||
results.append(r)
|
||||
|
||||
@@ -842,14 +842,13 @@ class AIRouter:
|
||||
空 dict 代表無資料或查詢失敗(caller 應降級為忽略)。
|
||||
"""
|
||||
try:
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.repositories.aider_event_repository import AiderEventRepository
|
||||
except ImportError:
|
||||
return {}
|
||||
|
||||
try:
|
||||
sf = get_session_factory()
|
||||
async with sf() as sess:
|
||||
async with get_db_context() as sess:
|
||||
repo_obj = AiderEventRepository(sess)
|
||||
stats = await repo_obj.model_stats_since(days=days)
|
||||
except Exception:
|
||||
|
||||
@@ -28,7 +28,7 @@ from datetime import timedelta
|
||||
import structlog
|
||||
from sqlalchemy import func, select, text
|
||||
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.db.models import AiGovernanceEvent, AutoRepairExecution, ApprovalRecord
|
||||
from src.utils.timezone import now_taipei
|
||||
|
||||
@@ -127,7 +127,7 @@ class AiSloCalculator:
|
||||
try:
|
||||
since = now_taipei() - timedelta(days=SLO_WINDOW_DAYS)
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
async with get_db_context() as session:
|
||||
slo1 = await self._calc_auto_success_rate(session, since)
|
||||
slo2 = await self._calc_human_override_rate(session, since)
|
||||
slo3 = await self._calc_false_neg_rate(session, since)
|
||||
@@ -210,7 +210,7 @@ class AiSloCalculator:
|
||||
只在 any_violated=True 時呼叫。不管舊違反是否解決。
|
||||
"""
|
||||
try:
|
||||
async with get_session_factory()() as session:
|
||||
async with get_db_context() as session:
|
||||
event = AiGovernanceEvent(
|
||||
event_type="slo_violation",
|
||||
details=report.to_dict(),
|
||||
|
||||
@@ -1933,14 +1933,14 @@ class DecisionManager:
|
||||
try:
|
||||
from src.core.feature_flags import aiops_flags as _p6_flags
|
||||
if _p6_flags.is_sub_flag_enabled("AIOPS_P6_SELF_DEMOTION"):
|
||||
from src.db.base import get_session_factory as _p6_sf
|
||||
from src.db.base import get_db_context as _p6_db_context
|
||||
from src.db.models import AiGovernanceEvent as _GovernanceEvent
|
||||
from sqlalchemy import select as _p6_select, func as _p6_func
|
||||
from datetime import timedelta as _p6_td
|
||||
|
||||
_now = __import__("src.utils.timezone", fromlist=["now_taipei"]).now_taipei()
|
||||
|
||||
async with _p6_sf()() as _p6_sess:
|
||||
async with _p6_db_context() as _p6_sess:
|
||||
# 過去 7 天有幾筆未解決的 slo_violation?
|
||||
_viol_7d_q = await _p6_sess.execute(
|
||||
_p6_select(_p6_func.count()).where(
|
||||
@@ -1980,8 +1980,8 @@ class DecisionManager:
|
||||
)
|
||||
# 記錄保守模式事件
|
||||
try:
|
||||
from src.db.base import get_session_factory as _p6_sf2
|
||||
async with _p6_sf2()() as _s2:
|
||||
from src.db.base import get_db_context as _p6_db_context2
|
||||
async with _p6_db_context2() as _s2:
|
||||
_s2.add(_GovernanceEvent(
|
||||
event_type="conservative_mode",
|
||||
details={
|
||||
@@ -2021,8 +2021,8 @@ class DecisionManager:
|
||||
_push_decision_to_telegram(incident, token.proposal_data)
|
||||
)
|
||||
try:
|
||||
from src.db.base import get_session_factory as _p6_sf3
|
||||
async with _p6_sf3()() as _s3:
|
||||
from src.db.base import get_db_context as _p6_db_context3
|
||||
async with _p6_db_context3() as _s3:
|
||||
_s3.add(_GovernanceEvent(
|
||||
event_type="self_demotion",
|
||||
details={
|
||||
|
||||
@@ -424,11 +424,10 @@ class DynamicBaselineService:
|
||||
async def _pg_upsert_baseline(self, state: BaselineState, promql: str, lookback_hours: int) -> None:
|
||||
"""寫入 DynamicBaselineRecord 到 PostgreSQL(INSERT,不更新舊記錄)"""
|
||||
try:
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.db.models import DynamicBaselineRecord
|
||||
|
||||
factory = get_session_factory()
|
||||
async with factory() as session:
|
||||
async with get_db_context() as session:
|
||||
record = DynamicBaselineRecord(
|
||||
metric_name=state.metric_name,
|
||||
mean=state.mean,
|
||||
@@ -449,11 +448,10 @@ class DynamicBaselineService:
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.db.models import DynamicBaselineRecord
|
||||
|
||||
factory = get_session_factory()
|
||||
async with factory() as session:
|
||||
async with get_db_context() as session:
|
||||
stmt = (
|
||||
select(DynamicBaselineRecord)
|
||||
.where(DynamicBaselineRecord.metric_name == metric_name)
|
||||
|
||||
@@ -52,7 +52,7 @@ from pathlib import Path
|
||||
import structlog
|
||||
from sqlalchemy import and_, select, text as sql_text
|
||||
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.db.models import AgentSession, AutoRepairExecution, IncidentEvidence
|
||||
from src.utils.timezone import now_taipei
|
||||
|
||||
@@ -107,9 +107,7 @@ class FineTuneExporter:
|
||||
|
||||
async def _run_export(self) -> tuple[str | None, int]:
|
||||
cutoff = now_taipei() - timedelta(days=EXPORT_LOOKBACK_DAYS)
|
||||
session_factory = get_session_factory()
|
||||
|
||||
async with session_factory() as db:
|
||||
async with get_db_context() as db:
|
||||
# 1. 取得成功驗證的 EvidenceSnapshot(有 evidence_summary + verification_result='success')
|
||||
stmt = select(IncidentEvidence).where(
|
||||
and_(
|
||||
@@ -153,7 +151,7 @@ class FineTuneExporter:
|
||||
with open(output_path, 'rb') as _f:
|
||||
_checksum = hashlib.sha256(_f.read()).hexdigest()
|
||||
_ids = [str(ev.id) for ev in evidences]
|
||||
async with session_factory() as _db:
|
||||
async with get_db_context() as _db:
|
||||
await _db.execute(
|
||||
sql_text("""
|
||||
INSERT INTO finetune_exports (
|
||||
|
||||
@@ -296,12 +296,11 @@ class LogAnomalyDetector:
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.db.models import LogClusterRecord
|
||||
from src.utils.timezone import now_taipei
|
||||
|
||||
factory = get_session_factory()
|
||||
async with factory() as session:
|
||||
async with get_db_context() as session:
|
||||
stmt = pg_insert(LogClusterRecord).values(
|
||||
cluster_id=cluster.cluster_id,
|
||||
template=cluster.template,
|
||||
|
||||
@@ -39,7 +39,7 @@ from dataclasses import dataclass
|
||||
import structlog
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.db.models import AiGovernanceEvent, PlaybookRecord
|
||||
from src.utils.timezone import now_taipei
|
||||
|
||||
@@ -115,7 +115,7 @@ class TrustDriftDetector:
|
||||
TrustDistribution(樣本不足時 drift_detected=False)
|
||||
"""
|
||||
try:
|
||||
async with get_session_factory()() as session:
|
||||
async with get_db_context() as session:
|
||||
# 只計算 approved 狀態的 Playbook
|
||||
total_q = await session.execute(
|
||||
select(func.count()).where(
|
||||
@@ -215,7 +215,7 @@ class TrustDriftDetector:
|
||||
async def save_drift_event(self, dist: TrustDistribution) -> None:
|
||||
"""將信任度漂移事件寫入 ai_governance_events。"""
|
||||
try:
|
||||
async with get_session_factory()() as session:
|
||||
async with get_db_context() as session:
|
||||
event = AiGovernanceEvent(
|
||||
event_type="trust_drift",
|
||||
details={
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import Any
|
||||
import structlog
|
||||
|
||||
from src.core.redis_client import get_redis, get_worker_redis, init_worker_redis_pool
|
||||
from src.db.base import get_session_factory
|
||||
from src.db.base import get_db_context
|
||||
from src.models.aider import AiderEventIn
|
||||
from src.repositories.aider_event_repository import AiderEventRepository
|
||||
from src.services.aider_event_service import build_signal_data, should_create_incident
|
||||
@@ -123,7 +123,7 @@ class AiderEventProcessor:
|
||||
self, stream_key: str, msg_id: Any, data: dict, _session_factory=None
|
||||
) -> None:
|
||||
"""處理單筆 message:parse → (maybe) incident → DB write → ACK。
|
||||
_session_factory: 可注入測試用 factory,預設使用 get_session_factory()。
|
||||
_session_factory: 可注入測試用 factory;production 預設使用 get_db_context() 設定 RLS context。
|
||||
"""
|
||||
try:
|
||||
raw = data.get(b"payload") or data.get("payload")
|
||||
@@ -151,14 +151,22 @@ class AiderEventProcessor:
|
||||
# 不中斷 — 即使 incident 失敗,event 仍要持久化
|
||||
|
||||
try:
|
||||
session_factory = _session_factory or get_session_factory()
|
||||
async with session_factory() as session:
|
||||
repo = AiderEventRepository(session)
|
||||
await repo.insert(
|
||||
session_id=ev.session_id, ts=ev.ts, type_=ev.type,
|
||||
host=ev.host, payload=ev.payload, incident_id=incident_id,
|
||||
)
|
||||
await session.commit()
|
||||
if _session_factory is None:
|
||||
async with get_db_context() as session:
|
||||
repo = AiderEventRepository(session)
|
||||
await repo.insert(
|
||||
session_id=ev.session_id, ts=ev.ts, type_=ev.type,
|
||||
host=ev.host, payload=ev.payload, incident_id=incident_id,
|
||||
)
|
||||
else:
|
||||
session_factory = _session_factory
|
||||
async with session_factory() as session:
|
||||
repo = AiderEventRepository(session)
|
||||
await repo.insert(
|
||||
session_id=ev.session_id, ts=ev.ts, type_=ev.type,
|
||||
host=ev.host, payload=ev.payload, incident_id=incident_id,
|
||||
)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
logger.exception("aider_processor_db_write_failed",
|
||||
session_id=ev.session_id)
|
||||
|
||||
@@ -1,3 +1,45 @@
|
||||
## 2026-05-12 | RLS Access Path Audit 收斂
|
||||
|
||||
**背景**:RLS role bootstrap 已完成後,下一個 gate 是確認 API runtime DB access 都會設定 `app.project_id`;否則一旦 fail-closed policy 上線,直接 session factory 入口會讀不到資料或寫入失敗。
|
||||
|
||||
**runtime 修補**:
|
||||
- `get_db()`:
|
||||
- 和 `get_db_context()` 對齊,改讀 `src.core.context.get_current_project_id()` 並以 bind parameter 設定 `app.project_id`。
|
||||
- `UnitOfWork`:
|
||||
- `__aenter__` 會讀 `src.core.context.get_current_project_id()`,並執行 `SELECT set_config('app.project_id', :pid, TRUE)`。
|
||||
- `IncidentApprovalService` 繼續注入 session factory,但經由 `UnitOfWork` 進入 RLS-safe path。
|
||||
- 將 production runtime 直接 `get_session_factory()` call sites 改為 `get_db_context()`:
|
||||
- `apps/api/src/jobs/kb_rot_cleaner.py`
|
||||
- `apps/api/src/jobs/knowledge_decay_job.py`
|
||||
- `apps/api/src/jobs/offline_replay_service.py`
|
||||
- `apps/api/src/services/ai_router.py`
|
||||
- `apps/api/src/services/ai_slo_calculator.py`
|
||||
- `apps/api/src/services/decision_manager.py`
|
||||
- `apps/api/src/services/dynamic_baseline_service.py`
|
||||
- `apps/api/src/services/finetune_exporter.py`
|
||||
- `apps/api/src/services/log_anomaly_detector.py`
|
||||
- `apps/api/src/services/trust_drift_detector.py`
|
||||
- `apps/api/src/workers/aider_event_processor.py`
|
||||
- `aider_event_processor` 的 production path 改走 `get_db_context()`;測試注入 `_session_factory` 時仍保留測試隔離。
|
||||
|
||||
**新增 audit gate**:
|
||||
- `scripts/ops/awooop-rls-access-audit.py`:
|
||||
- static 掃描 `apps/api/src` runtime 中的 `get_session_factory()`、`create_async_engine()`、`asyncpg.connect()`、`settings.DATABASE_URL`。
|
||||
- 只允許 engine owner、health `SELECT 1`、sanitized log、UnitOfWork injection 等明確例外;allowlist 以 path/rule/text pattern 判斷,避免行號漂移造成誤報。
|
||||
- exit `2` 表示還有 runtime blocker。
|
||||
- `docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md` 記錄 gate 與例外。
|
||||
|
||||
**驗證**:
|
||||
- `python3 scripts/ops/awooop-rls-access-audit.py --show-allowed` → `BLOCKED=0 ALLOW=10`。
|
||||
- `python3 -m py_compile` 對修改過的 runtime 檔與 audit script → passed。
|
||||
- `scripts/ops/awooop-rls-preflight.sh --exact-counts` → 仍為 `PASS=7 WARN=0 BLOCKED=1`;唯一 blocker 仍是尚未啟用 RLS policy,符合預期。
|
||||
- Production health `/api/v1/health` → 200 healthy。
|
||||
- 嘗試跑 `python3 -m pytest ...`,但本機 `/usr/bin/python3` 無 `pytest`,且 repo 內未找到可用 venv;本輪未安裝依賴,改以 compile/static/live smoke 驗證。
|
||||
|
||||
**下一步**:
|
||||
- 針對 manual scripts (`apps/api/scripts/`、top-level `scripts/`) 補 operator review policy;它們不是 API runtime,但 RLS policy 上線後若直接拿 `DATABASE_URL` 操作 tenant tables,仍需明確 `SET LOCAL app.project_id` 或用 migration/operator role。
|
||||
- 產出第一批 staged policy enablement SQL,先從空表 / 低流量 AwoooP tables canary,不從 incidents / knowledge_entries 開始。
|
||||
|
||||
## 2026-05-12 | RLS Role Bootstrap 已套用
|
||||
|
||||
**背景**:上一輪已新增 `scripts/ops/awooop-rls-role-bootstrap.sql`,但尚未執行;使用者批准後,本輪只執行 role bootstrap,不啟用 RLS policy。
|
||||
|
||||
49
docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md
Normal file
49
docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# AwoooP RLS Access Path Audit
|
||||
|
||||
> Purpose: verify API runtime DB access paths are ready for fail-closed RLS.
|
||||
|
||||
Before enabling RLS policies, runtime database access must set
|
||||
`app.project_id`. The approved paths are:
|
||||
|
||||
- FastAPI dependency `get_db()`.
|
||||
- Background/service context `get_db_context()`.
|
||||
- `UnitOfWork`, which now sets `app.project_id` on entry.
|
||||
|
||||
Both `get_db()` and `get_db_context()` derive `app.project_id` from
|
||||
`src.core.context.get_current_project_id()` unless the caller passes an explicit
|
||||
project id to `get_db_context()`.
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
python3 scripts/ops/awooop-rls-access-audit.py
|
||||
```
|
||||
|
||||
To include accepted exceptions:
|
||||
|
||||
```bash
|
||||
python3 scripts/ops/awooop-rls-access-audit.py --show-allowed
|
||||
```
|
||||
|
||||
## 2026-05-12 Result
|
||||
|
||||
After fixing direct `get_session_factory()` runtime call sites:
|
||||
|
||||
```text
|
||||
AwoooP RLS access audit: BLOCKED=0 ALLOW=10
|
||||
```
|
||||
|
||||
Accepted exceptions:
|
||||
|
||||
- `apps/api/src/db/base.py`: owns the shared engine/session factory and sets
|
||||
`app.project_id` in `get_db()` / `get_db_context()`.
|
||||
- `apps/api/src/routes/health.py`: raw `asyncpg` health check only runs
|
||||
`SELECT 1`, not tenant table queries.
|
||||
- `apps/api/src/main.py` and `apps/api/src/workers/signal_worker.py`: only log a
|
||||
sanitized DB host suffix.
|
||||
- `apps/api/src/services/incident_approval_service.py`: injects
|
||||
`UnitOfWork`; `UnitOfWork` now sets `app.project_id`.
|
||||
|
||||
Manual scripts under `apps/api/scripts/` and top-level `scripts/` are not API
|
||||
runtime. They still need operator review before being used against production
|
||||
after RLS policy enablement.
|
||||
163
scripts/ops/awooop-rls-access-audit.py
Executable file
163
scripts/ops/awooop-rls-access-audit.py
Executable file
@@ -0,0 +1,163 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Static RLS access-path audit for AWOOOI API runtime code.
|
||||
|
||||
The goal is narrow: find production runtime DB access that may bypass
|
||||
get_db()/get_db_context() and therefore miss SET LOCAL app.project_id.
|
||||
It is intentionally conservative and read-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
SRC_ROOT = ROOT / "apps/api/src"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Finding:
|
||||
severity: str
|
||||
path: Path
|
||||
line: int
|
||||
rule: str
|
||||
text: str
|
||||
reason: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AllowRule:
|
||||
path: str
|
||||
rule: str
|
||||
text_pattern: re.Pattern[str]
|
||||
reason: str
|
||||
|
||||
|
||||
RULES: list[tuple[str, re.Pattern[str]]] = [
|
||||
("session_factory", re.compile(r"\bget_session_factory\s*\(")),
|
||||
("create_async_engine", re.compile(r"\bcreate_async_engine\s*\(")),
|
||||
("asyncpg_connect", re.compile(r"\basyncpg\.connect\s*\(")),
|
||||
("settings_database_url", re.compile(r"\bsettings\.DATABASE_URL\b")),
|
||||
("env_database_url", re.compile(r"os\.environ(?:\.get)?\([\"']DATABASE_URL[\"']|os\.environ\[[\"']DATABASE_URL[\"']")),
|
||||
]
|
||||
|
||||
|
||||
ALLOW_RULES: tuple[AllowRule, ...] = (
|
||||
AllowRule(
|
||||
"apps/api/src/db/base.py",
|
||||
"settings_database_url",
|
||||
re.compile(r"\bdatabase_url\s*=\s*settings\.DATABASE_URL\b"),
|
||||
"DB engine owner reads DATABASE_URL and sets RLS context in get_db/get_db_context.",
|
||||
),
|
||||
AllowRule(
|
||||
"apps/api/src/db/base.py",
|
||||
"create_async_engine",
|
||||
re.compile(r"\b_engine\s*=\s*create_async_engine\("),
|
||||
"DB engine owner creates the shared async engine.",
|
||||
),
|
||||
AllowRule(
|
||||
"apps/api/src/db/base.py",
|
||||
"session_factory",
|
||||
re.compile(r"\bdef\s+get_session_factory\("),
|
||||
"Factory definition, not a call-site bypass.",
|
||||
),
|
||||
AllowRule(
|
||||
"apps/api/src/db/base.py",
|
||||
"session_factory",
|
||||
re.compile(r"\bfactory\s*=\s*get_session_factory\(\)"),
|
||||
"get_db/get_db_context wrap factory and set app.project_id.",
|
||||
),
|
||||
AllowRule(
|
||||
"apps/api/src/routes/health.py",
|
||||
"settings_database_url",
|
||||
re.compile(r"\bdb_url\s*=\s*settings\.DATABASE_URL\.replace\("),
|
||||
"Health check parses DATABASE_URL for SELECT 1 only.",
|
||||
),
|
||||
AllowRule(
|
||||
"apps/api/src/routes/health.py",
|
||||
"asyncpg_connect",
|
||||
re.compile(r"\basyncpg\.connect\(db_url\)"),
|
||||
"Health check raw asyncpg SELECT 1 does not read tenant tables.",
|
||||
),
|
||||
AllowRule(
|
||||
"apps/api/src/main.py",
|
||||
"settings_database_url",
|
||||
re.compile(r"\bdb_url\s*=\s*settings\.DATABASE_URL\b"),
|
||||
"Startup logs sanitized DB host suffix after init_db.",
|
||||
),
|
||||
AllowRule(
|
||||
"apps/api/src/workers/signal_worker.py",
|
||||
"settings_database_url",
|
||||
re.compile(r"\bdatabase_url=settings\.DATABASE_URL\.split\(\"@\"\)\[-1\]"),
|
||||
"Structured log uses redacted DATABASE_URL suffix only.",
|
||||
),
|
||||
AllowRule(
|
||||
"apps/api/src/services/incident_approval_service.py",
|
||||
"session_factory",
|
||||
re.compile(r"\bsession_factory=get_session_factory\(\),"),
|
||||
"IncidentApprovalService injects UnitOfWork; UnitOfWork now sets app.project_id.",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def classify(path: Path, rule: str, line_text: str) -> tuple[str, str]:
|
||||
rel = path.relative_to(ROOT).as_posix()
|
||||
for allow in ALLOW_RULES:
|
||||
if allow.path == rel and allow.rule == rule and allow.text_pattern.search(line_text):
|
||||
return "ALLOW", allow.reason
|
||||
return "BLOCKED", "Runtime DB access must set app.project_id through get_db/get_db_context or UnitOfWork."
|
||||
|
||||
|
||||
def scan() -> list[Finding]:
|
||||
findings: list[Finding] = []
|
||||
for path in sorted(SRC_ROOT.rglob("*.py")):
|
||||
try:
|
||||
lines = path.read_text(encoding="utf-8").splitlines()
|
||||
except UnicodeDecodeError:
|
||||
lines = path.read_text(errors="replace").splitlines()
|
||||
|
||||
for idx, line in enumerate(lines, start=1):
|
||||
for rule, pattern in RULES:
|
||||
if not pattern.search(line):
|
||||
continue
|
||||
severity, reason = classify(path, rule, line)
|
||||
findings.append(
|
||||
Finding(
|
||||
severity=severity,
|
||||
path=path.relative_to(ROOT),
|
||||
line=idx,
|
||||
rule=rule,
|
||||
text=line.strip(),
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
return findings
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Audit API runtime DB access paths for RLS readiness.")
|
||||
parser.add_argument("--show-allowed", action="store_true", help="Print allowed findings too.")
|
||||
args = parser.parse_args()
|
||||
|
||||
findings = scan()
|
||||
blocked = [item for item in findings if item.severity == "BLOCKED"]
|
||||
allowed = [item for item in findings if item.severity == "ALLOW"]
|
||||
|
||||
print(f"AwoooP RLS access audit: BLOCKED={len(blocked)} ALLOW={len(allowed)}")
|
||||
for item in blocked:
|
||||
print(f"{item.severity} {item.path}:{item.line} [{item.rule}] {item.text}")
|
||||
print(f" reason: {item.reason}")
|
||||
|
||||
if args.show_allowed:
|
||||
for item in allowed:
|
||||
print(f"{item.severity} {item.path}:{item.line} [{item.rule}] {item.text}")
|
||||
print(f" reason: {item.reason}")
|
||||
|
||||
return 2 if blocked else 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user