Files
awoooi/apps/api/tests/test_ai_router_cache_provider_policy.py
Your Name ee5e3bc94f
Some checks failed
Code Review / ai-code-review (push) Successful in 27s
CD Pipeline / tests (push) Successful in 5m17s
CD Pipeline / build-and-deploy (push) Failing after 5m35s
CD Pipeline / post-deploy-checks (push) Has been skipped
fix(openclaw): gate alert cloud fallback behind flag
2026-05-05 20:54:47 +08:00

91 lines
2.9 KiB
Python

from __future__ import annotations
import json
from typing import Any
import pytest
from src.services import ai_router as ai_router_module
from src.services.ai_providers.interfaces import AIResult
from src.services.ai_router import AIProviderRegistry, AIRouterExecutor
class _FakeRedis:
def __init__(self, cached_provider: str) -> None:
self.cached_provider = cached_provider
self.set_calls: list[tuple[str, str, int | None]] = []
async def get(self, key: str) -> str:
return json.dumps({
"response": '{"provider":"stale"}',
"provider": self.cached_provider,
})
async def set(self, key: str, value: str, ex: int | None = None) -> None:
self.set_calls.append((key, value, ex))
class _FakeProvider:
    """Stand-in AI provider named ``ollama_gcp_a`` with ``local`` privacy.

    ``analyze`` counts how often it is awaited and always returns a
    successful result whose payload is ``{"provider":"fresh_ollama"}``, so a
    test can distinguish a live provider call from a cached response.
    """

    name = "ollama_gcp_a"
    privacy_level = "local"
    is_enabled = True
    capabilities = {"rca", "chat"}

    def __init__(self) -> None:
        # Number of times analyze() has been awaited on this instance.
        self.calls = 0

    async def analyze(self, prompt: str, context: dict[str, Any] | None = None) -> AIResult:
        # prompt/context are accepted to match the provider call signature but ignored.
        self.calls = self.calls + 1
        fresh = AIResult(
            raw_response='{"provider":"fresh_ollama"}',
            success=True,
            provider=self.name,
        )
        return fresh
@pytest.mark.asyncio
async def test_executor_skips_cached_cloud_provider_when_ollama_lane_is_required(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A cached answer from a cloud provider must not be reused for an Ollama-first request.

    The Redis stub always returns an entry attributed to ``gemini``. The
    executor is expected to bypass it, call the registered Ollama provider
    exactly once, and write something back to the cache.
    """
    provider = _FakeProvider()
    redis_stub = _FakeRedis(cached_provider="gemini")
    registry = AIProviderRegistry()
    registry.register(provider)

    # Force MOCK_MODE off (presumably so execute() takes the real routing path).
    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    # Hand the in-memory stub to whoever calls get_redis (assumes the router
    # resolves it from src.core.redis_client at call time).
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: redis_stub)

    executor = AIRouterExecutor(registry)
    outcome = await executor.execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "ollama_gcp_b", "ollama_local", "gemini"],
        context={"intent_hint": "diagnose", "alert_type": "HostHighCpuLoad"},
    )

    # The response must come from the live Ollama provider, not the stale cache.
    assert outcome.provider == "ollama_gcp_a"
    assert outcome.raw_response == '{"provider":"fresh_ollama"}'
    assert provider.calls == 1
    # At least one cache write-back was attempted.
    assert redis_stub.set_calls
@pytest.mark.asyncio
async def test_executor_allows_cached_ollama_provider_for_ollama_lane(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A cached answer attributed to ``ollama`` may be served for an Ollama-only order.

    The Redis stub always returns an entry attributed to ``ollama``. The
    executor is expected to return that cached result without awaiting the
    registered provider at all.
    """
    provider = _FakeProvider()
    redis_stub = _FakeRedis(cached_provider="ollama")
    registry = AIProviderRegistry()
    registry.register(provider)

    # Force MOCK_MODE off (presumably so execute() takes the real routing path).
    monkeypatch.setattr(ai_router_module._settings, "MOCK_MODE", False)
    # Hand the in-memory stub to whoever calls get_redis (assumes the router
    # resolves it from src.core.redis_client at call time).
    monkeypatch.setattr("src.core.redis_client.get_redis", lambda: redis_stub)

    executor = AIRouterExecutor(registry)
    outcome = await executor.execute(
        prompt="diagnose alert",
        provider_order=["ollama_gcp_a", "ollama_gcp_b", "ollama_local"],
        context={"intent_hint": "diagnose", "alert_type": "HostHighCpuLoad"},
    )

    # Served from cache: provider name from the cached entry, no live call.
    assert outcome.provider == "ollama"
    assert outcome.from_cache is True
    assert provider.calls == 0