"""Advanced feature engineering: extract multi-dimensional match features from the database."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from math import radians, sin, cos, asin, sqrt
from typing import Iterable

from sqlalchemy import and_, desc, select
from sqlalchemy.ext.asyncio import AsyncSession

from ..db.models import Match, Team


@dataclass(frozen=True)
class MatchFeatureVector:
    """Pre-match feature vector for one match; "diff" fields are home minus away."""

    # Home team's rest days minus away team's rest days.
    rest_days_advantage: float
    # Great-circle distance between the two teams' configured locations (0.0 if unknown).
    travel_distance_km: float
    # Mean of the home team's recent xG minus mean of the away team's recent xG.
    recent_5_xg_diff: float
    # Home team's Elo rating minus away team's Elo rating.
    elo_rating_diff: float


def _haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Return the great-circle distance in km between two (lat, lon) points given in degrees.

    Uses the haversine formula with a mean Earth radius of 6371 km.
    """
    R = 6371.0
    dlat = radians(lat2 - lat1)
    dlon = radians(lon2 - lon1)
    a = sin(dlat / 2) ** 2 + cos(radians(lat1)) * cos(radians(lat2)) * sin(dlon / 2) ** 2
    # Clamp: floating-point error can push sqrt(a) marginally above 1.0, which asin rejects.
    return 2 * R * asin(min(1.0, sqrt(a)))


class MatchFeatureExtractor:
    """Extract and build pre-match features for a match stored in the database."""

    def __init__(
        self,
        session_factory,
        *,
        team_locations: dict[str, tuple[float, float]] | None = None,
    ) -> None:
        """Create an extractor.

        Args:
            session_factory: Callable whose return value is used as
                ``async with session_factory() as session`` to obtain an ``AsyncSession``.
            team_locations: Optional mapping ``{team_id: (lat, lon)}``. When either team of a
                match is missing from it, the travel distance falls back to 0.0.
        """
        self.session_factory = session_factory
        self.team_locations = team_locations or {}

    @staticmethod
    def _team_xg(match: Match, team_id: str) -> float:
        """Return the xG recorded for ``team_id``'s own side (home or away) of ``match``."""
        value = match.home_xg if match.home_team_id == team_id else match.away_xg
        return float(value or 0.0)

    @staticmethod
    def _mean(values: list[float]) -> float:
        """Return the arithmetic mean of ``values``, or 0.0 when the list is empty."""
        return sum(values) / len(values) if values else 0.0

    async def _previous_match(
        self, session: AsyncSession, team_id: str, match_time: datetime
    ) -> Match | None:
        """Return ``team_id``'s most recent match strictly before ``match_time``, or None.

        Only matches with both xG values recorded are considered.
        NOTE(review): the xG filter presumably acts as an "already played" proxy; a played match
        whose xG is missing is skipped, which would overstate rest days — confirm intent.
        """
        stmt = (
            select(Match)
            .where(
                and_(
                    (Match.home_team_id == team_id) | (Match.away_team_id == team_id),
                    Match.match_time_utc < match_time,
                    Match.home_xg.is_not(None),
                    Match.away_xg.is_not(None),
                ),
            )
            .order_by(desc(Match.match_time_utc))
            .limit(1)
        )
        result = await session.execute(stmt)
        return result.scalar_one_or_none()

    async def _recent_xg_series(
        self,
        session: AsyncSession,
        team_id: str,
        as_of_match_id: str,
        count: int = 5,
        *,
        before: datetime | None = None,
    ) -> list[float]:
        """Return ``team_id``'s own xG in up to ``count`` of its most recent matches, newest first.

        Exactly one value per match is produced, taken from whichever side the team played.
        (The opponent's xG is not mixed in.) Matches lacking either xG value are excluded, as is
        the match ``as_of_match_id`` itself. When ``before`` is given, only matches with
        ``match_time_utc`` strictly earlier are used. This keeps matches played after the target
        match out of pre-match features.
        """
        if count <= 0:
            return []
        conditions = [
            (Match.home_team_id == team_id) | (Match.away_team_id == team_id),
            Match.home_xg.is_not(None),
            Match.away_xg.is_not(None),
            Match.id != as_of_match_id,
        ]
        if before is not None:
            conditions.append(Match.match_time_utc < before)
        stmt = (
            select(Match)
            .where(*conditions)
            .order_by(desc(Match.match_time_utc))
            .limit(count)
        )
        result = await session.execute(stmt)
        return [self._team_xg(row, team_id) for row in result.scalars().all()]

    async def extract_features(self, match_id: str) -> MatchFeatureVector:
        """Build the four pre-match features for ``match_id``.

        1) rest_days_advantage: whole days since each team's previous match, home minus away.
           A team with no earlier match counts as 0 days.
        2) travel_distance_km: haversine distance between the teams' configured locations.
        3) recent_5_xg_diff: mean own-xG over each team's last 5 matches before kickoff, home
           minus away. A team with no history contributes 0.0.
        4) elo_rating_diff: home minus away ``current_elo_rating``. Missing ratings use 1500.

        Raises:
            ValueError: if the match, or either of its teams, cannot be found.
        """
        async with self.session_factory() as session:  # type: ignore[assignment]
            current_match = await session.get(Match, match_id)
            if current_match is None:
                raise ValueError(f'找不到 match_id={match_id}')

            home_team = await session.get(Team, current_match.home_team_id)
            away_team = await session.get(Team, current_match.away_team_id)
            if home_team is None or away_team is None:
                raise ValueError('比賽球隊資料不完整')

            kickoff = current_match.match_time_utc

            home_prev = await self._previous_match(session, home_team.id, kickoff)
            away_prev = await self._previous_match(session, away_team.id, kickoff)
            rest_home = (kickoff - home_prev.match_time_utc).days if home_prev is not None else 0
            rest_away = (kickoff - away_prev.match_time_utc).days if away_prev is not None else 0

            travel_distance = self._distance_between_teams(home_team.id, away_team.id)

            # Bound by kickoff so only matches played before this one feed the feature.
            home_xg = await self._recent_xg_series(
                session, home_team.id, current_match.id, before=kickoff
            )
            away_xg = await self._recent_xg_series(
                session, away_team.id, current_match.id, before=kickoff
            )
            recent_diff = self._mean(home_xg[:5]) - self._mean(away_xg[:5])

            home_elo = float(home_team.current_elo_rating or 1500)
            away_elo = float(away_team.current_elo_rating or 1500)

            return MatchFeatureVector(
                rest_days_advantage=float(rest_home - rest_away),
                travel_distance_km=float(travel_distance),
                recent_5_xg_diff=float(recent_diff),
                elo_rating_diff=float(home_elo - away_elo),
            )

    def _distance_between_teams(self, home_team_id: str, away_team_id: str) -> float:
        """Return the haversine distance (km) between the two teams' configured locations.

        Returns 0.0 when either team has no entry in ``self.team_locations``.
        """
        home_loc = self.team_locations.get(home_team_id)
        away_loc = self.team_locations.get(away_team_id)
        if home_loc is None or away_loc is None:
            return 0.0
        return float(_haversine_km(home_loc[0], home_loc[1], away_loc[0], away_loc[1]))

    @staticmethod
    def to_model_payload(
        features: MatchFeatureVector, columns: Iterable[str] | None = None
    ) -> dict:
        """Return the features as a flat dict that can be fed directly to XGBoost.

        Args:
            features: The feature vector to serialize.
            columns: Optional ordered column names. When given, the result contains only
                those names, in that order, with values cast to float. Names that are not
                known features are silently skipped.
        """
        payload = {
            'rest_days_advantage': features.rest_days_advantage,
            'travel_distance_km': features.travel_distance_km,
            'recent_5_xg_diff': features.recent_5_xg_diff,
            'elo_rating_diff': features.elo_rating_diff,
        }
        if columns is None:
            return payload
        return {c: float(payload[c]) for c in columns if c in payload}