fix(awooop): initialize mcp runtime for signal worker
This commit is contained in:
@@ -448,6 +448,7 @@ async def _add_observation_timeline(
|
||||
async def _run_pre_decision_investigation(incident: "Incident") -> int | None:
|
||||
started = time.monotonic()
|
||||
try:
|
||||
await _ensure_signal_observation_mcp_runtime()
|
||||
from src.services.pre_decision_investigator import get_pre_decision_investigator
|
||||
|
||||
await asyncio.wait_for(
|
||||
@@ -471,6 +472,33 @@ async def _run_pre_decision_investigation(incident: "Incident") -> int | None:
|
||||
return int((time.monotonic() - started) * 1000)
|
||||
|
||||
|
||||
async def _ensure_signal_observation_mcp_runtime() -> None:
|
||||
"""Make one-off signal observation runs use the same MCP runtime as API startup."""
|
||||
try:
|
||||
from src.plugins.mcp.providers import register_all_providers
|
||||
from src.plugins.mcp.registry import get_provider_registry
|
||||
from src.services.mcp_tool_registry import get_mcp_tool_registry, init_mcp_tool_registry
|
||||
|
||||
provider_registry = get_provider_registry()
|
||||
if len(provider_registry) == 0:
|
||||
register_all_providers()
|
||||
logger.info("signal_observation_mcp_providers_registered")
|
||||
|
||||
tool_registry = get_mcp_tool_registry()
|
||||
if tool_registry.tool_count == 0:
|
||||
await init_mcp_tool_registry()
|
||||
logger.info(
|
||||
"signal_observation_mcp_tool_registry_initialized",
|
||||
providers=tool_registry.provider_count,
|
||||
tools=tool_registry.tool_count,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"signal_observation_mcp_runtime_init_failed",
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
|
||||
async def record_signal_worker_observation(
|
||||
incident: "Incident",
|
||||
signal_data: dict[str, Any],
|
||||
|
||||
@@ -592,6 +592,16 @@ async def _main() -> None:
|
||||
message="Episodic Memory (DB) will be unavailable - incidents won't persist",
|
||||
)
|
||||
|
||||
try:
|
||||
from src.plugins.mcp.providers import register_all_providers
|
||||
from src.services.mcp_tool_registry import init_mcp_tool_registry
|
||||
|
||||
register_all_providers()
|
||||
await init_mcp_tool_registry()
|
||||
logger.info("signal_worker_mcp_runtime_initialized")
|
||||
except Exception as e:
|
||||
logger.warning("signal_worker_mcp_runtime_init_failed", error=str(e))
|
||||
|
||||
# Write health files for K8s probes
|
||||
await _write_health_files()
|
||||
|
||||
|
||||
@@ -337,6 +337,36 @@ class TestSuggestTools:
|
||||
registry = MCPToolRegistry()
|
||||
assert registry.suggest_tools() == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_host_error_log_flood_gets_host_observability_tools(self):
|
||||
registry = MCPToolRegistry()
|
||||
ssh_provider = _StubProvider(
|
||||
"ssh_host",
|
||||
["ssh_diagnose", "ssh_get_top_processes", "ssh_get_container_logs"],
|
||||
)
|
||||
signoz_provider = _StubProvider("signoz", ["query_logs"])
|
||||
prometheus_provider = _StubProvider("prometheus", ["prometheus_query"])
|
||||
await registry.register_provider(ssh_provider)
|
||||
await registry.register_provider(signoz_provider)
|
||||
await registry.register_provider(prometheus_provider)
|
||||
|
||||
tools = registry.suggest_tools(
|
||||
alertname="HostErrorLogFlood",
|
||||
incident_labels={
|
||||
"target": "ollama",
|
||||
"sensor_host": "ollama",
|
||||
"sensor_ip": "192.168.0.188",
|
||||
"host": "192.168.0.188",
|
||||
},
|
||||
max_tools=10,
|
||||
)
|
||||
names = [reg.tool.name for reg in tools]
|
||||
|
||||
assert "ssh_diagnose" in names
|
||||
assert "ssh_get_top_processes" in names
|
||||
assert "query_logs" in names
|
||||
assert "prometheus_query" in names
|
||||
|
||||
def test_get_all_tools_returns_all(self):
|
||||
registry = MCPToolRegistry()
|
||||
provider = _StubProvider("test", [])
|
||||
|
||||
Reference in New Issue
Block a user