497 lines
17 KiB
Python
497 lines
17 KiB
Python
"""
|
||
AI Rate Limiter - Gemini API 用量閥值控制
|
||
=========================================
|
||
|
||
防止 API 用量暴衝,超過閥值自動降級回 Ollama。
|
||
|
||
功能:
|
||
- 每分鐘請求限制 (RPM)
|
||
- 每日請求限制
|
||
- 每日 Token 限制
|
||
- 🔴 累積成本限制 ($5 USD) - 2026-03-29 ogt 新增
|
||
- 超限自動降級 + Telegram 告警
|
||
|
||
版本: v1.1
|
||
建立日期: 2026-03-26 21:00 (台北時區)
|
||
更新日期: 2026-03-29 22:45 (台北時區)
|
||
建立者: Claude Code
|
||
"""
|
||
|
||
import structlog
|
||
|
||
logger = structlog.get_logger(__name__)
|
||
|
||
|
||
# =============================================================================
# Configuration - threshold settings
# =============================================================================

# Per-provider ceilings:
#   rpm            - requests per minute
#   daily_requests - requests per day
#   daily_tokens   - tokens per day
RATE_LIMITS = {
    "gemini": {"rpm": 10, "daily_requests": 500, "daily_tokens": 100_000},
    "claude": {"rpm": 5, "daily_requests": 200, "daily_tokens": 50_000},
    # 2026-03-31 ogt: the free NVIDIA NIM tier has no daily limit.
    # Only the RPM limit (concurrency control, relaxed to 10) is meaningful;
    # the daily values are very large caps kept for monitoring so they are
    # never tripped by accident.
    # 2026-04-03 ogt: I3 - "nvidia" -> "openclaw_nemo" to match AIProviderEnum (Phase 24)
    "openclaw_nemo": {
        "rpm": 10,
        "daily_requests": 99_999,
        "daily_tokens": 9_999_999,
    },
}

# =============================================================================
# 2026-03-29 ogt: cumulative cost limits (requested by the commander)
# =============================================================================

# Per-provider cost ceilings in USD:
#   total_cost_usd      - hard cap; the provider is refused once it is reached
#   alert_threshold_usd - cumulative cost from which a warning is sent
COST_LIMITS = {
    # 🔴 Hard cap of $5 USD, warning from $4 USD.
    "gemini": {"total_cost_usd": 5.0, "alert_threshold_usd": 4.0},
    "claude": {"total_cost_usd": 10.0, "alert_threshold_usd": 8.0},
    # 2026-03-29 ogt: ADR-036 Nemotron (free tier; limits kept only for monitoring)
    # 2026-03-31 ogt: fixed the "$0.00 >= $0.00 is always True" bug by using a
    #                 large value to mean "no limit"
    # 2026-04-03 ogt: I3 - "nvidia" -> "openclaw_nemo" to match AIProviderEnum (Phase 24)
    # The 0.0 alert threshold is intended to mean "send no cost warning".
    # NOTE(review): confirm the consumer of alert_threshold_usd really skips a
    # 0.0 threshold, since a plain `total >= 0.0` comparison is always true.
    "openclaw_nemo": {"total_cost_usd": 999_999.0, "alert_threshold_usd": 0.0},
}

# Gemini 1.5 Flash pricing, expressed in USD per single token
GEMINI_PRICING = {
    "input_per_token": 7.5e-8,  # $0.075 per 1M tokens
    "output_per_token": 3e-7,  # $0.30 per 1M tokens
}

# Redis key templates; placeholders are filled in with str.format()
REDIS_KEY_PREFIX = "ai_rate:"
RPM_KEY = REDIS_KEY_PREFIX + "rpm:{provider}"
DAILY_REQ_KEY = REDIS_KEY_PREFIX + "daily_req:{provider}:{date}"
DAILY_TOKEN_KEY = REDIS_KEY_PREFIX + "daily_token:{provider}:{date}"
# 2026-03-29 ogt: cumulative cost key (never expires; reset manually)
TOTAL_COST_KEY = REDIS_KEY_PREFIX + "total_cost:{provider}"
COST_ALERT_SENT_KEY = REDIS_KEY_PREFIX + "cost_alert_sent:{provider}"


# =============================================================================
# Rate Limiter
# =============================================================================


class AIRateLimiter:
    """
    Usage limiter for AI provider APIs.

    Tracks usage with Redis counters; when a limit is exceeded it returns a
    refusal reason so the caller can downgrade to another provider.

    Usage:
        limiter = AIRateLimiter()
        allowed, reason = await limiter.check_and_increment("gemini")
        if not allowed:
            # fall back to Ollama
            provider = "ollama"
    """

    def __init__(self) -> None:
        # The Redis client is resolved on first use, see _get_redis().
        self._redis = None

    async def _get_redis(self):
        """Return the Redis client, importing and resolving it on first use."""
        if self._redis is None:
            from src.core.redis_client import get_redis

            self._redis = get_redis()
        return self._redis

    def _get_today(self) -> str:
        """Return today's date as YYYY-MM-DD, taken from now_taipei()."""
        from src.utils.timezone import now_taipei

        return now_taipei().strftime("%Y-%m-%d")

    @staticmethod
    def _cost_warning_key(provider: str) -> str:
        """Return the Redis key of the "cost warning already sent" flag."""
        return f"{REDIS_KEY_PREFIX}cost_warning_sent:{provider}"

    async def check_and_increment(
        self,
        provider: str,
        tokens: int = 0,
    ) -> tuple[bool, str | None]:
        """
        Check every limit for ``provider`` and, if all pass, count the request.

        Checks run in this order: cumulative cost, requests per minute, daily
        requests, daily tokens. The first exceeded limit short-circuits and no
        counter is incremented.

        NOTE(review): the reads and the increments are separate Redis round
        trips, so concurrent callers can overshoot a limit slightly.

        Args:
            provider: AI provider name (e.g. "gemini", "claude").
            tokens: Tokens to add to today's token counter; 0 skips it.

        Returns:
            tuple[bool, str | None]: (allowed, refusal reason). The reason is
            None when the request is allowed.
        """
        if provider not in RATE_LIMITS:
            return True, None  # provider without limits (e.g. ollama)

        limits = RATE_LIMITS[provider]
        r = await self._get_redis()
        today = self._get_today()

        # 0. 🔴 2026-03-29 ogt: cumulative cost check (highest priority)
        if provider in COST_LIMITS:
            cost_limit = COST_LIMITS[provider]["total_cost_usd"]
            total_cost_key = TOTAL_COST_KEY.format(provider=provider)
            current_cost = await r.get(total_cost_key)
            current_cost = float(current_cost) if current_cost else 0.0

            if current_cost >= cost_limit:
                logger.error(
                    "ai_cost_limit_exceeded_blocking",
                    provider=provider,
                    current_cost=f"${current_cost:.4f}",
                    limit=f"${cost_limit:.2f}",
                    action="AUTO_SWITCH_TO_OLLAMA",
                )
                # The alert de-duplicates itself, so calling it on every
                # blocked request is safe.
                await self._send_cost_alert(provider, current_cost, cost_limit)
                return False, f"🔴 成本超限! ${current_cost:.2f} >= ${cost_limit:.2f},已自動切換到 Ollama"

        # 1. Requests per minute
        rpm_key = RPM_KEY.format(provider=provider)
        current_rpm = await r.get(rpm_key)
        current_rpm = int(current_rpm) if current_rpm else 0

        if current_rpm >= limits["rpm"]:
            logger.warning(
                "ai_rate_limit_rpm",
                provider=provider,
                current=current_rpm,
                limit=limits["rpm"],
            )
            return False, f"RPM limit exceeded ({current_rpm}/{limits['rpm']})"

        # 2. Daily requests
        daily_req_key = DAILY_REQ_KEY.format(provider=provider, date=today)
        current_daily = await r.get(daily_req_key)
        current_daily = int(current_daily) if current_daily else 0

        if current_daily >= limits["daily_requests"]:
            logger.warning(
                "ai_rate_limit_daily",
                provider=provider,
                current=current_daily,
                limit=limits["daily_requests"],
            )
            return False, f"Daily request limit exceeded ({current_daily}/{limits['daily_requests']})"

        # 3. Daily tokens (only meaningful when token usage is recorded)
        daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today)
        current_tokens = await r.get(daily_token_key)
        current_tokens = int(current_tokens) if current_tokens else 0

        if current_tokens >= limits["daily_tokens"]:
            logger.warning(
                "ai_rate_limit_tokens",
                provider=provider,
                current=current_tokens,
                limit=limits["daily_tokens"],
            )
            return False, f"Daily token limit exceeded ({current_tokens}/{limits['daily_tokens']})"

        # 4. All checks passed: bump the counters in one pipeline
        pipe = r.pipeline()

        # RPM is a fixed 60-second window. SET NX EX creates the counter with
        # its TTL only when a window starts, and INCR leaves that TTL alone.
        # Re-arming the TTL on every request (INCR + EXPIRE) would keep
        # pushing the expiry forward, so under steady traffic the counter
        # would never reset and would eventually block a low-rate caller.
        pipe.set(rpm_key, 0, ex=60, nx=True)
        pipe.incr(rpm_key)

        # Daily counters are keyed by date, so refreshing their TTL only
        # controls how long a stale key lingers.
        pipe.incr(daily_req_key)
        pipe.expire(daily_req_key, 86400)

        if tokens > 0:
            pipe.incrby(daily_token_key, tokens)
            pipe.expire(daily_token_key, 86400)

        await pipe.execute()

        logger.debug(
            "ai_rate_check_passed",
            provider=provider,
            rpm=current_rpm + 1,
            daily=current_daily + 1,
        )

        return True, None

    async def record_cost(self, provider: str, cost_usd: float) -> None:
        """
        2026-03-29 ogt: add ``cost_usd`` to the provider's cumulative cost.

        Does nothing for providers without a COST_LIMITS entry or for a
        non-positive cost. When the new total reaches the provider's alert
        threshold a (rate-limited) warning is sent.

        Args:
            provider: AI provider name.
            cost_usd: Cost of this call in USD.
        """
        if provider not in COST_LIMITS or cost_usd <= 0:
            return

        r = await self._get_redis()
        total_cost_key = TOTAL_COST_KEY.format(provider=provider)

        # INCRBYFLOAT is atomic, so concurrent recorders do not lose updates.
        new_total = await r.incrbyfloat(total_cost_key, cost_usd)

        logger.info(
            "ai_cost_recorded",
            provider=provider,
            cost_usd=f"${cost_usd:.6f}",
            total_cost=f"${new_total:.4f}",
        )

        # A non-positive threshold means "send no cost warning" (that is how
        # openclaw_nemo is configured). Without this guard any positive total
        # would satisfy `>= 0.0` and trigger a warning.
        alert_threshold = COST_LIMITS[provider]["alert_threshold_usd"]
        if alert_threshold > 0 and new_total >= alert_threshold:
            await self._send_cost_warning(provider, new_total, alert_threshold)

    async def _send_telegram_html(self, message: str) -> bool:
        """
        Send an HTML-formatted message to the configured Telegram chat.

        Returns:
            bool: False, without sending, when the bot token or the target
            chat id is not configured; True once the request has been issued.

        Raises:
            Whatever the gateway raises; callers handle and log it.
        """
        from src.core.config import settings
        from src.services.telegram_gateway import get_telegram_gateway

        target_chat_id = settings.SRE_GROUP_CHAT_ID or settings.OPENCLAW_TG_CHAT_ID
        if not settings.OPENCLAW_TG_BOT_TOKEN or not target_chat_id:
            return False

        gateway = get_telegram_gateway()
        # NOTE(review): this goes through the gateway's private _send_request;
        # confirm whether a public send method should be used instead.
        await gateway._send_request(
            "sendMessage",
            {
                "chat_id": target_chat_id,
                "text": message,
                "parse_mode": "HTML",
            },
        )
        return True

    async def _send_cost_alert(self, provider: str, current_cost: float, limit: float) -> None:
        """
        2026-03-29 ogt: send the "cost limit exceeded" alert to Telegram.

        The alert is sent at most once per 24 hours per provider. Failures are
        logged and never propagated to the caller.
        """
        r = await self._get_redis()
        alert_sent_key = COST_ALERT_SENT_KEY.format(provider=provider)

        # Claim the "sent" flag atomically (SET NX EX) so that concurrent
        # blocked requests cannot both pass a separate GET and send duplicates.
        # The flag expires after 24 hours, after which the alert may repeat.
        if not await r.set(alert_sent_key, "1", ex=86400, nx=True):
            return

        try:
            message = (
                f"🚨🚨🚨 <b>AI 成本超限警報</b> 🚨🚨🚨\n\n"
                f"Provider: <code>{provider.upper()}</code>\n"
                f"累積成本: <b>${current_cost:.2f}</b>\n"
                f"上限: <b>${limit:.2f}</b>\n\n"
                f"⚡ <b>已自動切換到 Ollama</b>\n\n"
                f"如需恢復 {provider.upper()},請執行:\n"
                f"<code>redis-cli DEL ai_rate:total_cost:{provider}</code>"
            )

            if not await self._send_telegram_html(message):
                logger.warning("telegram_not_configured_for_cost_alert")
                return

            logger.error(
                "ai_cost_alert_sent",
                provider=provider,
                current_cost=f"${current_cost:.2f}",
                limit=f"${limit:.2f}",
            )

        except Exception as e:
            # Best effort: alerting must never break the rate-limit decision.
            logger.error("ai_cost_alert_failed", error=str(e))

    async def _send_cost_warning(self, provider: str, current_cost: float, threshold: float) -> None:
        """
        2026-03-29 ogt: send the "approaching the cost limit" warning.

        The warning is sent at most once per hour per provider. Failures are
        logged and never propagated to the caller.
        """
        r = await self._get_redis()
        warning_key = self._cost_warning_key(provider)

        # Atomic claim of the hourly flag, for the same reason as in
        # _send_cost_alert.
        if not await r.set(warning_key, "1", ex=3600, nx=True):
            return

        try:
            limit = COST_LIMITS[provider]["total_cost_usd"]
            remaining = limit - current_cost

            message = (
                f"⚠️ <b>AI 成本警告</b>\n\n"
                f"Provider: <code>{provider.upper()}</code>\n"
                f"累積成本: <b>${current_cost:.2f}</b> / ${limit:.2f}\n"
                f"剩餘額度: <b>${remaining:.2f}</b>\n\n"
                f"接近上限,請注意監控!"
            )

            if not await self._send_telegram_html(message):
                return

            logger.warning(
                "ai_cost_warning_sent",
                provider=provider,
                current_cost=f"${current_cost:.2f}",
                threshold=f"${threshold:.2f}",
            )

        except Exception as e:
            # Best effort: a failed warning must not fail cost recording.
            logger.warning("ai_cost_warning_failed", error=str(e))

    async def record_tokens(self, provider: str, tokens: int) -> None:
        """
        Record token usage for today (call after the response arrives).

        Does nothing for providers without a RATE_LIMITS entry or for a
        non-positive token count.

        Args:
            provider: AI provider name.
            tokens: Number of tokens used.
        """
        if provider not in RATE_LIMITS or tokens <= 0:
            return

        r = await self._get_redis()
        today = self._get_today()
        daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today)

        await r.incrby(daily_token_key, tokens)
        await r.expire(daily_token_key, 86400)

        logger.debug(
            "ai_tokens_recorded",
            provider=provider,
            tokens=tokens,
        )

    async def get_usage_stats(self, provider: str) -> dict:
        """
        Return current usage statistics, including cost when it is tracked.

        Args:
            provider: AI provider name.

        Returns:
            dict: ``{"provider", "limited": False}`` for a provider without
            limits; otherwise the date, the current/limit pairs for rpm,
            daily_requests and daily_tokens, plus ``total_cost_usd`` and
            ``cost_exceeded`` when the provider has a COST_LIMITS entry.
        """
        if provider not in RATE_LIMITS:
            return {"provider": provider, "limited": False}

        limits = RATE_LIMITS[provider]
        r = await self._get_redis()
        today = self._get_today()

        rpm_key = RPM_KEY.format(provider=provider)
        daily_req_key = DAILY_REQ_KEY.format(provider=provider, date=today)
        daily_token_key = DAILY_TOKEN_KEY.format(provider=provider, date=today)
        total_cost_key = TOTAL_COST_KEY.format(provider=provider)

        current_rpm = await r.get(rpm_key)
        current_daily = await r.get(daily_req_key)
        current_tokens = await r.get(daily_token_key)
        current_cost = await r.get(total_cost_key)

        # 2026-03-29 ogt: cost information, only for providers with a cost cap
        cost_info = {}
        if provider in COST_LIMITS:
            cost_limit = COST_LIMITS[provider]
            current_cost_float = float(current_cost) if current_cost else 0.0
            cost_info = {
                "total_cost_usd": {
                    "current": round(current_cost_float, 4),
                    "limit": cost_limit["total_cost_usd"],
                    "remaining": round(cost_limit["total_cost_usd"] - current_cost_float, 4),
                    "alert_threshold": cost_limit["alert_threshold_usd"],
                },
                "cost_exceeded": current_cost_float >= cost_limit["total_cost_usd"],
            }

        return {
            "provider": provider,
            "date": today,
            "rpm": {
                "current": int(current_rpm) if current_rpm else 0,
                "limit": limits["rpm"],
            },
            "daily_requests": {
                "current": int(current_daily) if current_daily else 0,
                "limit": limits["daily_requests"],
            },
            "daily_tokens": {
                "current": int(current_tokens) if current_tokens else 0,
                "limit": limits["daily_tokens"],
            },
            **cost_info,
        }

    async def reset_cost(self, provider: str) -> None:
        """
        2026-03-29 ogt: reset the cumulative cost (use after the commander's
        authorization).

        Also clears both notification flags (limit alert and threshold
        warning) so the next cycle can notify again immediately.

        Args:
            provider: AI provider name.
        """
        r = await self._get_redis()
        total_cost_key = TOTAL_COST_KEY.format(provider=provider)
        alert_sent_key = COST_ALERT_SENT_KEY.format(provider=provider)
        warning_sent_key = self._cost_warning_key(provider)

        await r.delete(total_cost_key, alert_sent_key, warning_sent_key)
        logger.info("ai_cost_reset", provider=provider)

    async def reset_limits(self, provider: str) -> None:
        """
        Reset the RPM counter and today's request/token counters (emergency use).

        Args:
            provider: AI provider name.
        """
        r = await self._get_redis()
        today = self._get_today()

        keys = [
            RPM_KEY.format(provider=provider),
            DAILY_REQ_KEY.format(provider=provider, date=today),
            DAILY_TOKEN_KEY.format(provider=provider, date=today),
        ]

        await r.delete(*keys)
        logger.info("ai_rate_limits_reset", provider=provider)


# =============================================================================
# Singleton
# =============================================================================

# Module-wide limiter instance, created on the first get_ai_rate_limiter() call.
_rate_limiter: AIRateLimiter | None = None


def get_ai_rate_limiter() -> AIRateLimiter:
    """Return the shared AIRateLimiter, creating it on first use."""
    global _rate_limiter
    instance = _rate_limiter
    if instance is None:
        instance = _rate_limiter = AIRateLimiter()
    return instance