165 lines
5.5 KiB
Python
165 lines
5.5 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from src.services.ai_providers import ollama as ollama_module
|
|
from src.services.ai_providers.ollama import OllamaGcpBProvider, OllamaProvider
|
|
|
|
|
|
class _FakeRegistry:
    """Stand-in for the object returned by ``get_model_registry``.

    Every lookup resolves to the same model name and the same small set of
    generation options, whatever provider or use case is asked for.
    """

    def get_model(self, provider: str, use_case: str) -> str:
        """Return the fixed model name; both arguments are ignored."""
        model_name = "qwen3:14b"
        return model_name

    def get_provider_options(self, provider: str) -> dict[str, Any]:
        """Return a fresh dict of fixed generation options; ``provider`` is ignored."""
        options: dict[str, Any] = dict(num_predict=32, temperature=0.1, top_p=0.9)
        return options
|
|
|
|
|
|
class _FakeResponse:
    """Always-successful HTTP response double.

    Exposes ``status_code``, ``raise_for_status()`` and ``json()``. The JSON body
    carries a ``response`` text field (itself a JSON string) plus token counts.
    """

    status_code = 200

    def raise_for_status(self) -> None:
        """Do nothing: this fake never represents an HTTP error."""
        return None

    def json(self) -> dict[str, Any]:
        """Build and return a new body dict on every call."""
        body: dict[str, Any] = {}
        body["response"] = '{"summary":"ok"}'
        body["eval_count"] = 4
        body["prompt_eval_count"] = 3
        return body
|
|
|
|
|
|
class _FakeClient:
    """Async HTTP client double that records every call and always succeeds.

    ``post`` records the URL and the ``json=`` keyword payload (``{}`` when
    absent); ``get`` records the URL. Both return a new ``_FakeResponse``.
    """

    def __init__(self) -> None:
        # Separate lists so tests can assert on generate calls vs. probes.
        self.posted_urls: list[str] = []
        self.posted_payloads: list[dict[str, Any]] = []
        self.checked_urls: list[str] = []

    async def post(self, url: str, **kwargs: Any) -> _FakeResponse:
        """Record a POST (URL and JSON body) and answer with a success response."""
        body = kwargs["json"] if "json" in kwargs else {}
        self.posted_urls.append(url)
        self.posted_payloads.append(body)
        return _FakeResponse()

    async def get(self, url: str, **kwargs: Any) -> _FakeResponse:
        """Record a GET URL and answer with a success response."""
        self.checked_urls += [url]
        return _FakeResponse()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ollama_gcp_b_analyze_uses_secondary_url(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The GCP-B provider must send generate requests to the secondary Ollama URL."""
    fake_client = _FakeClient()

    async def _return_fake_client() -> _FakeClient:
        return fake_client

    monkeypatch.setattr(ollama_module, "get_model_registry", lambda: _FakeRegistry())
    overrides = {
        "OLLAMA_URL": "http://primary:11435",
        "OLLAMA_SECONDARY_URL": "http://secondary:11436",
        "ALERT_OLLAMA_MODEL": "qwen3:14b",
    }
    for setting_name, setting_value in overrides.items():
        monkeypatch.setattr(ollama_module.settings, setting_name, setting_value)

    # Construct only after settings are patched, in case the provider reads them eagerly.
    gcp_b = OllamaGcpBProvider()
    monkeypatch.setattr(gcp_b, "_get_client", _return_fake_client)

    outcome = await gcp_b.analyze("diagnose", context={"task_type": "diagnose"})

    assert outcome.success is True
    assert outcome.provider == "ollama_gcp_b"
    assert fake_client.posted_urls == ["http://secondary:11436/api/generate"]
    assert fake_client.posted_payloads[0]["model"] == "qwen3:14b"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ollama_gcp_a_allows_heavy_diagnose_model(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The primary (GCP-A) provider keeps the heavy model for diagnose tasks.

    ``OLLAMA_HEALTH_CHECK_MODEL`` (the model non-diagnose work is coerced to) is
    pinned to a distinct light model. Without that, an environment whose
    health-check model equals the heavy model would make the final assertion
    unable to tell "kept" from "coerced".
    """
    monkeypatch.setattr(ollama_module, "get_model_registry", lambda: _FakeRegistry())
    monkeypatch.setattr(ollama_module.settings, "OLLAMA_URL", "http://primary:11435")
    monkeypatch.setattr(ollama_module.settings, "OLLAMA_SECONDARY_URL", "http://secondary:11436")
    monkeypatch.setattr(ollama_module.settings, "ALERT_OLLAMA_MODEL", "qwen3:14b")
    # Pin the coercion target so a pass cannot come from an ambient default.
    monkeypatch.setattr(ollama_module.settings, "OLLAMA_HEALTH_CHECK_MODEL", "gemma3:4b")

    client = _FakeClient()
    provider = OllamaProvider()

    async def _get_client() -> _FakeClient:
        return client

    monkeypatch.setattr(provider, "_get_client", _get_client)

    result = await provider.analyze("diagnose", context={"task_type": "diagnose"})

    assert result.success is True
    assert client.posted_urls == ["http://primary:11435/api/generate"]
    assert client.posted_payloads[0]["model"] == "qwen3:14b"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ollama_gcp_a_coerces_non_diagnosis_heavy_model_to_health_model(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Non-diagnose work on the primary provider falls back to the health-check model.

    The fake registry hands out ``qwen3:14b`` for every use case. For a
    ``summary`` task the request must carry ``OLLAMA_HEALTH_CHECK_MODEL`` instead.
    """
    recorder = _FakeClient()

    async def _fake_get_client() -> _FakeClient:
        return recorder

    monkeypatch.setattr(ollama_module, "get_model_registry", lambda: _FakeRegistry())
    for attr, value in (
        ("OLLAMA_URL", "http://primary:11435"),
        ("OLLAMA_SECONDARY_URL", "http://secondary:11436"),
        ("OLLAMA_HEALTH_CHECK_MODEL", "gemma3:4b"),
    ):
        monkeypatch.setattr(ollama_module.settings, attr, value)

    primary = OllamaProvider()
    monkeypatch.setattr(primary, "_get_client", _fake_get_client)

    result = await primary.analyze("background summary", context={"task_type": "summary"})

    assert result.success is True
    assert recorder.posted_urls == ["http://primary:11435/api/generate"]
    assert recorder.posted_payloads[0]["model"] == "gemma3:4b"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ollama_gcp_a_can_explicitly_allow_heavy_model(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``allow_gcp_heavy_model`` in the context keeps the heavy model on the primary.

    The health-check model (the coercion target) is pinned to a distinct light
    model so the final assertion really distinguishes "kept" from "coerced".
    The primary URL is asserted too, matching the sibling GCP-A tests.
    """
    monkeypatch.setattr(ollama_module, "get_model_registry", lambda: _FakeRegistry())
    monkeypatch.setattr(ollama_module.settings, "OLLAMA_URL", "http://primary:11435")
    monkeypatch.setattr(ollama_module.settings, "OLLAMA_SECONDARY_URL", "http://secondary:11436")
    monkeypatch.setattr(ollama_module.settings, "ALERT_OLLAMA_MODEL", "qwen3:14b")
    # Pin the coercion target so a pass cannot come from an ambient default.
    monkeypatch.setattr(ollama_module.settings, "OLLAMA_HEALTH_CHECK_MODEL", "gemma3:4b")

    client = _FakeClient()
    provider = OllamaProvider()

    async def _get_client() -> _FakeClient:
        return client

    monkeypatch.setattr(provider, "_get_client", _get_client)

    result = await provider.analyze(
        "deep diagnose",
        context={"task_type": "diagnose", "allow_gcp_heavy_model": True},
    )

    assert result.success is True
    assert client.posted_urls == ["http://primary:11435/api/generate"]
    assert client.posted_payloads[0]["model"] == "qwen3:14b"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ollama_base_provider_health_uses_endpoint_hook(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``health_check`` must probe ``<_endpoint_url()>/api/tags``.

    The hook returns a sentinel URL that differs from the pinned
    ``settings.OLLAMA_URL``. The test therefore fails if ``health_check`` bypasses
    the hook and builds the URL from settings directly. The registry is faked as
    in the other tests so provider construction does not depend on real config.
    """
    monkeypatch.setattr(ollama_module, "get_model_registry", lambda: _FakeRegistry())
    monkeypatch.setattr(ollama_module.settings, "OLLAMA_URL", "http://primary:11435")
    hook_url = "http://hooked-endpoint:11499"

    client = _FakeClient()
    provider = OllamaProvider()
    monkeypatch.setattr(provider, "_endpoint_url", lambda: hook_url)

    async def _get_client() -> _FakeClient:
        return client

    monkeypatch.setattr(provider, "_get_client", _get_client)

    assert await provider.health_check() is True
    assert client.checked_urls == [f"{hook_url}/api/tags"]
|