# File viewer metadata: 89 lines, 2.7 KiB, Python
from services.ollama_service import OllamaService
|
|
|
|
|
|
class FakeResponse:
    """Minimal stand-in for an HTTP response object used by these tests.

    Exposes only ``status_code``, ``text`` and ``json()``.
    """

    def __init__(self, status_code, payload=None, text=""):
        """Store the canned status code, JSON payload and raw body text.

        Args:
            status_code: HTTP status code to report.
            payload: Object returned by ``json()``; defaults to an empty dict
                when omitted.
            text: Raw response body text.
        """
        self.status_code = status_code
        # Only substitute the default when no payload was supplied; a falsy
        # payload such as [] must be returned as-is, not replaced by {}.
        self._payload = {} if payload is None else payload
        self.text = text

    def json(self):
        """Return the canned JSON payload."""
        return self._payload
|
|
|
|
|
|
def test_generate_embedding_uses_current_embed_endpoint(monkeypatch):
    """A successful call posts once to ``/api/embed`` and returns the first vector."""
    recorded = []

    def fake_post(url, json, timeout):
        # Record every outgoing request so the exact call can be asserted below.
        recorded.append((url, json, timeout))
        return FakeResponse(200, {"embeddings": [[0.1, 0.2, 0.3]]})

    monkeypatch.setattr("services.ollama_service.requests.post", fake_post)

    service = OllamaService()
    embedding = service.generate_embedding(
        "hello",
        model="bge-m3:latest",
        host="http://ollama",
        timeout=7,
    )

    expected_payload = {"model": "bge-m3:latest", "input": "hello", "keep_alive": "1m"}
    expected_call = ("http://ollama/api/embed", expected_payload, 7)

    assert embedding == [0.1, 0.2, 0.3]
    assert recorded == [expected_call]
|
|
|
|
|
|
def test_generate_embedding_falls_back_to_legacy_embeddings_endpoint(monkeypatch):
    """A 404 from ``/api/embed`` triggers a second post to ``/api/embeddings``.

    The host is passed with a trailing slash, and the expected URLs contain no
    doubled slash.
    """
    recorded = []

    def fake_post(url, json, timeout):
        recorded.append((url, json, timeout))
        # Anything other than the current endpoint answers like the legacy API.
        if not url.endswith("/api/embed"):
            return FakeResponse(200, {"embedding": [0.4, 0.5]})
        return FakeResponse(404, text="not found")

    monkeypatch.setattr("services.ollama_service.requests.post", fake_post)

    service = OllamaService()
    embedding = service.generate_embedding(
        "hello",
        model="bge-m3:latest",
        host="http://ollama/",
        timeout=9,
    )

    current_call = (
        "http://ollama/api/embed",
        {"model": "bge-m3:latest", "input": "hello", "keep_alive": "1m"},
        9,
    )
    legacy_call = (
        "http://ollama/api/embeddings",
        {"model": "bge-m3:latest", "prompt": "hello"},
        9,
    )

    assert embedding == [0.4, 0.5]
    assert recorded == [current_call, legacy_call]
|
|
|
|
|
|
def test_extract_embedding_accepts_flat_embeddings_shape():
    """A flat (non-nested) ``embeddings`` list is returned as the vector itself."""
    flat_response = {"embeddings": [0.1, 0.2]}
    extracted = OllamaService._extract_embedding(flat_response)
    assert extracted == [0.1, 0.2]
|
|
|
|
|
|
def test_generate_embedding_caps_timeout_and_clips_input(monkeypatch):
    """Timeout and input text are limited by the module-level embed caps.

    With ``EMBED_MAX_TIMEOUT`` patched to 3 and ``EMBED_MAX_CHARS`` patched to 5,
    a request for ``"hello world"`` with ``timeout=45`` is expected to be sent
    as ``"hello"`` with a timeout of 3.
    """
    recorded = []

    def fake_post(url, json, timeout):
        recorded.append((url, json, timeout))
        return FakeResponse(200, {"embeddings": [[0.1, 0.2, 0.3]]})

    # Apply all patches in one pass: the two limits plus the HTTP call itself.
    patches = (
        ("services.ollama_service.EMBED_MAX_TIMEOUT", 3),
        ("services.ollama_service.EMBED_MAX_CHARS", 5),
        ("services.ollama_service.requests.post", fake_post),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    service = OllamaService()
    embedding = service.generate_embedding(
        "hello world",
        model="bge-m3:latest",
        host="http://ollama",
        timeout=45,
    )

    expected_payload = {"model": "bge-m3:latest", "input": "hello", "keep_alive": "1m"}
    expected_call = ("http://ollama/api/embed", expected_payload, 3)

    assert embedding == [0.1, 0.2, 0.3]
    assert recorded == [expected_call]
|