174 lines
5.7 KiB
Python
174 lines
5.7 KiB
Python
import requests
|
|
|
|
|
|
class FakeResponse:
    """Small stand-in for an HTTP response object used by the tests below.

    It stores a status code and a payload, and exposes ``raise_for_status``
    and ``json``.
    """

    def __init__(self, status_code, payload):
        self._payload = payload
        self.status_code = status_code

    def raise_for_status(self):
        """Raise ``requests.HTTPError`` when the status code is 400 or above."""
        if self.status_code < 400:
            return
        failure = requests.HTTPError(f"{self.status_code} error")
        # Attach this object so handlers can read it from the exception.
        failure.response = self
        raise failure

    def json(self):
        """Return the payload supplied at construction, unchanged."""
        return self._payload
|
|
|
|
|
|
def test_elephant_service_falls_back_when_primary_model_is_unavailable(monkeypatch):
    """A 404 from the primary model should lead to a successful fallback call."""
    from services import elephant_service as elephant

    requested_models = []

    def stub_post(_url, json, headers, timeout):
        # Record every model the service attempts, in order.
        model_name = json["model"]
        requested_models.append(model_name)
        if model_name != "nvidia/unavailable":
            success_payload = {
                "choices": [{"message": {"content": "OK"}}],
                "usage": {"prompt_tokens": 3, "completion_tokens": 2},
            }
            return FakeResponse(200, success_payload)
        return FakeResponse(404, {"detail": "function not found"})

    monkeypatch.setattr(elephant, "ELEPHANT_FALLBACK_MODELS", ["nvidia/available"])
    monkeypatch.setattr(elephant.requests, "post", stub_post)

    svc = elephant.ElephantService(api_key="test-key", model="nvidia/unavailable")
    outcome = svc.generate("hello")

    assert outcome.success is True
    assert outcome.model == "nvidia/available"
    assert outcome.content == "OK"
    assert requested_models == ["nvidia/unavailable", "nvidia/available"]
|
|
|
|
|
|
def test_elephant_service_falls_back_when_primary_model_times_out(monkeypatch):
    """A ``requests.Timeout`` on the primary model should lead to a fallback call."""
    from services import elephant_service as elephant

    requested_models = []

    def stub_post(_url, json, headers, timeout):
        # Record every model the service attempts, in order.
        model_name = json["model"]
        requested_models.append(model_name)
        if model_name == "nvidia/slow":
            raise requests.Timeout("read timed out")
        success_payload = {
            "choices": [{"message": {"content": "Fallback OK"}}],
            "usage": {"prompt_tokens": 4, "completion_tokens": 3},
        }
        return FakeResponse(200, success_payload)

    monkeypatch.setattr(elephant, "ELEPHANT_FALLBACK_MODELS", ["nvidia/available"])
    monkeypatch.setattr(elephant.requests, "post", stub_post)

    svc = elephant.ElephantService(api_key="test-key", model="nvidia/slow")
    outcome = svc.generate("hello", timeout=3)

    assert outcome.success is True
    assert outcome.model == "nvidia/available"
    assert outcome.content == "Fallback OK"
    assert requested_models == ["nvidia/slow", "nvidia/available"]
|
|
|
|
|
|
def test_elephant_service_falls_back_when_primary_model_connection_fails(monkeypatch):
    """A ``requests.ConnectionError`` on the primary model should lead to a fallback call."""
    from services import elephant_service as elephant

    requested_models = []

    def stub_post(_url, json, headers, timeout):
        # Record every model the service attempts, in order.
        model_name = json["model"]
        requested_models.append(model_name)
        if model_name == "nvidia/disconnected":
            raise requests.ConnectionError("connection reset")
        success_payload = {
            "choices": [{"message": {"content": "Connected fallback"}}],
            "usage": {},
        }
        return FakeResponse(200, success_payload)

    monkeypatch.setattr(elephant, "ELEPHANT_FALLBACK_MODELS", ["nvidia/available"])
    monkeypatch.setattr(elephant.requests, "post", stub_post)

    svc = elephant.ElephantService(api_key="test-key", model="nvidia/disconnected")
    outcome = svc.generate("hello")

    assert outcome.success is True
    assert outcome.model == "nvidia/available"
    assert outcome.content == "Connected fallback"
    assert requested_models == ["nvidia/disconnected", "nvidia/available"]
|
|
|
|
|
|
def test_elephant_service_falls_back_on_transient_http_status(monkeypatch):
    """A 503 from the primary model should lead to a successful fallback call."""
    from services import elephant_service as elephant

    requested_models = []

    def stub_post(_url, json, headers, timeout):
        # Record every model the service attempts, in order.
        model_name = json["model"]
        requested_models.append(model_name)
        if model_name != "nvidia/overloaded":
            success_payload = {
                "choices": [{"message": {"content": "Recovered"}}],
                "usage": {},
            }
            return FakeResponse(200, success_payload)
        return FakeResponse(503, {"detail": "temporarily unavailable"})

    monkeypatch.setattr(elephant, "ELEPHANT_FALLBACK_MODELS", ["nvidia/available"])
    monkeypatch.setattr(elephant.requests, "post", stub_post)

    svc = elephant.ElephantService(api_key="test-key", model="nvidia/overloaded")
    outcome = svc.generate("hello")

    assert outcome.success is True
    assert outcome.model == "nvidia/available"
    assert outcome.content == "Recovered"
    assert requested_models == ["nvidia/overloaded", "nvidia/available"]
|
|
|
|
|
|
def test_elephant_service_does_not_fallback_on_non_transient_client_error(monkeypatch):
    """A 400 response should fail on the primary model without trying the fallback."""
    from services import elephant_service as elephant

    requested_models = []

    def stub_post(_url, json, headers, timeout):
        # Every request is rejected; the log shows how many attempts were made.
        requested_models.append(json["model"])
        rejection = FakeResponse(400, {"detail": "bad request"})
        return rejection

    monkeypatch.setattr(elephant, "ELEPHANT_FALLBACK_MODELS", ["nvidia/available"])
    monkeypatch.setattr(elephant.requests, "post", stub_post)

    svc = elephant.ElephantService(api_key="test-key", model="nvidia/bad-request")
    outcome = svc.generate("hello")

    assert outcome.success is False
    assert outcome.model == "nvidia/bad-request"
    assert requested_models == ["nvidia/bad-request"]
|
|
|
|
|
|
def test_elephant_service_uses_reasoning_content_when_content_is_empty(monkeypatch):
    """When ``content`` is ``None``, the result should carry ``reasoning_content``."""
    from services import elephant_service as elephant

    def stub_post(_url, json, headers, timeout):
        # The reply has no regular content, only a reasoning field.
        reply_message = {"content": None, "reasoning_content": "thinking"}
        body = {
            "choices": [{"message": reply_message}],
            "usage": {},
        }
        return FakeResponse(200, body)

    monkeypatch.setattr(elephant.requests, "post", stub_post)

    svc = elephant.ElephantService(api_key="test-key", model="nvidia/available")
    outcome = svc.generate("hello")

    assert outcome.success is True
    assert outcome.content == "thinking"
|