Some checks failed
CD Pipeline / tests (push) Successful in 1m59s
Code Review / ai-code-review (push) Successful in 28s
run-migration / migrate (push) Failing after 24s
CD Pipeline / post-deploy-checks (push) Has been cancelled
CD Pipeline / build-and-deploy (push) Has been cancelled
120 lines
3.8 KiB
Python
import pytest
|
|
|
|
from src.plugins.mcp.interfaces import MCPTool, MCPToolProvider, MCPToolResult
|
|
from src.plugins.mcp.registry import AuditedMCPToolProvider
|
|
from src.services.ai_providers.agent_loop import AgentToolExecutor
|
|
from src.services.ai_providers.permissions import filter_tools_for_agent, is_tool_allowed
|
|
from src.services.ai_providers.tool_schema import (
|
|
anthropic_tool_schema,
|
|
openai_tool_schema,
|
|
to_provider_tool_name,
|
|
tool_by_provider_name,
|
|
)
|
|
|
|
|
|
class FakeProvider(MCPToolProvider):
    """In-memory MCP tool provider double that records every ``execute`` call.

    ``calls`` accumulates ``(tool_name, parameters)`` tuples in invocation
    order, so tests can assert exactly what reached the wrapped provider
    (or that nothing did).
    """

    def __init__(self, name="kubernetes"):
        self._name = name
        self.calls = []

    @property
    def name(self):
        """Return the provider name supplied at construction time."""
        return self._name

    async def list_tools(self):
        """Advertise no tools; the tests in this module pass tools explicitly."""
        return []

    async def execute(self, tool_name, parameters):
        """Record the invocation and answer with a fixed successful result."""
        invocation = (tool_name, parameters)
        self.calls.append(invocation)
        return MCPToolResult(
            success=True,
            execution_id="exec-1",
            output={"ok": True},
        )
|
|
|
|
|
|
def _tool(server: str, name: str) -> MCPTool:
    """Build a minimal ``MCPTool`` named *name* and owned by *server*.

    The tool takes an empty JSON-object input schema, and its description is
    ``"<server>.<name>"``.
    """
    empty_object_schema = {"type": "object", "properties": {}}
    return MCPTool(
        server_name=server,
        name=name,
        description=f"{server}.{name}",
        input_schema=empty_object_schema,
    )
|
|
|
|
|
|
def test_agent_tool_permissions_are_role_scoped():
    """Each agent role is granted only the tools within its permission scope."""
    k8s_get = _tool("kubernetes", "kubectl_get")
    k8s_restart = _tool("kubernetes", "kubectl_restart")
    prom_query = _tool("prometheus", "prometheus_query")
    db_tool = _tool("database", "write_km_entry")

    # (tool, agent role, expected verdict from is_tool_allowed)
    expectations = [
        (k8s_restart, "openclaw", True),
        (k8s_get, "nemotron", True),
        (k8s_restart, "nemotron", False),
        (prom_query, "hermes", True),
        (db_tool, "hermes", True),
        (k8s_restart, "elephant_alpha", False),
    ]
    for candidate, role, verdict in expectations:
        assert is_tool_allowed(candidate, role) is verdict, (candidate.name, role)

    # For nemotron, kubectl_restart must be dropped while order is preserved.
    visible = filter_tools_for_agent([k8s_get, k8s_restart, prom_query], "nemotron")
    assert [t.name for t in visible] == ["kubectl_get", "prometheus_query"]
|
|
|
|
|
|
def test_tool_schema_round_trips_provider_safe_names():
    """Provider-safe tool names agree across schema flavours and map back."""
    tool = _tool("kubernetes", "kubectl_get")
    expected_safe_name = "kubernetes__kubectl_get"

    safe_name = to_provider_tool_name(tool)
    assert safe_name == expected_safe_name

    # Both provider schema formats must advertise the same safe name.
    anthropic_name = anthropic_tool_schema(tool)["name"]
    assert anthropic_name == safe_name
    openai_name = openai_tool_schema(tool)["function"]["name"]
    assert openai_name == safe_name

    # The safe name must resolve back to the very same MCPTool instance.
    resolved = tool_by_provider_name([tool], safe_name)
    assert resolved is tool
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_audited_provider_strips_internal_audit_context(monkeypatch):
    """AuditedMCPToolProvider hides ``_mcp_audit`` from the wrapped provider.

    The caller passes audit context under the ``_mcp_audit`` key. The wrapped
    provider must receive only the real tool parameters. The patched audit
    hook must receive the agent role, session id and incident id.
    """
    audit_calls = []

    async def fake_record_mcp_call(**kwargs):
        # Capture the keyword arguments the audit hook is called with.
        audit_calls.append(kwargs)

    # NOTE(review): patching the defining module only takes effect if the
    # registry resolves record_mcp_call at call time (not via a module-level
    # `from ... import`); confirm against src.plugins.mcp.registry.
    monkeypatch.setattr(
        "src.services.mcp_audit_service.record_mcp_call",
        fake_record_mcp_call,
    )

    provider = FakeProvider()
    audited = AuditedMCPToolProvider(provider)

    result = await audited.execute(
        "kubectl_get",
        {
            "resource": "pods",
            "_mcp_audit": {
                "agent_role": "openclaw",
                "session_id": "session-1",
                "incident_id": "INC-1",
                "flywheel_node": "sense",
            },
        },
    )

    assert result.success is True
    assert provider.calls == [("kubectl_get", {"resource": "pods"})]

    # Fail with a clear message rather than an IndexError when no audit
    # record was emitted at all.
    assert audit_calls, "expected AuditedMCPToolProvider to record an MCP audit call"
    recorded = audit_calls[0]
    assert recorded["agent_role"] == "openclaw"
    assert recorded["session_id"] == "session-1"
    assert recorded["incident_id"] == "INC-1"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_agent_tool_executor_blocks_disallowed_tool():
    """A nemotron executor refuses kubectl_restart without touching the provider."""
    provider = FakeProvider()
    executor = AgentToolExecutor(
        available_tools=[_tool("kubernetes", "kubectl_restart")],
        providers={"kubernetes": provider},
        agent_role="nemotron",
        incident_id="INC-1",
    )

    result = await executor.execute(
        "kubernetes__kubectl_restart",
        {"deployment": "api"},
    )

    error_text = result.error or ""
    assert result.success is False
    assert "not allowed" in error_text
    # The disallowed call must never have reached the provider.
    assert not provider.calls
|