#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tests/test_ai_call_logger.py

Unit tests for ai_call_logger (Operation Ollama-First v5.0 — Phase 1).

Test discipline (per the phase1 spec):
- context manager happy path (status='ok')
- context manager exception path (status='error', exception still re-raised)
- decorator happy path + automatic token extraction
- main flow survives DB write failures
- cost calculation (gemini-2.5-flash / unknown-model fallback / NIM is free)
- AI_CALL_LOGGING_ENABLED=false skips writes
- kill-switch degrades after >= 10 consecutive failures
- PII protection: set_prompt_hash stores only 12 characters
"""

import os
import sys
import time
import builtins
import logging

import pytest

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Isolated import: avoid being dragged in by database.manager, which
# ai_call_logger lazily imports internally.
import services.ai_call_logger as logger_mod
from services.ai_call_logger import (
    COST_TABLE,
    _calc_cost,
    _CallState,
    _is_logging_enabled,
    _normalize_provider,
    _reset_kill_switch,
    log_ai_call,
    logged_ai_call,
)


# ─────────────────────────────────────────────────────────────────────────────
# Fixtures
# ─────────────────────────────────────────────────────────────────────────────

@pytest.fixture(autouse=True)
def reset_state(monkeypatch):
    """Reset the kill-switch and stub out the DB write for every test.

    Yields:
        The list that the stubbed ``_write_to_db`` appends one dict per call to.
    """
    _reset_kill_switch()

    # Stub _write_to_db: collect what would be written instead of touching a DB.
    captured = []

    def fake_write(state):
        captured.append({
            'caller': state.caller,
            'provider': state.provider,
            'model': state.model,
            'input_tokens': state.input_tokens,
            'output_tokens': state.output_tokens,
            'duration_ms': state.duration_ms,
            'status': state.status,
            'fallback_to': state.fallback_to,
            'cost_usd': _calc_cost(state.model, state.input_tokens, state.output_tokens),
            'cache_hit': state.cache_hit,
            'rag_hit': state.rag_hit,
            'request_id': state.request_id,
            'error': state.error,
            'meta': dict(state.meta),
        })

    monkeypatch.setattr(logger_mod, '_write_to_db', fake_write)
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Expose the captured records to the test.
    yield captured

    # Reset again on teardown so a test that trips the kill-switch does not
    # leave it tripped for whatever runs next.
    _reset_kill_switch()


def _wait_for_async(captured, n=1, timeout=2.0):
    """Poll until ``captured`` holds at least ``n`` records.

    Args:
        captured: List that the stubbed writer appends to.
        n: Minimum number of records to wait for.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as ``len(captured) >= n``; False if the list is still
        short once the deadline has passed.
    """
    # Monotonic clock: a wall-clock adjustment must not stretch or cut the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if len(captured) >= n:
            return True
        time.sleep(0.01)
    # Final check: a record may have landed during the last sleep, and a
    # non-positive timeout never enters the loop at all.
    return len(captured) >= n


# ─────────────────────────────────────────────────────────────────────────────
# Context manager tests
# ─────────────────────────────────────────────────────────────────────────────

def test_context_manager_happy_path(reset_state):
    """A clean ``with`` block produces one record with status 'ok'."""
    captured = reset_state
    with log_ai_call('hermes_analyst', 'gcp_ollama', 'hermes3:latest') as ctx:
        ctx.set_tokens(input=120, output=80)
        ctx.set_cache_hit(False)

    assert _wait_for_async(captured, 1), "async write 未完成"
    assert len(captured) == 1
    rec = captured[0]
    assert rec['caller'] == 'hermes_analyst'
    assert rec['provider'] == 'gcp_ollama'
    assert rec['model'] == 'hermes3:latest'
    assert rec['input_tokens'] == 120
    assert rec['output_tokens'] == 80
    assert rec['status'] == 'ok'
    assert rec['error'] is None
    assert rec['duration_ms'] is not None and rec['duration_ms'] >= 0


def test_context_manager_exception_path(reset_state):
    """An exception inside the block is re-raised and recorded as status 'error'."""
    captured = reset_state
    with pytest.raises(ValueError, match="boom"):
        with log_ai_call('nemotron_dispatch', 'nim', 'meta/llama-3.1-8b-instruct'):
            raise ValueError("boom")

    assert _wait_for_async(captured, 1)
    rec = captured[0]
    assert rec['status'] == 'error'
    assert rec['error'] is not None
    assert 'ValueError' in rec['error']
    assert 'boom' in rec['error']


def test_context_manager_logs_registry_import_failure(monkeypatch, caplog):
    """A failed caller-registry import must not block telemetry but must be logged."""
    real_import = builtins.__import__

    # Parameter names mirror builtins.__import__ so keyword calls keep working.
    def _import_with_missing_registry(name, globals=None, locals=None, fromlist=(), level=0):
        if name == "services.llm_caller_registry":
            raise ImportError("registry unavailable")
        return real_import(name, globals, locals, fromlist, level)

    monkeypatch.setattr(builtins, "__import__", _import_with_missing_registry)
    caplog.set_level(logging.WARNING, logger="services.ai_call_logger")

    with log_ai_call('hermes_analyst', 'gcp_ollama', 'hermes3:latest') as ctx:
        ctx.set_tokens(input=10, output=5)

    assert "caller registry import failed" in caplog.text


def test_context_manager_explicit_fallback(reset_state):
    """``fallback_to_caller`` marks the record as a fallback to the named caller."""
    captured = reset_state
    with log_ai_call('openclaw_qa', 'gemini', 'gemini-2.5-flash') as ctx:
        ctx.fallback_to_caller('openclaw_bot_nim')

    assert _wait_for_async(captured, 1)
    rec = captured[0]
    assert rec['status'] == 'fallback'
    assert rec['fallback_to'] == 'openclaw_bot_nim'


def test_context_manager_set_error_without_raise(reset_state):
    """Caller records an error via set_error without raising (e.g. LLM returned success=false)."""
    captured = reset_state
    with log_ai_call('sales_copy', 'gcp_ollama', 'llama3.1:8b') as ctx:
        ctx.set_error('timeout after 30s')
        ctx.set_tokens(input=50, output=0)

    assert _wait_for_async(captured, 1)
    rec = captured[0]
    assert rec['status'] == 'error'
    assert 'timeout' in rec['error']


def test_context_manager_marks_rag_hit(reset_state):
    """``set_rag_hit(True)`` is carried through to the written record."""
    captured = reset_state
    with log_ai_call('openclaw_qa', 'gcp_ollama', 'qwen3:14b') as ctx:
        ctx.set_rag_hit(True)

    assert _wait_for_async(captured, 1)
    rec = captured[0]
    assert rec['status'] == 'ok'
    assert rec['rag_hit'] is True


def test_context_manager_can_update_actual_provider_after_retry(reset_state):
    """``set_provider`` overrides the provider given when the context was opened."""
    captured = reset_state
    with log_ai_call('code_review_openclaw', 'gcp_ollama', 'qwen3:14b') as ctx:
        ctx.set_provider('ollama_secondary')

    assert _wait_for_async(captured, 1)
    rec = captured[0]
    assert rec['provider'] == 'ollama_secondary'


def test_provider_normalization_keeps_ai_calls_check_safe(reset_state):
    """Provider strings are normalized, both directly and through ``log_ai_call``."""
    assert _normalize_provider('') == 'ollama_other'
    assert _normalize_provider('unknown') == 'ollama_other'
    assert _normalize_provider('ollama_other') == 'ollama_other'
    assert _normalize_provider('gemini-2.5-flash') == 'gemini'
    assert _normalize_provider('anthropic') == 'claude'
    assert _normalize_provider('nvidia/nemotron') == 'nim'

    captured = reset_state
    with log_ai_call('hermes_analyst', 'unknown', 'hermes3:latest'):
        pass

    assert _wait_for_async(captured, 1)
    assert captured[0]['provider'] == 'ollama_other'
# ─────────────────────────────────────────────────────────────────────────────
# Decorator tests
# ─────────────────────────────────────────────────────────────────────────────

def test_decorator_happy_path(reset_state):
    """The decorator returns the wrapped result and extracts Ollama-style token counts."""
    captured = reset_state

    @logged_ai_call(caller='trend_match', provider='gcp_ollama', model='llama3.1:8b')
    def fake_call(prompt: str):
        return {'response': 'ok', 'eval_count': 42, 'prompt_eval_count': 100}

    out = fake_call("hello")
    assert out['response'] == 'ok'

    assert _wait_for_async(captured, 1)
    rec = captured[0]
    assert rec['caller'] == 'trend_match'
    assert rec['model'] == 'llama3.1:8b'
    assert rec['input_tokens'] == 100
    assert rec['output_tokens'] == 42
    assert rec['status'] == 'ok'


def test_decorator_with_model_extractor(reset_state):
    """``model_extractor`` picks the model from call kwargs; ``usage`` tokens are read."""
    captured = reset_state

    @logged_ai_call(
        caller='ppt_gemini',
        provider='gemini',
        model_extractor=lambda args, kw: kw.get('model', 'gemini-2.0-flash'),
    )
    def fake_call(*, model: str, prompt: str):
        return {'usage': {'prompt_tokens': 200, 'completion_tokens': 50}}

    fake_call(model='gemini-2.5-flash', prompt='x')

    assert _wait_for_async(captured, 1)
    rec = captured[0]
    assert rec['model'] == 'gemini-2.5-flash'
    assert rec['input_tokens'] == 200
    assert rec['output_tokens'] == 50


def test_decorator_exception_does_reraise(reset_state):
    """An exception from the wrapped function propagates and is recorded as 'error'."""
    captured = reset_state

    @logged_ai_call(caller='code_review_hermes', provider='gcp_ollama', model='hermes3:latest')
    def fake_call():
        raise RuntimeError("net down")

    with pytest.raises(RuntimeError, match="net down"):
        fake_call()

    assert _wait_for_async(captured, 1)
    assert captured[0]['status'] == 'error'


# ─────────────────────────────────────────────────────────────────────────────
# DB failures must not break the main flow
# ─────────────────────────────────────────────────────────────────────────────

def test_db_failure_does_not_break_main_flow(monkeypatch, caplog):
    """A failing DB write must not raise into the caller's ``with`` block.

    The real ``_write_to_db`` is NOT exercised here: it is replaced by a
    stand-in that simulates a DB failure and handles it the way this test
    expects the real writer to (record the failure, log a warning). The
    writer thread is replaced by a synchronous shim so the write runs inline.
    """
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Swap the thread class for a synchronous shim so the write happens
    # inline, inside log_ai_call, where we can observe it.
    class SyncThread:
        def __init__(self, target=None, args=(), kwargs=None, **_):
            self._target = target
            self._args = args
            self._kwargs = kwargs or {}

        def start(self):
            self._target(*self._args, **self._kwargs)

    monkeypatch.setattr(logger_mod.threading, 'Thread', SyncThread)

    # The autouse fixture already stubbed _write_to_db; override it with a
    # writer that hits a simulated failure and swallows it after logging.
    def simulated_failing_write(state):
        try:
            raise ImportError("simulated DB unavailable")
        except Exception as e:
            logger_mod._record_failure()
            logger_mod.logger.warning(
                "[AICallLogger] write failed (caller=%s provider=%s): %s",
                state.caller, state.provider, e,
            )

    monkeypatch.setattr(logger_mod, '_write_to_db', simulated_failing_write)

    # The main flow must not raise.
    with caplog.at_level('WARNING'):
        with log_ai_call('hermes_intent', 'gcp_ollama', 'hermes3:latest') as ctx:
            ctx.set_tokens(input=10, output=5)

    # At least one "[AICallLogger] write failed" warning must have been logged.
    assert any('write failed' in msg for msg in caplog.messages), \
        "預期 _write_to_db 失敗時 log warning"


def test_async_dispatch_failure_swallowed(monkeypatch):
    """Even if constructing the writer thread fails, the main flow must not raise."""
    class BadThread:
        def __init__(self, *a, **kw):
            raise OSError("can't fork")

    monkeypatch.setattr(logger_mod.threading, 'Thread', BadThread)
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Must not raise.
    with log_ai_call('x', 'y', 'z'):
        pass


# ─────────────────────────────────────────────────────────────────────────────
# Cost calculation
# ─────────────────────────────────────────────────────────────────────────────

def test_calc_cost_gemini_flash():
    """gemini-2.5-flash, 1M in + 100K out = $0.075 + $0.030 = $0.105."""
    cost = _calc_cost('gemini-2.5-flash', 1_000_000, 100_000)
    assert cost == pytest.approx(0.105, rel=1e-6)


def test_calc_cost_claude_opus():
    """claude-opus-4-7, 1K in + 1K out at $15/M in and $75/M out = $0.015 + $0.075 = $0.090."""
    cost = _calc_cost('claude-opus-4-7', 1000, 1000)
    expected = (1000 * 15.0 + 1000 * 75.0) / 1_000_000
    assert cost == pytest.approx(expected, rel=1e-6)


def test_calc_cost_ollama_zero():
    """The Ollama models checked here cost nothing regardless of token count."""
    assert _calc_cost('hermes3:latest', 100_000, 100_000) == 0.0
    assert _calc_cost('llama3.1:8b', 999_999, 999_999) == 0.0


def test_calc_cost_unknown_model_returns_zero(caplog):
    """An unknown model costs 0.0 and logs an 'unknown model cost' warning."""
    with caplog.at_level('WARNING'):
        cost = _calc_cost('totally-fake-model-xyz', 1_000_000, 1_000_000)

    assert cost == 0.0
    assert any('unknown model cost' in msg for msg in caplog.messages)


def test_calc_cost_nim_prefix_silent_zero(caplog):
    """NIM-style prefixes (nvidia/*, meta/*, deepseek-*) must not trigger the unknown warning.

    Only the ``nvidia/`` prefix is exercised here.
    """
    with caplog.at_level('WARNING'):
        cost = _calc_cost('nvidia/some-future-model', 1_000_000, 1_000_000)

    assert cost == 0.0
    assert not any('unknown model cost' in msg for msg in caplog.messages)


def test_calc_cost_negative_or_none_safe():
    """None tokens, an empty model name, and negative tokens all yield 0.0."""
    assert _calc_cost('gemini-2.5-flash', None, None) == 0.0
    assert _calc_cost('', 100, 100) == 0.0
    assert _calc_cost('gemini-2.5-flash', -1, -5) == 0.0


# ─────────────────────────────────────────────────────────────────────────────
# Environment toggle
# ─────────────────────────────────────────────────────────────────────────────

def test_logging_disabled_skips_write(monkeypatch):
    """With AI_CALL_LOGGING_ENABLED=false nothing reaches ``_write_to_db``."""
    captured = []

    def fake_write(state):
        captured.append(state)

    monkeypatch.setattr(logger_mod, '_write_to_db', fake_write)
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'false')

    with log_ai_call('sales_copy', 'gcp_ollama', 'llama3.1:8b') as ctx:
        ctx.set_tokens(input=10, output=10)

    # Negative check: give a stray background write a moment to show up.
    time.sleep(0.05)
    assert len(captured) == 0, "AI_CALL_LOGGING_ENABLED=false 時不應寫入"


def test_logging_enabled_default_true(monkeypatch):
    """Logging is on when the variable is unset; '0' and 'OFF' disable it; 'true' enables it."""
    monkeypatch.delenv('AI_CALL_LOGGING_ENABLED', raising=False)
    assert _is_logging_enabled() is True

    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', '0')
    assert _is_logging_enabled() is False

    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'OFF')
    assert _is_logging_enabled() is False

    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')
    assert _is_logging_enabled() is True


# ─────────────────────────────────────────────────────────────────────────────
# Kill-switch
# ─────────────────────────────────────────────────────────────────────────────

def test_kill_switch_after_consecutive_failures(monkeypatch, caplog):
    """Ten recorded failures trip the kill-switch; no writer thread is created afterwards."""
    _reset_kill_switch()
    # The failures are recorded directly here rather than driven through
    # _write_to_db.
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Force exactly ten failures.
    failures_recorded = 0
    while failures_recorded < 10:
        logger_mod._record_failure()
        failures_recorded += 1

    assert logger_mod._is_killed() is True

    # From here on, logging a call should not construct a new thread.
    thread_targets = []

    class RecordingThread:
        """Stand-in for the thread class: notes each construction, never runs anything."""

        def __init__(self, *args, **kwargs):
            thread_targets.append(kwargs.get('target'))

        def start(self):
            return None

    monkeypatch.setattr(logger_mod.threading, 'Thread', RecordingThread)

    with log_ai_call('x', 'y', 'z'):
        pass

    time.sleep(0.05)
    assert not thread_targets, "kill-switch 啟動後不應再開新 thread"


def test_record_success_resets_failure_counter():
    """One recorded success brings the failure count back to zero."""
    _reset_kill_switch()

    failures_to_record = 5
    for _attempt in range(failures_to_record):
        logger_mod._record_failure()
    assert logger_mod._failure_state['count'] == failures_to_record

    logger_mod._record_success()
    assert logger_mod._failure_state['count'] == 0


# ─────────────────────────────────────────────────────────────────────────────
# PII protection
# ─────────────────────────────────────────────────────────────────────────────

def test_set_prompt_hash_truncates_to_12():
    """``set_prompt_hash`` stores a 12-character value that is not the raw prompt."""
    call_state = _CallState('a', 'b', 'c', None, {})
    call_state.set_prompt_hash('Hello world some sensitive PII content here')

    assert 'prompt_hash' in call_state.meta
    stored = call_state.meta['prompt_hash']
    assert len(stored) == 12
    # The stored value must not contain the original text.
    assert 'Hello' not in stored


def test_meta_does_not_leak_raw_prompt_into_call_state():
    """The prompt enters ``meta`` only as a hash; caller-supplied meta is preserved."""
    caller_meta = {'temperature': 0.3}
    with log_ai_call('x', 'y', 'z', meta=caller_meta) as ctx:
        ctx.set_prompt_hash("super secret user prompt 123")
        meta = ctx.meta
        assert 'prompt_hash' in meta
        assert meta['temperature'] == 0.3
        # No 'prompt' key should appear in meta (unless the caller adds one).
        assert 'prompt' not in meta


# ─────────────────────────────────────────────────────────────────────────────
# Misc: cost table key completeness
# ─────────────────────────────────────────────────────────────────────────────

def test_cost_table_contains_critical_models():
    """Every critical model listed by the phase0 audit must have a COST_TABLE entry."""
    required_models = (
        'gemini-2.5-flash',
        'gemini-2.0-flash',
        'meta/llama-3.1-8b-instruct',
        'hermes3:latest',
        'qwen2.5-coder:7b',
        'llama3.1:8b',
        'bge-m3:latest',
    )
    # Checked one by one so the failure message names the first missing model.
    for model_name in required_models:
        assert model_name in COST_TABLE, f"COST_TABLE missing {model_name}"