Files
aiagent/backend/tests/test_tool_registry.py

294 lines
10 KiB
Python
Raw Normal View History

"""
工具注册表单元测试
"""
from __future__ import annotations

import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS
@pytest.fixture
def registry():
    """Provide a brand-new ToolRegistry instance to each test that requests it."""
    return ToolRegistry()
class TestToolRegistryBuiltin:
    """Registration and lookup of built-in tools."""

    @pytest.mark.unit
    def test_register_and_get(self, registry):
        def my_tool(**kwargs):
            return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)}

        schema = {
            "type": "function",
            "function": {
                "name": "add",
                "description": "加法",
                "parameters": {"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}},
            },
        }
        registry.register_builtin_tool("add", my_tool, schema)
        assert registry.get_tool_schema("add") == schema
        # Identity check: the registry must hand back the exact callable that was registered.
        assert registry.get_tool_function("add") is my_tool
        assert registry.builtin_tool_count() == 1
        assert "add" in registry.builtin_tool_names()

    @pytest.mark.unit
    def test_get_missing_tool(self, registry):
        """Lookups for an unregistered name return None rather than raising."""
        assert registry.get_tool_schema("nonexistent") is None
        assert registry.get_tool_function("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_schemas(self, registry):
        schema1 = {"type": "function", "function": {"name": "tool1"}}
        schema2 = {"type": "function", "function": {"name": "tool2"}}
        registry.register_builtin_tool("tool1", lambda: None, schema1)
        registry.register_builtin_tool("tool2", lambda: None, schema2)
        assert len(registry.get_all_tool_schemas()) == 2

    @pytest.mark.unit
    def test_get_tools_by_names(self, registry):
        registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}})
        registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}})
        # "c" was never registered: it must be skipped, not raise.
        tools = registry.get_tools_by_names(["a", "b", "c"])
        assert [tool["function"]["name"] for tool in tools] == ["a", "b"]

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_sync_function_execution(self, registry):
        """A plain (non-async) tool function is executed and its result JSON-encoded."""

        def sync_func(x=0, y=0):
            return x + y

        registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}})
        # _run_function is a coroutine function; await it exactly once so no
        # un-awaited coroutine is left behind.
        result = await registry._run_function(sync_func, "add", {"x": 3, "y": 4})
        parsed = json.loads(result)
        assert parsed == 7
class TestToolRegistryHTTP:
    """HTTP tool execution tests (``httpx.AsyncClient.request`` is mocked)."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http(self, registry):
        config = {
            "url": "https://api.example.com/data?q={query}",
            "method": "GET",
            "headers": {"Authorization": "Bearer token"},
            "timeout": 10,
            "_type": "http",
        }
        registry._custom_tool_configs["test_http"] = config
        # request() is awaited, so it is an AsyncMock; the httpx.Response it
        # resolves to is a synchronous object, so it is a MagicMock.
        mock_response = MagicMock(status_code=200, text='{"result": "ok"}')
        with patch("httpx.AsyncClient.request", new=AsyncMock(return_value=mock_response)) as mock_request:
            result = await registry.execute_tool("test_http", {"query": "hello"})
        parsed = json.loads(result)
        assert parsed["status_code"] == 200
        # Verify URL template substitution.
        # Assumes execute_tool calls request(method, url, ...) positionally.
        called_url = mock_request.call_args.args[1]
        assert "hello" in called_url

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http_post_with_body(self, registry):
        config = {
            "url": "https://api.example.com/submit",
            "method": "POST",
            "headers": {},
            "body_template": {"name": "{name}", "age": "{age}"},
            "timeout": 10,
            "_type": "http",
        }
        registry._custom_tool_configs["test_post"] = config
        mock_response = MagicMock(status_code=201, text='{"id": 1}')
        with patch("httpx.AsyncClient.request", new=AsyncMock(return_value=mock_response)) as mock_request:
            result = await registry.execute_tool("test_post", {"name": "Alice", "age": 30})
        parsed = json.loads(result)
        assert parsed["status_code"] == 201
        # Verify the configured HTTP method is forwarded (first positional argument).
        assert mock_request.call_args.args[0] == "POST"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_http_no_url(self, registry):
        """An HTTP tool config without a URL yields an error result."""
        config = {"_type": "http"}
        registry._custom_tool_configs["bad_http"] = config
        result = await registry.execute_tool("bad_http", {})
        assert "error" in result
class TestToolRegistryCode:
    """Code-sandbox execution tests."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_simple(self, registry):
        config = {
            "source": "def run(args):\n return {'sum': args['a'] + args['b']}",
            "_type": "code",
        }
        registry._custom_tool_configs["calc"] = config
        result = await registry.execute_tool("calc", {"a": 10, "b": 20})
        parsed = json.loads(result)
        assert parsed["sum"] == 30

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_text_stats(self, registry):
        config = {
            "source": "def run(args):\n text = args.get('text', '')\n return {'len': len(text), 'words': len(text.split())}",
            "_type": "code",
        }
        registry._custom_tool_configs["stats"] = config
        result = await registry.execute_tool("stats", {"text": "hello world test"})
        parsed = json.loads(result)
        assert parsed["len"] == 16
        assert parsed["words"] == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_source_missing(self, registry):
        """A code tool config without ``source`` yields an error result."""
        config = {"_type": "code"}  # no source
        registry._custom_tool_configs["no_source"] = config
        result = await registry.execute_tool("no_source", {})
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_no_run_function(self, registry):
        """Source that defines no ``run`` function produces a message mentioning ``run``."""
        config = {"source": "x = 1", "_type": "code"}
        registry._custom_tool_configs["no_run"] = config
        result = await registry.execute_tool("no_run", {})
        assert "run" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_runtime_error(self, registry):
        """An exception raised inside ``run`` is reported as an error result."""
        config = {
            "source": "def run(args):\n raise ValueError('test error')",
            "_type": "code",
        }
        registry._custom_tool_configs["err"] = config
        result = await registry.execute_tool("err", {})
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_restriction(self, registry):
        """Verify the sandbox blocks ``import`` (restricted ``__builtins__``)."""
        config = {
            "source": "def run(args):\n import os\n return os.name",
            "_type": "code",
        }
        registry._custom_tool_configs["unsafe"] = config
        result = await registry.execute_tool("unsafe", {})
        # import should fail inside the sandbox.
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_no_file_access(self, registry, tmp_path):
        """Verify sandboxed code cannot open files on the host file system."""
        # Use a file that is guaranteed to exist, so the test cannot pass
        # vacuously on hosts without /etc/passwd (e.g. Windows): the only way
        # to get an error here is for the sandbox to block ``open``.
        target = tmp_path / "secret.txt"
        target.write_text("secret", encoding="utf-8")
        config = {
            "source": f"def run(args):\n open({str(target)!r})\n return 'ok'",
            "_type": "code",
        }
        registry._custom_tool_configs["file_access"] = config
        result = await registry.execute_tool("file_access", {})
        assert "error" in result
class TestToolRegistryTestHelpers:
    """Tool test helpers (run a tool ad hoc without saving it to the DB)."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_success(self, registry):
        source = "def run(args):\n return {'result': args['x'] * 2}"
        result = await registry.test_code_tool(source, {"x": 5})
        assert result["success"] is True
        assert result["result"]["result"] == 10

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_compile_error(self, registry):
        """Source with a syntax error is reported as unsuccessful with an error."""
        source = "def run(args):\n invalid syntax{{{"
        result = await registry.test_code_tool(source, {})
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_http_tool(self, registry):
        # request() is awaited (AsyncMock); the response object itself is a
        # synchronous httpx.Response, so it is modelled with MagicMock.
        mock_response = MagicMock(status_code=200, text='{"ip": "8.8.8.8"}')
        with patch("httpx.AsyncClient.request", new=AsyncMock(return_value=mock_response)):
            result = await registry.test_http_tool(
                url="https://httpbin.org/get?ip={ip}",
                method="GET",
                headers={},
                body=None,
                args={"ip": "8.8.8.8"},
                timeout=5,
            )
        assert result["success"] is True
        assert result["status_code"] == 200
class TestToolRegistryExecute:
    """End-to-end dispatch through ``execute_tool``."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_builtin(self, registry):
        """A registered built-in tool is invoked with the supplied arguments."""

        def greet(**kwargs):
            return f"Hello, {kwargs.get('name', 'world')}!"

        registry.register_builtin_tool("hello", greet, {"function": {"name": "hello"}})
        outcome = await registry.execute_tool("hello", {"name": "Test"})
        assert "Hello, Test!" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_nonexistent(self, registry):
        """Executing an unknown tool name yields an error result."""
        outcome = await registry.execute_tool("no_such_tool", {})
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_unsupported_type(self, registry):
        """A custom tool whose ``_type`` is unrecognised yields an error result."""
        registry._custom_tool_configs["weird"] = {"_type": "unsupported"}
        outcome = await registry.execute_tool("weird", {})
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_workflow_not_supported(self, registry):
        """Workflow-type tools report that they are not supported yet."""
        registry._custom_tool_configs["wf"] = {"_type": "workflow"}
        outcome = await registry.execute_tool("wf", {})
        assert "暂不支持" in outcome