Files
ewoooc/tests/test_ollama_embedding.py
OoO 353e565e52
All checks were successful
CD Pipeline / deploy (push) Successful in 1m4s
V10.417 protect embedding fallback routing
2026-05-24 14:53:43 +08:00

89 lines
2.7 KiB
Python

from services.ollama_service import OllamaService
class FakeResponse:
    """Canned stand-in for the object returned by ``requests.post``.

    Only what these tests touch is modelled: ``status_code``, ``text`` and a
    ``json()`` method that hands back a fixed payload.
    """

    def __init__(self, status_code, payload=None, text=""):
        # Any falsy payload (including the default None) is replaced by {}.
        self._payload = payload if payload else {}
        self.text = text
        self.status_code = status_code

    def json(self):
        """Return the payload supplied at construction time."""
        return self._payload
def test_generate_embedding_uses_current_embed_endpoint(monkeypatch):
    """A successful ``/api/embed`` call is made once and its vector returned.

    The request body must carry ``input`` and ``keep_alive`` (no legacy
    ``prompt`` key), and the caller's timeout is forwarded unchanged.
    """
    recorded = []

    def recording_post(url, json, timeout):
        # Parameter names kept as-is: the service presumably passes them by keyword.
        recorded.append((url, json, timeout))
        return FakeResponse(200, {"embeddings": [[0.1, 0.2, 0.3]]})

    monkeypatch.setattr("services.ollama_service.requests.post", recording_post)

    service = OllamaService()
    result = service.generate_embedding(
        "hello", model="bge-m3:latest", host="http://ollama", timeout=7
    )

    expected_request = (
        "http://ollama/api/embed",
        {"model": "bge-m3:latest", "input": "hello", "keep_alive": "1m"},
        7,
    )
    assert result == [0.1, 0.2, 0.3]
    assert recorded == [expected_request]
def test_generate_embedding_falls_back_to_legacy_embeddings_endpoint(monkeypatch):
    """A 404 from ``/api/embed`` triggers one retry on ``/api/embeddings``.

    The retry uses the legacy ``prompt`` payload and keeps the same timeout.
    The host's trailing slash must not produce a doubled ``//`` in either URL.
    """
    recorded = []

    def recording_post(url, json, timeout):
        recorded.append((url, json, timeout))
        # Only the modern endpoint "fails"; everything else answers legacy-style.
        return (
            FakeResponse(404, text="not found")
            if url.endswith("/api/embed")
            else FakeResponse(200, {"embedding": [0.4, 0.5]})
        )

    monkeypatch.setattr("services.ollama_service.requests.post", recording_post)

    service = OllamaService()
    result = service.generate_embedding(
        "hello", model="bge-m3:latest", host="http://ollama/", timeout=9
    )

    modern_attempt = (
        "http://ollama/api/embed",
        {"model": "bge-m3:latest", "input": "hello", "keep_alive": "1m"},
        9,
    )
    legacy_attempt = (
        "http://ollama/api/embeddings",
        {"model": "bge-m3:latest", "prompt": "hello"},
        9,
    )
    assert result == [0.4, 0.5]
    assert recorded == [modern_attempt, legacy_attempt]
def test_extract_embedding_accepts_flat_embeddings_shape():
    """An un-nested ``embeddings`` list is returned as the vector itself."""
    flat_payload = {"embeddings": [0.1, 0.2]}
    extracted = OllamaService._extract_embedding(flat_payload)
    assert extracted == [0.1, 0.2]
def test_generate_embedding_caps_timeout_and_clips_input(monkeypatch):
    """Oversized timeouts and inputs are bounded by the module limits.

    With ``EMBED_MAX_TIMEOUT`` patched to 3 and ``EMBED_MAX_CHARS`` to 5, a
    requested timeout of 45 must go out as 3. The text "hello world" must be
    sent as its first five characters.
    """
    recorded = []

    def recording_post(url, json, timeout):
        recorded.append((url, json, timeout))
        return FakeResponse(200, {"embeddings": [[0.1, 0.2, 0.3]]})

    limits = {
        "services.ollama_service.EMBED_MAX_TIMEOUT": 3,
        "services.ollama_service.EMBED_MAX_CHARS": 5,
    }
    for target, value in limits.items():
        monkeypatch.setattr(target, value)
    monkeypatch.setattr("services.ollama_service.requests.post", recording_post)

    result = OllamaService().generate_embedding(
        "hello world",
        model="bge-m3:latest",
        host="http://ollama",
        timeout=45,
    )

    clipped_request = (
        "http://ollama/api/embed",
        {"model": "bge-m3:latest", "input": "hello", "keep_alive": "1m"},
        3,
    )
    assert result == [0.1, 0.2, 0.3]
    assert recorded == [clipped_request]