"""Advanced feature engineering: extract multi-dimensional match features from the database."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from math import radians, sin, cos, asin, sqrt
|
|
from typing import Iterable
|
|
|
|
from sqlalchemy import and_, desc, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from ..db.models import Match, Team
|
|
|
|
|
|
@dataclass(frozen=True)
class MatchFeatureVector:
    """Immutable bundle of the four numeric pre-match features.

    Every field is a plain float. The "diff"/"advantage" fields are
    expressed as the home side's value minus the away side's value.
    """

    # Home team's rest days minus away team's rest days.
    rest_days_advantage: float
    # Distance in kilometres between the two teams' locations.
    travel_distance_km: float
    # Home minus away average of recent expected-goals (xG) values.
    recent_5_xg_diff: float
    # Home team's Elo rating minus away team's Elo rating.
    elo_rating_diff: float
|
|
|
|
|
|
def _haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
|
|
"""Haversine 地球大圓距離(公里)。"""
|
|
|
|
R = 6371.0
|
|
dlat = radians(lat2 - lat1)
|
|
dlon = radians(lon2 - lon1)
|
|
a = sin(dlat / 2) ** 2 + cos(radians(lat1)) * cos(radians(lat2)) * sin(dlon / 2) ** 2
|
|
return 2 * R * asin(min(1.0, sqrt(a)))
|
|
|
|
|
|
class MatchFeatureExtractor:
    """Extract and assemble pre-match features for a single match.

    Match and team rows are read through an async SQLAlchemy session obtained
    from ``session_factory``; team coordinates are supplied separately via
    ``team_locations``.
    """

    def __init__(
        self,
        session_factory,
        *,
        team_locations: dict[str, tuple[float, float]] | None = None,
    ) -> None:
        """Store the session factory and the optional team-location table.

        Args:
            session_factory: Zero-argument callable whose result is used as an
                async context manager yielding the session.
            team_locations: Optional ``{team_id: (lat, lon)}`` mapping. Teams
                missing from it make the travel distance fall back to 0.
        """
        self.session_factory = session_factory
        # Optional {team_id: (lat, lon)}; missing data falls back to a 0 km distance.
        self.team_locations = team_locations or {}

    async def _previous_match(self, session: AsyncSession, team_id: str, match_time: datetime) -> Match | None:
        """Return the team's latest match strictly before ``match_time``.

        Only matches with both xG values recorded are considered.
        NOTE(review): a match without xG data is therefore skipped when
        computing rest days -- confirm this is intended.

        Returns:
            The most recent qualifying ``Match``, or ``None`` if there is none.
        """
        stmt = (
            select(Match)
            .where(
                and_(
                    (Match.home_team_id == team_id) | (Match.away_team_id == team_id),
                    Match.match_time_utc < match_time,
                    Match.home_xg.is_not(None),
                    Match.away_xg.is_not(None),
                ),
            )
            .order_by(desc(Match.match_time_utc))
            .limit(1)
        )
        result = await session.execute(stmt)
        return result.scalar_one_or_none()

    @staticmethod
    def _team_xg(match, team_id: str) -> float:
        """Return the xG that ``team_id`` itself produced in ``match``."""
        own_xg = match.home_xg if match.home_team_id == team_id else match.away_xg
        return float(own_xg or 0.0)

    @staticmethod
    def _mean(values: list[float]) -> float:
        """Return the arithmetic mean of ``values``, or 0.0 when empty."""
        return sum(values) / len(values) if values else 0.0

    @staticmethod
    def _rest_days(match_time, previous_match) -> int:
        """Return whole days since ``previous_match``, or 0 when there is none."""
        if previous_match is None:
            return 0
        return (match_time - previous_match.match_time_utc).days

    async def _recent_xg_series(self, session: AsyncSession, team_id: str, as_of_match_id: str, count: int = 5) -> list[float]:
        """Return the team's own xG for its most recent matches, newest first.

        One value is produced per match (the xG of ``team_id``, not of its
        opponent), for at most ``count`` matches that have both xG values
        recorded. The as-of match itself is excluded and, when its kickoff
        time is known, so is every match that did not start strictly before
        it, so a pre-match feature never sees later results.

        Args:
            session: Open async session used for the queries.
            team_id: Team whose xG history is requested.
            as_of_match_id: Match the features are being computed for.
            count: Maximum number of past matches to include.

        Returns:
            Up to ``count`` xG values ordered from newest to oldest match.
        """
        conditions = [
            (Match.home_team_id == team_id) | (Match.away_team_id == team_id),
            Match.home_xg.is_not(None),
            Match.away_xg.is_not(None),
            Match.id != as_of_match_id,
        ]

        # Derive the time cutoff from the as-of match so later matches cannot
        # leak into the series; skip the cutoff only if it cannot be resolved.
        as_of_match = await session.get(Match, as_of_match_id)
        cutoff = getattr(as_of_match, 'match_time_utc', None)
        if cutoff is not None:
            conditions.append(Match.match_time_utc < cutoff)

        stmt = (
            select(Match)
            .where(*conditions)
            .order_by(desc(Match.match_time_utc))
            .limit(count)
        )
        result = await session.execute(stmt)
        return [self._team_xg(row, team_id) for row in result.scalars().all()]

    async def extract_features(self, match_id: str) -> MatchFeatureVector:
        """Produce the four key features for ``match_id``.

        1) rest_days_advantage: home rest days minus away rest days, where
           rest days are whole days since the team's previous match (0 when
           no previous match is found).
        2) travel_distance_km: great-circle distance between the two teams'
           configured locations (0 when either location is unknown).
        3) recent_5_xg_diff: home minus away mean of each team's recent xG
           series (an empty series counts as 0).
        4) elo_rating_diff: home minus away current Elo rating, with a
           missing or zero rating treated as 1500.

        Raises:
            ValueError: If the match, or either of its teams, is not found.
        """
        async with self.session_factory() as session:
            current_match = await session.get(Match, match_id)
            if current_match is None:
                raise ValueError(f'找不到 match_id={match_id}')

            home_team = await session.get(Team, current_match.home_team_id)
            away_team = await session.get(Team, current_match.away_team_id)
            if home_team is None or away_team is None:
                raise ValueError('比賽球隊資料不完整')

            kickoff = current_match.match_time_utc
            home_prev = await self._previous_match(session, home_team.id, kickoff)
            away_prev = await self._previous_match(session, away_team.id, kickoff)
            rest_home = self._rest_days(kickoff, home_prev)
            rest_away = self._rest_days(kickoff, away_prev)

            travel_distance = self._distance_between_teams(home_team.id, away_team.id)

            home_xg = await self._recent_xg_series(session, home_team.id, current_match.id)
            away_xg = await self._recent_xg_series(session, away_team.id, current_match.id)
            recent_diff = self._mean(home_xg) - self._mean(away_xg)

            home_elo = float(home_team.current_elo_rating or 1500)
            away_elo = float(away_team.current_elo_rating or 1500)

            return MatchFeatureVector(
                rest_days_advantage=float(rest_home - rest_away),
                travel_distance_km=float(travel_distance),
                recent_5_xg_diff=float(recent_diff),
                elo_rating_diff=float(home_elo - away_elo),
            )

    def _distance_between_teams(self, home_team_id: str, away_team_id: str) -> float:
        """Return the distance in km between two teams' configured locations.

        Returns 0.0 when either team has no entry in ``team_locations``.
        """
        home_loc = self.team_locations.get(home_team_id)
        away_loc = self.team_locations.get(away_team_id)

        if home_loc is None or away_loc is None:
            return 0.0

        return float(_haversine_km(home_loc[0], home_loc[1], away_loc[0], away_loc[1]))

    @staticmethod
    def to_model_payload(features: MatchFeatureVector, columns: Iterable[str] | None = None) -> dict:
        """Return the features as a dict that can be fed directly to XGBoost.

        Args:
            features: Feature vector to convert.
            columns: Optional feature names to emit, in this order. Names that
                are not one of the four features are silently skipped.

        Returns:
            All four features when ``columns`` is None; otherwise only the
            requested ones, each coerced to ``float``.
        """
        payload = {
            'rest_days_advantage': features.rest_days_advantage,
            'travel_distance_km': features.travel_distance_km,
            'recent_5_xg_diff': features.recent_5_xg_diff,
            'elo_rating_diff': features.elo_rating_diff,
        }

        if columns is None:
            return payload
        return {name: float(payload[name]) for name in columns if name in payload}
|