181 lines
5.6 KiB
Python
181 lines
5.6 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from src.services import ai_router as ai_router_module
|
|
from src.services.ai_providers.interfaces import AIResult
|
|
from src.services.ai_router import AIProviderRegistry, AIRouterExecutor
|
|
|
|
|
|
class _FakeRedis:
    """In-memory stand-in for the async Redis client used by the AI router cache.

    Every ``get`` returns the same canned JSON cache entry, regardless of key.
    The entry's ``provider`` field is ``cached_provider``.

    Pass ``cached_provider=None`` to simulate a cache miss. ``get`` then returns
    ``None``, as redis-py does for a missing key.

    Writes are not stored; they are only recorded in ``set_calls``.
    """

    def __init__(self, cached_provider: str | None) -> None:
        # Provider name embedded in the canned cache entry; None means "no entry".
        self.cached_provider = cached_provider
        # Every write, recorded as a (key, value, ex) tuple.
        self.set_calls: list[tuple[str, str, int | None]] = []

    async def get(self, key: str) -> str | None:
        """Return the canned JSON cache entry, or None when simulating a miss.

        The key is ignored: every lookup sees the same entry.
        """
        if self.cached_provider is None:
            return None
        return json.dumps(
            {
                # A "stale" marker that differs from the fake providers' fresh
                # responses, so tests can tell whether the cached value was used.
                "response": '{"provider":"stale"}',
                "provider": self.cached_provider,
            }
        )

    async def set(self, key: str, value: str, ex: int | None = None) -> None:
        """Record the write as ``(key, value, ex)`` without storing anything."""
        self.set_calls.append((key, value, ex))
|
|
|
|
|
|
class _FakeProvider:
    """Local-privacy provider registered as ``ollama_gcp_a`` that always succeeds.

    Each ``analyze`` call increments ``calls`` and returns a fixed fresh payload.
    Tests use the counter to see whether the router actually invoked the provider.
    """

    capabilities = {"rca", "chat"}
    is_enabled = True
    privacy_level = "local"
    name = "ollama_gcp_a"

    def __init__(self) -> None:
        # Number of analyze() invocations so far.
        self.calls = 0

    async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult:
        """Count the call and return a successful result tagged with this provider's name."""
        self.calls = self.calls + 1
        fresh_payload = '{"provider":"fresh_ollama"}'
        return AIResult(success=True, provider=self.name, raw_response=fresh_payload)
|
|
|
|
|
|
class _FailingLocalProvider:
    """Local-privacy provider whose every ``analyze`` call fails.

    The name is supplied per instance, so one class can stand in for several
    lanes (for example ``ollama_gcp_a`` and ``ollama_local``).
    """

    capabilities = {"rca", "chat"}
    is_enabled = True
    privacy_level = "local"

    def __init__(self, name: str) -> None:
        # Number of analyze() attempts so far.
        self.calls = 0
        self.name = name

    async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult:
        """Count the attempt and report a forced failure with an empty response."""
        self.calls = self.calls + 1
        return AIResult(
            provider=self.name,
            success=False,
            raw_response="",
            error="forced failure",
        )
|
|
|
|
|
|
class _CloudProvider:
    """Cloud-privacy provider registered as ``gemini`` that always succeeds.

    Tests read ``calls`` to check whether the router fell back to the cloud.
    """

    capabilities = {"rca", "chat"}
    is_enabled = True
    privacy_level = "cloud"
    name = "gemini"

    def __init__(self) -> None:
        # Number of analyze() invocations so far.
        self.calls = 0

    async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult:
        """Count the call and return a successful gemini-tagged result."""
        self.calls = self.calls + 1
        cloud_payload = '{"provider":"gemini"}'
        return AIResult(provider=self.name, success=True, raw_response=cloud_payload)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_executor_skips_cached_cloud_provider_when_ollama_lane_is_required(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A cached entry attributed to gemini must not be served on the Ollama lane.

    The request's provider order starts with the Ollama providers and ends with
    gemini; the cache holds a gemini entry. The executor is expected to call the
    registered ``ollama_gcp_a`` provider once, return its fresh response, and
    write that result back to the cache.
    """
    redis_stub = _FakeRedis(cached_provider="gemini")
    ollama_a = _FakeProvider()
    providers = AIProviderRegistry()
    providers.register(ollama_a)

    # Disable mock mode and route the router's Redis lookups to the stub.
    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: redis_stub)

    executor = AIRouterExecutor(providers)
    outcome = await executor.execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "ollama_gcp_b", "ollama_local", "gemini"],
        context={"intent_hint": "diagnose", "alert_type": "HostHighCpuLoad"},
    )

    assert outcome.provider == "ollama_gcp_a"
    assert outcome.raw_response == '{"provider":"fresh_ollama"}'
    assert ollama_a.calls == 1
    assert len(redis_stub.set_calls) > 0, "fresh result should have been cached"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_executor_allows_cached_ollama_provider_for_ollama_lane(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A cached entry attributed to ``ollama`` may be served on the Ollama lane.

    The result must come from the cache (``from_cache``), and the live provider
    must not be called.
    """
    redis_stub = _FakeRedis(cached_provider="ollama")
    ollama_a = _FakeProvider()
    providers = AIProviderRegistry()
    providers.register(ollama_a)

    # Disable mock mode and route the router's Redis lookups to the stub.
    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: redis_stub)

    executor = AIRouterExecutor(providers)
    outcome = await executor.execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "ollama_gcp_b", "ollama_local"],
        context={"intent_hint": "diagnose", "alert_type": "HostHighCpuLoad"},
    )

    assert outcome.provider == "ollama"
    assert outcome.from_cache is True
    assert ollama_a.calls == 0, "cached hit should not reach the live provider"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_alert_cloud_backup_waits_for_real_ollama_local_attempt(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Gemini must not be used as a backup until ``ollama_local`` has been tried.

    The context sets ``alert_requires_ollama_before_cloud``. The only registered
    local provider (``ollama_gcp_a``) fails, and ``ollama_local`` is absent from
    the provider order. The executor should therefore report failure without
    ever calling gemini.
    """
    redis_stub = _FakeRedis(cached_provider="none")
    failing_gcp_a = _FailingLocalProvider("ollama_gcp_a")
    cloud = _CloudProvider()
    providers = AIProviderRegistry()
    for candidate in (failing_gcp_a, cloud):
        providers.register(candidate)

    # Disable mock mode and route the router's Redis lookups to the stub.
    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: redis_stub)

    alert_context = {
        "intent_hint": "diagnose",
        "alert_type": "HostHighCpuLoad",
        "alert_requires_ollama_before_cloud": True,
    }
    executor = AIRouterExecutor(providers)
    outcome = await executor.execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "gemini"],
        context=alert_context,
    )

    assert outcome.success is False
    assert failing_gcp_a.calls == 1
    assert cloud.calls == 0, "cloud backup must wait for an ollama_local attempt"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_alert_cloud_backup_allowed_after_ollama_local_attempt(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Gemini may serve as a backup once ``ollama_local`` has actually failed.

    With ``alert_requires_ollama_before_cloud`` set, both local providers fail,
    each tried exactly once. The executor should then fall through to gemini
    and return its successful result.
    """
    redis_stub = _FakeRedis(cached_provider="none")
    failing_gcp_a = _FailingLocalProvider("ollama_gcp_a")
    failing_local = _FailingLocalProvider("ollama_local")
    cloud = _CloudProvider()
    providers = AIProviderRegistry()
    for candidate in (failing_gcp_a, failing_local, cloud):
        providers.register(candidate)

    # Disable mock mode and route the router's Redis lookups to the stub.
    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: redis_stub)

    alert_context = {
        "intent_hint": "diagnose",
        "alert_type": "HostHighCpuLoad",
        "alert_requires_ollama_before_cloud": True,
    }
    executor = AIRouterExecutor(providers)
    outcome = await executor.execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "ollama_local", "gemini"],
        context=alert_context,
    )

    assert outcome.success is True
    assert outcome.provider == "gemini"
    assert (failing_gcp_a.calls, failing_local.calls, cloud.calls) == (1, 1, 1)
|