from __future__ import annotations import json from typing import Any import pytest from src.services import ai_router as ai_router_module from src.services.ai_providers.interfaces import AIResult from src.services.ai_router import AIProviderRegistry, AIRouterExecutor class _FakeRedis: def __init__(self, cached_provider: str) -> None: self.cached_provider = cached_provider self.set_calls: list[tuple[str, str, int | None]] = [] async def get(self, key: str) -> str: return json.dumps({ "response": '{"provider":"stale"}', "provider": self.cached_provider, }) async def set(self, key: str, value: str, ex: int | None = None) -> None: self.set_calls.append((key, value, ex)) class _FakeProvider: name = "ollama_gcp_a" privacy_level = "local" is_enabled = True capabilities = {"rca", "chat"} def __init__(self) -> None: self.calls = 0 async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult: self.calls += 1 return AIResult( raw_response='{"provider":"fresh_ollama"}', success=True, provider=self.name, ) class _FailingLocalProvider: privacy_level = "local" is_enabled = True capabilities = {"rca", "chat"} def __init__(self, name: str) -> None: self.name = name self.calls = 0 async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult: self.calls += 1 return AIResult(raw_response="", success=False, provider=self.name, error="forced failure") class _CloudProvider: name = "gemini" privacy_level = "cloud" is_enabled = True capabilities = {"rca", "chat"} def __init__(self) -> None: self.calls = 0 async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult: self.calls += 1 return AIResult(raw_response='{"provider":"gemini"}', success=True, provider=self.name) @pytest.mark.asyncio async def test_executor_skips_cached_cloud_provider_when_ollama_lane_is_required( monkeypatch: pytest.MonkeyPatch, ) -> None: fake_redis = _FakeRedis(cached_provider="gemini") fake_provider = _FakeProvider() registry = AIProviderRegistry() registry.register(fake_provider) monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False) monkeypatch.setattr("src.core.redis_client.get_redis", lambda: fake_redis) result = await AIRouterExecutor(registry).execute( prompt="diagnose alert", provider_order=["ollama_gcp_a", "ollama_gcp_b", "ollama_local", "gemini"], context={"intent_hint": "diagnose", "alert_type": "HostHighCpuLoad"}, ) assert result.provider == "ollama_gcp_a" assert result.raw_response == '{"provider":"fresh_ollama"}' assert fake_provider.calls == 1 assert fake_redis.set_calls @pytest.mark.asyncio async def test_executor_allows_cached_ollama_provider_for_ollama_lane( monkeypatch: pytest.MonkeyPatch, ) -> None: fake_redis = _FakeRedis(cached_provider="ollama") fake_provider = _FakeProvider() registry = AIProviderRegistry() registry.register(fake_provider) monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False) monkeypatch.setattr("src.core.redis_client.get_redis", lambda: fake_redis) result = await AIRouterExecutor(registry).execute( prompt="diagnose alert", provider_order=["ollama_gcp_a", "ollama_gcp_b", "ollama_local"], context={"intent_hint": "diagnose", "alert_type": "HostHighCpuLoad"}, ) assert result.provider == "ollama" assert result.from_cache is True assert fake_provider.calls == 0 @pytest.mark.asyncio async def test_alert_cloud_backup_waits_for_real_ollama_local_attempt( monkeypatch: pytest.MonkeyPatch, ) -> None: fake_redis = _FakeRedis(cached_provider="none") gcp_a = _FailingLocalProvider("ollama_gcp_a") gemini = _CloudProvider() registry = AIProviderRegistry() registry.register(gcp_a) registry.register(gemini) monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False) monkeypatch.setattr("src.core.redis_client.get_redis", lambda: fake_redis) result = await AIRouterExecutor(registry).execute( prompt="diagnose alert", provider_order=["ollama_gcp_a", "gemini"], context={ "intent_hint": "diagnose", "alert_type": "HostHighCpuLoad", "alert_requires_ollama_before_cloud": True, }, ) assert result.success is False assert gcp_a.calls == 1 assert gemini.calls == 0 @pytest.mark.asyncio async def test_alert_cloud_backup_allowed_after_ollama_local_attempt( monkeypatch: pytest.MonkeyPatch, ) -> None: fake_redis = _FakeRedis(cached_provider="none") gcp_a = _FailingLocalProvider("ollama_gcp_a") local = _FailingLocalProvider("ollama_local") gemini = _CloudProvider() registry = AIProviderRegistry() registry.register(gcp_a) registry.register(local) registry.register(gemini) monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False) monkeypatch.setattr("src.core.redis_client.get_redis", lambda: fake_redis) result = await AIRouterExecutor(registry).execute( prompt="diagnose alert", provider_order=["ollama_gcp_a", "ollama_local", "gemini"], context={ "intent_hint": "diagnose", "alert_type": "HostHighCpuLoad", "alert_requires_ollama_before_cloud": True, }, ) assert result.success is True assert result.provider == "gemini" assert gcp_a.calls == 1 assert local.calls == 1 assert gemini.calls == 1