294 lines
10 KiB
Python
294 lines
10 KiB
Python
|
|
"""
|
|||
|
|
工具注册表单元测试
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import pytest
|
|||
|
|
from unittest.mock import patch, AsyncMock
|
|||
|
|
|
|||
|
|
from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def registry():
    """Provide a freshly constructed ``ToolRegistry`` for each test."""
    return ToolRegistry()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestToolRegistryBuiltin:
    """Registration and lookup of built-in tools."""

    @pytest.mark.unit
    def test_register_and_get(self, registry):
        """A registered tool is retrievable by name with its schema and callable."""

        def my_tool(**kwargs):
            return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)}

        schema = {
            "type": "function",
            "function": {
                "name": "add",
                "description": "加法",
                "parameters": {
                    "type": "object",
                    "properties": {"x": {"type": "number"}, "y": {"type": "number"}},
                },
            },
        }
        registry.register_builtin_tool("add", my_tool, schema)

        assert registry.get_tool_schema("add") == schema
        assert registry.get_tool_function("add") == my_tool
        assert registry.builtin_tool_count() == 1
        assert "add" in registry.builtin_tool_names()

    @pytest.mark.unit
    def test_get_missing_tool(self, registry):
        """Looking up an unregistered name yields ``None`` for schema and function."""
        assert registry.get_tool_schema("nonexistent") is None
        assert registry.get_tool_function("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_schemas(self, registry):
        """``get_all_tool_schemas`` reports one entry per registered tool."""
        schema1 = {"type": "function", "function": {"name": "tool1"}}
        schema2 = {"type": "function", "function": {"name": "tool2"}}
        registry.register_builtin_tool("tool1", lambda: None, schema1)
        registry.register_builtin_tool("tool2", lambda: None, schema2)

        assert len(registry.get_all_tool_schemas()) == 2

    @pytest.mark.unit
    def test_get_tools_by_names(self, registry):
        """Only registered names are returned; the unknown name "c" is dropped."""
        registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}})
        registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}})

        tools = registry.get_tools_by_names(["a", "b", "c"])

        assert len(tools) == 2
        assert tools[0]["function"]["name"] == "a"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_sync_function_execution(self, registry):
        """A plain synchronous callable runs through ``_run_function``.

        The coroutine returned by ``_run_function`` is awaited exactly once.
        A previous version also called it without awaiting, which leaked an
        un-awaited coroutine and triggered a RuntimeWarning.
        """

        def sync_func(x=0, y=0):
            return x + y

        registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}})

        result = await registry._run_function(sync_func, "add", {"x": 3, "y": 4})

        # The result is JSON text, so decode it before comparing.
        parsed = json.loads(result)
        assert parsed == 7
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.mark.usefixtures("registry")
class TestToolRegistryHTTP:
    """HTTP tool execution tests (httpx is mocked)."""

    @staticmethod
    def _fake_response(status_code, text):
        """Build a mock response carrying the given status code and body text."""
        response = AsyncMock()
        response.status_code = status_code
        response.text = text
        return response

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http(self, registry):
        """A GET tool substitutes arguments into the URL template."""
        registry._custom_tool_configs["test_http"] = {
            "url": "https://api.example.com/data?q={query}",
            "method": "GET",
            "headers": {"Authorization": "Bearer token"},
            "timeout": 10,
            "_type": "http",
        }

        request_mock = AsyncMock(
            return_value=self._fake_response(200, '{"result": "ok"}')
        )
        with patch("httpx.AsyncClient.request", new=request_mock):
            raw = await registry.execute_tool("test_http", {"query": "hello"})
            payload = json.loads(raw)
            assert payload["status_code"] == 200

            # Verify URL template substitution: the URL is the second
            # positional argument of the mocked request call.
            requested_url = request_mock.call_args[0][1]
            assert "hello" in requested_url

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http_post_with_body(self, registry):
        """A POST tool with a body template issues a POST request."""
        registry._custom_tool_configs["test_post"] = {
            "url": "https://api.example.com/submit",
            "method": "POST",
            "headers": {},
            "body_template": {"name": "{name}", "age": "{age}"},
            "timeout": 10,
            "_type": "http",
        }

        request_mock = AsyncMock(return_value=self._fake_response(201, '{"id": 1}'))
        with patch("httpx.AsyncClient.request", new=request_mock):
            raw = await registry.execute_tool("test_post", {"name": "Alice", "age": 30})
            payload = json.loads(raw)
            assert payload["status_code"] == 201

            # Verify POST method: it is the first positional argument.
            assert request_mock.call_args[0][0] == "POST"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_http_no_url(self, registry):
        """An HTTP tool config without a URL produces an error result."""
        registry._custom_tool_configs["bad_http"] = {"_type": "http"}

        outcome = await registry.execute_tool("bad_http", {})

        assert "error" in outcome
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestToolRegistryCode:
    """Code sandbox execution tests."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_simple(self, registry):
        """A code tool's ``run(args)`` return value comes back as JSON."""
        config = {
            "source": "def run(args):\n return {'sum': args['a'] + args['b']}",
            "_type": "code",
        }
        registry._custom_tool_configs["calc"] = config

        result = await registry.execute_tool("calc", {"a": 10, "b": 20})

        parsed = json.loads(result)
        assert parsed["sum"] == 30

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_text_stats(self, registry):
        """The sandbox allows ``len`` and ``str.split`` on the input text."""
        config = {
            "source": "def run(args):\n text = args.get('text', '')\n return {'len': len(text), 'words': len(text.split())}",
            "_type": "code",
        }
        registry._custom_tool_configs["stats"] = config

        result = await registry.execute_tool("stats", {"text": "hello world test"})

        parsed = json.loads(result)
        assert parsed["len"] == 16
        assert parsed["words"] == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_source_missing(self, registry):
        """A code tool config without ``source`` produces an error result."""
        config = {"_type": "code"}  # no source
        registry._custom_tool_configs["no_source"] = config

        result = await registry.execute_tool("no_source", {})

        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_no_run_function(self, registry):
        """Source that defines no ``run`` function is reported as such.

        The tool name deliberately does not contain "run". Otherwise the
        assertion below could pass just because the result echoes the tool
        name, without the message mentioning the missing function.
        """
        config = {"source": "x = 1", "_type": "code"}
        registry._custom_tool_configs["missing_entry"] = config

        result = await registry.execute_tool("missing_entry", {})

        assert "run" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_runtime_error(self, registry):
        """An exception raised inside ``run`` is turned into an error result."""
        config = {
            "source": "def run(args):\n raise ValueError('test error')",
            "_type": "code",
        }
        registry._custom_tool_configs["err"] = config

        result = await registry.execute_tool("err", {})

        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_restriction(self, registry):
        """Verify that ``__builtins__`` is disabled."""
        config = {
            "source": "def run(args):\n import os\n return os.name",
            "_type": "code",
        }
        registry._custom_tool_configs["unsafe"] = config

        result = await registry.execute_tool("unsafe", {})

        # import should fail inside the sandbox
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_no_file_access(self, registry):
        """Verify that the file system cannot be accessed."""
        config = {
            "source": "def run(args):\n open('/etc/passwd')\n return 'ok'",
            "_type": "code",
        }
        registry._custom_tool_configs["file_access"] = config

        result = await registry.execute_tool("file_access", {})

        assert "error" in result
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestToolRegistryTestHelpers:
    """Tool dry-run helpers (nothing is saved to the DB)."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_success(self, registry):
        """``test_code_tool`` reports success and exposes ``run``'s return value."""
        source = "def run(args):\n return {'result': args['x'] * 2}"

        result = await registry.test_code_tool(source, {"x": 5})

        assert result["success"] is True
        assert result["result"]["result"] == 10

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_compile_error(self, registry):
        """Source that fails to compile yields ``success=False`` plus an error."""
        source = "def run(args):\n invalid syntax{{{"

        result = await registry.test_code_tool(source, {})

        assert result["success"] is False
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_http_tool(self, registry):
        """``test_http_tool`` reports the mocked response's status code."""
        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            mock_response = AsyncMock()
            mock_response.status_code = 200
            mock_response.text = '{"ip": "8.8.8.8"}'
            mock_request.return_value = mock_response

            result = await registry.test_http_tool(
                url="https://httpbin.org/get?ip={ip}",
                method="GET",
                headers={},
                body=None,
                args={"ip": "8.8.8.8"},
                timeout=5,
            )

            assert result["success"] is True
            assert result["status_code"] == 200
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestToolRegistryExecute:
    """End-to-end flow of ``execute_tool``."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_builtin(self, registry):
        """A registered built-in tool is invoked with the supplied arguments."""

        def hello(**kwargs):
            who = kwargs.get('name', 'world')
            return f"Hello, {who}!"

        registry.register_builtin_tool("hello", hello, {"function": {"name": "hello"}})

        output = await registry.execute_tool("hello", {"name": "Test"})

        assert "Hello, Test!" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_nonexistent(self, registry):
        """Executing an unknown tool name produces an error result."""
        output = await registry.execute_tool("no_such_tool", {})

        assert "error" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_unsupported_type(self, registry):
        """A custom tool with an unrecognised ``_type`` produces an error result."""
        registry._custom_tool_configs["weird"] = {"_type": "unsupported"}

        output = await registry.execute_tool("weird", {})

        assert "error" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_workflow_not_supported(self, registry):
        """Workflow-type tools report that they are not supported yet."""
        registry._custom_tool_configs["wf"] = {"_type": "workflow"}

        output = await registry.execute_tool("wf", {})

        assert "暂不支持" in output
|