""" Enterprise-Grade SSE (Server-Sent Events) Module ================================================= Production-ready SSE implementation with: - EventPublisher: Pub/Sub pattern for broadcasting events - Client disconnect detection via asyncio.CancelledError - Automatic resource cleanup on disconnect - Heartbeat mechanism to detect stale connections - Backpressure handling with bounded queues ADR-004: SSE 串流企業級實作模式 (Buffer + AbortController + Zustand) """ import asyncio import json import uuid from collections.abc import AsyncGenerator, Callable from dataclasses import dataclass, field from datetime import UTC, datetime from enum import Enum from typing import Any from src.core.logging import get_logger logger = get_logger("awoooi.sse") # ============================================================================= # Constants # ============================================================================= HEARTBEAT_INTERVAL = 15.0 # seconds CLIENT_QUEUE_SIZE = 100 # max queued events per client CLEANUP_INTERVAL = 30.0 # seconds between cleanup runs # ============================================================================= # Event Types # ============================================================================= class EventType(str, Enum): """ Standard SSE event types Phase 19: 新增 TERMINAL_* 事件類型 @see ADR-031 Omni-Terminal SSE Architecture """ # Standard events CONNECTED = "connected" HEARTBEAT = "heartbeat" HOST_UPDATE = "host_update" ALERT = "alert" APPROVAL = "approval" AI_THINKING = "ai_thinking" METRIC_UPDATE = "metric_update" DISCONNECTED = "disconnected" ERROR = "error" # Phase 19: Terminal events TERMINAL_THOUGHT = "terminal_thought" TERMINAL_TOOL_CALL = "terminal_tool_call" TERMINAL_RENDER_UI = "terminal_render_ui" TERMINAL_ACTION_REQUEST = "terminal_action_request" TERMINAL_ACTION_RESULT = "terminal_action_result" TERMINAL_COMPLETE = "terminal_complete" TERMINAL_ERROR = "terminal_error" TERMINAL_HEARTBEAT = "terminal_heartbeat" @dataclass class SSEEvent: """SSE Event structure""" type: EventType data: dict[str, Any] id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) timestamp: datetime = field(default_factory=lambda: datetime.now(UTC)) retry: int | None = None # Client retry interval in ms def to_sse_format(self) -> str: """Convert to SSE wire format""" lines = [] if self.id: lines.append(f"id: {self.id}") lines.append(f"event: {self.type.value}") # Add timestamp to data payload = { **self.data, "timestamp": self.timestamp.isoformat(), "event_id": self.id, } lines.append(f"data: {json.dumps(payload, ensure_ascii=False)}") if self.retry is not None: lines.append(f"retry: {self.retry}") return "\n".join(lines) + "\n\n" # ============================================================================= # Client Connection # ============================================================================= @dataclass class SSEClient: """ Individual SSE client connection Tracks: - Unique client ID - Event queue (bounded to prevent memory bloat) - Connection state - Last activity timestamp """ id: str = field(default_factory=lambda: str(uuid.uuid4())) queue: asyncio.Queue = field(default_factory=lambda: asyncio.Queue(maxsize=CLIENT_QUEUE_SIZE)) connected_at: datetime = field(default_factory=lambda: datetime.now(UTC)) last_activity: datetime = field(default_factory=lambda: datetime.now(UTC)) is_active: bool = True metadata: dict[str, Any] = field(default_factory=dict) def touch(self) -> None: """Update last activity timestamp""" self.last_activity = datetime.now(UTC) async def send(self, event: SSEEvent) -> bool: """ Send event to client queue Returns False if queue is full (backpressure) """ if not self.is_active: return False try: self.queue.put_nowait(event) self.touch() return True except asyncio.QueueFull: logger.warning( "sse_client_queue_full", client_id=self.id, queue_size=self.queue.qsize(), ) return False def disconnect(self) -> None: """Mark client as disconnected""" self.is_active = False # ============================================================================= # Event Publisher (Pub/Sub Pattern) # ============================================================================= class EventPublisher: """ Enterprise-grade SSE Event Publisher Features: - Pub/Sub pattern for event broadcasting - Automatic client disconnect detection - Resource cleanup on disconnect - Heartbeat mechanism - Topic-based subscriptions Usage: publisher = EventPublisher() # Subscribe a client client = await publisher.subscribe() # Publish events await publisher.publish(SSEEvent(type=EventType.ALERT, data={...})) # Client generator for streaming async for event in publisher.stream(client): yield event.to_sse_format() """ def __init__(self) -> None: self._clients: dict[str, SSEClient] = {} self._topics: dict[str, set[str]] = {} # topic -> client_ids self._lock = asyncio.Lock() self._heartbeat_task: asyncio.Task | None = None self._cleanup_task: asyncio.Task | None = None self._running = False self._on_disconnect_callbacks: list[Callable[[str], None]] = [] async def start(self) -> None: """Start background tasks""" if self._running: return self._running = True self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) self._cleanup_task = asyncio.create_task(self._cleanup_loop()) logger.info("sse_publisher_started") async def stop(self) -> None: """Stop background tasks and disconnect all clients""" self._running = False if self._heartbeat_task: self._heartbeat_task.cancel() try: await self._heartbeat_task except asyncio.CancelledError: pass if self._cleanup_task: self._cleanup_task.cancel() try: await self._cleanup_task except asyncio.CancelledError: pass # Disconnect all clients async with self._lock: for client in self._clients.values(): client.disconnect() self._clients.clear() self._topics.clear() logger.info("sse_publisher_stopped") async def subscribe( self, topics: list[str] | None = None, metadata: dict[str, Any] | None = None, ) -> SSEClient: """ Subscribe a new client Args: topics: Optional list of topics to subscribe to metadata: Optional client metadata (user_id, etc.) Returns: SSEClient instance """ client = SSEClient(metadata=metadata or {}) async with self._lock: self._clients[client.id] = client # Subscribe to topics if topics: for topic in topics: if topic not in self._topics: self._topics[topic] = set() self._topics[topic].add(client.id) logger.info( "sse_client_connected", client_id=client.id, topics=topics, total_clients=len(self._clients), ) # Send connected event await client.send(SSEEvent( type=EventType.CONNECTED, data={ "client_id": client.id, "message": "SSE connection established", }, )) return client async def unsubscribe(self, client_id: str) -> None: """ Unsubscribe and cleanup a client Called automatically on disconnect or manually. """ async with self._lock: if client_id not in self._clients: return client = self._clients.pop(client_id) client.disconnect() # Remove from all topics for topic_clients in self._topics.values(): topic_clients.discard(client_id) # Call disconnect callbacks for callback in self._on_disconnect_callbacks: try: callback(client_id) except Exception as e: logger.error("sse_disconnect_callback_error", error=str(e)) logger.info( "sse_client_disconnected", client_id=client_id, total_clients=len(self._clients), ) def on_disconnect(self, callback: Callable[[str], None]) -> None: """Register a disconnect callback""" self._on_disconnect_callbacks.append(callback) async def publish( self, event: SSEEvent, topic: str | None = None, client_ids: list[str] | None = None, ) -> int: """ Publish event to clients Args: event: SSE event to publish topic: Optional topic to publish to client_ids: Optional specific client IDs Returns: Number of clients event was sent to """ sent_count = 0 async with self._lock: # Determine target clients if client_ids: target_ids = set(client_ids) & set(self._clients.keys()) elif topic and topic in self._topics: target_ids = self._topics[topic] else: target_ids = set(self._clients.keys()) # Send to all targets for client_id in target_ids: client = self._clients.get(client_id) if client and await client.send(event): sent_count += 1 if sent_count > 0: logger.debug( "sse_event_published", event_type=event.type.value, sent_count=sent_count, topic=topic, ) return sent_count async def stream(self, client: SSEClient) -> AsyncGenerator[str, None]: """ Stream events to a client This is the main generator for SSE responses. Handles: - Event delivery from queue - Client disconnect detection - Automatic cleanup Usage: async for data in publisher.stream(client): yield data """ try: while client.is_active: try: # Wait for event with timeout (allows disconnect detection) event = await asyncio.wait_for( client.queue.get(), timeout=HEARTBEAT_INTERVAL + 5, ) yield event.to_sse_format() except TimeoutError: # No event received, but connection might still be alive # Heartbeat will be sent by background task continue except asyncio.CancelledError: # Client disconnected (browser closed, network error, etc.) logger.info("sse_client_cancelled", client_id=client.id) raise except Exception as e: logger.error( "sse_stream_error", client_id=client.id, error=str(e), ) finally: # Cleanup: Always unsubscribe on exit await self.unsubscribe(client.id) async def _heartbeat_loop(self) -> None: """Background task: Send periodic heartbeats""" while self._running: try: await asyncio.sleep(HEARTBEAT_INTERVAL) heartbeat = SSEEvent( type=EventType.HEARTBEAT, data={"clients": len(self._clients)}, ) async with self._lock: for client in self._clients.values(): await client.send(heartbeat) except asyncio.CancelledError: break except Exception as e: logger.error("sse_heartbeat_error", error=str(e)) async def _cleanup_loop(self) -> None: """Background task: Cleanup stale connections""" while self._running: try: await asyncio.sleep(CLEANUP_INTERVAL) now = datetime.now(UTC) stale_threshold = HEARTBEAT_INTERVAL * 3 # 45 seconds async with self._lock: stale_clients = [ client_id for client_id, client in self._clients.items() if (now - client.last_activity).total_seconds() > stale_threshold and not client.is_active ] for client_id in stale_clients: await self.unsubscribe(client_id) logger.info("sse_stale_client_removed", client_id=client_id) except asyncio.CancelledError: break except Exception as e: logger.error("sse_cleanup_error", error=str(e)) @property def client_count(self) -> int: """Get current client count""" return len(self._clients) @property def is_running(self) -> bool: """Check if publisher is running""" return self._running # ============================================================================= # Global Publisher Instance # ============================================================================= # Singleton publisher for the application publisher = EventPublisher() async def get_publisher() -> EventPublisher: """ Get the global publisher instance Ensures publisher is started before returning. """ if not publisher.is_running: await publisher.start() return publisher