Files
aiagent/backend/tests/test_tool_registry.py
renjianbo 7b9e0826de feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖
- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser
- 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行)
- 新增 agent_vector_memory / knowledge_base 模型及对应数据库表
- 实现 SSE 流式响应与 Agent 预算控制
- AgentChat.vue 集成 MainLayout 导航布局
- 完善测试体系:7 个新测试文件共 110 个测试覆盖
- 修复 conftest.py SQLite 内存数据库连接隔离问题

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-01 22:30:46 +08:00

294 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工具注册表单元测试
"""
from __future__ import annotations

import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.tool_registry import ToolRegistry, _CODE_SAFE_GLOBALS
@pytest.fixture
def registry():
    """Provide a new ``ToolRegistry`` instance for each test (function-scoped fixture)."""
    return ToolRegistry()
class TestToolRegistryBuiltin:
    """Registration and lookup of built-in tools."""

    @pytest.mark.unit
    def test_register_and_get(self, registry):
        def my_tool(**kwargs):
            return {"result": kwargs.get("x", 0) + kwargs.get("y", 0)}

        schema = {
            "type": "function",
            "function": {
                "name": "add",
                "description": "加法",
                "parameters": {"type": "object", "properties": {"x": {"type": "number"}, "y": {"type": "number"}}},
            },
        }
        registry.register_builtin_tool("add", my_tool, schema)
        assert registry.get_tool_schema("add") == schema
        # The registry must hand back the very callable that was registered;
        # `is` states that identity requirement explicitly.
        assert registry.get_tool_function("add") is my_tool
        assert registry.builtin_tool_count() == 1
        assert "add" in registry.builtin_tool_names()

    @pytest.mark.unit
    def test_get_missing_tool(self, registry):
        """Lookups for an unknown tool name return None rather than raising."""
        assert registry.get_tool_schema("nonexistent") is None
        assert registry.get_tool_function("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_schemas(self, registry):
        schema1 = {"type": "function", "function": {"name": "tool1"}}
        schema2 = {"type": "function", "function": {"name": "tool2"}}
        registry.register_builtin_tool("tool1", lambda: None, schema1)
        registry.register_builtin_tool("tool2", lambda: None, schema2)
        assert len(registry.get_all_tool_schemas()) == 2

    @pytest.mark.unit
    def test_get_tools_by_names(self, registry):
        registry.register_builtin_tool("a", lambda: None, {"function": {"name": "a"}})
        registry.register_builtin_tool("b", lambda: None, {"function": {"name": "b"}})
        # "c" was never registered: it must be skipped, and the requested
        # order of the known names must be kept.
        tools = registry.get_tools_by_names(["a", "b", "c"])
        assert [tool["function"]["name"] for tool in tools] == ["a", "b"]

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_sync_function_execution(self, registry):
        """A plain synchronous function is executed and its result JSON-encoded."""

        def sync_func(x=0, y=0):
            return x + y

        registry.register_builtin_tool("add", sync_func, {"function": {"name": "add"}})
        # _run_function returns a coroutine, so it must be awaited exactly once;
        # calling it without awaiting leaks a never-awaited coroutine.
        result = await registry._run_function(sync_func, "add", {"x": 3, "y": 4})
        assert json.loads(result) == 7
@pytest.mark.usefixtures("registry")
class TestToolRegistryHTTP:
"""HTTP 工具执行测试mock httpx"""
@pytest.mark.unit
@pytest.mark.asyncio
async def test_execute_http(self, registry):
config = {
"url": "https://api.example.com/data?q={query}",
"method": "GET",
"headers": {"Authorization": "Bearer token"},
"timeout": 10,
"_type": "http",
}
registry._custom_tool_configs["test_http"] = config
with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.text = '{"result": "ok"}'
mock_request.return_value = mock_response
result = await registry.execute_tool("test_http", {"query": "hello"})
parsed = json.loads(result)
assert parsed["status_code"] == 200
# 验证 URL 模板替换
called_url = mock_request.call_args[0][1]
assert "hello" in called_url
@pytest.mark.unit
@pytest.mark.asyncio
async def test_execute_http_post_with_body(self, registry):
config = {
"url": "https://api.example.com/submit",
"method": "POST",
"headers": {},
"body_template": {"name": "{name}", "age": "{age}"},
"timeout": 10,
"_type": "http",
}
registry._custom_tool_configs["test_post"] = config
with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
mock_response = AsyncMock()
mock_response.status_code = 201
mock_response.text = '{"id": 1}'
mock_request.return_value = mock_response
result = await registry.execute_tool("test_post", {"name": "Alice", "age": 30})
parsed = json.loads(result)
assert parsed["status_code"] == 201
# Verify POST method
assert mock_request.call_args[0][0] == "POST"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_http_no_url(self, registry):
config = {"_type": "http"}
registry._custom_tool_configs["bad_http"] = config
result = await registry.execute_tool("bad_http", {})
assert "error" in result
class TestToolRegistryCode:
    """Tests for executing user-supplied code tools in the sandbox."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_simple(self, registry):
        """The tool's ``run(args)`` return value comes back as JSON."""
        registry._custom_tool_configs["calc"] = {
            "source": "def run(args):\n return {'sum': args['a'] + args['b']}",
            "_type": "code",
        }
        output = json.loads(await registry.execute_tool("calc", {"a": 10, "b": 20}))
        assert output["sum"] == 30

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_code_text_stats(self, registry):
        """Sandboxed code can use basic builtins such as ``len`` and ``str.split``."""
        registry._custom_tool_configs["stats"] = {
            "source": (
                "def run(args):\n"
                " text = args.get('text', '')\n"
                " return {'len': len(text), 'words': len(text.split())}"
            ),
            "_type": "code",
        }
        output = json.loads(await registry.execute_tool("stats", {"text": "hello world test"}))
        assert output["len"] == 16
        assert output["words"] == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_source_missing(self, registry):
        """A code tool config with no ``source`` yields an error result."""
        registry._custom_tool_configs["no_source"] = {"_type": "code"}
        outcome = await registry.execute_tool("no_source", {})
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_no_run_function(self, registry):
        """Source that defines no ``run`` function is rejected with a message mentioning ``run``."""
        registry._custom_tool_configs["no_run"] = {"source": "x = 1", "_type": "code"}
        outcome = await registry.execute_tool("no_run", {})
        assert "run" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_runtime_error(self, registry):
        """An exception raised inside ``run`` is reported as an error result."""
        registry._custom_tool_configs["err"] = {
            "source": "def run(args):\n raise ValueError('test error')",
            "_type": "code",
        }
        outcome = await registry.execute_tool("err", {})
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_restriction(self, registry):
        """An ``import`` statement inside sandboxed code must fail."""
        registry._custom_tool_configs["unsafe"] = {
            "source": "def run(args):\n import os\n return os.name",
            "_type": "code",
        }
        outcome = await registry.execute_tool("unsafe", {})
        assert "error" in outcome

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_code_sandbox_no_file_access(self, registry):
        """Sandboxed code that tries to open a file gets an error result."""
        registry._custom_tool_configs["file_access"] = {
            "source": "def run(args):\n open('/etc/passwd')\n return 'ok'",
            "_type": "code",
        }
        outcome = await registry.execute_tool("file_access", {})
        assert "error" in outcome
class TestToolRegistryTestHelpers:
    """Tests for the ad-hoc tool test helpers (tools are executed without being saved to the DB)."""

    # Unlike execute_tool (which returns a JSON string), these helpers return
    # a dict, so the results are indexed directly.

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_success(self, registry):
        source = "def run(args):\n return {'result': args['x'] * 2}"
        result = await registry.test_code_tool(source, {"x": 5})
        assert result["success"] is True
        assert result["result"]["result"] == 10

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_code_tool_compile_error(self, registry):
        """Syntactically invalid source is reported as a failure with an ``error`` key."""
        source = "def run(args):\n invalid syntax{{{"
        result = await registry.test_code_tool(source, {})
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_test_http_tool(self, registry):
        with patch("httpx.AsyncClient.request", new=AsyncMock()) as mock_request:
            # Only request() is awaited; the returned httpx.Response has
            # synchronous attributes/methods, so it is stubbed with MagicMock.
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.text = '{"ip": "8.8.8.8"}'
            mock_request.return_value = mock_response
            result = await registry.test_http_tool(
                url="https://httpbin.org/get?ip={ip}",
                method="GET",
                headers={},
                body=None,
                args={"ip": "8.8.8.8"},
                timeout=5,
            )
        assert result["success"] is True
        assert result["status_code"] == 200
class TestToolRegistryExecute:
    """End-to-end dispatch through ``execute_tool`` for built-in and custom tools."""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_builtin(self, registry):
        def hello(**kwargs):
            who = kwargs.get('name', 'world')
            return f"Hello, {who}!"

        registry.register_builtin_tool("hello", hello, {"function": {"name": "hello"}})
        output = await registry.execute_tool("hello", {"name": "Test"})
        assert "Hello, Test!" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_nonexistent(self, registry):
        """Executing an unknown tool name yields an error result."""
        output = await registry.execute_tool("no_such_tool", {})
        assert "error" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_unsupported_type(self, registry):
        """A custom tool with an unrecognised ``_type`` yields an error result."""
        registry._custom_tool_configs["weird"] = {"_type": "unsupported"}
        output = await registry.execute_tool("weird", {})
        assert "error" in output

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_workflow_not_supported(self, registry):
        """Workflow-type tools are rejected with a "not supported yet" message."""
        registry._custom_tool_configs["wf"] = {"_type": "workflow"}
        output = await registry.execute_tool("wf", {})
        # The expected message text is Chinese ("暂不支持" = "not supported yet").
        assert "暂不支持" in output