""" LLM Output Schema Validator ============================= AwoooP Phase 3.3: LLM 輸出 → schema 驗證 → retry 機制(ADR-112) 2026-05-04 ogt + Claude Sonnet 4.6 設計原則: - LLM 輸出必須通過 Pydantic schema 驗證才能到達 channel adapter - 驗證失敗 → 自動 retry(最多 3 次,含 retry prompt) - 3 次全部失敗 → 拋出 SchemaValidationError(E-SCHEMA-001) - 支援六合約家族 + 自訂 Pydantic model 位置:介於 LLM response 和 channel adapter 之間 呼叫方:任何需要結構化 LLM 輸出的 service(playbook_generator, decision_manager 等) """ from __future__ import annotations import json import re from typing import Any, TypeVar import structlog from pydantic import BaseModel, ValidationError logger = structlog.get_logger(__name__) T = TypeVar("T", bound=BaseModel) _MAX_RETRIES = 3 _JSON_EXTRACT_RE = re.compile(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```|(\{[\s\S]*\})", re.DOTALL) # ───────────────────────────────────────────────────────────────────────────── # 錯誤定義 # ───────────────────────────────────────────────────────────────────────────── class SchemaValidationError(Exception): """LLM 輸出連續 3 次 schema 驗證失敗""" error_code: str = "E-SCHEMA-001" def __init__(self, model_name: str, attempts: int, last_error: str) -> None: self.model_name = model_name self.attempts = attempts self.last_error = last_error super().__init__( f"[E-SCHEMA-001] LLM 輸出 {attempts} 次驗證失敗 " f"(model={model_name}): {last_error}" ) # ───────────────────────────────────────────────────────────────────────────── # JSON 萃取(容錯解析) # ───────────────────────────────────────────────────────────────────────────── def extract_json_from_llm_output(raw: str) -> dict[str, Any] | None: """ 從 LLM 原始輸出中萃取 JSON。 策略: 1. 直接 json.loads(最常見:LLM 直接回傳 JSON) 2. 從 ```json ... ``` 程式碼區塊萃取 3. 找第一個 { ... } 區塊嘗試解析 """ raw = raw.strip() # 策略 1:直接解析 try: obj = json.loads(raw) if isinstance(obj, dict): return obj except json.JSONDecodeError: pass # 策略 2 + 3:正則萃取 for match in _JSON_EXTRACT_RE.finditer(raw): candidate = match.group(1) or match.group(2) if candidate: try: obj = json.loads(candidate) if isinstance(obj, dict): return obj except json.JSONDecodeError: continue return None # ───────────────────────────────────────────────────────────────────────────── # Retry prompt builder # ───────────────────────────────────────────────────────────────────────────── def build_retry_prompt( original_prompt: str, failed_output: str, validation_error: str, model_name: str, attempt: int, ) -> str: """ 建立包含錯誤回饋的 retry prompt。 讓 LLM 知道上次輸出哪裡出錯,引導修正。 """ return ( f"{original_prompt}\n\n" f"---\n" f"[SCHEMA VALIDATION RETRY {attempt}/{_MAX_RETRIES}]\n" f"上次回應未通過結構驗證({model_name}),請修正以下問題後重新回應:\n\n" f"驗證錯誤:\n{validation_error}\n\n" f"上次回應(供參考):\n{failed_output[:500]}...\n" f"---\n\n" f"請只回傳符合格式的 JSON 物件,不要包含任何額外說明。" ) # ───────────────────────────────────────────────────────────────────────────── # Core validator # ───────────────────────────────────────────────────────────────────────────── async def validate_llm_output( *, raw_output: str, model_cls: type[T], llm_caller: Any, # Callable[[str], Awaitable[str]] — 供 retry 使用 original_prompt: str, context: dict[str, Any] | None = None, ) -> T: """ 驗證 LLM 輸出是否符合 Pydantic model。 Args: raw_output: LLM 第一次回傳的原始字串 model_cls: 目標 Pydantic model class llm_caller: async callable(prompt: str) -> str,用於 retry original_prompt: 原始 prompt(retry 時附加錯誤回饋) context: 額外 logging context Returns: 驗證成功的 model instance Raises: SchemaValidationError: 連續 3 次失敗後拋出 """ model_name = model_cls.__name__ ctx = context or {} current_output = raw_output last_error = "" for attempt in range(1, _MAX_RETRIES + 1): # 1. 萃取 JSON parsed = extract_json_from_llm_output(current_output) if parsed is None: last_error = "無法從 LLM 輸出中萃取 JSON 物件" logger.warning( "schema_validator_no_json", model_name=model_name, attempt=attempt, output_preview=current_output[:200], **ctx, ) else: # 2. Pydantic 驗證 try: instance = model_cls.model_validate(parsed) logger.info( "schema_validator_passed", model_name=model_name, attempt=attempt, **ctx, ) return instance except ValidationError as exc: last_error = exc.json(indent=None) logger.warning( "schema_validator_failed", model_name=model_name, attempt=attempt, error=last_error[:500], **ctx, ) # 3. Retry(如果不是最後一次) if attempt < _MAX_RETRIES: retry_prompt = build_retry_prompt( original_prompt=original_prompt, failed_output=current_output, validation_error=last_error, model_name=model_name, attempt=attempt, ) try: current_output = await llm_caller(retry_prompt) except Exception as exc: logger.warning( "schema_validator_llm_retry_failed", model_name=model_name, attempt=attempt, error=str(exc), **ctx, ) # LLM 呼叫本身失敗,保留上次 output,繼續嘗試(或直接結束) break # 3 次全失敗 logger.error( "schema_validator_exhausted", model_name=model_name, total_attempts=_MAX_RETRIES, last_error=last_error[:500], **ctx, ) raise SchemaValidationError(model_name, _MAX_RETRIES, last_error) # ───────────────────────────────────────────────────────────────────────────── # 便利方法:從 contract family 名稱驗證(不需知道具體 model class) # ───────────────────────────────────────────────────────────────────────────── async def validate_llm_output_by_family( *, raw_output: str, contract_family: str, llm_caller: Any, original_prompt: str, context: dict[str, Any] | None = None, ) -> BaseModel: """ 依 contract_family 自動選擇 model class 並驗證。 適合 generic pipeline 呼叫(不知道具體 model)。 """ from src.models.awooop_contracts import CONTRACT_FAMILY_MODELS, VALID_CONTRACT_FAMILIES model_cls = CONTRACT_FAMILY_MODELS.get(contract_family) if model_cls is None: raise ValueError( f"未知 contract_family: {contract_family!r}," f"合法值:{sorted(VALID_CONTRACT_FAMILIES)}" ) return await validate_llm_output( raw_output=raw_output, model_cls=model_cls, llm_caller=llm_caller, original_prompt=original_prompt, context=context, ) # ───────────────────────────────────────────────────────────────────────────── # 同步版本(非 LLM retry,只做一次驗證)— 供測試和非 LLM 路徑使用 # ───────────────────────────────────────────────────────────────────────────── def validate_once(raw: str | dict[str, Any], model_cls: type[T]) -> T: """ 單次驗證,不做 retry。 適合:已知格式正確的內部資料、測試 fixture 驗證。 """ if isinstance(raw, str): parsed = extract_json_from_llm_output(raw) if parsed is None: raise SchemaValidationError(model_cls.__name__, 1, "無法萃取 JSON") return model_cls.model_validate(parsed) return model_cls.model_validate(raw)