Files
awoooi/scripts/ops/awooop_rls_preflight.py
Your Name b7af597459
All checks were successful
Code Review / ai-code-review (push) Successful in 10s
chore(rls): 套用 tool registry canary wave1.1
2026-05-12 21:15:14 +08:00

389 lines
15 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Read-only AwoooP RLS preflight.
This script is designed to run inside the production API pod. It uses the
pod-local DATABASE_URL and never prints the URL or credentials.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import os
import sys
from dataclasses import asdict, dataclass
from typing import Any
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
# Tables in the public schema that the preflight inspects for a project_id
# column, RLS flags, and policy expressions.
TARGET_TABLES = [
    # Core platform tables
    "incidents",
    "knowledge_entries",
    "playbooks",
    "audit_logs",
    "budget_ledger",
    # Project / contract tables
    "awooop_projects",
    "awooop_contracts",
    "awooop_contract_revisions",
    "awooop_published_contracts",
    # Run state and cost tables
    "awooop_run_state",
    "awooop_run_event",
    "awooop_cost_ledger",
    # MCP tables
    "awooop_mcp_tool_registry",
    "awooop_mcp_grants",
    "awooop_mcp_credential_refs",
    "awooop_mcp_gateway_audit",
    # Conversation tables
    "awooop_conversation_event",
    "awooop_outbound_message",
]
# Database roles whose existence the preflight verifies; membership of the
# connected user is additionally evaluated for "awooop_app".
REQUIRED_ROLES = [
    "awooop_app",
    "awooop_platform_admin",
    "awooop_migration",
]
@dataclass
class Check:
    """Outcome of a single preflight check.

    Instances are created positionally or by keyword and serialized with
    ``dataclasses.asdict`` for JSON output, so the field order is part of
    the interface.
    """

    # Stable identifier of the check, e.g. "required_roles".
    name: str
    # Result label; this script emits "PASS", "WARN" and "BLOCKED".
    status: str
    # Human-readable explanation of the result.
    detail: str
def add(checks: list[Check], name: str, status: str, detail: str) -> None:
    """Append a new ``Check`` built from the given fields to ``checks`` in place."""
    entry = Check(name, status, detail)
    checks.append(entry)
async def scalar(conn: Any, sql: str, params: dict[str, Any] | None = None) -> Any:
    """Execute ``sql`` on ``conn`` and return the result of ``conn.scalar``.

    ``params`` supplies bind values for ``:name`` placeholders; a missing or
    empty mapping is sent as an empty dict.
    """
    statement = text(sql)
    bind_values = params or {}
    return await conn.scalar(statement, bind_values)
async def rows(conn: Any, sql: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
    """Execute ``sql`` on ``conn`` and return every result row as a plain dict.

    ``params`` supplies bind values for ``:name`` placeholders; a missing or
    empty mapping is sent as an empty dict.
    """
    result = await conn.execute(text(sql), params or {})
    records: list[dict[str, Any]] = []
    for record in result.fetchall():
        # Copy the row mapping into a regular dict keyed by column name.
        records.append(dict(record._mapping))
    return records
async def _check_current_role(
    conn: Any, checks: list[Check], evidence: dict[str, Any]
) -> dict[str, Any]:
    """Record the connected role's privilege flags and check it is subject to RLS.

    Returns the role row (empty dict when the lookup returned nothing).
    """
    current_role = await rows(
        conn,
        """
        SELECT
            current_user AS current_user,
            session_user AS session_user,
            r.rolsuper AS current_user_superuser,
            r.rolcreaterole AS current_user_createrole,
            r.rolcreatedb AS current_user_createdb,
            r.rolbypassrls AS current_user_bypassrls
        FROM pg_roles r
        WHERE r.rolname = current_user
        """,
    )
    role = current_role[0] if current_role else {}
    evidence["current_role"] = role
    if not role:
        # Fail closed: without the role row we cannot claim RLS is enforced.
        add(
            checks,
            "current_role_rls_enforced",
            "BLOCKED",
            "current_user not found in pg_roles; cannot verify RLS enforcement",
        )
    elif role.get("current_user_superuser") or role.get("current_user_bypassrls"):
        add(
            checks,
            "current_role_rls_enforced",
            "BLOCKED",
            f"current_user={role.get('current_user')} can bypass RLS",
        )
    else:
        add(
            checks,
            "current_role_rls_enforced",
            "PASS",
            f"current_user={role.get('current_user')} is subject to RLS",
        )
    return role


async def _check_project_context(
    conn: Any, checks: list[Check], evidence: dict[str, Any]
) -> None:
    """Probe that ``app.project_id`` can be set with ``set_config`` on this connection."""
    before = await scalar(conn, "SELECT current_setting('app.project_id', TRUE)")
    # Third argument TRUE makes the setting local to the current transaction.
    await scalar(conn, "SELECT set_config('app.project_id', :pid, TRUE)", {"pid": "awoooi"})
    after = await scalar(conn, "SELECT current_setting('app.project_id', TRUE)")
    evidence["project_context_probe"] = {"before": before, "after": after}
    if after == "awoooi":
        add(checks, "project_context_set_config", "PASS", "set_config app.project_id works")
    else:
        add(checks, "project_context_set_config", "BLOCKED", f"expected awoooi, got {after!r}")


async def _check_required_roles(
    conn: Any, checks: list[Check], evidence: dict[str, Any], role: dict[str, Any]
) -> None:
    """Check that REQUIRED_ROLES exist and that the current user is a member of awooop_app."""
    roles = await rows(
        conn,
        """
        WITH required_roles(rolname) AS (
            SELECT jsonb_array_elements_text(CAST(:roles_json AS jsonb))
        )
        SELECT
            rr.rolname,
            r.rolsuper,
            r.rolcreaterole,
            r.rolbypassrls,
            r.oid IS NOT NULL AS exists,
            CASE
                WHEN r.oid IS NULL THEN FALSE
                ELSE pg_has_role(current_user, rr.rolname, 'member')
            END AS current_user_is_member
        FROM required_roles rr
        LEFT JOIN pg_roles r ON r.rolname = rr.rolname
        ORDER BY rr.rolname
        """,
        {"roles_json": json.dumps(REQUIRED_ROLES)},
    )
    evidence["required_roles"] = roles
    present_roles = {row["rolname"] for row in roles if row["exists"]}
    missing_roles = [role_name for role_name in REQUIRED_ROLES if role_name not in present_roles]
    if missing_roles:
        add(checks, "required_roles", "BLOCKED", f"missing roles: {', '.join(missing_roles)}")
    else:
        add(checks, "required_roles", "PASS", "all required RLS roles exist")

    # Bootstrap authority only matters when there is something left to create.
    if missing_roles:
        can_create_roles = role.get("current_user_superuser") or role.get("current_user_createrole")
        if can_create_roles:
            add(checks, "role_bootstrap_authority", "PASS", "current DB user can create roles")
        else:
            add(
                checks,
                "role_bootstrap_authority",
                "WARN",
                "current API DB user cannot create missing roles; bootstrap requires postgres/CREATEROLE",
            )

    app_role = next((row for row in roles if row["rolname"] == "awooop_app" and row["exists"]), None)
    if app_role is None:
        add(checks, "app_role_membership", "WARN", "awooop_app role missing; membership not evaluated")
    elif app_role["current_user_is_member"]:
        add(checks, "app_role_membership", "PASS", "current API DB user is member of awooop_app")
    else:
        add(
            checks,
            "app_role_membership",
            "BLOCKED",
            "current API DB user is not a member of awooop_app; policies FOR awooop_app would not apply",
        )


async def _check_tables(
    conn: Any, checks: list[Check], evidence: dict[str, Any]
) -> list[dict[str, Any]]:
    """Inspect TARGET_TABLES for project_id, RLS flags and fail-open policy expressions.

    Returns the rows of the target tables that actually exist.
    """
    table_rows = await rows(
        conn,
        """
        WITH target(relname) AS (
            SELECT jsonb_array_elements_text(CAST(:tables_json AS jsonb))
        ),
        rels AS (
            SELECT
                t.relname,
                c.oid,
                c.relrowsecurity,
                c.relforcerowsecurity,
                pg_get_userbyid(c.relowner) AS table_owner,
                COALESCE(c.reltuples, 0)::bigint AS estimated_rows
            FROM target t
            LEFT JOIN pg_class c
                ON c.relname = t.relname
                AND c.relkind IN ('r', 'p')
                AND c.relnamespace = 'public'::regnamespace
        ),
        project_columns AS (
            SELECT table_name, TRUE AS has_project_id
            FROM information_schema.columns
            WHERE table_schema = 'public'
                AND column_name = 'project_id'
                AND table_name IN (SELECT relname FROM target)
        ),
        policy_exprs AS (
            -- pg_get_expr deparses text literals with an explicit cast
            -- ('app.project_id'::text, ''::text); strip the cast so the
            -- patterns below match the deparsed form as well.
            SELECT
                p.polrelid,
                REPLACE(COALESCE(pg_get_expr(p.polqual, p.polrelid), ''), '::text', '') AS using_expr,
                REPLACE(COALESCE(pg_get_expr(p.polwithcheck, p.polrelid), ''), '::text', '') AS check_expr
            FROM pg_policy p
        ),
        policy_stats AS (
            SELECT
                pe.polrelid,
                COUNT(*) AS policy_count,
                BOOL_OR(
                    pe.using_expr ILIKE '%current_setting(''app.project_id'', true) IS NULL%'
                    OR pe.check_expr ILIKE '%current_setting(''app.project_id'', true) IS NULL%'
                ) AS has_null_fail_open_policy,
                BOOL_OR(
                    pe.using_expr ILIKE '%current_setting(''app.project_id'', true) = ''''%'
                    OR pe.check_expr ILIKE '%current_setting(''app.project_id'', true) = ''''%'
                ) AS has_empty_string_fail_open_policy
            FROM policy_exprs pe
            GROUP BY pe.polrelid
        )
        SELECT
            r.relname AS table_name,
            r.oid IS NOT NULL AS exists,
            COALESCE(pc.has_project_id, FALSE) AS has_project_id,
            COALESCE(r.relrowsecurity, FALSE) AS rls_enabled,
            COALESCE(r.relforcerowsecurity, FALSE) AS rls_forced,
            COALESCE(ps.policy_count, 0) AS policy_count,
            COALESCE(ps.has_null_fail_open_policy, FALSE) AS has_null_fail_open_policy,
            COALESCE(ps.has_empty_string_fail_open_policy, FALSE) AS has_empty_string_fail_open_policy,
            r.table_owner,
            r.estimated_rows
        FROM rels r
        LEFT JOIN project_columns pc ON pc.table_name = r.relname
        LEFT JOIN policy_stats ps ON ps.polrelid = r.oid
        ORDER BY r.relname
        """,
        {"tables_json": json.dumps(TARGET_TABLES)},
    )
    evidence["tables"] = table_rows
    existing = [row for row in table_rows if row["exists"]]

    missing_project_id = [row["table_name"] for row in existing if not row["has_project_id"]]
    if missing_project_id:
        add(checks, "project_id_columns", "BLOCKED", f"missing project_id: {', '.join(missing_project_id)}")
    else:
        add(checks, "project_id_columns", "PASS", "all existing target tables have project_id")

    rls_missing = [
        row["table_name"]
        for row in existing
        if not row["rls_enabled"] or not row["rls_forced"] or row["policy_count"] == 0
    ]
    if rls_missing:
        add(
            checks,
            "rls_enabled_forced_policy",
            "BLOCKED",
            f"RLS not fully enabled/forced/policied: {', '.join(rls_missing)}",
        )
    else:
        add(checks, "rls_enabled_forced_policy", "PASS", "all existing target tables have forced RLS policy")

    fail_open = [
        row["table_name"]
        for row in existing
        if row["has_null_fail_open_policy"] or row["has_empty_string_fail_open_policy"]
    ]
    if fail_open:
        add(checks, "fail_open_policies", "BLOCKED", f"fail-open policy expressions: {', '.join(fail_open)}")
    else:
        add(checks, "fail_open_policies", "PASS", "no fail-open policy expressions detected")
    return existing


async def _check_exact_counts(
    conn: Any,
    checks: list[Check],
    evidence: dict[str, Any],
    existing: list[dict[str, Any]],
) -> None:
    """Run exact COUNT(*) per existing table that has project_id and flag NULL project_id rows."""
    exact_rows: list[dict[str, Any]] = []
    for row in existing:
        if not row["has_project_id"]:
            continue
        # Table names come from TARGET_TABLES; quote as an identifier anyway
        # because identifiers cannot be passed as bind parameters.
        quoted = '"' + row["table_name"].replace('"', '""') + '"'
        count_row = await rows(
            conn,
            f"""
            SELECT
                :table_name AS table_name,
                CAST(:rls_filtered AS boolean) AS rls_filtered,
                current_setting('app.project_id', TRUE) AS project_context,
                COUNT(*) AS total_rows,
                COUNT(*) FILTER (WHERE project_id IS NULL) AS null_project_id_rows
            FROM {quoted}
            """,
            {
                "table_name": row["table_name"],
                "rls_filtered": bool(row["rls_enabled"]),
            },
        )
        exact_rows.extend(count_row)
    evidence["exact_counts"] = exact_rows

    null_tables = [row["table_name"] for row in exact_rows if int(row["null_project_id_rows"]) > 0]
    if null_tables:
        add(checks, "project_id_backfill", "BLOCKED", f"NULL project_id remains: {', '.join(null_tables)}")
    else:
        add(checks, "project_id_backfill", "PASS", "no NULL project_id rows in counted tables")

    if any(row.get("rls_filtered") for row in exact_rows):
        add(
            checks,
            "exact_counts_scope",
            "WARN",
            "counts for RLS-enabled tables are tenant-visible only; use operator role for global counts",
        )


async def collect(exact_counts: bool) -> tuple[list[Check], dict[str, Any]]:
    """Run all preflight checks against the database named by DATABASE_URL.

    Args:
        exact_counts: When true, run exact ``COUNT(*)`` queries per table to
            detect rows with NULL ``project_id``; otherwise a WARN check is
            recorded saying the counts were skipped.

    Returns:
        ``(checks, evidence)``: the ordered list of check results and a dict
        of the raw query results they were derived from. When DATABASE_URL
        is unset, a single BLOCKED check and empty evidence are returned
        without connecting.

    Raises:
        Whatever the engine or a query raises; the engine is disposed first.
    """
    database_url = os.environ.get("DATABASE_URL")
    if not database_url:
        return [Check("database_url", "BLOCKED", "DATABASE_URL is not set in this environment")], {}

    engine = create_async_engine(database_url, pool_pre_ping=True)
    checks: list[Check] = []
    evidence: dict[str, Any] = {}
    try:
        async with engine.connect() as conn:
            role = await _check_current_role(conn, checks, evidence)
            await _check_project_context(conn, checks, evidence)
            await _check_required_roles(conn, checks, evidence, role)
            existing = await _check_tables(conn, checks, evidence)
            if exact_counts:
                await _check_exact_counts(conn, checks, evidence, existing)
            else:
                add(
                    checks,
                    "project_id_backfill",
                    "WARN",
                    "exact counts skipped; rerun with --exact-counts before enabling RLS",
                )
    finally:
        # Dispose even when a query fails so pooled connections are not leaked.
        await engine.dispose()
    return checks, evidence
def print_human(checks: list[Check], evidence: dict[str, Any]) -> None:
    """Print a plain-text report of ``checks`` and ``evidence`` to stdout.

    Output order: a summary line with PASS/WARN/BLOCKED totals, one line per
    check, then a ``role`` line (if role evidence is present), one ``table``
    line per inspected table, and one ``count`` line per exact-count row.
    """
    tally = {"PASS": 0, "WARN": 0, "BLOCKED": 0}
    for entry in checks:
        if entry.status in tally:
            tally[entry.status] += 1
    print(f"AwoooP RLS preflight: PASS={tally['PASS']} WARN={tally['WARN']} BLOCKED={tally['BLOCKED']}")
    for entry in checks:
        print(f"{entry.status:<7} {entry.name}: {entry.detail}")

    role = evidence.get("current_role") or {}
    if role:
        # (printed label, evidence key) pairs, in output order.
        role_fields = (
            ("current_user", "current_user"),
            ("session_user", "session_user"),
            ("superuser", "current_user_superuser"),
            ("createrole", "current_user_createrole"),
            ("bypassrls", "current_user_bypassrls"),
        )
        print("role " + " ".join(f"{label}={role.get(key)}" for label, key in role_fields))

    # (printed label, row key) pairs, in output order; keys are required.
    table_fields = (
        ("exists", "exists"),
        ("project_id", "has_project_id"),
        ("rls", "rls_enabled"),
        ("force", "rls_forced"),
        ("policies", "policy_count"),
        ("fail_open_null", "has_null_fail_open_policy"),
        ("fail_open_empty", "has_empty_string_fail_open_policy"),
        ("owner", "table_owner"),
        ("estimated_rows", "estimated_rows"),
    )
    for row in evidence.get("tables", []):
        parts = ["table", f"{row['table_name']}"]
        parts.extend(f"{label}={row[key]}" for label, key in table_fields)
        print(" ".join(parts))

    for row in evidence.get("exact_counts", []):
        scope = "global_visible"
        if row.get("rls_filtered"):
            scope = "rls_filtered"
        print(
            f"count {row['table_name']} scope={scope} "
            f"project_context={row.get('project_context')} "
            f"total_rows={row['total_rows']} "
            f"null_project_id_rows={row['null_project_id_rows']}"
        )
async def main() -> int:
    """Parse CLI flags, run the preflight, print the report, and return the exit code.

    Returns:
        2 when at least one check has status "BLOCKED", otherwise 0.
    """
    parser = argparse.ArgumentParser(description="Run read-only AwoooP RLS preflight checks.")
    parser.add_argument(
        "--exact-counts",
        action="store_true",
        help="Run exact COUNT(*) checks for project_id backfill.",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Print JSON instead of human-readable output.",
    )
    options = parser.parse_args()

    checks, evidence = await collect(exact_counts=options.exact_counts)

    if options.json:
        payload = {
            "checks": [asdict(entry) for entry in checks],
            "evidence": evidence,
        }
        # default=str stringifies any value json cannot encode natively.
        print(json.dumps(payload, ensure_ascii=False, default=str))
    else:
        print_human(checks, evidence)

    if any(entry.status == "BLOCKED" for entry in checks):
        return 2
    return 0
def redact_database_url(message: str) -> str:
    """Return ``message`` with any literal occurrence of DATABASE_URL replaced.

    Some driver and URL-parsing errors embed the connection string in their
    message; this keeps the URL (and the credentials inside it) out of the
    script's output. The message is returned unchanged when DATABASE_URL is
    unset or empty.
    """
    database_url = os.environ.get("DATABASE_URL")
    if not database_url:
        return message
    return message.replace(database_url, "<redacted DATABASE_URL>")


if __name__ == "__main__":
    try:
        raise SystemExit(asyncio.run(main()))
    except KeyboardInterrupt:
        # Conventional exit status for termination by Ctrl-C (128 + SIGINT).
        raise SystemExit(130)
    except Exception as exc:
        # Any unexpected failure counts as BLOCKED (exit 2); scrub the
        # connection string before echoing the error.
        print(f"BLOCKED preflight_exception: {redact_database_url(str(exc))}", file=sys.stderr)
        raise SystemExit(2)