fix(rls): 收斂 API DB access context
All checks were successful
Code Review / ai-code-review (push) Successful in 21s
CD Pipeline / tests (push) Successful in 1m20s
CD Pipeline / build-and-deploy (push) Successful in 4m15s
CD Pipeline / post-deploy-checks (push) Successful in 1m58s

This commit is contained in:
Your Name
2026-05-12 19:55:13 +08:00
parent 33c0577e93
commit ff30c61c4c
16 changed files with 327 additions and 55 deletions

View File

@@ -17,6 +17,7 @@ PostgreSQL 事務管理器,確保多表操作原子性。
from typing import Any
import structlog
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
logger = structlog.get_logger(__name__)
@@ -49,14 +50,20 @@ class UnitOfWork:
- Redis 操作失敗時必須手動呼叫 rollback()
"""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]):
def __init__(
self,
session_factory: async_sessionmaker[AsyncSession],
project_id: str | None = None,
):
"""
初始化 UnitOfWork
Args:
session_factory: SQLAlchemy async session factory
project_id: RLS project context. None means contextvar/default awoooi.
"""
self._session_factory = session_factory
self._project_id = project_id
self._session: AsyncSession | None = None
self._committed = False
@@ -74,9 +81,18 @@ class UnitOfWork:
async def __aenter__(self) -> "UnitOfWork":
"""進入事務"""
from src.core.context import get_current_project_id
self._session = self._session_factory()
effective_pid = (
self._project_id if self._project_id is not None else get_current_project_id()
)
await self._session.execute(
text("SELECT set_config('app.project_id', :pid, TRUE)"),
{"pid": effective_pid},
)
self._committed = False
logger.debug("uow_started")
logger.debug("uow_started", project_id=effective_pid)
return self
async def __aexit__(

View File

@@ -106,10 +106,13 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
factory = get_session_factory()
async with factory() as session:
try:
from src.core.context import get_current_project_id
# AwoooP Phase 2.3 (2026-05-04 ogt): SET LOCAL app.project_id 讓 RLS Policy 生效
# 預設 'awoooi',多租戶路由將在 middleware 注入實際 project_id
# 預設 'awoooi',多租戶路由將透過 contextvar 注入實際 project_id
await session.execute(
text("SELECT set_config('app.project_id', 'awoooi', TRUE)")
text("SELECT set_config('app.project_id', :pid, TRUE)"),
{"pid": get_current_project_id()},
)
yield session
await session.commit()

View File

@@ -28,7 +28,7 @@ from datetime import timedelta
import structlog
from sqlalchemy import select, update
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.db.models import AiGovernanceEvent, KnowledgeEntryRecord
from src.utils.timezone import now_taipei
@@ -129,7 +129,7 @@ class KbRotCleaner:
rot_reasons: dict[str, list[str]] = {}
total = 0
async with get_session_factory()() as session:
async with get_db_context() as session:
# 只掃 active 狀態(非 archived
q = await session.execute(
select(KnowledgeEntryRecord).where(
@@ -193,7 +193,7 @@ class KbRotCleaner:
if not result.stale_ids:
return
async with get_session_factory()() as session:
async with get_db_context() as session:
# 逐條更新(避免 bulk update 覆蓋 tags JSONB
q = await session.execute(
select(KnowledgeEntryRecord).where(
@@ -220,7 +220,7 @@ class KbRotCleaner:
async def _save_event(self, result: RotScanResult) -> None:
"""寫 kb_stale 事件到 ai_governance_events。"""
try:
async with get_session_factory()() as session:
async with get_db_context() as session:
event = AiGovernanceEvent(
event_type="kb_stale",
details=result.to_dict(),

View File

@@ -33,7 +33,7 @@ from datetime import timedelta
import structlog
from sqlalchemy import and_, select, update
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.db.models import KnowledgeEntryRecord
from src.models.knowledge import EntryStatus
from src.utils.timezone import now_taipei
@@ -112,8 +112,7 @@ class KnowledgeDecayJob:
cutoff = now_taipei() - timedelta(days=DECAY_AGE_DAYS)
decayable_statuses = [EntryStatus.DRAFT.value, EntryStatus.REVIEW.value]
session_factory = get_session_factory()
async with session_factory() as db:
async with get_db_context() as db:
# 查30 天未引用view_count=0且 updated_at < cutoff 的 draft/review 條目
stmt = select(KnowledgeEntryRecord).where(
and_(

View File

@@ -29,7 +29,7 @@ from datetime import timedelta
import structlog
from sqlalchemy import and_, select
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.db.models import AgentSession, AiGovernanceEvent, AutoRepairExecution, IncidentEvidence
from src.utils.timezone import now_taipei
@@ -109,9 +109,7 @@ class OfflineReplayService:
async def _run_replay(self) -> OfflineReplayReport:
cutoff = now_taipei() - timedelta(days=REPLAY_LOOKBACK_DAYS)
session_factory = get_session_factory()
async with session_factory() as db:
async with get_db_context() as db:
# 1. 取最近 N 個有 AgentSession(coordinator) 的 Incident
stmt = (
select(AgentSession.incident_id)
@@ -137,7 +135,7 @@ class OfflineReplayService:
)
results: list[IncidentReplayResult] = []
async with session_factory() as db:
async with get_db_context() as db:
for incident_id in incident_ids:
r = await self._replay_one(db, incident_id)
results.append(r)

View File

@@ -842,14 +842,13 @@ class AIRouter:
空 dict 代表無資料或查詢失敗caller 應降級為忽略)。
"""
try:
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.repositories.aider_event_repository import AiderEventRepository
except ImportError:
return {}
try:
sf = get_session_factory()
async with sf() as sess:
async with get_db_context() as sess:
repo_obj = AiderEventRepository(sess)
stats = await repo_obj.model_stats_since(days=days)
except Exception:

View File

@@ -28,7 +28,7 @@ from datetime import timedelta
import structlog
from sqlalchemy import func, select, text
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.db.models import AiGovernanceEvent, AutoRepairExecution, ApprovalRecord
from src.utils.timezone import now_taipei
@@ -127,7 +127,7 @@ class AiSloCalculator:
try:
since = now_taipei() - timedelta(days=SLO_WINDOW_DAYS)
async with get_session_factory()() as session:
async with get_db_context() as session:
slo1 = await self._calc_auto_success_rate(session, since)
slo2 = await self._calc_human_override_rate(session, since)
slo3 = await self._calc_false_neg_rate(session, since)
@@ -210,7 +210,7 @@ class AiSloCalculator:
只在 any_violated=True 時呼叫。不管舊違反是否解決。
"""
try:
async with get_session_factory()() as session:
async with get_db_context() as session:
event = AiGovernanceEvent(
event_type="slo_violation",
details=report.to_dict(),

View File

@@ -1933,14 +1933,14 @@ class DecisionManager:
try:
from src.core.feature_flags import aiops_flags as _p6_flags
if _p6_flags.is_sub_flag_enabled("AIOPS_P6_SELF_DEMOTION"):
from src.db.base import get_session_factory as _p6_sf
from src.db.base import get_db_context as _p6_db_context
from src.db.models import AiGovernanceEvent as _GovernanceEvent
from sqlalchemy import select as _p6_select, func as _p6_func
from datetime import timedelta as _p6_td
_now = __import__("src.utils.timezone", fromlist=["now_taipei"]).now_taipei()
async with _p6_sf()() as _p6_sess:
async with _p6_db_context() as _p6_sess:
# 過去 7 天有幾筆未解決的 slo_violation
_viol_7d_q = await _p6_sess.execute(
_p6_select(_p6_func.count()).where(
@@ -1980,8 +1980,8 @@ class DecisionManager:
)
# 記錄保守模式事件
try:
from src.db.base import get_session_factory as _p6_sf2
async with _p6_sf2()() as _s2:
from src.db.base import get_db_context as _p6_db_context2
async with _p6_db_context2() as _s2:
_s2.add(_GovernanceEvent(
event_type="conservative_mode",
details={
@@ -2021,8 +2021,8 @@ class DecisionManager:
_push_decision_to_telegram(incident, token.proposal_data)
)
try:
from src.db.base import get_session_factory as _p6_sf3
async with _p6_sf3()() as _s3:
from src.db.base import get_db_context as _p6_db_context3
async with _p6_db_context3() as _s3:
_s3.add(_GovernanceEvent(
event_type="self_demotion",
details={

View File

@@ -424,11 +424,10 @@ class DynamicBaselineService:
async def _pg_upsert_baseline(self, state: BaselineState, promql: str, lookback_hours: int) -> None:
"""寫入 DynamicBaselineRecord 到 PostgreSQLINSERT不更新舊記錄"""
try:
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.db.models import DynamicBaselineRecord
factory = get_session_factory()
async with factory() as session:
async with get_db_context() as session:
record = DynamicBaselineRecord(
metric_name=state.metric_name,
mean=state.mean,
@@ -449,11 +448,10 @@ class DynamicBaselineService:
try:
from sqlalchemy import select
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.db.models import DynamicBaselineRecord
factory = get_session_factory()
async with factory() as session:
async with get_db_context() as session:
stmt = (
select(DynamicBaselineRecord)
.where(DynamicBaselineRecord.metric_name == metric_name)

View File

@@ -52,7 +52,7 @@ from pathlib import Path
import structlog
from sqlalchemy import and_, select, text as sql_text
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.db.models import AgentSession, AutoRepairExecution, IncidentEvidence
from src.utils.timezone import now_taipei
@@ -107,9 +107,7 @@ class FineTuneExporter:
async def _run_export(self) -> tuple[str | None, int]:
cutoff = now_taipei() - timedelta(days=EXPORT_LOOKBACK_DAYS)
session_factory = get_session_factory()
async with session_factory() as db:
async with get_db_context() as db:
# 1. 取得成功驗證的 EvidenceSnapshot有 evidence_summary + verification_result='success'
stmt = select(IncidentEvidence).where(
and_(
@@ -153,7 +151,7 @@ class FineTuneExporter:
with open(output_path, 'rb') as _f:
_checksum = hashlib.sha256(_f.read()).hexdigest()
_ids = [str(ev.id) for ev in evidences]
async with session_factory() as _db:
async with get_db_context() as _db:
await _db.execute(
sql_text("""
INSERT INTO finetune_exports (

View File

@@ -296,12 +296,11 @@ class LogAnomalyDetector:
"""
try:
from sqlalchemy.dialects.postgresql import insert as pg_insert
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.db.models import LogClusterRecord
from src.utils.timezone import now_taipei
factory = get_session_factory()
async with factory() as session:
async with get_db_context() as session:
stmt = pg_insert(LogClusterRecord).values(
cluster_id=cluster.cluster_id,
template=cluster.template,

View File

@@ -39,7 +39,7 @@ from dataclasses import dataclass
import structlog
from sqlalchemy import func, select
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.db.models import AiGovernanceEvent, PlaybookRecord
from src.utils.timezone import now_taipei
@@ -115,7 +115,7 @@ class TrustDriftDetector:
TrustDistribution樣本不足時 drift_detected=False
"""
try:
async with get_session_factory()() as session:
async with get_db_context() as session:
# 只計算 approved 狀態的 Playbook
total_q = await session.execute(
select(func.count()).where(
@@ -215,7 +215,7 @@ class TrustDriftDetector:
async def save_drift_event(self, dist: TrustDistribution) -> None:
"""將信任度漂移事件寫入 ai_governance_events。"""
try:
async with get_session_factory()() as session:
async with get_db_context() as session:
event = AiGovernanceEvent(
event_type="trust_drift",
details={

View File

@@ -21,7 +21,7 @@ from typing import Any
import structlog
from src.core.redis_client import get_redis, get_worker_redis, init_worker_redis_pool
from src.db.base import get_session_factory
from src.db.base import get_db_context
from src.models.aider import AiderEventIn
from src.repositories.aider_event_repository import AiderEventRepository
from src.services.aider_event_service import build_signal_data, should_create_incident
@@ -123,7 +123,7 @@ class AiderEventProcessor:
self, stream_key: str, msg_id: Any, data: dict, _session_factory=None
) -> None:
"""處理單筆 messageparse → (maybe) incident → DB write → ACK。
_session_factory: 可注入測試用 factory,預設使用 get_session_factory()
_session_factory: 可注入測試用 factoryproduction 預設使用 get_db_context() 設定 RLS context
"""
try:
raw = data.get(b"payload") or data.get("payload")
@@ -151,14 +151,22 @@ class AiderEventProcessor:
# 不中斷 — 即使 incident 失敗event 仍要持久化
try:
session_factory = _session_factory or get_session_factory()
async with session_factory() as session:
repo = AiderEventRepository(session)
await repo.insert(
session_id=ev.session_id, ts=ev.ts, type_=ev.type,
host=ev.host, payload=ev.payload, incident_id=incident_id,
)
await session.commit()
if _session_factory is None:
async with get_db_context() as session:
repo = AiderEventRepository(session)
await repo.insert(
session_id=ev.session_id, ts=ev.ts, type_=ev.type,
host=ev.host, payload=ev.payload, incident_id=incident_id,
)
else:
session_factory = _session_factory
async with session_factory() as session:
repo = AiderEventRepository(session)
await repo.insert(
session_id=ev.session_id, ts=ev.ts, type_=ev.type,
host=ev.host, payload=ev.payload, incident_id=incident_id,
)
await session.commit()
except Exception:
logger.exception("aider_processor_db_write_failed",
session_id=ev.session_id)

View File

@@ -1,3 +1,45 @@
## 2026-05-12 | RLS Access Path Audit 收斂
**背景**RLS role bootstrap 已完成後,下一個 gate 是確認 API runtime DB access 都會設定 `app.project_id`;否則一旦 fail-closed policy 上線,直接 session factory 入口會讀不到資料或寫入失敗。
**runtime 修補**
- `get_db()`
-`get_db_context()` 對齊,改讀 `src.core.context.get_current_project_id()` 並以 bind parameter 設定 `app.project_id`
- `UnitOfWork`
- `__aenter__` 會讀 `src.core.context.get_current_project_id()`,並執行 `SELECT set_config('app.project_id', :pid, TRUE)`
- `IncidentApprovalService` 繼續注入 session factory但經由 `UnitOfWork` 進入 RLS-safe path。
- 將 production runtime 直接 `get_session_factory()` call sites 改為 `get_db_context()`
- `apps/api/src/jobs/kb_rot_cleaner.py`
- `apps/api/src/jobs/knowledge_decay_job.py`
- `apps/api/src/jobs/offline_replay_service.py`
- `apps/api/src/services/ai_router.py`
- `apps/api/src/services/ai_slo_calculator.py`
- `apps/api/src/services/decision_manager.py`
- `apps/api/src/services/dynamic_baseline_service.py`
- `apps/api/src/services/finetune_exporter.py`
- `apps/api/src/services/log_anomaly_detector.py`
- `apps/api/src/services/trust_drift_detector.py`
- `apps/api/src/workers/aider_event_processor.py`
- `aider_event_processor` 的 production path 改走 `get_db_context()`;測試注入 `_session_factory` 時仍保留測試隔離。
**新增 audit gate**
- `scripts/ops/awooop-rls-access-audit.py`
- static 掃描 `apps/api/src` runtime 中的 `get_session_factory()``create_async_engine()``asyncpg.connect()``settings.DATABASE_URL`
- 只允許 engine owner、health `SELECT 1`、sanitized log、UnitOfWork injection 等明確例外allowlist 以 path/rule/text pattern 判斷,避免行號漂移造成誤報。
- exit `2` 表示還有 runtime blocker。
- `docs/runbooks/AWOOOP-RLS-ACCESS-AUDIT.md` 記錄 gate 與例外。
**驗證**
- `python3 scripts/ops/awooop-rls-access-audit.py --show-allowed``BLOCKED=0 ALLOW=10`
- `python3 -m py_compile` 對修改過的 runtime 檔與 audit script → passed。
- `scripts/ops/awooop-rls-preflight.sh --exact-counts` → 仍為 `PASS=7 WARN=0 BLOCKED=1`;唯一 blocker 仍是尚未啟用 RLS policy符合預期。
- Production health `/api/v1/health` → 200 healthy。
- 嘗試跑 `python3 -m pytest ...`,但本機 `/usr/bin/python3``pytest`,且 repo 內未找到可用 venv本輪未安裝依賴改以 compile/static/live smoke 驗證。
**下一步**
- 針對 manual scripts (`apps/api/scripts/`、top-level `scripts/`) 補 operator review policy它們不是 API runtime但 RLS policy 上線後若直接拿 `DATABASE_URL` 操作 tenant tables仍需明確 `SET LOCAL app.project_id` 或用 migration/operator role。
- 產出第一批 staged policy enablement SQL先從空表 / 低流量 AwoooP tables canary不從 incidents / knowledge_entries 開始。
## 2026-05-12 | RLS Role Bootstrap 已套用
**背景**:上一輪已新增 `scripts/ops/awooop-rls-role-bootstrap.sql`,但尚未執行;使用者批准後,本輪只執行 role bootstrap不啟用 RLS policy。

View File

@@ -0,0 +1,49 @@
# AwoooP RLS Access Path Audit
> Purpose: verify API runtime DB access paths are ready for fail-closed RLS.
Before enabling RLS policies, runtime database access must set
`app.project_id`. The approved paths are:
- FastAPI dependency `get_db()`.
- Background/service context `get_db_context()`.
- `UnitOfWork`, which now sets `app.project_id` on entry.
Both `get_db()` and `get_db_context()` derive `app.project_id` from
`src.core.context.get_current_project_id()` unless the caller passes an explicit
project id to `get_db_context()`.
Run:
```bash
python3 scripts/ops/awooop-rls-access-audit.py
```
To include accepted exceptions:
```bash
python3 scripts/ops/awooop-rls-access-audit.py --show-allowed
```
## 2026-05-12 Result
After fixing direct `get_session_factory()` runtime call sites:
```text
AwoooP RLS access audit: BLOCKED=0 ALLOW=10
```
Accepted exceptions:
- `apps/api/src/db/base.py`: owns the shared engine/session factory and sets
`app.project_id` in `get_db()` / `get_db_context()`.
- `apps/api/src/routes/health.py`: raw `asyncpg` health check only runs
`SELECT 1`, not tenant table queries.
- `apps/api/src/main.py` and `apps/api/src/workers/signal_worker.py`: only log a
sanitized DB host suffix.
- `apps/api/src/services/incident_approval_service.py`: injects
`UnitOfWork`; `UnitOfWork` now sets `app.project_id`.
Manual scripts under `apps/api/scripts/` and top-level `scripts/` are not API
runtime. They still need operator review before being used against production
after RLS policy enablement.

View File

@@ -0,0 +1,163 @@
#!/usr/bin/env python3
"""Static RLS access-path audit for AWOOOI API runtime code.
The goal is narrow: find production runtime DB access that may bypass
get_db()/get_db_context() and therefore miss SET LOCAL app.project_id.
It is intentionally conservative and read-only.
"""
from __future__ import annotations
import argparse
import re
from dataclasses import dataclass
from pathlib import Path
ROOT = Path(__file__).resolve().parents[2]
SRC_ROOT = ROOT / "apps/api/src"
@dataclass(frozen=True)
class Finding:
severity: str
path: Path
line: int
rule: str
text: str
reason: str
@dataclass(frozen=True)
class AllowRule:
path: str
rule: str
text_pattern: re.Pattern[str]
reason: str
RULES: list[tuple[str, re.Pattern[str]]] = [
("session_factory", re.compile(r"\bget_session_factory\s*\(")),
("create_async_engine", re.compile(r"\bcreate_async_engine\s*\(")),
("asyncpg_connect", re.compile(r"\basyncpg\.connect\s*\(")),
("settings_database_url", re.compile(r"\bsettings\.DATABASE_URL\b")),
("env_database_url", re.compile(r"os\.environ(?:\.get)?\([\"']DATABASE_URL[\"']|os\.environ\[[\"']DATABASE_URL[\"']")),
]
ALLOW_RULES: tuple[AllowRule, ...] = (
AllowRule(
"apps/api/src/db/base.py",
"settings_database_url",
re.compile(r"\bdatabase_url\s*=\s*settings\.DATABASE_URL\b"),
"DB engine owner reads DATABASE_URL and sets RLS context in get_db/get_db_context.",
),
AllowRule(
"apps/api/src/db/base.py",
"create_async_engine",
re.compile(r"\b_engine\s*=\s*create_async_engine\("),
"DB engine owner creates the shared async engine.",
),
AllowRule(
"apps/api/src/db/base.py",
"session_factory",
re.compile(r"\bdef\s+get_session_factory\("),
"Factory definition, not a call-site bypass.",
),
AllowRule(
"apps/api/src/db/base.py",
"session_factory",
re.compile(r"\bfactory\s*=\s*get_session_factory\(\)"),
"get_db/get_db_context wrap factory and set app.project_id.",
),
AllowRule(
"apps/api/src/routes/health.py",
"settings_database_url",
re.compile(r"\bdb_url\s*=\s*settings\.DATABASE_URL\.replace\("),
"Health check parses DATABASE_URL for SELECT 1 only.",
),
AllowRule(
"apps/api/src/routes/health.py",
"asyncpg_connect",
re.compile(r"\basyncpg\.connect\(db_url\)"),
"Health check raw asyncpg SELECT 1 does not read tenant tables.",
),
AllowRule(
"apps/api/src/main.py",
"settings_database_url",
re.compile(r"\bdb_url\s*=\s*settings\.DATABASE_URL\b"),
"Startup logs sanitized DB host suffix after init_db.",
),
AllowRule(
"apps/api/src/workers/signal_worker.py",
"settings_database_url",
re.compile(r"\bdatabase_url=settings\.DATABASE_URL\.split\(\"@\"\)\[-1\]"),
"Structured log uses redacted DATABASE_URL suffix only.",
),
AllowRule(
"apps/api/src/services/incident_approval_service.py",
"session_factory",
re.compile(r"\bsession_factory=get_session_factory\(\),"),
"IncidentApprovalService injects UnitOfWork; UnitOfWork now sets app.project_id.",
),
)
def classify(path: Path, rule: str, line_text: str) -> tuple[str, str]:
rel = path.relative_to(ROOT).as_posix()
for allow in ALLOW_RULES:
if allow.path == rel and allow.rule == rule and allow.text_pattern.search(line_text):
return "ALLOW", allow.reason
return "BLOCKED", "Runtime DB access must set app.project_id through get_db/get_db_context or UnitOfWork."
def scan() -> list[Finding]:
findings: list[Finding] = []
for path in sorted(SRC_ROOT.rglob("*.py")):
try:
lines = path.read_text(encoding="utf-8").splitlines()
except UnicodeDecodeError:
lines = path.read_text(errors="replace").splitlines()
for idx, line in enumerate(lines, start=1):
for rule, pattern in RULES:
if not pattern.search(line):
continue
severity, reason = classify(path, rule, line)
findings.append(
Finding(
severity=severity,
path=path.relative_to(ROOT),
line=idx,
rule=rule,
text=line.strip(),
reason=reason,
)
)
return findings
def main() -> int:
parser = argparse.ArgumentParser(description="Audit API runtime DB access paths for RLS readiness.")
parser.add_argument("--show-allowed", action="store_true", help="Print allowed findings too.")
args = parser.parse_args()
findings = scan()
blocked = [item for item in findings if item.severity == "BLOCKED"]
allowed = [item for item in findings if item.severity == "ALLOW"]
print(f"AwoooP RLS access audit: BLOCKED={len(blocked)} ALLOW={len(allowed)}")
for item in blocked:
print(f"{item.severity} {item.path}:{item.line} [{item.rule}] {item.text}")
print(f" reason: {item.reason}")
if args.show_allowed:
for item in allowed:
print(f"{item.severity} {item.path}:{item.line} [{item.rule}] {item.text}")
print(f" reason: {item.reason}")
return 2 if blocked else 0
if __name__ == "__main__":
raise SystemExit(main())