- Fix delete agent 500: clean up FK records (agent_llm_logs, permissions, schedules, executions, team_members) and unbind goals/tasks before delete - Remove hardcoded personality templates in Android, replace with dynamic system prompt generation from name + description - Set promptSectionsEnabled=false to bypass PromptComposer for personality - Add Tencent Cloud Linux deployment guide (Docker Compose) - Accumulated backend service updates, frontend UI fixes, Android app changes Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
389 lines
14 KiB
Python
389 lines
14 KiB
Python
"""
|
||
工具注册表单元测试
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS
|
||
|
||
|
||
@pytest.fixture
def registry():
    """Provide a newly constructed ``ToolRegistry`` to each test that requests it."""
    return ToolRegistry()
|
||
|
||
|
||
class TestToolRegistryBuiltin:
    """Registration and lookup of built-in tools."""

    @pytest.mark.unit
    def test_register_and_get(self, registry):
        def my_tool(**kwargs):
            return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)}

        schema = {
            "type": "function",
            "function": {
                "name": "add",
                "description": "加法",
                "parameters": {"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}},
            },
        }
        registry.register_builtin_tool("add", my_tool, schema)
        assert registry.get_tool_schema("add") == schema
        # The registry must hand back the exact callable it was given.
        assert registry.get_tool_function("add") is my_tool
        assert registry.builtin_tool_count() == 1
        assert "add" in registry.builtin_tool_names()

    @pytest.mark.unit
    def test_get_missing_tool(self, registry):
        assert registry.get_tool_schema("nonexistent") is None
        assert registry.get_tool_function("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_schemas(self, registry):
        schema1 = {"type": "function", "function": {"name": "tool1"}}
        schema2 = {"type": "function", "function": {"name": "tool2"}}
        registry.register_builtin_tool("tool1", lambda: None, schema1)
        registry.register_builtin_tool("tool2", lambda: None, schema2)
        assert len(registry.get_all_tool_schemas()) == 2

    @pytest.mark.unit
    def test_get_tools_by_names(self, registry):
        registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}})
        registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}})
        tools = registry.get_tools_by_names(["a", "b", "c"])
        # "c" was never registered and must be skipped; "a" and "b" are returned.
        assert [t["function"]["name"] for t in tools] == ["a", "b"]

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_sync_function_execution(self, registry):
        """A plain (non-async) tool function is executed and its result JSON-encoded."""
        def sync_func(x=0, y=0):
            return x + y

        registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}})
        # _run_function is a coroutine function, so it must be awaited exactly once;
        # the previous version also created a second, never-awaited coroutine.
        result = await registry._run_function(sync_func, "add", {"x": 3, "y": 4})
        assert json.loads(result) == 7
|
||
|
||
|
||
class TestToolRegistryHTTP:
    """HTTP tool execution, with ``httpx.AsyncClient.request`` mocked out.

    Only the request coroutine is an ``AsyncMock``; the response it resolves to
    is a ``MagicMock`` because ``httpx.Response`` is a synchronous object (an
    ``AsyncMock`` response would turn every method call into a coroutine).
    """

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http(self, registry):
        config = {
            "url": "https://api.example.com/data?q={query}",
            "method": "GET",
            "headers": {"Authorization": "Bearer token"},
            "timeout": 10,
            "_type": "http",
        }
        registry._custom_tool_configs["test_http"] = config

        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            mock_request.return_value = MagicMock(status_code=200, text='{"result": "ok"}')

            result = await registry.execute_tool("test_http", {"query": "hello"})
            parsed = json.loads(result)
            assert parsed["status_code"] == 200
            # The {query} placeholder in the URL template must be substituted.
            called_url = mock_request.call_args[0][1]
            assert "hello" in called_url
            assert "{query}" not in called_url

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http_post_with_body(self, registry):
        config = {
            "url": "https://api.example.com/submit",
            "method": "POST",
            "headers": {},
            "body_template": {"name": "{name}", "age": "{age}"},
            "timeout": 10,
            "_type": "http",
        }
        registry._custom_tool_configs["test_post"] = config

        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            mock_request.return_value = MagicMock(status_code=201, text='{"id": 1}')

            result = await registry.execute_tool("test_post", {"name": "Alice", "age": 30})
            parsed = json.loads(result)
            assert parsed["status_code"] == 201
            # The configured HTTP method is forwarded as the first positional argument.
            assert mock_request.call_args[0][0] == "POST"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_http_no_url(self, registry):
        config = {"_type": "http"}
        registry._custom_tool_configs["bad_http"] = config
        result = await registry.execute_tool("bad_http", {})
        assert "error" in result
|
||
|
||
|
||
class TestToolRegistryCode:
    """Execution of ``code`` tools inside the restricted sandbox."""

    @staticmethod
    async def _execute_code_tool(registry, tool_name, args, source=None):
        """Register a ``code`` tool config under *tool_name* and execute it with *args*.

        When *source* is None the config deliberately carries no ``source`` key.
        """
        if source is None:
            cfg = {"_type": "code"}
        else:
            cfg = {"source": source, "_type": "code"}
        registry._custom_tool_configs[tool_name] = cfg
        return await registry.execute_tool(tool_name, args)

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_simple(self, registry):
        raw = await self._execute_code_tool(
            registry,
            "calc",
            {"a": 10, "b": 20},
            source="def run(args):\n return {'sum': args['a'] + args['b']}",
        )
        assert json.loads(raw)["sum"] == 30

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_text_stats(self, registry):
        raw = await self._execute_code_tool(
            registry,
            "stats",
            {"text": "hello world test"},
            source="def run(args):\n text = args.get('text', '')\n return {'len': len(text), 'words': len(text.split())}",
        )
        payload = json.loads(raw)
        assert payload["len"] == 16
        assert payload["words"] == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_source_missing(self, registry):
        # No "source" key at all in the tool config.
        raw = await self._execute_code_tool(registry, "no_source", {})
        assert "error" in raw

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_no_run_function(self, registry):
        # Valid Python, but it never defines the required run() entry point.
        raw = await self._execute_code_tool(registry, "no_run", {}, source="x = 1")
        assert "run" in raw

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_runtime_error(self, registry):
        raw = await self._execute_code_tool(
            registry,
            "err",
            {},
            source="def run(args):\n raise ValueError('test error')",
        )
        assert "error" in raw

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_restriction(self, registry):
        """Importing a module (``os``) from sandboxed code must produce an error."""
        raw = await self._execute_code_tool(
            registry,
            "unsafe",
            {},
            source="def run(args):\n import os\n return os.name",
        )
        assert "error" in raw

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_no_file_access(self, registry):
        """Calling ``open`` on a filesystem path from sandboxed code must produce an error."""
        raw = await self._execute_code_tool(
            registry,
            "file_access",
            {},
            source="def run(args):\n open('/etc/passwd')\n return 'ok'",
        )
        assert "error" in raw
|
||
|
||
|
||
class TestToolRegistryTestHelpers:
    """Ad-hoc tool test helpers (``test_code_tool`` / ``test_http_tool``; nothing is saved to the DB)."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_success(self, registry):
        source = "def run(args):\n return {'result': args['x'] * 2}"
        result = await registry.test_code_tool(source, {"x": 5})
        assert result["success"] is True
        assert result["result"]["result"] == 10

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_compile_error(self, registry):
        source = "def run(args):\n invalid syntax{{{"
        result = await registry.test_code_tool(source, {})
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_http_tool(self, registry):
        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            # httpx.Response is synchronous, so model it with MagicMock; only the
            # request call itself is awaited.
            mock_request.return_value = MagicMock(status_code=200, text='{"ip": "8.8.8.8"}')

            result = await registry.test_http_tool(
                url="https://httpbin.org/get?ip={ip}",
                method="GET",
                headers={},
                body=None,
                args={"ip": "8.8.8.8"},
                timeout=5,
            )
            assert result["success"] is True
            assert result["status_code"] == 200
|
||
|
||
|
||
class TestToolRegistryExecute:
    """End-to-end dispatch through ``execute_tool`` for each kind of tool."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_builtin(self, registry):
        def greeter(name="world"):
            return f"Hello, {name}!"

        registry.register_builtin_tool("hello", greeter, {"function": {"name": "hello"}})
        output = await registry.execute_tool("hello", {"name": "Test"})
        assert "Hello, Test!" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_nonexistent(self, registry):
        output = await registry.execute_tool("no_such_tool", {})
        assert "error" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_unsupported_type(self, registry):
        registry._custom_tool_configs["weird"] = {"_type": "unsupported"}
        output = await registry.execute_tool("weird", {})
        assert "error" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_workflow_not_supported(self, registry):
        # Workflow tools are expected to be rejected with a "not supported yet" message.
        registry._custom_tool_configs["wf"] = {"_type": "workflow"}
        output = await registry.execute_tool("wf", {})
        assert "暂不支持" in output
|
||
|
||
|
||
class TestToolRegistryParameterFiltering:
    """Argument filtering: LLM-generated argument names that the tool function
    does not accept are dropped before the call, so the tool still runs."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_unknown_params_filtered(self, registry):
        """An unknown argument is dropped and the tool runs normally."""
        def greet(name="world"):
            return f"Hello, {name}!"

        registry.register_builtin_tool("greet", greet, {"function": {"name": "greet"}})
        # Simulates the LLM inventing an extra argument, extra_field.
        result = await registry._run_function(greet, "greet", {
            "name": "Test", "extra_field": "should_be_filtered"
        })
        assert "Hello, Test!" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_missing_params_use_defaults(self, registry):
        """Omitted arguments fall back to the function's own defaults."""
        def greet(name="world"):
            return f"Hello, {name}!"

        registry.register_builtin_tool("greet", greet, {"function": {"name": "greet"}})
        result = await registry._run_function(greet, "greet", {})
        assert "Hello, world!" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_no_valid_params_does_not_crash(self, registry):
        """When every argument is unknown, all are dropped and defaults are used."""
        def greet(name="world"):
            return f"Hello, {name}!"

        registry.register_builtin_tool("greet", greet, {"function": {"name": "greet"}})
        # None of these names match the function signature.
        result = await registry._run_function(greet, "greet", {
            "bad_param_1": 1, "bad_param_2": 2
        })
        assert "Hello, world!" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_mixed_valid_and_invalid_params(self, registry):
        """With mixed valid/invalid arguments, only the valid ones reach the function."""
        def calc(x=0, y=0):
            return x + y

        registry.register_builtin_tool("calc", calc, {"function": {"name": "calc"}})
        result = await registry._run_function(calc, "calc", {
            "x": 3, "y": 4, "z": 100, "unused": "ignored"
        })
        # Decode and compare exactly: a bare substring check for "7" could also
        # match unrelated text (e.g. a line number inside an error message).
        assert json.loads(result) == 7

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_type_coercion_handled_gracefully(self, registry):
        """Argument values of an unexpected type do not crash the call."""
        def add(x=0, y=0):
            return str(x + y)

        registry.register_builtin_tool("add", add, {"function": {"name": "add"}})
        # In plain Python, str + str concatenates rather than raising TypeError.
        # The test only requires a string result, whether the call succeeds or
        # is reported as an error.
        result = await registry._run_function(add, "add", {
            "x": "hello", "y": "world"
        })
        assert isinstance(result, str)

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_empty_args_dict(self, registry):
        """An empty argument dict works for a function that takes no parameters."""
        def status():
            return "ok"

        registry.register_builtin_tool("status", status, {"function": {"name": "status"}})
        result = await registry._run_function(status, "status", {})
        assert "ok" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_extra_positional_not_in_sig(self, registry):
        """Keyword arguments whose names are not in the function signature are dropped."""
        def read_file(path=""):
            return f"read: {path}"

        registry.register_builtin_tool("read_file", read_file, {"function": {"name": "read_file"}})
        # "encoding" and "mode" are not parameters of read_file.
        result = await registry._run_function(read_file, "read_file", {
            "path": "/data.txt", "encoding": "utf-8", "mode": "r"
        })
        assert "read: /data.txt" in result
|