Files
aiagent/backend/app/services/tool_registry.py
renjianbo 3c102ed5f9 feat: #27 插件系统 — 第三方节点扩展
- NodePlugin 模型: manifest规范(name/version/node_type/inputs_schema/outputs_schema)
- plugin_loader 服务: manifest校验、代码加载/卸载、沙箱执行(subprocess隔离+超时30s)
- plugins API: CRUD、启用/禁用、市场浏览、安装计数、沙箱测试执行
- PluginMarket.vue: 插件市场上传/浏览/安装/启用禁用/删除/测试
- 注册 register_external_tool 到 tool_registry,供工作流编辑器使用

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-06 21:44:45 +08:00

372 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工具注册表 — 管理所有可用工具(内置 / HTTP / 代码段)。
提供统一的工具注册、查找和执行接口。
"""
from __future__ import annotations

import inspect
import json
import logging
import time
import traceback
from typing import Any, Callable, Dict, List, Optional

from sqlalchemy.orm import Session

from app.models.tool import Tool
logger = logging.getLogger(__name__)

# Globals exposed to "code" tools, whose source is run with exec() and an
# empty ``__builtins__``.
#
# SECURITY NOTE(review): despite the name, this is NOT a sandbox. Emptying
# __builtins__ does not stop attribute traversal (e.g.
# ``().__class__.__base__.__subclasses__()``), through which tool code can
# reach arbitrary objects and modules such as ``os``. Only run code tools
# written by trusted authors, or move execution to an isolated subprocess.
_CODE_SAFE_GLOBALS = {
    "json": json,
    "dict": dict,
    "list": list,
    "str": str,
    "int": int,
    "float": float,
    "bool": bool,
    "len": len,
    "range": range,
    "enumerate": enumerate,
    "zip": zip,
    "map": map,
    "filter": filter,
    "sorted": sorted,
    "min": min,
    "max": max,
    "sum": sum,
    "abs": abs,
    "round": round,
    "isinstance": isinstance,
    "type": type,
    # True/False/None are keywords in Python 3, so these three entries are
    # never actually looked up; kept for backward compatibility.
    "True": True,
    "False": False,
    "None": None,
    # Common exception classes. With __builtins__ emptied these names would
    # otherwise be undefined, so ``except KeyError:`` or ``raise ValueError``
    # in tool code would itself fail with NameError.
    "Exception": Exception,
    "ValueError": ValueError,
    "TypeError": TypeError,
    "KeyError": KeyError,
    "IndexError": IndexError,
}
class ToolRegistry:
    """Registry of every available tool (builtin / HTTP / code snippet).

    Provides one interface for registering tools, looking up their
    function-calling schemas, and executing them by name.

    SECURITY NOTE(review): "code" tools are run in-process with ``exec``.
    The restricted globals only hide builtins; they do not prevent attribute
    traversal (``().__class__.__base__.__subclasses__()``). Execution is also
    synchronous inside an ``async`` method with no timeout, so a slow or
    malicious snippet blocks the caller's event loop. Treat code tools as
    trusted code, or move them into a subprocess with a timeout.
    """

    # Maximum number of characters of an HTTP response body that is returned.
    _HTTP_BODY_LIMIT = 10000

    def __init__(self):
        # name -> in-process callable (sync or async) for builtin tools.
        self._builtin_tools: Dict[str, Callable] = {}
        # name -> function-calling schema, for every known tool.
        self._tool_schemas: Dict[str, Dict[str, Any]] = {}
        # name -> implementation config for non-builtin tools (loaded from
        # the DB, or registered as external/plugin tools).
        self._custom_tool_configs: Dict[str, Dict[str, Any]] = {}

    # ─── Builtin tool registration ────────────────────────────

    def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]):
        """Register an in-process tool.

        Args:
            name: Tool name.
            func: Tool function (sync or async); it is called with the tool
                arguments as keyword arguments.
            schema: OpenAI function-calling schema for the tool.
        """
        self._builtin_tools[name] = func
        self._tool_schemas[name] = schema
        logger.debug("注册内置工具: %s", name)

    def register_external_tool(
        self, name: str, description: str, parameters: dict, category: str = "plugin"
    ):
        """Register the schema of an external/plugin tool.

        No executable function is attached. The entry exists so the workflow
        editor can list the tool; ``execute_tool`` answers it with an
        "unsupported implementation type" error.

        Args:
            name: Tool name.
            description: Human-readable description.
            parameters: JSON-schema of the tool parameters.
            category: Category label stored alongside the description.
        """
        schema = {
            "name": name,
            "description": description,
            "parameters": parameters,
        }
        self._tool_schemas[name] = schema
        self._custom_tool_configs[name] = {"category": category, "description": description}
        logger.debug("注册外部工具 schema: %s (category=%s)", name, category)

    # ─── Lookup ───────────────────────────────────────────────

    def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the schema registered under *name*, or None."""
        return self._tool_schemas.get(name)

    def get_tool_function(self, name: str) -> Optional[Callable]:
        """Return the builtin function registered under *name*, or None."""
        return self._builtin_tools.get(name)

    def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return the schemas of every known tool, in dict insertion order."""
        return list(self._tool_schemas.values())

    def builtin_tool_count(self) -> int:
        """Return how many builtin tools are registered."""
        return len(self._builtin_tools)

    def builtin_tool_names(self) -> List[str]:
        """Return the builtin tool names, sorted alphabetically."""
        return sorted(self._builtin_tools)

    def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]:
        """Return the schemas for *tool_names*; unknown names are logged and skipped."""
        tools = []
        for name in tool_names:
            schema = self.get_tool_schema(name)
            if schema:
                tools.append(schema)
            else:
                logger.warning("工具 %s 未找到", name)
        return tools

    # ─── Loading custom tools from the database ──────────────

    def load_tools_from_db(self, db: Session, tool_names: Optional[List[str]] = None):
        """Load tool definitions from the database.

        Every matching public tool has its schema registered. Non-builtin
        tools also get their implementation config stored for
        ``execute_tool``. Builtin tools must already be registered in-process;
        a warning is logged when one is missing.

        Args:
            db: Database session.
            tool_names: Names to load; None (or empty) loads every public tool.
        """
        # SQLAlchemy needs ``==`` (not ``is``) to build the SQL expression.
        query = db.query(Tool).filter(Tool.is_public == True)  # noqa: E712
        if tool_names:
            query = query.filter(Tool.name.in_(tool_names))

        count = 0
        for tool in query.all():
            self._tool_schemas[tool.name] = tool.function_schema
            if tool.implementation_type == "builtin":
                if tool.name not in self._builtin_tools:
                    logger.warning("工具 %s 标记为 builtin 但未注册", tool.name)
                continue
            # Copy before adding private keys. Mutating the ORM attribute in
            # place would leak ``_type``/``_db_id`` into the model instance
            # (and possibly into the DB on a later flush).
            config = dict(tool.implementation_config or {})
            config["_type"] = tool.implementation_type
            config["_db_id"] = tool.id
            self._custom_tool_configs[tool.name] = config
            count += 1
        if count:
            logger.info("从数据库加载了 %d 个自定义工具", count)

    # ─── Execution ────────────────────────────────────────────

    async def execute_tool(self, name: str, args: Dict[str, Any]) -> str:
        """Execute any tool (builtin / HTTP / code) by name.

        Args:
            name: Tool name.
            args: Tool arguments.

        Returns:
            The result as a string. Unknown tools, unsupported types and
            execution failures are reported as a JSON object with an
            ``"error"`` key instead of being raised.
        """
        func = self._builtin_tools.get(name)
        if func:
            return await self._run_function(func, name, args)

        config = self._custom_tool_configs.get(name)
        if not config:
            return json.dumps({"error": f"工具 '{name}' 不存在"}, ensure_ascii=False)

        impl_type = config.get("_type", "")
        try:
            if impl_type == "http":
                return await self._execute_http_tool(name, config, args)
            if impl_type == "code":
                return await self._execute_code_tool(name, config, args)
            if impl_type == "workflow":
                return json.dumps({"error": "工作流工具暂不支持动态执行"}, ensure_ascii=False)
            return json.dumps({"error": f"不支持的实现类型: {impl_type}"}, ensure_ascii=False)
        except Exception as e:
            logger.exception("工具 %s 执行失败: %s", name, e)
            return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"}, ensure_ascii=False)

    @staticmethod
    async def _run_function(func: Callable, name: str, args: Dict[str, Any]) -> str:
        """Call a builtin tool and serialize its result.

        dict/list results are JSON-encoded; anything else goes through
        ``str``. Exceptions are logged and returned as a JSON error object.
        """
        try:
            result = func(**args)
            # Await whatever comes back awaitable instead of inspecting *func*.
            # This also covers sync callables that return a coroutine (e.g. a
            # plain wrapper around an async function), which a function-type
            # check would miss, leaving the coroutine stringified and unawaited.
            if inspect.isawaitable(result):
                result = await result
            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            return str(result)
        except Exception as e:
            logger.exception("工具 '%s' 执行失败: %s", name, e)
            return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"}, ensure_ascii=False)

    # ─── Template helpers ─────────────────────────────────────

    @staticmethod
    def _format_template(template: str, args: Dict[str, Any]) -> str:
        """Replace ``{param}`` placeholders in *template* with values from *args*.

        Best-effort: if formatting fails (missing key, positional field,
        unbalanced brace, bad attribute/index access), the template is
        returned unchanged.
        """
        try:
            return template.format(**args)
        except (LookupError, ValueError, AttributeError, TypeError):
            return template

    @classmethod
    def _render_body(cls, template: Any, args: Dict[str, Any]) -> Any:
        """Recursively substitute placeholders in the string values of a body template.

        Dict keys and non-string scalars are left unchanged. Working on the
        structure, rather than on its JSON text, avoids ``str.format``
        misreading the JSON object's own braces as fields. It also keeps
        substituted values that contain quotes from corrupting the JSON.
        """
        if isinstance(template, str):
            return cls._format_template(template, args)
        if isinstance(template, dict):
            return {key: cls._render_body(value, args) for key, value in template.items()}
        if isinstance(template, list):
            return [cls._render_body(item, args) for item in template]
        return template

    # ─── HTTP tools ───────────────────────────────────────────

    async def _execute_http_tool(
        self, name: str, config: Dict[str, Any], args: Dict[str, Any]
    ) -> str:
        """Execute an HTTP tool.

        implementation_config format::

            {
                "url": "https://api.example.com/{path_param}",
                "method": "GET",
                "headers": {"Authorization": "Bearer xxx"},
                "body_template": {"key": "{param}"},   # optional
                "timeout": 30
            }

        ``{param}`` placeholders in the URL and in every string value of
        ``body_template`` are replaced from *args*. Headers are sent as-is.

        Returns:
            JSON with ``status_code`` and the response body truncated to
            ``_HTTP_BODY_LIMIT`` characters.
        """
        import httpx

        url = config.get("url", "")
        method = (config.get("method") or "GET").upper()
        headers = config.get("headers") or {}
        body_template = config.get("body_template")
        timeout = config.get("timeout", 30)

        if not url:
            return json.dumps({"error": "HTTP 工具未配置 URL"}, ensure_ascii=False)

        url = self._format_template(url, args)
        body: Any = self._render_body(body_template, args) if body_template else None

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.request(method, url, headers=headers, json=body)
            result = {
                "status_code": response.status_code,
                "body": response.text[: self._HTTP_BODY_LIMIT],
            }
        return json.dumps(result, ensure_ascii=False)

    # ─── Code tools ───────────────────────────────────────────

    @staticmethod
    def _exec_code_source(source: str) -> Dict[str, Any]:
        """Run *source* in a fresh restricted namespace and return that namespace.

        Exceptions raised while executing the source propagate to the caller.
        See the class security note: this is not an isolation boundary.
        """
        namespace = _CODE_SAFE_GLOBALS.copy()
        namespace["__builtins__"] = {}  # hide builtins not listed explicitly
        exec(source, namespace)  # noqa: S102 - trusted tool code only
        return namespace

    async def _execute_code_tool(
        self, name: str, config: Dict[str, Any], args: Dict[str, Any]
    ) -> str:
        """Execute a code tool.

        implementation_config format::

            {
                "source": "def run(args):\\n    return {'sum': args['a'] + args['b']}",
                "language": "python"
            }

        The source must define a ``run(args)`` entry point. ``language`` is
        not checked; the source is always executed as Python, in-process
        (see the class security note: file-system/network access is NOT
        prevented).
        """
        source = config.get("source", "")
        if not source:
            return json.dumps({"error": "代码工具未配置 source"}, ensure_ascii=False)

        try:
            namespace = self._exec_code_source(source)
        except Exception as e:
            return json.dumps({
                "error": f"代码编译失败: {e}",
                "traceback": traceback.format_exc(),
            }, ensure_ascii=False)

        run_func = namespace.get("run")
        if not callable(run_func):
            return json.dumps({
                "error": "代码工具必须定义一个 run(args) 函数"
            }, ensure_ascii=False)

        try:
            result = run_func(args)
            # Serialization stays inside the try so an unserializable result
            # is reported as an execution failure.
            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            return str(result)
        except Exception as e:
            return json.dumps({
                "error": f"代码执行失败: {e}",
                "traceback": traceback.format_exc(),
            }, ensure_ascii=False)

    # ─── Dry runs (nothing is saved) ──────────────────────────

    async def test_http_tool(
        self, url: str, method: str, headers: Dict[str, str],
        body: Optional[Dict[str, Any]], args: Dict[str, Any],
        timeout: int = 30,
    ) -> Dict[str, Any]:
        """Dry-run an HTTP tool configuration (nothing is saved to the DB).

        Placeholders are substituted exactly as in ``_execute_http_tool``.

        Returns:
            ``{"success": True, "status_code", "elapsed_ms", "body"}`` with the
            body truncated to ``_HTTP_BODY_LIMIT`` characters, or
            ``{"success": False, "error": ...}`` if the request fails.
        """
        import httpx

        url = self._format_template(url, args)
        body_sent = self._render_body(body, args) if body else None

        start = time.perf_counter()
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.request(method.upper(), url,
                                                headers=headers, json=body_sent)
            elapsed_ms = int((time.perf_counter() - start) * 1000)
            return {
                "success": True,
                "status_code": response.status_code,
                "elapsed_ms": elapsed_ms,
                "body": response.text[: self._HTTP_BODY_LIMIT],
            }
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def test_code_tool(
        self, source: str, args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Dry-run a code tool (nothing is saved to the DB).

        Returns:
            ``{"success": True, "elapsed_ms", "result"}`` with the raw ``run``
            return value, or ``{"success": False, "error": ...}``.
        """
        try:
            namespace = self._exec_code_source(source)
        except Exception as e:
            return {"success": False, "error": f"编译失败: {e}"}

        run_func = namespace.get("run")
        # Same entry-point check as _execute_code_tool.
        if not callable(run_func):
            return {"success": False, "error": "代码须定义 run(args) 函数"}

        try:
            start = time.perf_counter()
            result = run_func(args)
            elapsed_ms = int((time.perf_counter() - start) * 1000)
            return {"success": True, "elapsed_ms": elapsed_ms, "result": result}
        except Exception as e:
            return {"success": False, "error": f"执行失败: {e}"}
# Module-level singleton instance of the tool registry.
tool_registry: ToolRegistry = ToolRegistry()