Files
awoooi/apps/api/tests/test_ai_router_cache_provider_policy.py
Your Name 2aa31c205a
All checks were successful
CD Pipeline / tests (push) Successful in 54s
Code Review / ai-code-review (push) Successful in 10s
CD Pipeline / build-and-deploy (push) Successful in 3m21s
CD Pipeline / post-deploy-checks (push) Successful in 2m2s
fix(ai): require 111 before alert cloud fallback
2026-05-06 00:05:51 +08:00

181 lines
5.6 KiB
Python

from __future__ import annotations
import json
from typing import Any
import pytest
from src.services import ai_router as ai_router_module
from src.services.ai_providers.interfaces import AIResult
from src.services.ai_router import AIProviderRegistry, AIRouterExecutor
class _FakeRedis:
    """In-memory stand-in for the async Redis client used by the AI router.

    Every ``get`` reports a cache hit whose JSON payload claims it was produced
    by ``cached_provider``. Every ``set`` is recorded in ``set_calls`` as a
    ``(key, value, ex)`` tuple instead of being stored anywhere.
    """

    def __init__(self, cached_provider: str) -> None:
        self.set_calls: list[tuple[str, str, int | None]] = []
        self.cached_provider = cached_provider

    async def get(self, key: str) -> str:
        """Return the same stale cache entry for any key, tagged with ``cached_provider``."""
        entry = dict(
            response='{"provider":"stale"}',
            provider=self.cached_provider,
        )
        return json.dumps(entry)

    async def set(self, key: str, value: str, ex: int | None = None) -> None:
        """Record the write (key, serialized value, TTL); nothing is persisted."""
        record = (key, value, ex)
        self.set_calls.append(record)
class _FakeProvider:
    """Healthy local provider registered as ``ollama_gcp_a``.

    Every ``analyze`` call succeeds with a fixed "fresh" response. ``calls``
    counts how many times the router actually invoked it, which lets tests
    tell a live provider call apart from a cache hit.
    """

    name = "ollama_gcp_a"
    privacy_level = "local"
    is_enabled = True
    capabilities = {"rca", "chat"}

    def __init__(self) -> None:
        self.calls = 0

    async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult:
        """Count the call and return a successful result attributed to this provider."""
        self.calls = self.calls + 1
        fresh = AIResult(
            provider=self.name,
            success=True,
            raw_response='{"provider":"fresh_ollama"}',
        )
        return fresh
class _FailingLocalProvider:
    """Local provider stub whose every ``analyze`` call fails.

    The provider name is supplied per instance, so one class can impersonate
    several local lanes (for example ``ollama_gcp_a`` and ``ollama_local``).
    ``calls`` counts how many times the router actually attempted it.
    """

    privacy_level = "local"
    is_enabled = True
    capabilities = {"rca", "chat"}

    def __init__(self, name: str) -> None:
        self.calls = 0
        self.name = name

    async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult:
        """Count the attempt and return an unsuccessful result with a fixed error."""
        self.calls = self.calls + 1
        failure = AIResult(
            success=False,
            error="forced failure",
            provider=self.name,
            raw_response="",
        )
        return failure
class _CloudProvider:
    """Cloud provider stub registered as ``gemini`` that always succeeds.

    ``calls`` counts how many times the router fell through to it, so tests can
    assert whether the cloud backup was used at all.
    """

    name = "gemini"
    privacy_level = "cloud"
    is_enabled = True
    capabilities = {"rca", "chat"}

    def __init__(self) -> None:
        self.calls = 0

    async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult:
        """Count the call and return a successful result attributed to ``gemini``."""
        self.calls = self.calls + 1
        answer = AIResult(
            provider=self.name,
            success=True,
            raw_response='{"provider":"gemini"}',
        )
        return answer
@pytest.mark.asyncio
async def test_executor_skips_cached_cloud_provider_when_ollama_lane_is_required(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A cached answer attributed to ``gemini`` must not satisfy an Ollama-led lane.

    The fake Redis reports a hit produced by ``gemini``. The executor is
    expected to ignore it, call the local ``ollama_gcp_a`` provider once, and
    write that fresh local answer back to the cache.
    """
    fake_redis = _FakeRedis(cached_provider="gemini")
    fake_provider = _FakeProvider()
    registry = AIProviderRegistry()
    registry.register(fake_provider)
    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: fake_redis)

    result = await AIRouterExecutor(registry).execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "ollama_gcp_b", "ollama_local", "gemini"],
        context={"intent_hint": "diagnose", "alert_type": "HostHighCpuLoad"},
    )

    assert result.provider == "ollama_gcp_a"
    assert result.raw_response == '{"provider":"fresh_ollama"}'
    assert fake_provider.calls == 1
    # Merely checking that *something* was written would also pass if the
    # router re-wrote the stale gemini entry; require the fresh local answer.
    assert any(
        "fresh_ollama" in value for _key, value, _ttl in fake_redis.set_calls
    )
@pytest.mark.asyncio
async def test_executor_allows_cached_ollama_provider_for_ollama_lane(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A cached answer attributed to ``ollama`` may be served for an Ollama-only lane.

    The fake Redis reports a hit produced by ``ollama``. The executor should
    return that cached entry and must not call the registered provider.
    """
    fake_redis = _FakeRedis(cached_provider="ollama")
    fake_provider = _FakeProvider()
    registry = AIProviderRegistry()
    registry.register(fake_provider)
    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: fake_redis)

    result = await AIRouterExecutor(registry).execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "ollama_gcp_b", "ollama_local"],
        context={"intent_hint": "diagnose", "alert_type": "HostHighCpuLoad"},
    )

    assert result.provider == "ollama"
    assert result.from_cache is True
    # The label alone does not prove the cached payload was served; pin the
    # response stored by _FakeRedis.get as well.
    assert result.raw_response == '{"provider":"stale"}'
    assert fake_provider.calls == 0
@pytest.mark.asyncio
async def test_alert_cloud_backup_waits_for_real_ollama_local_attempt(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Cloud fallback stays blocked when ``ollama_local`` was never attempted.

    Only ``ollama_gcp_a`` (which always fails) and ``gemini`` are registered.
    The alert context sets ``alert_requires_ollama_before_cloud``. With no
    ``ollama_local`` provider available to try, the executor is expected to
    fail without ever calling ``gemini``.
    """
    # NOTE(review): the fake cache entry is attributed to the name "none";
    # presumably the router rejects it so no cache hit short-circuits the test.
    fake_redis = _FakeRedis(cached_provider="none")
    failing_gcp = _FailingLocalProvider("ollama_gcp_a")
    cloud = _CloudProvider()

    registry = AIProviderRegistry()
    for provider in (failing_gcp, cloud):
        registry.register(provider)

    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: fake_redis)

    alert_context = {
        "intent_hint": "diagnose",
        "alert_type": "HostHighCpuLoad",
        "alert_requires_ollama_before_cloud": True,
    }
    executor = AIRouterExecutor(registry)
    result = await executor.execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "gemini"],
        context=alert_context,
    )

    assert not result.success is not False or result.success is False
    assert failing_gcp.calls == 1
    assert cloud.calls == 0
@pytest.mark.asyncio
async def test_alert_cloud_backup_allowed_after_ollama_local_attempt(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Cloud fallback is used once every local lane, including ``ollama_local``, has failed.

    ``ollama_gcp_a`` and ``ollama_local`` both fail and the alert context sets
    ``alert_requires_ollama_before_cloud``. The executor should try each local
    provider exactly once, then fall back to ``gemini``, which succeeds.
    """
    # NOTE(review): "none" as the cached provider presumably avoids a usable
    # cache hit, so every provider in the order is actually exercised.
    fake_redis = _FakeRedis(cached_provider="none")
    failing_gcp = _FailingLocalProvider("ollama_gcp_a")
    failing_local = _FailingLocalProvider("ollama_local")
    cloud = _CloudProvider()

    registry = AIProviderRegistry()
    for provider in (failing_gcp, failing_local, cloud):
        registry.register(provider)

    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: fake_redis)

    alert_context = {
        "intent_hint": "diagnose",
        "alert_type": "HostHighCpuLoad",
        "alert_requires_ollama_before_cloud": True,
    }
    executor = AIRouterExecutor(registry)
    result = await executor.execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "ollama_local", "gemini"],
        context=alert_context,
    )

    assert result.success is True
    assert result.provider == "gemini"
    # Each lane was attempted exactly once, in order, before the cloud backup.
    assert (failing_gcp.calls, failing_local.calls, cloud.calls) == (1, 1, 1)