# Tests for SSHProvider and helpers in src.plugins.mcp.providers.ssh_provider.
import logging
|
|
|
|
import pytest
|
|
|
|
import src.plugins.mcp.providers.ssh_provider as ssh_provider_module
|
|
from src.plugins.mcp.providers.ssh_provider import (
|
|
SSHProvider,
|
|
_normalize_ssh_host,
|
|
_quiet_asyncssh_info_logs,
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ssh_diagnose_is_registered_read_only_tool():
    """ssh_diagnose is listed as a tool and its command has CPU and disk sections.

    The command is built with no arguments; it must contain the "CPU TOP"
    section marker and a ``df -h`` disk-usage invocation.
    """
    ssh = SSHProvider()

    registered = set()
    for tool in await ssh.list_tools():
        registered.add(tool.name)

    diagnose_cmd = ssh._build_command("ssh_diagnose", {})

    assert "ssh_diagnose" in registered
    for fragment in ("CPU TOP", "df -h"):
        assert fragment in diagnose_cmd
|
|
|
|
|
|
def test_ssh_diagnose_can_include_container_read_only_context():
    """Passing container_name adds docker stats/inspect for that container."""
    container = "sentry-self-hosted-clickhouse-1"
    ssh = SSHProvider()

    diagnose_cmd = ssh._build_command(
        "ssh_diagnose", {"container_name": container}
    )

    assert f"docker stats --no-stream {container}" in diagnose_cmd
    assert f"docker inspect {container}" in diagnose_cmd
|
|
|
|
|
|
def test_ssh_provider_uses_ollama_user_for_188():
    """Host .188 logs in as 'ollama' while .110 logs in as 'wooo'."""
    ssh = SSHProvider()
    # Insertion order keeps the .188 lookup first, as before.
    expected_users = {
        "192.168.0.188": "ollama",
        "192.168.0.110": "wooo",
    }

    for host, user in expected_users.items():
        assert ssh._ssh_user_for_host(host) == user
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        # Full IP with an exporter port suffix.
        ("192.168.0.110:9100", "192.168.0.110"),
        # Last-octet shorthand plus an exporter port.
        ("110:9100", "192.168.0.110"),
        # Bare alias, then alias with a port.
        ("wooo", "192.168.0.110"),
        ("wooo:9100", "192.168.0.110"),
        # Bare last-octet shorthand.
        ("188", "192.168.0.188"),
        # user@host form.
        ("wooo@192.168.0.110", "192.168.0.110"),
        # ssh:// URL with user and SSH port.
        ("ssh://wooo@192.168.0.110:22", "192.168.0.110"),
        # An already-normalized IP passes through unchanged.
        ("192.168.0.188", "192.168.0.188"),
    ],
)
def test_normalize_ssh_host_strips_exporter_ports_and_users(raw, expected):
    """_normalize_ssh_host reduces every accepted spelling to a bare IP."""
    normalized = _normalize_ssh_host(raw)
    assert normalized == expected
|
|
|
|
|
|
def test_quiet_asyncssh_info_logs_sets_asyncssh_to_warning(monkeypatch):
    """_quiet_asyncssh_info_logs sets the asyncssh logger to WARNING and flags it.

    The module-level ``_asyncssh_logger_configured`` guard is forced to False
    through ``monkeypatch`` (presumably the helper is a no-op once the guard
    is True — confirm against the provider). monkeypatch restores the guard's
    original value at teardown, so it must not be reset by hand here: a
    direct assignment would be dead code and would misstate the prior value.
    The logger level is process-global state that monkeypatch does not own,
    so it is restored explicitly in ``finally``.
    """
    monkeypatch.setattr(ssh_provider_module, "_asyncssh_logger_configured", False)

    asyncssh_logger = logging.getLogger("asyncssh")
    previous_level = asyncssh_logger.level
    # Start below WARNING so the assertion proves the helper changed the level.
    asyncssh_logger.setLevel(logging.INFO)

    try:
        _quiet_asyncssh_info_logs()

        assert asyncssh_logger.level == logging.WARNING
        assert ssh_provider_module._asyncssh_logger_configured is True
    finally:
        # Use setLevel (not a raw `.level` assignment) so logging's cached
        # effective levels are invalidated as well.
        asyncssh_logger.setLevel(previous_level)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ssh_execute_normalizes_host_before_allowed_check(monkeypatch):
    """execute() strips the port from the host before the allow-list check.

    The stubbed allow-list contains only the bare IP, so the port-suffixed
    host must be normalized for the call to reach the (stubbed) transport.
    The stub records the arguments it receives.
    """
    ssh = SSHProvider()
    calls = {}

    # Parameter names match the real _ssh_exec in case execute() passes
    # any of them by keyword.
    async def recording_ssh_exec(host, cmd, timeout, username=None):
        calls.update(host=host, timeout=timeout, username=username)
        return "ok", ""

    monkeypatch.setattr(ssh, "_allowed_hosts", lambda: ["192.168.0.110"])
    monkeypatch.setattr(ssh, "_ssh_exec", recording_ssh_exec)

    result = await ssh.execute("ssh_diagnose", {"host": "192.168.0.110:9100"})

    assert result.success is True
    assert calls["host"] == "192.168.0.110"
    assert isinstance(calls["timeout"], int)
|
|
|
|
|
|
def test_ssh_container_status_accepts_legacy_container_name_alias():
    """ssh_get_container_status still honours the legacy container_name key."""
    legacy_args = {"container_name": "awoooi-api"}
    provider = SSHProvider()

    status_cmd = provider._build_command("ssh_get_container_status", legacy_args)

    assert status_cmd == "docker ps -a --filter name=awoooi-api"
|