Files
ewoooc/tests/test_ai_call_logger.py
OoO 12c8c7e94d
All checks were successful
CD Pipeline / deploy (push) Successful in 1m6s
V10.538 對齊 ai_calls ollama_other provider
2026-06-01 02:37:12 +08:00

485 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tests/test_ai_call_logger.py
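ai_call_logger unit tests (Operation Ollama-First v5.0 — Phase 1)
Test discipline (per the phase1 spec):
- context manager happy path: status='ok'
- context manager exception path: status='error', and the exception is still re-raised
- decorator happy path + auto token extract
- the main flow must not blow up when the DB fails
- cost calculation is correct (gemini-2.5-flash / unknown model fallback / NIM is free)
- writes are skipped when the environment switch AI_CALL_LOGGING_ENABLED=false
- kill-switch degrades after >= 10 consecutive failures
- PII protection: set_prompt_hash stores only the first 12 characters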
"""
import os
import sys
import time
import builtins
import logging
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import in isolation so we are not dragged down by database.manager, which ai_call_logger lazy-imports internally
import services.ai_call_logger as logger_mod
from services.ai_call_logger import (
COST_TABLE,
_calc_cost,
_CallState,
_is_logging_enabled,
_normalize_provider,
_reset_kill_switch,
log_ai_call,
logged_ai_call,
)
# ─────────────────────────────────────────────────────────────────────────────
# Fixtures
# ─────────────────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def reset_state(monkeypatch):
    """Reset the kill-switch around every test and stub out the real DB write.

    Yields:
        list[dict]: one dict per call to the stubbed ``_write_to_db``, holding
        the fields copied from the call state plus the computed ``cost_usd``.
    """
    _reset_kill_switch()

    # Stub _write_to_db: collect what would have been written into a list so
    # no test ever talks to a real database.
    captured = []

    def fake_write(state):
        captured.append({
            'caller': state.caller,
            'provider': state.provider,
            'model': state.model,
            'input_tokens': state.input_tokens,
            'output_tokens': state.output_tokens,
            'duration_ms': state.duration_ms,
            'status': state.status,
            'fallback_to': state.fallback_to,
            'cost_usd': _calc_cost(state.model, state.input_tokens, state.output_tokens),
            'cache_hit': state.cache_hit,
            'rag_hit': state.rag_hit,
            'request_id': state.request_id,
            'error': state.error,
            'meta': dict(state.meta),
        })

    monkeypatch.setattr(logger_mod, '_write_to_db', fake_write)
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Expose the captured records to the test.
    yield captured

    # Teardown: some tests deliberately record failures or trip the
    # kill-switch. Reset again so that state cannot leak into whatever runs
    # next (other modules, `-k` selections, randomised ordering).
    _reset_kill_switch()
def _wait_for_async(captured, n=1, timeout=2.0):
    """Poll until the asynchronous writer has delivered at least ``n`` records.

    Args:
        captured: the list the stubbed writer appends to.
        n: minimum number of records to wait for.
        timeout: maximum time to wait, in seconds.

    Returns:
        True as soon as ``len(captured) >= n``; False if the deadline passes
        first.
    """
    # A monotonic clock is immune to wall-clock adjustments (NTP, DST).
    deadline = time.monotonic() + timeout
    # Test the condition before the deadline so that records which arrive
    # during the last sleep, or are already present when timeout == 0, still
    # count as success.
    while len(captured) < n:
        if time.monotonic() >= deadline:
            return False
        time.sleep(0.01)
    return True
# ─────────────────────────────────────────────────────────────────────────────
# context manager 測試
# ─────────────────────────────────────────────────────────────────────────────
def test_context_manager_happy_path(reset_state):
    """A clean ``with`` block produces exactly one 'ok' record carrying the values set on it."""
    writes = reset_state

    with log_ai_call('hermes_analyst', 'gcp_ollama', 'hermes3:latest') as call:
        call.set_tokens(input=120, output=80)
        call.set_cache_hit(False)

    delivered = _wait_for_async(writes, 1)
    assert delivered, "async write 未完成"
    assert len(writes) == 1

    row = writes[0]
    identity = (row['caller'], row['provider'], row['model'])
    assert identity == ('hermes_analyst', 'gcp_ollama', 'hermes3:latest')
    assert (row['input_tokens'], row['output_tokens']) == (120, 80)
    assert row['status'] == 'ok'
    assert row['error'] is None

    elapsed = row['duration_ms']
    assert elapsed is not None and elapsed >= 0
def test_context_manager_exception_path(reset_state):
    """An exception inside the block is re-raised and recorded with status 'error'."""
    writes = reset_state

    with pytest.raises(ValueError, match="boom"):
        with log_ai_call('nemotron_dispatch', 'nim', 'meta/llama-3.1-8b-instruct'):
            raise ValueError("boom")

    assert _wait_for_async(writes, 1)

    row = writes[0]
    assert row['status'] == 'error'
    recorded_error = row['error']
    assert recorded_error is not None
    # The stored error text must name both the exception type and its message.
    for fragment in ('ValueError', 'boom'):
        assert fragment in recorded_error
def test_context_manager_logs_registry_import_failure(monkeypatch, caplog):
    """A failing caller-registry import must not block telemetry, but must leave a diagnostic log."""
    original_import = builtins.__import__

    def _import_with_missing_registry(name, globals=None, locals=None, fromlist=(), level=0):
        # Everything except the registry module goes through the real importer.
        if name != "services.llm_caller_registry":
            return original_import(name, globals, locals, fromlist, level)
        raise ImportError("registry unavailable")

    monkeypatch.setattr(builtins, "__import__", _import_with_missing_registry)
    caplog.set_level(logging.WARNING, logger="services.ai_call_logger")

    with log_ai_call('hermes_analyst', 'gcp_ollama', 'hermes3:latest') as call:
        call.set_tokens(input=10, output=5)

    assert "caller registry import failed" in caplog.text
def test_context_manager_explicit_fallback(reset_state):
    """Calling ``fallback_to_caller`` records status 'fallback' and the fallback target."""
    writes = reset_state

    with log_ai_call('openclaw_qa', 'gemini', 'gemini-2.5-flash') as call:
        call.fallback_to_caller('openclaw_bot_nim')

    assert _wait_for_async(writes, 1)
    row = writes[0]
    assert (row['status'], row['fallback_to']) == ('fallback', 'openclaw_bot_nim')
def test_context_manager_set_error_without_raise(reset_state):
    """The caller may flag an error via ``set_error`` without raising (e.g. the LLM returned success=false)."""
    writes = reset_state

    with log_ai_call('sales_copy', 'gcp_ollama', 'llama3.1:8b') as call:
        call.set_error('timeout after 30s')
        call.set_tokens(input=50, output=0)

    assert _wait_for_async(writes, 1)
    row = writes[0]
    assert row['status'] == 'error'
    assert 'timeout' in row['error']
def test_context_manager_marks_rag_hit(reset_state):
    """``set_rag_hit(True)`` is carried into the written record; status stays 'ok'."""
    writes = reset_state

    with log_ai_call('openclaw_qa', 'gcp_ollama', 'qwen3:14b') as call:
        call.set_rag_hit(True)

    assert _wait_for_async(writes, 1)
    row = writes[0]
    assert row['status'] == 'ok'
    assert row['rag_hit'] is True
def test_context_manager_can_update_actual_provider_after_retry(reset_state):
    """``set_provider`` overrides the provider that was passed when the block was opened."""
    writes = reset_state

    with log_ai_call('code_review_openclaw', 'gcp_ollama', 'qwen3:14b') as call:
        call.set_provider('ollama_secondary')

    assert _wait_for_async(writes, 1)
    recorded_provider = writes[0]['provider']
    assert recorded_provider == 'ollama_secondary'
def test_provider_normalization_keeps_ai_calls_check_safe(reset_state):
    """``_normalize_provider`` maps raw provider strings onto the expected names, and the written record uses the mapped name."""
    expected_by_raw = {
        '': 'ollama_other',
        'unknown': 'ollama_other',
        'ollama_other': 'ollama_other',
        'gemini-2.5-flash': 'gemini',
        'anthropic': 'claude',
        'nvidia/nemotron': 'nim',
    }
    for raw, normalized in expected_by_raw.items():
        assert _normalize_provider(raw) == normalized

    # End to end: an unrecognised provider passed to log_ai_call is stored normalized.
    writes = reset_state
    with log_ai_call('hermes_analyst', 'unknown', 'hermes3:latest'):
        pass

    assert _wait_for_async(writes, 1)
    assert writes[0]['provider'] == 'ollama_other'
# ─────────────────────────────────────────────────────────────────────────────
# decorator 測試
# ─────────────────────────────────────────────────────────────────────────────
def test_decorator_happy_path(reset_state):
    """The decorator passes the return value through and picks token counts out of it."""
    writes = reset_state

    @logged_ai_call(caller='trend_match', provider='gcp_ollama', model='llama3.1:8b')
    def fake_call(prompt: str):
        return {'response': 'ok', 'eval_count': 42, 'prompt_eval_count': 100}

    result = fake_call("hello")
    assert result['response'] == 'ok'

    assert _wait_for_async(writes, 1)
    row = writes[0]
    assert (row['caller'], row['model']) == ('trend_match', 'llama3.1:8b')
    # prompt_eval_count -> input tokens, eval_count -> output tokens.
    assert (row['input_tokens'], row['output_tokens']) == (100, 42)
    assert row['status'] == 'ok'
def test_decorator_with_model_extractor(reset_state):
    """``model_extractor`` picks the model from the call arguments; token counts come from ``usage``."""
    writes = reset_state

    def _pick_model(args, kw):
        return kw.get('model', 'gemini-2.0-flash')

    @logged_ai_call(
        caller='ppt_gemini',
        provider='gemini',
        model_extractor=_pick_model,
    )
    def fake_call(*, model: str, prompt: str):
        return {'usage': {'prompt_tokens': 200, 'completion_tokens': 50}}

    fake_call(model='gemini-2.5-flash', prompt='x')

    assert _wait_for_async(writes, 1)
    row = writes[0]
    assert row['model'] == 'gemini-2.5-flash'
    assert (row['input_tokens'], row['output_tokens']) == (200, 50)
def test_decorator_exception_does_reraise(reset_state):
    """An exception from the decorated function propagates and is recorded as 'error'."""
    writes = reset_state

    @logged_ai_call(caller='code_review_hermes', provider='gcp_ollama', model='hermes3:latest')
    def fake_call():
        raise RuntimeError("net down")

    with pytest.raises(RuntimeError, match="net down"):
        fake_call()

    assert _wait_for_async(writes, 1)
    first = writes[0]
    assert first['status'] == 'error'
# ─────────────────────────────────────────────────────────────────────────────
# DB 失敗不爆主流程
# ─────────────────────────────────────────────────────────────────────────────
def test_db_failure_does_not_break_main_flow(monkeypatch, caplog):
    """A failing DB write must not propagate into the main flow.

    ``threading.Thread`` is swapped for a synchronous stand-in so the write
    runs inline, and ``_write_to_db`` is replaced by a version that simulates
    a DB failure, records it and logs a warning.
    """
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Replace the thread with a synchronous call so the write is observed directly.
    class SyncThread:
        def __init__(self, target=None, args=(), kwargs=None, **_):
            self._fn = target
            self._positional = args
            self._named = kwargs or {}

        def start(self):
            self._fn(*self._positional, **self._named)

    monkeypatch.setattr(logger_mod.threading, 'Thread', SyncThread)

    # The autouse fixture already stubbed _write_to_db; override it with a
    # version that fails and handles the failure itself.
    def real_write_that_fails(state):
        try:
            raise ImportError("simulated DB unavailable")
        except Exception as exc:
            logger_mod._record_failure()
            logger_mod.logger.warning(
                "[AICallLogger] write failed (caller=%s provider=%s): %s",
                state.caller, state.provider, exc,
            )

    monkeypatch.setattr(logger_mod, '_write_to_db', real_write_that_fails)

    # The main flow must not raise.
    with caplog.at_level('WARNING'), \
            log_ai_call('hermes_intent', 'gcp_ollama', 'hermes3:latest') as call:
        call.set_tokens(input=10, output=5)

    # At least one "[AICallLogger] write failed" warning must have been logged.
    saw_warning = any('write failed' in r.message for r in caplog.records)
    assert saw_warning, \
        "預期 _write_to_db 失敗時 log warning"
def test_async_dispatch_failure_swallowed(monkeypatch):
    """Even if creating the writer thread fails (extreme case), the main flow must not raise."""
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    class BadThread:
        def __init__(self, *args, **kwargs):
            raise OSError("can't fork")

    monkeypatch.setattr(logger_mod.threading, 'Thread', BadThread)

    # Must complete without raising.
    with log_ai_call('x', 'y', 'z'):
        pass
# ─────────────────────────────────────────────────────────────────────────────
# cost 計算
# ─────────────────────────────────────────────────────────────────────────────
def test_calc_cost_gemini_flash():
    """gemini-2.5-flash: 1M input + 100K output tokens = $0.075 + $0.030 = $0.105."""
    expected = pytest.approx(0.105, rel=1e-6)
    assert _calc_cost('gemini-2.5-flash', 1_000_000, 100_000) == expected
def test_calc_cost_claude_opus():
    """claude-opus-4-7: 1K input + 1K output tokens = $0.015 + $0.075 = $0.090.

    The expected value assumes $15 per 1M input tokens and $75 per 1M output
    tokens, i.e. (1000 * 15 + 1000 * 75) / 1_000_000 = 0.09 USD.
    """
    cost = _calc_cost('claude-opus-4-7', 1000, 1000)
    expected = (1000 * 15.0 + 1000 * 75.0) / 1_000_000
    assert cost == pytest.approx(expected, rel=1e-6)
def test_calc_cost_ollama_zero():
    """These Ollama models cost nothing regardless of token volume."""
    cases = (
        ('hermes3:latest', 100_000),
        ('llama3.1:8b', 999_999),
    )
    for model_name, tokens in cases:
        assert _calc_cost(model_name, tokens, tokens) == 0.0
def test_calc_cost_unknown_model_returns_zero(caplog):
    """An unknown model costs 0.0 and emits an 'unknown model cost' warning."""
    with caplog.at_level('WARNING'):
        cost = _calc_cost('totally-fake-model-xyz', 1_000_000, 1_000_000)

    assert cost == 0.0
    messages = [record.message for record in caplog.records]
    assert any('unknown model cost' in text for text in messages)
def test_calc_cost_nim_prefix_silent_zero(caplog):
    """nvidia/* meta/* deepseek-* models must not trigger the unknown-model warning."""
    with caplog.at_level('WARNING'):
        cost = _calc_cost('nvidia/some-future-model', 1_000_000, 1_000_000)

    assert cost == 0.0
    messages = [record.message for record in caplog.records]
    assert all('unknown model cost' not in text for text in messages)
def test_calc_cost_negative_or_none_safe():
    """None tokens, an empty model name and negative tokens all yield 0.0."""
    degenerate_inputs = [
        ('gemini-2.5-flash', None, None),
        ('', 100, 100),
        ('gemini-2.5-flash', -1, -5),
    ]
    for model_name, tokens_in, tokens_out in degenerate_inputs:
        assert _calc_cost(model_name, tokens_in, tokens_out) == 0.0
# ─────────────────────────────────────────────────────────────────────────────
# 環境開關
# ─────────────────────────────────────────────────────────────────────────────
def test_logging_disabled_skips_write(monkeypatch):
    """With AI_CALL_LOGGING_ENABLED=false the write function is never invoked."""
    seen = []

    def fake_write(state):
        seen.append(state)

    monkeypatch.setattr(logger_mod, '_write_to_db', fake_write)
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'false')

    with log_ai_call('sales_copy', 'gcp_ollama', 'llama3.1:8b') as call:
        call.set_tokens(input=10, output=10)

    # Give any (unexpected) asynchronous write a moment to land.
    time.sleep(0.05)
    assert not seen, "AI_CALL_LOGGING_ENABLED=false 時不應寫入"
def test_logging_enabled_default_true(monkeypatch):
    """Logging is on when the variable is unset; '0' and 'OFF' turn it off; 'true' turns it on."""
    monkeypatch.delenv('AI_CALL_LOGGING_ENABLED', raising=False)
    assert _is_logging_enabled() is True

    for raw_value, expected in (('0', False), ('OFF', False), ('true', True)):
        monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', raw_value)
        assert _is_logging_enabled() is expected
# ─────────────────────────────────────────────────────────────────────────────
# Kill-switch
# ─────────────────────────────────────────────────────────────────────────────
def test_kill_switch_after_consecutive_failures(monkeypatch, caplog):
    """After 10 recorded failures the kill-switch is on and no new writer thread is created."""
    _reset_kill_switch()
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Drive the failure counter directly instead of going through a failing write.
    failures_needed = 10
    for _ in range(failures_needed):
        logger_mod._record_failure()
    assert logger_mod._is_killed() is True

    # From here on, log_ai_call must not construct a Thread at all.
    thread_targets = []

    class TrackingThread:
        def __init__(self, *args, **kwargs):
            thread_targets.append(kwargs.get('target'))

        def start(self):
            pass

    monkeypatch.setattr(logger_mod.threading, 'Thread', TrackingThread)

    with log_ai_call('x', 'y', 'z'):
        pass

    time.sleep(0.05)
    assert len(thread_targets) == 0, "kill-switch 啟動後不應再開新 thread"
def test_record_success_resets_failure_counter():
    """``_record_success`` sets the failure count back to zero."""

    def failure_count():
        # Look the state up fresh on every call rather than caching a reference.
        return logger_mod._failure_state['count']

    _reset_kill_switch()
    for _ in range(5):
        logger_mod._record_failure()
    assert failure_count() == 5

    logger_mod._record_success()
    assert failure_count() == 0
# ─────────────────────────────────────────────────────────────────────────────
# PII 保護
# ─────────────────────────────────────────────────────────────────────────────
def test_set_prompt_hash_truncates_to_12():
    """``set_prompt_hash`` stores a 12-character value under ``meta['prompt_hash']``, not the raw text."""
    call_state = _CallState('a', 'b', 'c', None, {})
    call_state.set_prompt_hash('Hello world some sensitive PII content here')

    assert 'prompt_hash' in call_state.meta
    digest = call_state.meta['prompt_hash']
    assert len(digest) == 12
    # The stored value must not contain the original text.
    assert 'Hello' not in digest
def test_meta_does_not_leak_raw_prompt_into_call_state():
    """After ``set_prompt_hash``, meta holds 'prompt_hash' and the caller's own keys, but no 'prompt' key."""
    with log_ai_call('x', 'y', 'z', meta={'temperature': 0.3}) as call:
        call.set_prompt_hash("super secret user prompt 123")

    meta = call.meta
    assert 'prompt_hash' in meta
    assert meta['temperature'] == 0.3
    # No 'prompt' key should be present unless the caller added one itself.
    assert 'prompt' not in meta
# ─────────────────────────────────────────────────────────────────────────────
# Misc: cost table key completeness
# ─────────────────────────────────────────────────────────────────────────────
def test_cost_table_contains_critical_models():
    """Every critical model enumerated by the phase0 audit must have a COST_TABLE entry."""
    required_models = (
        'gemini-2.5-flash',
        'gemini-2.0-flash',
        'meta/llama-3.1-8b-instruct',
        'hermes3:latest',
        'qwen2.5-coder:7b',
        'llama3.1:8b',
        'bge-m3:latest',
    )
    for model_name in required_models:
        assert model_name in COST_TABLE, f"COST_TABLE missing {model_name}"