389 lines
15 KiB
Python
Executable File
389 lines
15 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Read-only AwoooP RLS preflight.
|
|
|
|
This script is designed to run inside the production API pod. It uses the
|
|
pod-local DATABASE_URL and never prints the URL or credentials.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
from dataclasses import asdict, dataclass
|
|
from typing import Any
|
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
|
|
|
|
# Tables inspected by the preflight. Each name is looked up in the "public"
# schema; tables that do not exist are still listed in the "tables" evidence
# (with exists=False) but are excluded from the per-table checks.
TARGET_TABLES = [
    "incidents",
    "knowledge_entries",
    "playbooks",
    "audit_logs",
    "budget_ledger",
    "awooop_projects",
    "awooop_contracts",
    "awooop_contract_revisions",
    "awooop_published_contracts",
    "awooop_run_state",
    "awooop_run_event",
    "awooop_cost_ledger",
    "awooop_mcp_tool_registry",
    "awooop_mcp_grants",
    "awooop_mcp_credential_refs",
    "awooop_mcp_gateway_audit",
    "awooop_conversation_event",
    "awooop_outbound_message",
]
|
|
|
|
# Database roles the preflight expects to find in pg_roles. Any missing role
# produces a BLOCKED "required_roles" check; "awooop_app" is additionally used
# for the role-membership check of the current DB user.
REQUIRED_ROLES = [
    "awooop_app",
    "awooop_platform_admin",
    "awooop_migration",
]
|
|
|
|
|
|
@dataclass
class Check:
    """One preflight result: a named check, its outcome, and an explanation."""

    # Stable identifier of the check, e.g. "required_roles".
    name: str
    # Outcome label; this file emits "PASS", "WARN" and "BLOCKED".
    status: str
    # Human-readable explanation printed next to the status.
    detail: str
|
|
|
|
|
|
def add(checks: list[Check], name: str, status: str, detail: str) -> None:
    """Record one check result by appending a new ``Check`` to ``checks`` in place."""
    entry = Check(name, status, detail)
    checks.append(entry)
|
|
|
|
|
|
async def scalar(conn: Any, sql: str, params: dict[str, Any] | None = None) -> Any:
    """Run ``sql`` on ``conn`` and return whatever ``conn.scalar`` yields.

    ``params`` supplies the bind values; a missing or empty mapping is sent
    as an empty dict.
    """
    statement = text(sql)
    bound = params if params else {}
    return await conn.scalar(statement, bound)
|
|
|
|
|
|
async def rows(conn: Any, sql: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
    """Execute ``sql`` on ``conn`` and return every result row as a plain dict.

    ``params`` supplies the bind values; a missing or empty mapping is sent
    as an empty dict.
    """
    bound = params if params else {}
    result = await conn.execute(text(sql), bound)
    records: list[dict[str, Any]] = []
    for record in result.fetchall():
        # ``_mapping`` exposes the row as a column-name -> value mapping.
        records.append(dict(record._mapping))
    return records
|
|
|
|
|
|
async def collect(exact_counts: bool) -> tuple[list[Check], dict[str, Any]]:
    """Run every preflight probe and return ``(checks, evidence)``.

    The database is reached through the ``DATABASE_URL`` environment variable.
    When it is unset, a single BLOCKED check and empty evidence are returned
    without connecting.

    Args:
        exact_counts: When true, run ``COUNT(*)`` per existing target table
            that has a ``project_id`` column and report NULL ``project_id``
            rows. When false, a WARN check notes that the counts were skipped.

    Returns:
        A list of ``Check`` results in execution order, and an evidence dict
        with the keys ``current_role``, ``project_context_probe``,
        ``required_roles``, ``tables`` and (only with ``exact_counts``)
        ``exact_counts``.

    Raises:
        Any exception raised while connecting or querying propagates to the
        caller; the engine is disposed first.
    """
    database_url = os.environ.get("DATABASE_URL")
    if not database_url:
        return [Check("database_url", "BLOCKED", "DATABASE_URL is not set in this environment")], {}

    engine = create_async_engine(database_url, pool_pre_ping=True)
    checks: list[Check] = []
    evidence: dict[str, Any] = {}

    # try/finally guarantees the engine is disposed even when a probe raises.
    try:
        async with engine.connect() as conn:
            # --- 1. Can the connected role bypass RLS? ---------------------
            current_role = await rows(
                conn,
                """
                SELECT
                    current_user AS current_user,
                    session_user AS session_user,
                    r.rolsuper AS current_user_superuser,
                    r.rolcreaterole AS current_user_createrole,
                    r.rolcreatedb AS current_user_createdb,
                    r.rolbypassrls AS current_user_bypassrls
                FROM pg_roles r
                WHERE r.rolname = current_user
                """,
            )
            evidence["current_role"] = current_role[0] if current_role else {}
            role = evidence["current_role"]
            if not role:
                # Fail closed: without role attributes we cannot claim that
                # the connection is subject to RLS.
                add(
                    checks,
                    "current_role_rls_enforced",
                    "BLOCKED",
                    "could not read current_user attributes from pg_roles; RLS enforcement unknown",
                )
            elif role.get("current_user_superuser") or role.get("current_user_bypassrls"):
                add(
                    checks,
                    "current_role_rls_enforced",
                    "BLOCKED",
                    f"current_user={role.get('current_user')} can bypass RLS",
                )
            else:
                add(
                    checks,
                    "current_role_rls_enforced",
                    "PASS",
                    f"current_user={role.get('current_user')} is subject to RLS",
                )

            # --- 2. Can app.project_id be set for this transaction? --------
            # The third set_config argument (TRUE) makes the setting local to
            # the current transaction; it stays in effect for the queries
            # below, including the optional exact counts.
            before = await scalar(conn, "SELECT current_setting('app.project_id', TRUE)")
            await scalar(conn, "SELECT set_config('app.project_id', :pid, TRUE)", {"pid": "awoooi"})
            after = await scalar(conn, "SELECT current_setting('app.project_id', TRUE)")
            evidence["project_context_probe"] = {"before": before, "after": after}
            if after == "awoooi":
                add(checks, "project_context_set_config", "PASS", "set_config app.project_id works")
            else:
                add(checks, "project_context_set_config", "BLOCKED", f"expected awoooi, got {after!r}")

            # --- 3. Required roles and membership --------------------------
            roles = await rows(
                conn,
                """
                WITH required_roles(rolname) AS (
                    SELECT jsonb_array_elements_text(CAST(:roles_json AS jsonb))
                )
                SELECT
                    rr.rolname,
                    r.rolsuper,
                    r.rolcreaterole,
                    r.rolbypassrls,
                    r.oid IS NOT NULL AS exists,
                    CASE
                        WHEN r.oid IS NULL THEN FALSE
                        ELSE pg_has_role(current_user, rr.rolname, 'member')
                    END AS current_user_is_member
                FROM required_roles rr
                LEFT JOIN pg_roles r ON r.rolname = rr.rolname
                ORDER BY rr.rolname
                """,
                {"roles_json": json.dumps(REQUIRED_ROLES)},
            )
            evidence["required_roles"] = roles
            present_roles = {row["rolname"] for row in roles if row["exists"]}
            missing_roles = [role_name for role_name in REQUIRED_ROLES if role_name not in present_roles]
            if missing_roles:
                add(checks, "required_roles", "BLOCKED", f"missing roles: {', '.join(missing_roles)}")
            else:
                add(checks, "required_roles", "PASS", "all required RLS roles exist")

            # Bootstrap authority only matters when something is missing.
            if missing_roles:
                if role.get("current_user_superuser") or role.get("current_user_createrole"):
                    add(checks, "role_bootstrap_authority", "PASS", "current DB user can create roles")
                else:
                    add(
                        checks,
                        "role_bootstrap_authority",
                        "WARN",
                        "current API DB user cannot create missing roles; bootstrap requires postgres/CREATEROLE",
                    )

            app_role = next((row for row in roles if row["rolname"] == "awooop_app" and row["exists"]), None)
            if app_role is None:
                add(checks, "app_role_membership", "WARN", "awooop_app role missing; membership not evaluated")
            elif app_role["current_user_is_member"]:
                add(checks, "app_role_membership", "PASS", "current API DB user is member of awooop_app")
            else:
                add(
                    checks,
                    "app_role_membership",
                    "BLOCKED",
                    "current API DB user is not a member of awooop_app; policies FOR awooop_app would not apply",
                )

            # --- 4. Per-table RLS state and policy inspection --------------
            # pg_get_expr deparses string constants with an explicit cast
            # (e.g. current_setting('app.project_id'::text, true)), so the
            # "::text" casts are stripped before matching; otherwise the
            # fail-open patterns below could never match.
            table_rows = await rows(
                conn,
                """
                WITH target(relname) AS (
                    SELECT jsonb_array_elements_text(CAST(:tables_json AS jsonb))
                ),
                rels AS (
                    SELECT
                        t.relname,
                        c.oid,
                        c.relrowsecurity,
                        c.relforcerowsecurity,
                        pg_get_userbyid(c.relowner) AS table_owner,
                        COALESCE(c.reltuples, 0)::bigint AS estimated_rows
                    FROM target t
                    LEFT JOIN pg_class c
                        ON c.relname = t.relname
                        AND c.relkind IN ('r', 'p')
                        AND c.relnamespace = 'public'::regnamespace
                ),
                project_columns AS (
                    SELECT table_name, TRUE AS has_project_id
                    FROM information_schema.columns
                    WHERE table_schema = 'public'
                        AND column_name = 'project_id'
                        AND table_name IN (SELECT relname FROM target)
                ),
                policy_exprs AS (
                    SELECT
                        p.polrelid,
                        REPLACE(COALESCE(pg_get_expr(p.polqual, p.polrelid), ''), '::text', '') AS qual_expr,
                        REPLACE(COALESCE(pg_get_expr(p.polwithcheck, p.polrelid), ''), '::text', '') AS check_expr
                    FROM pg_policy p
                ),
                policy_stats AS (
                    SELECT
                        pe.polrelid,
                        COUNT(*) AS policy_count,
                        BOOL_OR(
                            pe.qual_expr ILIKE '%current_setting(''app.project_id'', true) IS NULL%'
                            OR pe.check_expr ILIKE '%current_setting(''app.project_id'', true) IS NULL%'
                        ) AS has_null_fail_open_policy,
                        BOOL_OR(
                            pe.qual_expr ILIKE '%current_setting(''app.project_id'', true) = ''''%'
                            OR pe.check_expr ILIKE '%current_setting(''app.project_id'', true) = ''''%'
                        ) AS has_empty_string_fail_open_policy
                    FROM policy_exprs pe
                    GROUP BY pe.polrelid
                )
                SELECT
                    r.relname AS table_name,
                    r.oid IS NOT NULL AS exists,
                    COALESCE(pc.has_project_id, FALSE) AS has_project_id,
                    COALESCE(r.relrowsecurity, FALSE) AS rls_enabled,
                    COALESCE(r.relforcerowsecurity, FALSE) AS rls_forced,
                    COALESCE(ps.policy_count, 0) AS policy_count,
                    COALESCE(ps.has_null_fail_open_policy, FALSE) AS has_null_fail_open_policy,
                    COALESCE(ps.has_empty_string_fail_open_policy, FALSE) AS has_empty_string_fail_open_policy,
                    r.table_owner,
                    r.estimated_rows
                FROM rels r
                LEFT JOIN project_columns pc ON pc.table_name = r.relname
                LEFT JOIN policy_stats ps ON ps.polrelid = r.oid
                ORDER BY r.relname
                """,
                {"tables_json": json.dumps(TARGET_TABLES)},
            )
            evidence["tables"] = table_rows

            # Only tables that actually exist take part in the checks below.
            existing = [row for row in table_rows if row["exists"]]
            missing_project_id = [row["table_name"] for row in existing if not row["has_project_id"]]
            if missing_project_id:
                add(checks, "project_id_columns", "BLOCKED", f"missing project_id: {', '.join(missing_project_id)}")
            else:
                add(checks, "project_id_columns", "PASS", "all existing target tables have project_id")

            rls_missing = [
                row["table_name"]
                for row in existing
                if not row["rls_enabled"] or not row["rls_forced"] or row["policy_count"] == 0
            ]
            if rls_missing:
                add(
                    checks,
                    "rls_enabled_forced_policy",
                    "BLOCKED",
                    f"RLS not fully enabled/forced/policied: {', '.join(rls_missing)}",
                )
            else:
                add(checks, "rls_enabled_forced_policy", "PASS", "all existing target tables have forced RLS policy")

            fail_open = [
                row["table_name"]
                for row in existing
                if row["has_null_fail_open_policy"] or row["has_empty_string_fail_open_policy"]
            ]
            if fail_open:
                add(checks, "fail_open_policies", "BLOCKED", f"fail-open policy expressions: {', '.join(fail_open)}")
            else:
                add(checks, "fail_open_policies", "PASS", "no fail-open policy expressions detected")

            # --- 5. Optional exact row counts ------------------------------
            if exact_counts:
                exact_rows: list[dict[str, Any]] = []
                for row in existing:
                    if not row["has_project_id"]:
                        continue
                    # Identifiers cannot be bound as parameters, so the table
                    # name is double-quoted with embedded quotes doubled.
                    quoted = '"' + row["table_name"].replace('"', '""') + '"'
                    count_row = await rows(
                        conn,
                        f"""
                        SELECT
                            :table_name AS table_name,
                            CAST(:rls_filtered AS boolean) AS rls_filtered,
                            current_setting('app.project_id', TRUE) AS project_context,
                            COUNT(*) AS total_rows,
                            COUNT(*) FILTER (WHERE project_id IS NULL) AS null_project_id_rows
                        FROM {quoted}
                        """,
                        {
                            "table_name": row["table_name"],
                            "rls_filtered": bool(row["rls_enabled"]),
                        },
                    )
                    exact_rows.extend(count_row)
                evidence["exact_counts"] = exact_rows
                null_tables = [row["table_name"] for row in exact_rows if int(row["null_project_id_rows"]) > 0]
                rls_filtered_tables = [row["table_name"] for row in exact_rows if row.get("rls_filtered")]
                if null_tables:
                    add(checks, "project_id_backfill", "BLOCKED", f"NULL project_id remains: {', '.join(null_tables)}")
                else:
                    add(checks, "project_id_backfill", "PASS", "no NULL project_id rows in counted tables")
                if rls_filtered_tables:
                    add(
                        checks,
                        "exact_counts_scope",
                        "WARN",
                        "counts for RLS-enabled tables are tenant-visible only; use operator role for global counts",
                    )
            else:
                add(
                    checks,
                    "project_id_backfill",
                    "WARN",
                    "exact counts skipped; rerun with --exact-counts before enabling RLS",
                )
    finally:
        await engine.dispose()

    return checks, evidence
|
|
|
|
|
|
def print_human(checks: list[Check], evidence: dict[str, Any]) -> None:
    """Print the preflight report to stdout as plain text.

    Output order: one summary line with PASS/WARN/BLOCKED totals, one line
    per check, then an optional ``role`` line, one ``table`` line per entry
    in ``evidence["tables"]`` and one ``count`` line per entry in
    ``evidence["exact_counts"]``.
    """
    tally = {"PASS": 0, "WARN": 0, "BLOCKED": 0}
    for item in checks:
        if item.status in tally:
            tally[item.status] += 1
    print(f"AwoooP RLS preflight: PASS={tally['PASS']} WARN={tally['WARN']} BLOCKED={tally['BLOCKED']}")
    for item in checks:
        print(f"{item.status:<7} {item.name}: {item.detail}")

    # (printed label, evidence key) pairs, in output order.
    current = evidence.get("current_role")
    if current:
        role_fields = (
            ("current_user", "current_user"),
            ("session_user", "session_user"),
            ("superuser", "current_user_superuser"),
            ("createrole", "current_user_createrole"),
            ("bypassrls", "current_user_bypassrls"),
        )
        print("role " + " ".join(f"{label}={current.get(key)}" for label, key in role_fields))

    table_fields = (
        ("exists", "exists"),
        ("project_id", "has_project_id"),
        ("rls", "rls_enabled"),
        ("force", "rls_forced"),
        ("policies", "policy_count"),
        ("fail_open_null", "has_null_fail_open_policy"),
        ("fail_open_empty", "has_empty_string_fail_open_policy"),
        ("owner", "table_owner"),
        ("estimated_rows", "estimated_rows"),
    )
    for entry in evidence.get("tables", []):
        pieces = [f"table {entry['table_name']}"]
        pieces.extend(f"{label}={entry[key]}" for label, key in table_fields)
        print(" ".join(pieces))

    for entry in evidence.get("exact_counts", []):
        if entry.get("rls_filtered"):
            scope = "rls_filtered"
        else:
            scope = "global_visible"
        pieces = [
            "count",
            f"{entry['table_name']}",
            f"scope={scope}",
            f"project_context={entry.get('project_context')}",
            f"total_rows={entry['total_rows']}",
            f"null_project_id_rows={entry['null_project_id_rows']}",
        ]
        print(" ".join(pieces))
|
|
|
|
|
|
async def main() -> int:
    """Parse the CLI flags, run the preflight, print the report, return the exit code.

    Returns:
        2 when at least one check has status ``BLOCKED``, otherwise 0.
    """
    cli = argparse.ArgumentParser(description="Run read-only AwoooP RLS preflight checks.")
    cli.add_argument("--exact-counts", action="store_true", help="Run exact COUNT(*) checks for project_id backfill.")
    cli.add_argument("--json", action="store_true", help="Print JSON instead of human-readable output.")
    options = cli.parse_args()

    checks, evidence = await collect(exact_counts=options.exact_counts)

    if not options.json:
        print_human(checks, evidence)
    else:
        payload = {"checks": [asdict(check) for check in checks], "evidence": evidence}
        # default=str stringifies evidence values json cannot serialize natively.
        print(json.dumps(payload, ensure_ascii=False, default=str))

    for check in checks:
        if check.status == "BLOCKED":
            return 2
    return 0
|
|
|
|
|
|
# Script entry point. Exit codes: whatever main() returns (0 = no BLOCKED
# checks, 2 = at least one BLOCKED check), 130 on Ctrl-C, 2 on any unexpected
# exception.
if __name__ == "__main__":
    try:
        raise SystemExit(asyncio.run(main()))
    except KeyboardInterrupt:
        raise SystemExit(130)
    except Exception as exc:
        message = str(exc)
        # The module promises never to print the URL or credentials, but an
        # exception message may embed the connection string verbatim (e.g. a
        # URL parse error), so redact it before printing.
        secret_url = os.environ.get("DATABASE_URL")
        if secret_url:
            message = message.replace(secret_url, "<DATABASE_URL redacted>")
        print(f"BLOCKED preflight_exception: {message}", file=sys.stderr)
        raise SystemExit(2)
|