diff --git a/apps/api/src/services/channel_event_dossier_service.py b/apps/api/src/services/channel_event_dossier_service.py index 75fedcd7..0c101870 100644 --- a/apps/api/src/services/channel_event_dossier_service.py +++ b/apps/api/src/services/channel_event_dossier_service.py @@ -84,9 +84,21 @@ async def fetch_channel_event_dossier( effective_project_id = project_id or "awoooi" safe_limit = max(1, min(limit, _MAX_DOSSIER_EVENTS)) + where_clauses = ["project_id = :project_id"] + params: dict[str, Any] = { + "project_id": effective_project_id, + "limit": safe_limit, + } + if run_id is not None: + where_clauses.append("run_id = CAST(:run_id AS uuid)") + params["run_id"] = str(run_id) + if provider_event_id: + where_clauses.append("provider_event_id = :provider_event_id") + params["provider_event_id"] = provider_event_id + async with get_db_context(effective_project_id) as db: result = await db.execute( - text(""" + text(f""" SELECT event_id, project_id, @@ -101,18 +113,11 @@ async def fetch_channel_event_dossier( provider_ts, received_at FROM awooop_conversation_event - WHERE project_id = :project_id - AND (:run_id IS NULL OR run_id = :run_id) - AND (:provider_event_id IS NULL OR provider_event_id = :provider_event_id) + WHERE {" AND ".join(where_clauses)} ORDER BY received_at ASC LIMIT :limit """), - { - "project_id": effective_project_id, - "run_id": run_id, - "provider_event_id": provider_event_id, - "limit": safe_limit, - }, + params, ) rows = [dict(row) for row in result.mappings().all()] diff --git a/apps/api/tests/test_channel_event_dossier_service.py b/apps/api/tests/test_channel_event_dossier_service.py index e98ccad3..dee22bab 100644 --- a/apps/api/tests/test_channel_event_dossier_service.py +++ b/apps/api/tests/test_channel_event_dossier_service.py @@ -2,7 +2,9 @@ from __future__ import annotations import pytest from fastapi import HTTPException +from uuid import UUID +from src.services import channel_event_dossier_service from src.services.channel_event_dossier_service import ( build_dossier_event, fetch_channel_event_dossier, @@ -63,3 +65,52 @@ async def test_fetch_channel_event_dossier_requires_source() -> None: ) assert exc_info.value.status_code == 422 + + +@pytest.mark.asyncio +async def test_fetch_channel_event_dossier_uses_typed_run_filter(monkeypatch) -> None: + captured: dict[str, object] = {} + + class FakeMappings: + def all(self) -> list[dict[str, object]]: + return [] + + class FakeResult: + def mappings(self) -> FakeMappings: + return FakeMappings() + + class FakeDb: + async def execute(self, statement, params): # noqa: ANN001 + captured["sql"] = str(statement) + captured["params"] = params + return FakeResult() + + class FakeContext: + async def __aenter__(self) -> FakeDb: + return FakeDb() + + async def __aexit__(self, exc_type, exc, tb) -> None: # noqa: ANN001 + return None + + monkeypatch.setattr( + channel_event_dossier_service, + "get_db_context", + lambda _project_id: FakeContext(), + ) + + run_id = UUID("0a4c365f-609e-5441-bc29-4c7ebc3603b6") + result = await fetch_channel_event_dossier( + project_id="awoooi", + run_id=run_id, + provider_event_id=None, + limit=20, + ) + + assert result["total"] == 0 + assert "run_id = CAST(:run_id AS uuid)" in str(captured["sql"]) + assert ":run_id IS NULL" not in str(captured["sql"]) + assert captured["params"] == { + "project_id": "awoooi", + "run_id": str(run_id), + "limit": 20, + }