""" 工具注册表单元测试 """ from __future__ import annotations import json import pytest from unittest.mock import patch, AsyncMock from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS @pytest.fixture def registry(): r = ToolRegistry() return r class TestToolRegistryBuiltin: """内置工具注册与查询""" @pytest.mark.unit def test_register_and_get(self, registry): def my_tool(**kwargs): return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)} schema = { "type": "function", "function": { "name": "add", "description": "加法", "parameters": {"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}}, }, } registry.register_builtin_tool("add", my_tool, schema) assert registry.get_tool_schema("add") == schema assert registry.get_tool_function("add") == my_tool assert registry.builtin_tool_count() == 1 assert "add" in registry.builtin_tool_names() @pytest.mark.unit def test_get_missing_tool(self, registry): assert registry.get_tool_schema("nonexistent") is None assert registry.get_tool_function("nonexistent") is None @pytest.mark.unit def test_get_all_schemas(self, registry): schema1 = {"type": "function", "function": {"name": "tool1"}} schema2 = {"type": "function", "function": {"name": "tool2"}} registry.register_builtin_tool("tool1", lambda: None, schema1) registry.register_builtin_tool("tool2", lambda: None, schema2) assert len(registry.get_all_tool_schemas()) == 2 @pytest.mark.unit def test_get_tools_by_names(self, registry): registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}}) registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}}) tools = registry.get_tools_by_names(["a", "b", "c"]) assert len(tools) == 2 assert tools[0]["function"]["name"] == "a" @pytest.mark.unit def test_sync_function_execution(self, registry): def sync_func(x=0, y=0): return x + y registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}}) result = registry._run_function(sync_func, "add", {"x": 3, "y": 4}) import asyncio result = asyncio.run(registry._run_function(sync_func, "add", {"x": 3, "y": 4})) parsed = json.loads(result) assert parsed == 7 @pytest.mark.usefixtures("registry") class TestToolRegistryHTTP: """HTTP 工具执行测试(mock httpx)""" @pytest.mark.unit @pytest.mark.asyncio async def test_execute_http(self, registry): config = { "url": "https://api.example.com/data?q={query}", "method": "GET", "headers": {"Authorization": "Bearer token"}, "timeout": 10, "_type": "http", } registry._custom_tool_configs["test_http"] = config with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request: mock_response = AsyncMock() mock_response.status_code = 200 mock_response.text = '{"result": "ok"}' mock_request.return_value = mock_response result = await registry.execute_tool("test_http", {"query": "hello"}) parsed = json.loads(result) assert parsed["status_code"] == 200 # 验证 URL 模板替换 called_url = mock_request.call_args[0][1] assert "hello" in called_url @pytest.mark.unit @pytest.mark.asyncio async def test_execute_http_post_with_body(self, registry): config = { "url": "https://api.example.com/submit", "method": "POST", "headers": {}, "body_template": {"name": "{name}", "age": "{age}"}, "timeout": 10, "_type": "http", } registry._custom_tool_configs["test_post"] = config with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request: mock_response = AsyncMock() mock_response.status_code = 201 mock_response.text = '{"id": 1}' mock_request.return_value = mock_response result = await registry.execute_tool("test_post", {"name": "Alice", "age": 30}) parsed = json.loads(result) assert parsed["status_code"] == 201 # Verify POST method assert mock_request.call_args[0][0] == "POST" @pytest.mark.unit @pytest.mark.asyncio async def test_http_no_url(self, registry): config = {"_type": "http"} registry._custom_tool_configs["bad_http"] = config result = await registry.execute_tool("bad_http", {}) assert "error" in result class TestToolRegistryCode: """代码沙箱执行测试""" @pytest.mark.unit @pytest.mark.asyncio async def test_execute_code_simple(self, registry): config = { "source": "def run(args):\n return {'sum': args['a'] + args['b']}", "_type": "code", } registry._custom_tool_configs["calc"] = config result = await registry.execute_tool("calc", {"a": 10, "b": 20}) parsed = json.loads(result) assert parsed["sum"] == 30 @pytest.mark.unit @pytest.mark.asyncio async def test_execute_code_text_stats(self, registry): config = { "source": "def run(args):\n text = args.get('text', '')\n return {'len': len(text), 'words': len(text.split())}", "_type": "code", } registry._custom_tool_configs["stats"] = config result = await registry.execute_tool("stats", {"text": "hello world test"}) parsed = json.loads(result) assert parsed["len"] == 16 assert parsed["words"] == 3 @pytest.mark.unit @pytest.mark.asyncio async def test_code_source_missing(self, registry): config = {"_type": "code"} # no source registry._custom_tool_configs["no_source"] = config result = await registry.execute_tool("no_source", {}) assert "error" in result @pytest.mark.unit @pytest.mark.asyncio async def test_code_no_run_function(self, registry): config = {"source": "x = 1", "_type": "code"} registry._custom_tool_configs["no_run"] = config result = await registry.execute_tool("no_run", {}) assert "run" in result @pytest.mark.unit @pytest.mark.asyncio async def test_code_runtime_error(self, registry): config = { "source": "def run(args):\n raise ValueError('test error')", "_type": "code", } registry._custom_tool_configs["err"] = config result = await registry.execute_tool("err", {}) assert "error" in result @pytest.mark.unit @pytest.mark.asyncio async def test_code_sandbox_restriction(self, registry): """验证 __builtins__ 被禁用""" config = { "source": "def run(args):\n import os\n return os.name", "_type": "code", } registry._custom_tool_configs["unsafe"] = config result = await registry.execute_tool("unsafe", {}) # import 在沙箱中应该失败 assert "error" in result @pytest.mark.unit @pytest.mark.asyncio async def test_code_sandbox_no_file_access(self, registry): """验证无法访问文件系统""" config = { "source": "def run(args):\n open('/etc/passwd')\n return 'ok'", "_type": "code", } registry._custom_tool_configs["file_access"] = config result = await registry.execute_tool("file_access", {}) assert "error" in result class TestToolRegistryTestHelpers: """测试工具(不保存到 DB)""" @pytest.mark.unit @pytest.mark.asyncio async def test_test_code_tool_success(self, registry): source = "def run(args):\n return {'result': args['x'] * 2}" result = await registry.test_code_tool(source, {"x": 5}) assert result["success"] is True assert result["result"]["result"] == 10 @pytest.mark.unit @pytest.mark.asyncio async def test_test_code_tool_compile_error(self, registry): source = "def run(args):\n invalid syntax{{{" result = await registry.test_code_tool(source, {}) assert result["success"] is False assert "error" in result @pytest.mark.unit @pytest.mark.asyncio async def test_test_http_tool(self, registry): with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request: mock_response = AsyncMock() mock_response.status_code = 200 mock_response.text = '{"ip": "8.8.8.8"}' mock_request.return_value = mock_response from app.services.tool_registry import _CODE_SAFE_GLOBALS result = await registry.test_http_tool( url="https://httpbin.org/get?ip={ip}", method="GET", headers={}, body=None, args={"ip": "8.8.8.8"}, timeout=5, ) assert result["success"] is True assert result["status_code"] == 200 class TestToolRegistryExecute: """execute_tool 整体流程""" @pytest.mark.unit @pytest.mark.asyncio async def test_execute_builtin(self, registry): def hello(**kwargs): return f"Hello, {kwargs.get('name', 'world')}!" registry.register_builtin_tool("hello", hello, {"function": {"name": "hello"}}) result = await registry.execute_tool("hello", {"name": "Test"}) assert "Hello, Test!" in result @pytest.mark.unit @pytest.mark.asyncio async def test_execute_nonexistent(self, registry): result = await registry.execute_tool("no_such_tool", {}) assert "error" in result @pytest.mark.unit @pytest.mark.asyncio async def test_execute_unsupported_type(self, registry): config = {"_type": "unsupported"} registry._custom_tool_configs["weird"] = config result = await registry.execute_tool("weird", {}) assert "error" in result @pytest.mark.unit @pytest.mark.asyncio async def test_execute_workflow_not_supported(self, registry): config = {"_type": "workflow"} registry._custom_tool_configs["wf"] = config result = await registry.execute_tool("wf", {}) assert "暂不支持" in result