- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser - 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行) - 新增 agent_vector_memory / knowledge_base 模型及对应数据库表 - 实现 SSE 流式响应与 Agent 预算控制 - AgentChat.vue 集成 MainLayout 导航布局 - 完善测试体系:7 个新测试文件共 110 个测试覆盖 - 修复 conftest.py SQLite 内存数据库连接隔离问题 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
294 lines
10 KiB
Python
294 lines
10 KiB
Python
"""
|
||
工具注册表单元测试
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import pytest
|
||
from unittest.mock import patch, AsyncMock
|
||
|
||
from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS
|
||
|
||
|
||
@pytest.fixture
def registry():
    """Provide a freshly constructed ``ToolRegistry`` to the requesting test."""
    return ToolRegistry()
|
||
|
||
|
||
class TestToolRegistryBuiltin:
    """Registration and lookup of built-in tools."""

    @pytest.mark.unit
    def test_register_and_get(self, registry):
        """A registered tool is retrievable by name (schema, callable, count, names)."""

        def my_tool(**kwargs):
            return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)}

        schema = {
            "type": "function",
            "function": {
                "name": "add",
                "description": "加法",
                "parameters": {"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}},
            },
        }
        registry.register_builtin_tool("add", my_tool, schema)
        assert registry.get_tool_schema("add") == schema
        assert registry.get_tool_function("add") == my_tool
        assert registry.builtin_tool_count() == 1
        assert "add" in registry.builtin_tool_names()

    @pytest.mark.unit
    def test_get_missing_tool(self, registry):
        """Looking up an unregistered name returns ``None`` for schema and function."""
        assert registry.get_tool_schema("nonexistent") is None
        assert registry.get_tool_function("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_schemas(self, registry):
        """Two registered tools yield two schemas from ``get_all_tool_schemas``."""
        schema1 = {"type": "function", "function": {"name": "tool1"}}
        schema2 = {"type": "function", "function": {"name": "tool2"}}
        registry.register_builtin_tool("tool1", lambda: None, schema1)
        registry.register_builtin_tool("tool2", lambda: None, schema2)
        assert len(registry.get_all_tool_schemas()) == 2

    @pytest.mark.unit
    def test_get_tools_by_names(self, registry):
        """Unknown names ("c") are skipped; known ones come back in request order."""
        registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}})
        registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}})
        tools = registry.get_tools_by_names(["a", "b", "c"])
        assert len(tools) == 2
        assert tools[0]["function"]["name"] == "a"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_sync_function_execution(self, registry):
        """``_run_function`` executes a plain (sync) callable and returns JSON text.

        The coroutine is awaited exactly once. The previous version also
        called ``_run_function`` without awaiting it, which left a
        never-awaited coroutine behind, and drove the real call through
        ``asyncio.run`` inside a sync test instead of using the asyncio
        marker like the rest of this module.
        """

        def sync_func(x=0, y=0):
            return x + y

        registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}})
        result = await registry._run_function(sync_func, "add", {"x": 3, "y": 4})
        parsed = json.loads(result)
        assert parsed == 7
|
||
|
||
|
||
@pytest.mark.usefixtures("registry")
class TestToolRegistryHTTP:
    """HTTP tool execution tests (httpx is mocked)."""

    @staticmethod
    def _stub_response(status_code, text):
        """Build the mock object handed back by the patched ``request`` call."""
        response = AsyncMock()
        response.status_code = status_code
        response.text = text
        return response

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http(self, registry):
        """GET tool: status code is reported and the URL template is filled in."""
        registry._custom_tool_configs["test_http"] = {
            "url": "https://api.example.com/data?q={query}",
            "method": "GET",
            "headers": {"Authorization": "Bearer token"},
            "timeout": 10,
            "_type": "http",
        }

        with patch("httpx.AsyncClient.request", new=AsyncMock()) as request_mock:
            request_mock.return_value = self._stub_response(200, '{"result": "ok"}')

            raw = await registry.execute_tool("test_http", {"query": "hello"})
            payload = json.loads(raw)
            assert payload["status_code"] == 200
            # Verify URL template substitution: the second positional
            # argument of the request call is expected to be the URL.
            positional_args = request_mock.call_args[0]
            assert "hello" in positional_args[1]

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_http_post_with_body(self, registry):
        """POST tool with a body template: status is reported and POST is used."""
        registry._custom_tool_configs["test_post"] = {
            "url": "https://api.example.com/submit",
            "method": "POST",
            "headers": {},
            "body_template": {"name": "{name}", "age": "{age}"},
            "timeout": 10,
            "_type": "http",
        }

        with patch("httpx.AsyncClient.request", new=AsyncMock()) as request_mock:
            request_mock.return_value = self._stub_response(201, '{"id": 1}')

            raw = await registry.execute_tool("test_post", {"name": "Alice", "age": 30})
            payload = json.loads(raw)
            assert payload["status_code"] == 201
            # Verify POST method: first positional argument of the request call.
            positional_args = request_mock.call_args[0]
            assert positional_args[0] == "POST"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_http_no_url(self, registry):
        """An HTTP tool config without a URL produces an error result."""
        registry._custom_tool_configs["bad_http"] = {"_type": "http"}
        outcome = await registry.execute_tool("bad_http", {})
        assert "error" in outcome
|
||
|
||
|
||
class TestToolRegistryCode:
    """Code sandbox execution tests."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_simple(self, registry):
        """A ``run(args)`` function can read its arguments and return a dict."""
        registry._custom_tool_configs["calc"] = {
            "source": "def run(args):\n return {'sum': args['a'] + args['b']}",
            "_type": "code",
        }
        raw = await registry.execute_tool("calc", {"a": 10, "b": 20})
        payload = json.loads(raw)
        assert payload["sum"] == 30

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_text_stats(self, registry):
        """``len`` and ``str.split`` are usable from inside the executed code."""
        registry._custom_tool_configs["stats"] = {
            "source": "def run(args):\n text = args.get('text', '')\n return {'len': len(text), 'words': len(text.split())}",
            "_type": "code",
        }
        raw = await registry.execute_tool("stats", {"text": "hello world test"})
        payload = json.loads(raw)
        assert payload["len"] == 16
        assert payload["words"] == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_source_missing(self, registry):
        """A code tool config with no ``source`` key produces an error result."""
        registry._custom_tool_configs["no_source"] = {"_type": "code"}  # no source
        outcome = await registry.execute_tool("no_source", {})
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_no_run_function(self, registry):
        """Source that defines no ``run`` yields a result mentioning "run"."""
        registry._custom_tool_configs["no_run"] = {"source": "x = 1", "_type": "code"}
        outcome = await registry.execute_tool("no_run", {})
        assert "run" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_runtime_error(self, registry):
        """An exception raised inside ``run`` is reported as an error result."""
        registry._custom_tool_configs["err"] = {
            "source": "def run(args):\n raise ValueError('test error')",
            "_type": "code",
        }
        outcome = await registry.execute_tool("err", {})
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_restriction(self, registry):
        """Verify that ``__builtins__`` is disabled."""
        registry._custom_tool_configs["unsafe"] = {
            "source": "def run(args):\n import os\n return os.name",
            "_type": "code",
        }
        outcome = await registry.execute_tool("unsafe", {})
        # ``import`` is expected to fail inside the sandbox.
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_no_file_access(self, registry):
        """Verify that the file system cannot be accessed."""
        registry._custom_tool_configs["file_access"] = {
            "source": "def run(args):\n open('/etc/passwd')\n return 'ok'",
            "_type": "code",
        }
        outcome = await registry.execute_tool("file_access", {})
        assert "error" in outcome
|
||
|
||
|
||
class TestToolRegistryTestHelpers:
    """Tests for the try-out helpers (nothing is saved to the DB)."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_success(self, registry):
        """``test_code_tool`` reports success and exposes the value returned by ``run``."""
        source = "def run(args):\n return {'result': args['x'] * 2}"
        result = await registry.test_code_tool(source, {"x": 5})
        assert result["success"] is True
        assert result["result"]["result"] == 10

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_compile_error(self, registry):
        """Source that does not compile is reported as a failure with an error entry."""
        source = "def run(args):\n invalid syntax{{{"
        result = await registry.test_code_tool(source, {})
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_http_tool(self, registry):
        """``test_http_tool`` reports success and the mocked status code.

        The redundant in-function import of ``_CODE_SAFE_GLOBALS`` was
        removed: the name was never used in this test.
        """
        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            mock_response = AsyncMock()
            mock_response.status_code = 200
            mock_response.text = '{"ip": "8.8.8.8"}'
            mock_request.return_value = mock_response

            result = await registry.test_http_tool(
                url="https://httpbin.org/get?ip={ip}",
                method="GET",
                headers={},
                body=None,
                args={"ip": "8.8.8.8"},
                timeout=5,
            )
            assert result["success"] is True
            assert result["status_code"] == 200
|
||
|
||
|
||
class TestToolRegistryExecute:
    """End-to-end behaviour of ``execute_tool``."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_builtin(self, registry):
        """A registered built-in tool is invoked with the supplied arguments."""

        def hello(**kwargs):
            return f"Hello, {kwargs.get('name', 'world')}!"

        tool_schema = {"function": {"name": "hello"}}
        registry.register_builtin_tool("hello", hello, tool_schema)
        outcome = await registry.execute_tool("hello", {"name": "Test"})
        assert "Hello, Test!" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_nonexistent(self, registry):
        """Executing an unknown tool name produces an error result."""
        outcome = await registry.execute_tool("no_such_tool", {})
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_unsupported_type(self, registry):
        """A custom tool with an unrecognised ``_type`` produces an error result."""
        registry._custom_tool_configs["weird"] = {"_type": "unsupported"}
        outcome = await registry.execute_tool("weird", {})
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_workflow_not_supported(self, registry):
        """A ``workflow`` tool yields the "not yet supported" message."""
        registry._custom_tool_configs["wf"] = {"_type": "workflow"}
        outcome = await registry.execute_tool("wf", {})
        assert "暂不支持" in outcome
|