Files
awoooi/apps/api/src/services/run_state_machine.py
Your Name bb1995f349
Some checks failed
CD Pipeline / tests (push) Failing after 1m48s
CD Pipeline / build-and-deploy (push) Has been skipped
CD Pipeline / post-deploy-checks (push) Has been skipped
Code Review / ai-code-review (push) Has been cancelled
fix(awooop): use naive utc for run lease timestamps
2026-05-05 13:53:07 +08:00

311 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Run State Machine
==================
AwoooP Phase 4: Run FSM transition rules + worker lease (ADR-114/ADR-119)
2026-05-04 ogt + Claude Sonnet 4.6

State machine:
    PENDING → RUNNING (a worker acquires the lease)
    RUNNING → WAITING_TOOL (waiting for a tool call to finish)
    RUNNING → WAITING_APPROVAL (waiting for human approval)
    RUNNING → COMPLETED / FAILED / CANCELLED
    WAITING_TOOL → RUNNING (tool call finished)
    WAITING_TOOL → FAILED (tool call failed and max_attempts exceeded)
    WAITING_APPROVAL → RUNNING (approved)
    WAITING_APPROVAL → CANCELLED (rejected / timed out)
    * → TIMEOUT (lease_until expired and max_attempts exceeded)

SKIP LOCKED:
    Workers pick up runs with SELECT ... FOR UPDATE SKIP LOCKED to prevent
    double pickup.
    Lease TTL = 60 seconds; the heartbeat refreshes it every 15 seconds.

Stale run reaper:
    Every minute, scan running runs whose lease_until < NOW():
        attempt_count < max_attempts  → reset to PENDING
        attempt_count >= max_attempts → mark FAILED (E-RUN-002)
"""
from __future__ import annotations
import socket
import uuid
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
import structlog
from sqlalchemy import select, text, update
from src.db.awooop_models import AwoooPRunState
from src.db.base import get_db_context
if TYPE_CHECKING:
from uuid import UUID
logger = structlog.get_logger(__name__)
# Lease timing, in seconds. Per the module docstring: a lease lasts 60 s, the
# heartbeat refreshes it every 15 s, and the stale-run reaper scans every 60 s.
LEASE_TTL_SECONDS = 60
HEARTBEAT_INTERVAL_SECONDS = 15
STALE_REAPER_INTERVAL_SECONDS = 60

# Allowed FSM transitions: source state -> set of permitted target states.
# A state mapped to an empty set has no outgoing transition (terminal).
_VALID_TRANSITIONS: dict[str, frozenset[str]] = {
    "pending": frozenset({"running", "cancelled"}),
    "running": frozenset(
        {
            "waiting_tool",
            "waiting_approval",
            "completed",
            "failed",
            "cancelled",
            "timeout",
        }
    ),
    "waiting_tool": frozenset({"running", "failed", "cancelled"}),
    "waiting_approval": frozenset({"running", "cancelled", "timeout"}),
    "completed": frozenset(),
    "failed": frozenset(),
    "cancelled": frozenset(),
    "timeout": frozenset(),
}

# Terminal states are exactly the states with no permitted target.
TERMINAL_STATES = frozenset(
    state for state, targets in _VALID_TRANSITIONS.items() if not targets
)

# Default worker identity for this process: hostname plus a short random suffix.
_WORKER_ID = f"{socket.gethostname()}:{uuid.uuid4().hex[:8]}"
def _utc_now_naive() -> datetime:
"""Return UTC now matching AwoooP timestamp-without-timezone columns."""
return datetime.now(UTC).replace(tzinfo=None)
# ─────────────────────────────────────────────────────────────────────────────
# FSM validation
# ─────────────────────────────────────────────────────────────────────────────
class InvalidStateTransitionError(Exception):
    """Raised when a requested run state change is not allowed by the FSM.

    Attributes:
        from_state: The state the run was in when the change was requested.
        to_state: The state that was requested.
    """

    def __init__(self, from_state: str, to_state: str) -> None:
        self.from_state = from_state
        self.to_state = to_state
        # Separate the two reprs with an arrow; without a separator they run
        # together (e.g. 'pending''completed') and the message is unreadable.
        super().__init__(f"非法 FSM 轉換: {from_state!r} → {to_state!r}")
def validate_transition(from_state: str, to_state: str) -> None:
    """Check that moving from ``from_state`` to ``to_state`` is a legal FSM step.

    A ``from_state`` that is not in the transition table is treated as having
    no legal targets, so any transition out of it is rejected.

    Args:
        from_state: Current state of the run.
        to_state: Requested next state.

    Raises:
        InvalidStateTransitionError: If ``to_state`` is not among the targets
            permitted for ``from_state``.
    """
    if to_state in _VALID_TRANSITIONS.get(from_state, frozenset()):
        return
    raise InvalidStateTransitionError(from_state, to_state)
# ─────────────────────────────────────────────────────────────────────────────
# Worker lease (SKIP LOCKED)
# ─────────────────────────────────────────────────────────────────────────────
async def acquire_pending_run(
    project_id: str,
    worker_id: str = _WORKER_ID,
) -> AwoooPRunState | None:
    """Claim the oldest PENDING run of a project and place it under a lease.

    The candidate row is selected with ``FOR UPDATE SKIP LOCKED`` so that a
    row already locked by another worker is skipped instead of waited on.
    The claimed run is switched to ``running``, gets ``lease_until`` set to
    now + ``LEASE_TTL_SECONDS``, has ``heartbeat_at`` / ``started_at`` set to
    now, records ``worker_id`` and has ``attempt_count`` incremented by one.

    Args:
        project_id: Project whose pending runs are considered; also passed to
            ``get_db_context``.
        worker_id: Identifier stored on the run as the lease holder. Defaults
            to this process's ``_WORKER_ID``.

    Returns:
        The refreshed ``AwoooPRunState`` row, or ``None`` when no pending run
        is currently available.
    """
    now = _utc_now_naive()
    lease_until = now + timedelta(seconds=LEASE_TTL_SECONDS)

    async with get_db_context(project_id) as db:
        # The lease cutoff is bound from the same naive-UTC clock that writes
        # lease_until. SQL NOW() is a timestamptz interpreted in the session
        # time zone, which would skew the comparison on non-UTC sessions.
        result = await db.execute(
            text("""
                SELECT run_id FROM awooop_run_state
                WHERE project_id = :project_id
                  AND state = 'pending'
                  AND (lease_until IS NULL OR lease_until < :now)
                ORDER BY created_at ASC
                LIMIT 1
                FOR UPDATE SKIP LOCKED
            """),
            {"project_id": project_id, "now": now},
        )
        row = result.fetchone()
        if row is None:
            return None
        run_id = row[0]

        # Take the lease and move the run to RUNNING in the same transaction
        # that holds the row lock.
        await db.execute(
            update(AwoooPRunState)
            .where(
                AwoooPRunState.run_id == run_id,
                AwoooPRunState.project_id == project_id,
            )
            .values(
                state="running",
                lease_until=lease_until,
                heartbeat_at=now,
                worker_id=worker_id,
                started_at=now,
                attempt_count=AwoooPRunState.attempt_count + 1,
            )
        )

        # Re-read the full record so the caller sees the values just written.
        result2 = await db.execute(
            select(AwoooPRunState).where(AwoooPRunState.run_id == run_id)
        )
        run = result2.scalar_one()
        # Read the attribute while the session is still open so logging never
        # depends on loading it after the context has exited.
        attempt_count = run.attempt_count

    logger.info(
        "run_lease_acquired",
        run_id=str(run_id),
        project_id=project_id,
        worker_id=worker_id,
        attempt_count=attempt_count,
    )
    return run
async def heartbeat(run_id: UUID, project_id: str) -> None:
    """Refresh a running run's heartbeat and extend its lease.

    Sets ``heartbeat_at`` to now and ``lease_until`` to
    now + ``LEASE_TTL_SECONDS``. Only a row that matches ``run_id`` and
    ``project_id`` and is currently in the ``running`` state is updated;
    if no row matches, the call changes nothing and raises nothing.

    Args:
        run_id: Identifier of the run to refresh.
        project_id: Project the run belongs to; also passed to
            ``get_db_context``.
    """
    now = _utc_now_naive()
    new_lease = now + timedelta(seconds=LEASE_TTL_SECONDS)

    async with get_db_context(project_id) as db:
        await db.execute(
            update(AwoooPRunState)
            .where(
                AwoooPRunState.run_id == run_id,
                # Scope by project as well, matching the lease-acquisition
                # UPDATE, so a heartbeat can only touch the caller's project.
                AwoooPRunState.project_id == project_id,
                AwoooPRunState.state == "running",
            )
            .values(
                heartbeat_at=now,
                lease_until=new_lease,
            )
        )
async def transition(
    run_id: UUID,
    project_id: str,
    to_state: str,
    *,
    error_code: str | None = None,
    error_detail: str | None = None,
    output_sha256: str | None = None,
    cost_usd_delta: float = 0.0,
    step_count_delta: int = 0,
) -> None:
    """Move a run to ``to_state`` after validating the move against the FSM.

    The current state is read from the database under a row lock, checked
    with ``validate_transition``, and then the row is updated. When
    ``to_state`` is terminal, ``completed_at`` is set to now and
    ``lease_until`` / ``worker_id`` are cleared.

    Args:
        run_id: Identifier of the run to transition.
        project_id: Project the run belongs to; also passed to
            ``get_db_context``.
        to_state: Requested target state.
        error_code: Stored on the run when truthy.
        error_detail: Stored on the run when truthy.
        output_sha256: Stored on the run when truthy.
        cost_usd_delta: Added to the run's ``cost_usd`` when non-zero.
        step_count_delta: Added to the run's ``step_count`` when non-zero.

    Raises:
        ValueError: If no row is visible for ``run_id`` / ``project_id``.
        InvalidStateTransitionError: If the FSM does not allow the move.
    """
    async with get_db_context(project_id) as db:
        # Lock the row while reading its state. Without the lock, two
        # concurrent transitions could both validate against the same stale
        # state and then overwrite each other.
        result = await db.execute(
            select(AwoooPRunState.state)
            .where(
                AwoooPRunState.run_id == run_id,
                AwoooPRunState.project_id == project_id,
            )
            .with_for_update()
        )
        row = result.fetchone()
        if row is None:
            raise ValueError(f"run {run_id} 不存在或無 RLS 權限")
        from_state = row[0]
        validate_transition(from_state, to_state)

        values: dict[str, object] = {"state": to_state}
        if error_code:
            values["error_code"] = error_code
        if error_detail:
            values["error_detail"] = error_detail
        if output_sha256:
            values["output_sha256"] = output_sha256
        # Deltas are applied as SQL expressions so the increment happens in
        # the database rather than on a value read earlier.
        if cost_usd_delta:
            values["cost_usd"] = AwoooPRunState.cost_usd + cost_usd_delta
        if step_count_delta:
            values["step_count"] = AwoooPRunState.step_count + step_count_delta
        if to_state in TERMINAL_STATES:
            # A finished run no longer holds a lease or belongs to a worker.
            values["completed_at"] = _utc_now_naive()
            values["lease_until"] = None
            values["worker_id"] = None

        await db.execute(
            update(AwoooPRunState)
            .where(
                AwoooPRunState.run_id == run_id,
                AwoooPRunState.project_id == project_id,
            )
            .values(**values)
        )

    logger.info(
        "run_state_transition",
        run_id=str(run_id),
        from_state=from_state,
        to_state=to_state,
        error_code=error_code,
    )
# ─────────────────────────────────────────────────────────────────────────────
# Stale Run Reaper
# ─────────────────────────────────────────────────────────────────────────────
async def reap_stale_runs(project_id: str) -> int:
    """Recover RUNNING runs of a project whose lease has expired.

    For every run in state ``running`` with ``lease_until`` earlier than now:

    - ``attempt_count < max_attempts``: the run is reset to ``pending`` and
      its lease, worker and heartbeat fields are cleared so it can be picked
      up again.
    - otherwise: the run is marked ``failed`` with error code ``E-RUN-002``
      and ``completed_at`` set to now.

    Args:
        project_id: Project to scan; also passed to ``get_db_context``.

    Returns:
        The number of stale runs that were processed.
    """
    now = _utc_now_naive()
    reaped = 0

    async with get_db_context(project_id) as db:
        # Lock the stale rows for the rest of the transaction so a heartbeat
        # or transition cannot change them between this read and the updates
        # below. Rows locked elsewhere are skipped and left for a later pass.
        result = await db.execute(
            select(AwoooPRunState)
            .where(
                AwoooPRunState.project_id == project_id,
                AwoooPRunState.state == "running",
                AwoooPRunState.lease_until < now,
            )
            .with_for_update(skip_locked=True)
        )
        stale_runs = list(result.scalars().all())

        for run in stale_runs:
            # Same guard on both branches: only touch the row if it is still
            # this project's RUNNING run.
            still_running = (
                AwoooPRunState.run_id == run.run_id,
                AwoooPRunState.project_id == project_id,
                AwoooPRunState.state == "running",
            )
            if run.attempt_count < run.max_attempts:
                # Attempts remain: put the run back in the queue.
                await db.execute(
                    update(AwoooPRunState)
                    .where(*still_running)
                    .values(
                        state="pending",
                        lease_until=None,
                        worker_id=None,
                        heartbeat_at=None,
                    )
                )
                logger.warning(
                    "stale_run_requeued",
                    run_id=str(run.run_id),
                    attempt_count=run.attempt_count,
                    max_attempts=run.max_attempts,
                )
            else:
                # Attempts exhausted: fail the run permanently.
                await db.execute(
                    update(AwoooPRunState)
                    .where(*still_running)
                    .values(
                        state="failed",
                        error_code="E-RUN-002",
                        error_detail=f"max_attempts={run.max_attempts} 超過stale run 已廢棄",
                        completed_at=now,
                        lease_until=None,
                        worker_id=None,
                    )
                )
                logger.error(
                    "stale_run_failed",
                    run_id=str(run.run_id),
                    attempt_count=run.attempt_count,
                )
            reaped += 1

    if reaped:
        logger.info("stale_run_reaper_done", project_id=project_id, reaped=reaped)
    return reaped