Files
aiagent/backend/tests/test_tool_registry.py
renjianbo beff3fac8d fix: delete agent 500 error + dynamic personality + deployment guide
- Fix delete agent 500: clean up FK records (agent_llm_logs, permissions,
  schedules, executions, team_members) and unbind goals/tasks before delete
- Remove hardcoded personality templates in Android, replace with dynamic
  system prompt generation from name + description
- Set promptSectionsEnabled=false to bypass PromptComposer for personality
- Add Tencent Cloud Linux deployment guide (Docker Compose)
- Accumulated backend service updates, frontend UI fixes, Android app changes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-06-29 01:17:21 +08:00

389 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工具注册表单元测试
"""
from __future__ import annotations

import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS
@pytest.fixture
def registry():
    """Give every test its own freshly constructed ``ToolRegistry``."""
    return ToolRegistry()
class TestToolRegistryBuiltin:
    """Registration and lookup of built-in tools."""

    @pytest.mark.unit
    def test_register_and_get(self, registry):
        """A registered tool's schema and callable can be looked up by name."""
        def my_tool(**kwargs):
            return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)}

        schema = {
            "type": "function",
            "function": {
                "name": "add",
                "description": "加法",
                "parameters": {"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}},
            },
        }
        registry.register_builtin_tool("add", my_tool, schema)
        assert registry.get_tool_schema("add") == schema
        # Functions compare by identity anyway; ``is`` states the intent directly.
        assert registry.get_tool_function("add") is my_tool
        assert registry.builtin_tool_count() == 1
        assert "add" in registry.builtin_tool_names()

    @pytest.mark.unit
    def test_get_missing_tool(self, registry):
        """Looking up an unregistered name returns ``None`` for schema and function."""
        assert registry.get_tool_schema("nonexistent") is None
        assert registry.get_tool_function("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_schemas(self, registry):
        """Every registered built-in schema is returned."""
        schema1 = {"type": "function", "function": {"name": "tool1"}}
        schema2 = {"type": "function", "function": {"name": "tool2"}}
        registry.register_builtin_tool("tool1", lambda: None, schema1)
        registry.register_builtin_tool("tool2", lambda: None, schema2)
        assert len(registry.get_all_tool_schemas()) == 2

    @pytest.mark.unit
    def test_get_tools_by_names(self, registry):
        """Known names are returned in request order; unknown names are skipped."""
        registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}})
        registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}})
        tools = registry.get_tools_by_names(["a", "b", "c"])
        assert len(tools) == 2
        assert [tool["function"]["name"] for tool in tools] == ["a", "b"]

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_sync_function_execution(self, registry):
        """A plain (non-async) function is run and its return value JSON-encoded."""
        def sync_func(x=0, y=0):
            return x + y

        registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}})
        # ``_run_function`` is a coroutine function: await it exactly once.
        # Calling it without awaiting would leak a never-awaited coroutine.
        result = await registry._run_function(sync_func, "add", {"x": 3, "y": 4})
        assert json.loads(result) == 7
class TestToolRegistryHTTP:
    """HTTP tool execution tests (``httpx.AsyncClient.request`` is mocked)."""

    @staticmethod
    def _mock_response(status_code, text):
        """Build a stand-in for an ``httpx.Response``.

        Only the ``request`` call itself is awaited; the response object it
        returns is synchronous. A ``MagicMock`` is therefore used. With an
        ``AsyncMock``, every method call on the response (e.g. ``.json()``)
        would return an un-awaited coroutine.
        """
        response = MagicMock()
        response.status_code = status_code
        response.text = text
        return response

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http(self, registry):
        """GET: the ``{query}`` placeholder in the URL template is filled from the args."""
        registry._custom_tool_configs["test_http"] = {
            "url": "https://api.example.com/data?q={query}",
            "method": "GET",
            "headers": {"Authorization": "Bearer token"},
            "timeout": 10,
            "_type": "http",
        }
        mock_request = AsyncMock(return_value=self._mock_response(200, '{"result": "ok"}'))
        with patch("httpx.AsyncClient.request", new=mock_request):
            result = await registry.execute_tool("test_http", {"query": "hello"})
        parsed = json.loads(result)
        assert parsed["status_code"] == 200
        # Positional args are (method, url). The patched AsyncMock is not a
        # descriptor, so no ``self`` is bound in front of them.
        assert mock_request.call_args[0][0] == "GET"
        called_url = mock_request.call_args[0][1]
        assert "hello" in called_url

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http_post_with_body(self, registry):
        """POST with a body template is sent with the configured method."""
        registry._custom_tool_configs["test_post"] = {
            "url": "https://api.example.com/submit",
            "method": "POST",
            "headers": {},
            "body_template": {"name": "{name}", "age": "{age}"},
            "timeout": 10,
            "_type": "http",
        }
        mock_request = AsyncMock(return_value=self._mock_response(201, '{"id": 1}'))
        with patch("httpx.AsyncClient.request", new=mock_request):
            result = await registry.execute_tool("test_post", {"name": "Alice", "age": 30})
        parsed = json.loads(result)
        assert parsed["status_code"] == 201
        # Verify POST method
        assert mock_request.call_args[0][0] == "POST"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_http_no_url(self, registry):
        """An HTTP tool config without a ``url`` produces an error result."""
        registry._custom_tool_configs["bad_http"] = {"_type": "http"}
        result = await registry.execute_tool("bad_http", {})
        assert "error" in result
class TestToolRegistryCode:
    """Sandboxed execution of user-supplied ``code`` tools."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_simple(self, registry):
        """``run(args)`` receives the tool args; its dict result is JSON-encoded."""
        config = {
            "source": "def run(args):\n return {'sum': args['a'] + args['b']}",
            "_type": "code",
        }
        registry._custom_tool_configs["calc"] = config
        result = await registry.execute_tool("calc", {"a": 10, "b": 20})
        parsed = json.loads(result)
        assert parsed["sum"] == 30

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_text_stats(self, registry):
        """Builtins such as ``len`` and ``str.split`` are usable inside the sandbox."""
        config = {
            "source": "def run(args):\n text = args.get('text', '')\n return {'len': len(text), 'words': len(text.split())}",
            "_type": "code",
        }
        registry._custom_tool_configs["stats"] = config
        result = await registry.execute_tool("stats", {"text": "hello world test"})
        parsed = json.loads(result)
        assert parsed["len"] == 16
        assert parsed["words"] == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_source_missing(self, registry):
        """A code tool config without ``source`` produces an error result."""
        config = {"_type": "code"}  # no source
        registry._custom_tool_configs["no_source"] = config
        result = await registry.execute_tool("no_source", {})
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_no_run_function(self, registry):
        """Source that defines no ``run`` function yields an error mentioning ``run``."""
        # The tool name deliberately does not contain "run". With the old name
        # ("no_run") the assertion could pass merely because the result text
        # echoed the tool name.
        config = {"source": "x = 1", "_type": "code"}
        registry._custom_tool_configs["missing_entry"] = config
        result = await registry.execute_tool("missing_entry", {})
        assert "error" in result
        assert "run" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_runtime_error(self, registry):
        """An exception raised inside ``run`` is reported as an error result."""
        config = {
            "source": "def run(args):\n raise ValueError('test error')",
            "_type": "code",
        }
        registry._custom_tool_configs["err"] = config
        result = await registry.execute_tool("err", {})
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_restriction(self, registry):
        """Sandboxed code must not be able to import modules (``import os``)."""
        config = {
            "source": "def run(args):\n import os\n return os.name",
            "_type": "code",
        }
        registry._custom_tool_configs["unsafe"] = config
        result = await registry.execute_tool("unsafe", {})
        # The import is expected to fail inside the sandbox.
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_no_file_access(self, registry):
        """Sandboxed code must not be able to open files."""
        config = {
            "source": "def run(args):\n open('/etc/passwd')\n return 'ok'",
            "_type": "code",
        }
        registry._custom_tool_configs["file_access"] = config
        result = await registry.execute_tool("file_access", {})
        assert "error" in result
class TestToolRegistryTestHelpers:
    """Ad-hoc tool test helpers (nothing is saved to the DB)."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_success(self, registry):
        """A valid source reports success and exposes ``run``'s return value."""
        source = "def run(args):\n return {'result': args['x'] * 2}"
        result = await registry.test_code_tool(source, {"x": 5})
        assert result["success"] is True
        assert result["result"]["result"] == 10

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_compile_error(self, registry):
        """A syntax error reports failure with an ``error`` entry."""
        source = "def run(args):\n invalid syntax{{{"
        result = await registry.test_code_tool(source, {})
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_http_tool(self, registry):
        """An HTTP tool can be exercised directly with a mocked transport."""
        # httpx.Response is synchronous (only ``request`` is awaited), so the
        # response stand-in is a MagicMock rather than an AsyncMock.
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.text = '{"ip": "8.8.8.8"}'
        with patch("httpx.AsyncClient.request", new=AsyncMock(return_value=mock_response)):
            result = await registry.test_http_tool(
                url="https://httpbin.org/get?ip={ip}",
                method="GET",
                headers={},
                body=None,
                args={"ip": "8.8.8.8"},
                timeout=5,
            )
        assert result["success"] is True
        assert result["status_code"] == 200
class TestToolRegistryExecute:
    """End-to-end dispatch through ``execute_tool``."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_builtin(self, registry):
        """A registered built-in tool is found and invoked with the given args."""
        def hello(name="world"):
            return f"Hello, {name}!"

        registry.register_builtin_tool("hello", hello, {"function": {"name": "hello"}})
        output = await registry.execute_tool("hello", {"name": "Test"})
        assert "Hello, Test!" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_nonexistent(self, registry):
        """Dispatching an unknown tool name yields an error result."""
        output = await registry.execute_tool("no_such_tool", {})
        assert "error" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_unsupported_type(self, registry):
        """A custom tool with an unrecognised ``_type`` yields an error result."""
        registry._custom_tool_configs["weird"] = {"_type": "unsupported"}
        output = await registry.execute_tool("weird", {})
        assert "error" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_workflow_not_supported(self, registry):
        """Workflow-type custom tools report that they are not supported yet."""
        registry._custom_tool_configs["wf"] = {"_type": "workflow"}
        output = await registry.execute_tool("wf", {})
        assert "暂不支持" in output
class TestToolRegistryParameterFiltering:
    """Argument filtering: LLM-invented parameter names are dropped before the call."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_unknown_params_filtered(self, registry):
        """Unknown args are filtered out and the tool still runs."""
        def greet(name="world"):
            return f"Hello, {name}!"

        registry.register_builtin_tool("greet", greet, {"function": {"name": "greet"}})
        # The LLM produced an extra, undeclared argument ``extra_field``.
        result = await registry._run_function(greet, "greet", {
            "name": "Test", "extra_field": "should_be_filtered"
        })
        assert "Hello, Test!" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_missing_params_use_defaults(self, registry):
        """Omitted args fall back to the function's own defaults."""
        def greet(name="world"):
            return f"Hello, {name}!"

        registry.register_builtin_tool("greet", greet, {"function": {"name": "greet"}})
        result = await registry._run_function(greet, "greet", {})
        assert "Hello, world!" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_no_valid_params_does_not_crash(self, registry):
        """When every arg is invalid, the call proceeds with defaults only."""
        def greet(name="world"):
            return f"Hello, {name}!"

        registry.register_builtin_tool("greet", greet, {"function": {"name": "greet"}})
        # None of the supplied names match the signature.
        result = await registry._run_function(greet, "greet", {
            "bad_param_1": 1, "bad_param_2": 2
        })
        assert "Hello, world!" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_mixed_valid_and_invalid_params(self, registry):
        """With a mix of valid and invalid args, only the valid ones are passed."""
        def calc(x=0, y=0):
            return str(x + y)

        registry.register_builtin_tool("calc", calc, {"function": {"name": "calc"}})
        result = await registry._run_function(calc, "calc", {
            "x": 3, "y": 4, "z": 100, "unused": "ignored"
        })
        assert "7" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_type_coercion_handled_gracefully(self, registry):
        """Args of an unexpected type are passed through unchanged (no coercion)."""
        def add(x=0, y=0):
            return str(x + y)

        registry.register_builtin_tool("add", add, {"function": {"name": "add"}})
        # Strings were sent where numbers were expected. ``str + str`` is valid
        # Python (concatenation), so no TypeError occurs: the call must succeed
        # and produce the concatenated text, not just "some string".
        result = await registry._run_function(add, "add", {
            "x": "hello", "y": "world"
        })
        assert isinstance(result, str)
        assert "helloworld" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_empty_args_dict(self, registry):
        """A zero-parameter tool runs fine with an empty args dict."""
        def status():
            return "ok"

        registry.register_builtin_tool("status", status, {"function": {"name": "status"}})
        result = await registry._run_function(status, "status", {})
        assert "ok" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_extra_positional_not_in_sig(self, registry):
        """Arg names absent from the function signature are filtered out."""
        # NOTE: despite the test name, these are keyword args, not positional
        # ones. The name is kept so existing test IDs stay stable.
        def read_file(path=""):
            return f"read: {path}"

        registry.register_builtin_tool("read_file", read_file, {"function": {"name": "read_file"}})
        result = await registry._run_function(read_file, "read_file", {
            "path": "/data.txt", "encoding": "utf-8", "mode": "r"
        })
        assert "read: /data.txt" in result