311 lines
11 KiB
Python
311 lines
11 KiB
Python
"""
|
||
Run State Machine
|
||
==================
|
||
AwoooP Phase 4: Run FSM 轉換規則 + Worker Lease(ADR-114/ADR-119)
|
||
2026-05-04 ogt + Claude Sonnet 4.6
|
||
|
||
狀態機:
|
||
PENDING → RUNNING(worker 取得 lease)
|
||
RUNNING → WAITING_TOOL(等待 tool call 完成)
|
||
RUNNING → WAITING_APPROVAL(等待人工審核)
|
||
RUNNING → COMPLETED / FAILED / CANCELLED
|
||
WAITING_TOOL → RUNNING(tool call 完成)
|
||
WAITING_TOOL → FAILED(tool call 失敗 + 超過 max_attempts)
|
||
WAITING_APPROVAL → RUNNING(核准)
|
||
WAITING_APPROVAL → CANCELLED(拒絕/超時)
|
||
RUNNING / WAITING_APPROVAL → TIMEOUT(僅此兩個狀態可轉為 TIMEOUT;stale run reaper 對超過 max_attempts 的 run 標記為 FAILED,而非 TIMEOUT)
|
||
|
||
SKIP LOCKED:
|
||
Worker 以 SELECT ... FOR UPDATE SKIP LOCKED 取單,防 double-pickup。
|
||
Lease TTL = 60 秒;Heartbeat 每 15 秒更新。
|
||
|
||
Stale run reaper:
|
||
每分鐘掃描 lease_until < NOW() 的 running run:
|
||
attempt_count < max_attempts → 重設 PENDING
|
||
attempt_count >= max_attempts → 標記 FAILED(E-RUN-002)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import socket
|
||
import uuid
|
||
from datetime import UTC, datetime, timedelta
|
||
from typing import TYPE_CHECKING
|
||
|
||
import structlog
|
||
from sqlalchemy import select, text, update
|
||
|
||
from src.db.awooop_models import AwoooPRunState
|
||
from src.db.base import get_db_context
|
||
|
||
if TYPE_CHECKING:
|
||
from uuid import UUID
|
||
|
||
logger = structlog.get_logger(__name__)
|
||
|
||
# Worker lease timing, all expressed in seconds.
LEASE_TTL_SECONDS = 60  # validity window written to lease_until on acquire/heartbeat
HEARTBEAT_INTERVAL_SECONDS = 15  # heartbeat cadence described in the module docstring
STALE_REAPER_INTERVAL_SECONDS = 60  # reaper cadence described in the module docstring

# Legal FSM transitions: from_state -> states reachable from it.
# A state mapped to an empty set is terminal (nothing may follow it).
_VALID_TRANSITIONS: dict[str, frozenset[str]] = {
    "pending": frozenset(("running", "cancelled")),
    "running": frozenset(
        (
            "waiting_tool",
            "waiting_approval",
            "completed",
            "failed",
            "cancelled",
            "timeout",
        )
    ),
    "waiting_tool": frozenset(("running", "failed", "cancelled")),
    "waiting_approval": frozenset(("running", "cancelled", "timeout")),
    **dict.fromkeys(("completed", "failed", "cancelled", "timeout"), frozenset()),
}

# Terminal states are exactly the states with no outgoing transition.
TERMINAL_STATES = frozenset(
    state for state, targets in _VALID_TRANSITIONS.items() if not targets
)

# Per-process worker identity: hostname plus a random 8-hex-character suffix.
_WORKER_ID = "{}:{}".format(socket.gethostname(), uuid.uuid4().hex[:8])
|
||
|
||
|
||
def _utc_now_naive() -> datetime:
|
||
"""Return UTC now matching AwoooP timestamp-without-timezone columns."""
|
||
return datetime.now(UTC).replace(tzinfo=None)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# FSM 驗證
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
class InvalidStateTransitionError(Exception):
    """Raised when a requested FSM transition is not allowed.

    Attributes are documented inline in ``__init__``; the exception message
    embeds the ``repr`` of both states.
    """

    def __init__(self, from_state: str, to_state: str) -> None:
        message = f"非法 FSM 轉換: {from_state!r} → {to_state!r}"
        super().__init__(message)
        # State the run was in when the transition was attempted.
        self.from_state = from_state
        # State the caller tried to move the run to.
        self.to_state = to_state
|
||
|
||
|
||
def validate_transition(from_state: str, to_state: str) -> None:
    """Check that moving from ``from_state`` to ``to_state`` is allowed.

    Args:
        from_state: Current state of the run.
        to_state: State the caller wants to move to.

    Raises:
        InvalidStateTransitionError: If ``to_state`` is not listed for
            ``from_state`` in ``_VALID_TRANSITIONS``. An unknown
            ``from_state`` has no allowed targets, so it always raises.
    """
    if to_state in _VALID_TRANSITIONS.get(from_state, frozenset()):
        return
    raise InvalidStateTransitionError(from_state, to_state)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Worker Lease(SKIP LOCKED)
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
async def acquire_pending_run(
    project_id: str,
    worker_id: str = _WORKER_ID,
) -> AwoooPRunState | None:
    """Claim the oldest PENDING run of a project and put a lease on it.

    The candidate row is picked with ``SELECT ... FOR UPDATE SKIP LOCKED``,
    so rows already locked by another transaction are skipped rather than
    waited on. The claimed row is moved to ``running``, stamped with
    ``worker_id``, ``started_at`` and ``heartbeat_at``, given a lease valid
    for ``LEASE_TTL_SECONDS``, and its ``attempt_count`` is incremented.

    Args:
        project_id: Project whose pending runs are polled; also passed to
            ``get_db_context``.
        worker_id: Identifier stored on the claimed row. Defaults to this
            process's ``_WORKER_ID``.

    Returns:
        The re-read ``AwoooPRunState`` row, or ``None`` when no pending run
        is currently available.
    """
    now = _utc_now_naive()
    lease_until = now + timedelta(seconds=LEASE_TTL_SECONDS)

    async with get_db_context(project_id) as db:
        # lease_until is written from the naive-UTC Python clock, so it is
        # compared against that same clock (bound as :now) instead of the
        # database's NOW(), whose conversion to a timestamp without time
        # zone depends on the session TimeZone setting.
        result = await db.execute(
            text("""
                SELECT run_id FROM awooop_run_state
                WHERE project_id = :project_id
                  AND state = 'pending'
                  AND (lease_until IS NULL OR lease_until < :now)
                ORDER BY created_at ASC
                LIMIT 1
                FOR UPDATE SKIP LOCKED
            """),
            {"project_id": project_id, "now": now},
        )
        row = result.fetchone()
        if row is None:
            return None

        run_id = row[0]

        # Take the lease and move the run to RUNNING. The row stays locked
        # by the SELECT above for the rest of this transaction.
        await db.execute(
            update(AwoooPRunState)
            .where(
                AwoooPRunState.run_id == run_id,
                AwoooPRunState.project_id == project_id,
            )
            .values(
                state="running",
                lease_until=lease_until,
                heartbeat_at=now,
                worker_id=worker_id,
                started_at=now,
                attempt_count=AwoooPRunState.attempt_count + 1,
            )
        )

        # Re-read the full record, scoped to the project like the UPDATE.
        refreshed = await db.execute(
            select(AwoooPRunState).where(
                AwoooPRunState.run_id == run_id,
                AwoooPRunState.project_id == project_id,
            )
        )
        run = refreshed.scalar_one()

        logger.info(
            "run_lease_acquired",
            run_id=str(run_id),
            project_id=project_id,
            worker_id=worker_id,
            attempt_count=run.attempt_count,
        )
        return run
|
||
|
||
|
||
async def heartbeat(
    run_id: UUID,
    project_id: str,
    *,
    worker_id: str | None = None,
) -> None:
    """Refresh a RUNNING run's heartbeat and extend its lease.

    Sets ``heartbeat_at`` to now and ``lease_until`` to now plus
    ``LEASE_TTL_SECONDS``. Only rows that are still in state ``running`` and
    belong to ``project_id`` are touched.

    Args:
        run_id: Run whose lease is extended.
        project_id: Owning project; passed to ``get_db_context`` and used as
            a filter on the UPDATE.
        worker_id: When given, the lease is only extended if the row is
            still owned by this worker. ``None`` (the default) skips the
            ownership check.

    Returns:
        None. A heartbeat that matches no row is logged as a warning and
        does not raise.
    """
    now = _utc_now_naive()
    new_lease = now + timedelta(seconds=LEASE_TTL_SECONDS)

    conditions = [
        AwoooPRunState.run_id == run_id,
        AwoooPRunState.project_id == project_id,
        AwoooPRunState.state == "running",
    ]
    if worker_id is not None:
        # Prevents a worker whose run was reaped and re-acquired elsewhere
        # from extending the new owner's lease.
        conditions.append(AwoooPRunState.worker_id == worker_id)

    async with get_db_context(project_id) as db:
        result = await db.execute(
            update(AwoooPRunState)
            .where(*conditions)
            .values(
                heartbeat_at=now,
                lease_until=new_lease,
            )
        )
        if result.rowcount == 0:
            # The run is no longer RUNNING, or no longer owned by this
            # worker; surface it instead of failing silently.
            logger.warning(
                "run_heartbeat_missed",
                run_id=str(run_id),
                project_id=project_id,
                worker_id=worker_id,
            )
|
||
|
||
|
||
async def transition(
    run_id: UUID,
    project_id: str,
    to_state: str,
    *,
    error_code: str | None = None,
    error_detail: str | None = None,
    output_sha256: str | None = None,
    cost_usd_delta: float = 0.0,
    step_count_delta: int = 0,
) -> None:
    """Validate and apply an FSM state transition for one run.

    The current state is read under a row lock (``SELECT ... FOR UPDATE``),
    checked with ``validate_transition``, and then updated in the same
    transaction, so two concurrent callers cannot both validate against the
    same state and overwrite each other.

    Args:
        run_id: Run to transition.
        project_id: Owning project; passed to ``get_db_context`` and used as
            a filter on both the read and the write.
        to_state: Target state.
        error_code: Written to ``error_code`` when truthy.
        error_detail: Written to ``error_detail`` when truthy.
        output_sha256: Written to ``output_sha256`` when truthy.
        cost_usd_delta: Added to the stored ``cost_usd`` when non-zero.
        step_count_delta: Added to the stored ``step_count`` when non-zero.

    Raises:
        ValueError: If no row matches ``run_id`` and ``project_id``.
        InvalidStateTransitionError: If the transition is not allowed.
    """
    async with get_db_context(project_id) as db:
        # Lock the row so the state checked here is the state being replaced.
        result = await db.execute(
            select(AwoooPRunState.state)
            .where(
                AwoooPRunState.run_id == run_id,
                AwoooPRunState.project_id == project_id,
            )
            .with_for_update()
        )
        row = result.fetchone()
        if row is None:
            raise ValueError(f"run {run_id} 不存在或無 RLS 權限")

        from_state = row[0]
        validate_transition(from_state, to_state)

        values: dict = {"state": to_state}
        if error_code:
            values["error_code"] = error_code
        if error_detail:
            values["error_detail"] = error_detail
        if output_sha256:
            values["output_sha256"] = output_sha256
        if cost_usd_delta:
            # SQL-side increment, so concurrent cost updates are not lost.
            values["cost_usd"] = AwoooPRunState.cost_usd + cost_usd_delta
        if step_count_delta:
            values["step_count"] = AwoooPRunState.step_count + step_count_delta
        if to_state in TERMINAL_STATES:
            # A terminal run records its end time and releases its lease.
            values["completed_at"] = _utc_now_naive()
            values["lease_until"] = None
            values["worker_id"] = None
        # NOTE(review): re-entering "running" does not renew lease_until
        # here; presumably the worker's next heartbeat does — confirm that
        # the reaper cannot see a stale lease in between.

        await db.execute(
            update(AwoooPRunState)
            .where(
                AwoooPRunState.run_id == run_id,
                AwoooPRunState.project_id == project_id,
            )
            .values(**values)
        )

    logger.info(
        "run_state_transition",
        run_id=str(run_id),
        from_state=from_state,
        to_state=to_state,
        error_code=error_code,
    )
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Stale Run Reaper
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
async def reap_stale_runs(project_id: str) -> int:
    """Requeue or fail RUNNING runs whose lease has expired.

    For every run of ``project_id`` in state ``running`` with
    ``lease_until`` earlier than now:

    - ``attempt_count < max_attempts``: reset to ``pending`` and clear the
      lease, worker and heartbeat fields so it can be picked up again.
    - otherwise: mark ``failed`` with error code ``E-RUN-002`` and set
      ``completed_at``.

    Stale rows are selected with ``FOR UPDATE SKIP LOCKED``. This stops a
    concurrent heartbeat or a second reaper from changing a row between the
    scan and its update; rows locked elsewhere are left for the next pass.

    Args:
        project_id: Project to scan; also passed to ``get_db_context``.

    Returns:
        Number of stale runs processed (requeued plus failed).
    """
    now = _utc_now_naive()

    async with get_db_context(project_id) as db:
        # Find and lock all stale RUNNING runs.
        result = await db.execute(
            select(AwoooPRunState)
            .where(
                AwoooPRunState.project_id == project_id,
                AwoooPRunState.state == "running",
                AwoooPRunState.lease_until < now,
            )
            .with_for_update(skip_locked=True)
        )
        stale_runs = result.scalars().all()

        for run in stale_runs:
            owned_row = (
                AwoooPRunState.run_id == run.run_id,
                AwoooPRunState.project_id == project_id,
            )
            if run.attempt_count < run.max_attempts:
                # Retry: reset to PENDING.
                await db.execute(
                    update(AwoooPRunState)
                    .where(*owned_row)
                    .values(
                        state="pending",
                        lease_until=None,
                        worker_id=None,
                        heartbeat_at=None,
                    )
                )
                logger.warning(
                    "stale_run_requeued",
                    run_id=str(run.run_id),
                    attempt_count=run.attempt_count,
                    max_attempts=run.max_attempts,
                )
            else:
                # Retry budget exhausted: mark FAILED.
                await db.execute(
                    update(AwoooPRunState)
                    .where(*owned_row)
                    .values(
                        state="failed",
                        error_code="E-RUN-002",
                        error_detail=f"max_attempts={run.max_attempts} 超過,stale run 已廢棄",
                        completed_at=now,
                        lease_until=None,
                        worker_id=None,
                    )
                )
                logger.error(
                    "stale_run_failed",
                    run_id=str(run.run_id),
                    attempt_count=run.attempt_count,
                )

    reaped = len(stale_runs)
    if reaped:
        logger.info("stale_run_reaper_done", project_id=project_id, reaped=reaped)
    return reaped
|