468 lines
18 KiB
Python
468 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
tests/test_ai_call_logger.py
|
||
ai_call_logger 單元測試 (Operation Ollama-First v5.0 — Phase 1)
|
||
|
||
測試紀律 (對應 phase1 spec):
|
||
- context manager 正常路徑(status='ok')
|
||
- context manager 例外路徑(status='error',例外仍 re-raise)
|
||
- decorator 正常路徑 + auto token extract
|
||
- DB 失敗時主流程不爆
|
||
- cost 計算正確(gemini-2.5-flash / 未知 model fallback / NIM 免費)
|
||
- 環境開關 AI_CALL_LOGGING_ENABLED=false 時跳過寫入
|
||
- kill-switch 連續失敗 ≥ 10 次降級
|
||
- PII 保護:set_prompt_hash 只存前 12 碼
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
import builtins
|
||
import logging
|
||
|
||
import pytest
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
# 隔離 import:避免被 ai_call_logger 內部 lazy import 的 database.manager 拖到
|
||
import services.ai_call_logger as logger_mod
|
||
from services.ai_call_logger import (
|
||
COST_TABLE,
|
||
_calc_cost,
|
||
_CallState,
|
||
_is_logging_enabled,
|
||
_reset_kill_switch,
|
||
log_ai_call,
|
||
logged_ai_call,
|
||
)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Fixtures
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
@pytest.fixture(autouse=True)
def reset_state(monkeypatch):
    """Reset the kill-switch and stub out the real DB write around every test.

    The stub replaces ``logger_mod._write_to_db`` and appends a plain-dict
    snapshot of each call state to a list, so no test touches a real database.
    ``AI_CALL_LOGGING_ENABLED`` is forced to ``'true'`` for the test.

    Yields:
        list: one dict per intercepted ``_write_to_db`` call, in call order.
    """
    _reset_kill_switch()

    # Collected write payloads; the stub appends here instead of hitting the DB.
    captured = []

    def fake_write(state):
        captured.append({
            'caller': state.caller,
            'provider': state.provider,
            'model': state.model,
            'input_tokens': state.input_tokens,
            'output_tokens': state.output_tokens,
            'duration_ms': state.duration_ms,
            'status': state.status,
            'fallback_to': state.fallback_to,
            'cost_usd': _calc_cost(state.model, state.input_tokens, state.output_tokens),
            'cache_hit': state.cache_hit,
            'rag_hit': state.rag_hit,
            'request_id': state.request_id,
            'error': state.error,
            # Copy so later mutation of the state cannot change the snapshot.
            'meta': dict(state.meta),
        })

    monkeypatch.setattr(logger_mod, '_write_to_db', fake_write)
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Expose the captured records to the test.
    yield captured

    # Teardown: some tests deliberately trip the kill-switch. Reset it again so
    # that module-global state does not leak into tests outside this file.
    _reset_kill_switch()
|
||
|
||
|
||
def _wait_for_async(captured, n=1, timeout=2.0):
|
||
"""等待 daemon thread 寫完。"""
|
||
deadline = time.time() + timeout
|
||
while time.time() < deadline:
|
||
if len(captured) >= n:
|
||
return True
|
||
time.sleep(0.01)
|
||
return False
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# context manager 測試
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
def test_context_manager_happy_path(reset_state):
    """A clean ``with`` block produces exactly one record with status 'ok'."""
    records = reset_state

    with log_ai_call('hermes_analyst', 'gcp_ollama', 'hermes3:latest') as call:
        call.set_tokens(input=120, output=80)
        call.set_cache_hit(False)

    assert _wait_for_async(records, 1), "async write 未完成"
    assert len(records) == 1

    entry = records[0]
    expected_fields = {
        'caller': 'hermes_analyst',
        'provider': 'gcp_ollama',
        'model': 'hermes3:latest',
        'input_tokens': 120,
        'output_tokens': 80,
        'status': 'ok',
    }
    for field, wanted in expected_fields.items():
        assert entry[field] == wanted

    assert entry['error'] is None
    elapsed_ms = entry['duration_ms']
    assert elapsed_ms is not None and elapsed_ms >= 0
|
||
|
||
|
||
def test_context_manager_exception_path(reset_state):
    """An exception inside the block is re-raised and logged with status 'error'."""
    records = reset_state

    with pytest.raises(ValueError, match="boom"), \
            log_ai_call('nemotron_dispatch', 'nim', 'meta/llama-3.1-8b-instruct'):
        raise ValueError("boom")

    assert _wait_for_async(records, 1)
    entry = records[0]
    assert entry['status'] == 'error'

    error_text = entry['error']
    assert error_text is not None
    # Both the exception type name and its message must be recorded.
    for fragment in ('ValueError', 'boom'):
        assert fragment in error_text
|
||
|
||
|
||
def test_context_manager_logs_registry_import_failure(monkeypatch, caplog):
    """A failing caller-registry import must not block LLM telemetry, but must leave a diagnostic log."""
    original_import = builtins.__import__

    def _import_with_missing_registry(name, globals=None, locals=None, fromlist=(), level=0):
        # Only the registry module is made to fail; every other import is delegated.
        if name != "services.llm_caller_registry":
            return original_import(name, globals, locals, fromlist, level)
        raise ImportError("registry unavailable")

    monkeypatch.setattr(builtins, "__import__", _import_with_missing_registry)
    caplog.set_level(logging.WARNING, logger="services.ai_call_logger")

    with log_ai_call('hermes_analyst', 'gcp_ollama', 'hermes3:latest') as call:
        call.set_tokens(input=10, output=5)

    assert "caller registry import failed" in caplog.text
|
||
|
||
|
||
def test_context_manager_explicit_fallback(reset_state):
    """``fallback_to_caller`` records status 'fallback' together with the target caller."""
    records = reset_state

    with log_ai_call('openclaw_qa', 'gemini', 'gemini-2.5-flash') as call:
        call.fallback_to_caller('openclaw_bot_nim')

    assert _wait_for_async(records, 1)
    entry = records[0]
    observed = (entry['status'], entry['fallback_to'])
    assert observed == ('fallback', 'openclaw_bot_nim')
|
||
|
||
|
||
def test_context_manager_set_error_without_raise(reset_state):
    """The caller reports an error via ``set_error`` without raising (e.g. the LLM answered success=false)."""
    records = reset_state

    with log_ai_call('sales_copy', 'gcp_ollama', 'llama3.1:8b') as call:
        call.set_error('timeout after 30s')
        call.set_tokens(input=50, output=0)

    assert _wait_for_async(records, 1)
    entry = records[0]
    assert entry['status'] == 'error'
    assert 'timeout' in entry['error']
|
||
|
||
|
||
def test_context_manager_marks_rag_hit(reset_state):
    """``set_rag_hit(True)`` is carried into the record while the status stays 'ok'."""
    records = reset_state

    with log_ai_call('openclaw_qa', 'gcp_ollama', 'qwen3:14b') as call:
        call.set_rag_hit(True)

    assert _wait_for_async(records, 1)
    entry = records[0]
    assert entry['status'] == 'ok'
    assert entry['rag_hit'] is True
|
||
|
||
|
||
def test_context_manager_can_update_actual_provider_after_retry(reset_state):
    """``set_provider`` overrides the provider that was passed to ``log_ai_call``."""
    records = reset_state

    with log_ai_call('code_review_openclaw', 'gcp_ollama', 'qwen3:14b') as call:
        call.set_provider('ollama_secondary')

    assert _wait_for_async(records, 1)
    logged_provider = records[0]['provider']
    assert logged_provider == 'ollama_secondary'
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# decorator 測試
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
def test_decorator_happy_path(reset_state):
    """The decorated function's result is returned unchanged and one 'ok' record is logged.

    The logged token counts are expected to match the ``prompt_eval_count`` /
    ``eval_count`` values in the returned dict.
    """
    records = reset_state

    @logged_ai_call(caller='trend_match', provider='gcp_ollama', model='llama3.1:8b')
    def fake_call(prompt: str):
        return {'response': 'ok', 'eval_count': 42, 'prompt_eval_count': 100}

    result = fake_call("hello")
    assert result['response'] == 'ok'

    assert _wait_for_async(records, 1)
    entry = records[0]
    expected_fields = {
        'caller': 'trend_match',
        'model': 'llama3.1:8b',
        'input_tokens': 100,
        'output_tokens': 42,
        'status': 'ok',
    }
    for field, wanted in expected_fields.items():
        assert entry[field] == wanted
|
||
|
||
|
||
def test_decorator_with_model_extractor(reset_state):
    """``model_extractor`` picks the model from the call's keyword arguments.

    The logged token counts are expected to match the ``usage`` sub-dict of
    the returned value.
    """
    records = reset_state

    def pick_model(args, kw):
        # Fall back to a default model when the caller did not pass one.
        return kw.get('model', 'gemini-2.0-flash')

    @logged_ai_call(
        caller='ppt_gemini',
        provider='gemini',
        model_extractor=pick_model,
    )
    def fake_call(*, model: str, prompt: str):
        return {'usage': {'prompt_tokens': 200, 'completion_tokens': 50}}

    fake_call(model='gemini-2.5-flash', prompt='x')

    assert _wait_for_async(records, 1)
    entry = records[0]
    logged = (entry['model'], entry['input_tokens'], entry['output_tokens'])
    assert logged == ('gemini-2.5-flash', 200, 50)
|
||
|
||
|
||
def test_decorator_exception_does_reraise(reset_state):
    """An exception from the decorated function propagates and is logged as 'error'."""
    records = reset_state

    @logged_ai_call(caller='code_review_hermes', provider='gcp_ollama', model='hermes3:latest')
    def fake_call():
        raise RuntimeError("net down")

    with pytest.raises(RuntimeError, match="net down"):
        fake_call()

    assert _wait_for_async(records, 1)
    first = records[0]
    assert first['status'] == 'error'
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# DB 失敗不爆主流程
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
def test_db_failure_does_not_break_main_flow(monkeypatch, caplog):
    """A failing DB write must not propagate into the caller's main flow.

    ``threading.Thread`` is swapped for a synchronous stand-in so the write
    runs inline, and ``_write_to_db`` is replaced by a version that hits a
    simulated failure, records it, and logs a warning. NOTE(review): the real
    ``_write_to_db`` is not exercised here; the stand-in presumably mirrors
    its try/except shape -- confirm against the implementation.
    """
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Run the "thread" target inline so the write happens before the with-block returns.
    class SyncThread:
        def __init__(self, target=None, args=(), kwargs=None, **_):
            self._target = target
            self._args = args
            self._kwargs = kwargs or {}

        def start(self):
            self._target(*self._args, **self._kwargs)

    monkeypatch.setattr(logger_mod.threading, 'Thread', SyncThread)

    # The autouse fixture already stubbed _write_to_db; override it with a failing version.
    def real_write_that_fails(state):
        try:
            raise ImportError("simulated DB unavailable")
        except Exception as e:
            logger_mod._record_failure()
            logger_mod.logger.warning(
                "[AICallLogger] write failed (caller=%s provider=%s): %s",
                state.caller, state.provider, e,
            )

    monkeypatch.setattr(logger_mod, '_write_to_db', real_write_that_fails)

    # The main flow must not raise.
    with caplog.at_level('WARNING'), \
            log_ai_call('hermes_intent', 'gcp_ollama', 'hermes3:latest') as call:
        call.set_tokens(input=10, output=5)

    # At least one "[AICallLogger] write failed" warning must have been logged.
    logged_messages = [record.message for record in caplog.records]
    assert any('write failed' in text for text in logged_messages), \
        "預期 _write_to_db 失敗時 log warning"
|
||
|
||
|
||
def test_async_dispatch_failure_swallowed(monkeypatch):
    """If the writer thread cannot even be constructed (extreme case), the main flow must not raise."""

    class BadThread:
        def __init__(self, *a, **kw):
            raise OSError("can't fork")

    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')
    monkeypatch.setattr(logger_mod.threading, 'Thread', BadThread)

    # Leaving the block must not raise.
    with log_ai_call('x', 'y', 'z'):
        pass
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# cost 計算
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
def test_calc_cost_gemini_flash():
    """gemini-2.5-flash: 1M input + 100K output tokens = $0.075 + $0.030 = $0.105."""
    input_tokens, output_tokens = 1_000_000, 100_000
    total = _calc_cost('gemini-2.5-flash', input_tokens, output_tokens)
    assert total == pytest.approx(0.105, rel=1e-6)
|
||
|
||
|
||
def test_calc_cost_claude_opus():
    """claude-opus-4-7: 1K input + 1K output tokens = $0.015 + $0.075 = $0.090.

    The expected value below uses $15 per 1M input tokens and $75 per 1M
    output tokens.
    """
    cost = _calc_cost('claude-opus-4-7', 1000, 1000)
    expected = (1000 * 15.0 + 1000 * 75.0) / 1_000_000
    assert cost == pytest.approx(expected, rel=1e-6)
|
||
|
||
|
||
def test_calc_cost_ollama_zero():
    """The two Ollama model names used here cost nothing regardless of token volume."""
    cases = (
        ('hermes3:latest', 100_000),
        ('llama3.1:8b', 999_999),
    )
    for model_name, tokens in cases:
        assert _calc_cost(model_name, tokens, tokens) == 0.0
|
||
|
||
|
||
def test_calc_cost_unknown_model_returns_zero(caplog):
    """An unrecognised model costs 0.0 and triggers an 'unknown model cost' warning."""
    with caplog.at_level('WARNING'):
        cost = _calc_cost('totally-fake-model-xyz', 1_000_000, 1_000_000)

    assert cost == 0.0
    logged_messages = [record.message for record in caplog.records]
    assert any('unknown model cost' in text for text in logged_messages)
|
||
|
||
|
||
def test_calc_cost_nim_prefix_silent_zero(caplog):
    """nvidia/* meta/* deepseek-* models must not trigger the unknown-model warning."""
    with caplog.at_level('WARNING'):
        cost = _calc_cost('nvidia/some-future-model', 1_000_000, 1_000_000)

    assert cost == 0.0
    logged_messages = [record.message for record in caplog.records]
    assert not any('unknown model cost' in text for text in logged_messages)
|
||
|
||
|
||
def test_calc_cost_negative_or_none_safe():
    """None tokens, an empty model name and negative token counts all yield 0.0."""
    degenerate_inputs = (
        ('gemini-2.5-flash', None, None),
        ('', 100, 100),
        ('gemini-2.5-flash', -1, -5),
    )
    for model_name, tokens_in, tokens_out in degenerate_inputs:
        assert _calc_cost(model_name, tokens_in, tokens_out) == 0.0
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# 環境開關
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
def test_logging_disabled_skips_write(monkeypatch):
    """With AI_CALL_LOGGING_ENABLED=false nothing reaches ``_write_to_db``."""
    written_states = []

    def fake_write(state):
        written_states.append(state)

    monkeypatch.setattr(logger_mod, '_write_to_db', fake_write)
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'false')

    with log_ai_call('sales_copy', 'gcp_ollama', 'llama3.1:8b') as call:
        call.set_tokens(input=10, output=10)

    # Give a (wrongly) spawned writer thread a moment to run before checking.
    time.sleep(0.05)
    assert len(written_states) == 0, "AI_CALL_LOGGING_ENABLED=false 時不應寫入"
|
||
|
||
|
||
def test_logging_enabled_default_true(monkeypatch):
    """Logging is on when the variable is unset; '0' and 'OFF' disable it, 'true' enables it."""
    monkeypatch.delenv('AI_CALL_LOGGING_ENABLED', raising=False)
    assert _is_logging_enabled() is True

    for raw_value, expected in (('0', False), ('OFF', False), ('true', True)):
        monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', raw_value)
        assert _is_logging_enabled() is expected
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Kill-switch
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
def test_kill_switch_after_consecutive_failures(monkeypatch, caplog):
    """After 10 recorded failures the logger is killed and no new writer thread is created."""
    _reset_kill_switch()

    # The failures are recorded directly rather than through a failing _write_to_db.
    monkeypatch.setenv('AI_CALL_LOGGING_ENABLED', 'true')

    # Force ten failures in a row.
    failures_needed = 10
    for _ in range(failures_needed):
        logger_mod._record_failure()

    assert logger_mod._is_killed() is True

    # From here on a logged call must not construct a Thread at all.
    spawned_targets = []

    class TrackingThread:
        def __init__(self, *a, **kw):
            spawned_targets.append(kw.get('target'))

        def start(self):
            pass

    monkeypatch.setattr(logger_mod.threading, 'Thread', TrackingThread)

    with log_ai_call('x', 'y', 'z'):
        pass

    time.sleep(0.05)
    assert len(spawned_targets) == 0, "kill-switch 啟動後不應再開新 thread"
|
||
|
||
|
||
def test_record_success_resets_failure_counter():
    """``_record_success`` brings the failure counter back to zero."""
    _reset_kill_switch()

    failures = 5
    for _ in range(failures):
        logger_mod._record_failure()
    assert logger_mod._failure_state['count'] == 5

    logger_mod._record_success()
    assert logger_mod._failure_state['count'] == 0
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# PII 保護
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
def test_set_prompt_hash_truncates_to_12():
    """``set_prompt_hash`` stores a 12-character value that does not contain the raw text."""
    call_state = _CallState('a', 'b', 'c', None, {})
    call_state.set_prompt_hash('Hello world some sensitive PII content here')

    assert 'prompt_hash' in call_state.meta
    digest = call_state.meta['prompt_hash']
    assert len(digest) == 12
    # Make sure the stored value is not the original text.
    assert 'Hello' not in digest
|
||
|
||
|
||
def test_meta_does_not_leak_raw_prompt_into_call_state():
    """The raw prompt enters ``meta`` only as a hash via ``set_prompt_hash``, never under a 'prompt' key."""
    with log_ai_call('x', 'y', 'z', meta={'temperature': 0.3}) as call:
        call.set_prompt_hash("super secret user prompt 123")

        meta = call.meta
        assert 'prompt_hash' in meta
        assert meta['temperature'] == 0.3
        # meta must not contain a 'prompt' key (unless the caller adds one itself).
        assert 'prompt' not in meta
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# 雜項:cost table 鍵值完整性
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
def test_cost_table_contains_critical_models():
    """Every key model enumerated by the phase-0 audit must be present in ``COST_TABLE``."""
    required_models = (
        'gemini-2.5-flash',
        'gemini-2.0-flash',
        'meta/llama-3.1-8b-instruct',
        'hermes3:latest',
        'qwen2.5-coder:7b',
        'llama3.1:8b',
        'bge-m3:latest',
    )
    for model_name in required_models:
        assert model_name in COST_TABLE, f"COST_TABLE missing {model_name}"
|