""" 工具注册表 — 管理所有可用工具(内置 / HTTP / 代码段)。 提供统一的工具注册、查找和执行接口。 """ from __future__ import annotations import json import logging import traceback from typing import Any, Callable, Dict, List, Optional from app.models.tool import Tool from sqlalchemy.orm import Session logger = logging.getLogger(__name__) # 代码工具的安全内置模块 _CODE_SAFE_GLOBALS = { "json": json, "dict": dict, "list": list, "str": str, "int": int, "float": float, "bool": bool, "len": len, "range": range, "enumerate": enumerate, "zip": zip, "map": map, "filter": filter, "sorted": sorted, "min": min, "max": max, "sum": sum, "abs": abs, "round": round, "isinstance": isinstance, "type": type, "True": True, "False": False, "None": None, } class ToolRegistry: """工具注册表 — 管理所有可用工具。""" def __init__(self): self._builtin_tools: Dict[str, Callable] = {} self._tool_schemas: Dict[str, Dict[str, Any]] = {} # 自定义工具配置(从 DB 加载的非 builtin 工具) self._custom_tool_configs: Dict[str, Dict[str, Any]] = {} # ─── 内置工具注册 ───────────────────────────────────────── def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]): """ 注册内置工具。 Args: name: 工具名称 func: 工具函数(同步或异步) schema: OpenAI function calling 格式 schema """ self._builtin_tools[name] = func self._tool_schemas[name] = schema logger.debug("注册内置工具: %s", name) def register_external_tool(self, name: str, description: str, parameters: dict, category: str = "plugin"): """注册外部/插件工具 schema(无实际执行函数,仅用于工作流编辑器展示)。""" schema = { "name": name, "description": description, "parameters": parameters, } self._tool_schemas[name] = schema self._custom_tool_configs[name] = {"category": category, "description": description} logger.debug("注册外部工具 schema: %s (category=%s)", name, category) # ─── 工具信息查询 ───────────────────────────────────────── def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]: return self._tool_schemas.get(name) def get_tool_function(self, name: str) -> Optional[Callable]: return self._builtin_tools.get(name) def get_all_tool_schemas(self) -> List[Dict[str, Any]]: return list(self._tool_schemas.values()) def builtin_tool_count(self) -> int: return len(self._builtin_tools) def builtin_tool_names(self) -> List[str]: return sorted(self._builtin_tools.keys()) def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]: tools = [] for name in tool_names: schema = self.get_tool_schema(name) if schema: tools.append(schema) else: logger.warning("工具 %s 未找到", name) return tools # ─── 从数据库加载自定义工具 ────────────────────────────── def load_tools_from_db(self, db: Session, tool_names: Optional[List[str]] = None): """ 从数据库加载工具定义。 Args: db: 数据库会话 tool_names: 指定名称列表;None 则加载所有公开工具 """ query = db.query(Tool).filter(Tool.is_public == True) if tool_names: query = query.filter(Tool.name.in_(tool_names)) tools = query.all() count = 0 for tool in tools: self._tool_schemas[tool.name] = tool.function_schema if tool.implementation_type == "builtin": if tool.name not in self._builtin_tools: logger.warning("工具 %s 标记为 builtin 但未注册", tool.name) else: # 存储自定义工具配置 config = tool.implementation_config or {} config["_type"] = tool.implementation_type config["_db_id"] = tool.id self._custom_tool_configs[tool.name] = config count += 1 if count: logger.info("从数据库加载了 %d 个自定义工具", count) # ─── 工具执行 ───────────────────────────────────────────── async def execute_tool(self, name: str, args: Dict[str, Any]) -> str: """ 执行任意工具(内置 / HTTP / 代码段)。 Args: name: 工具名称 args: 参数字典 Returns: 结果字符串 """ # 1. 内置工具 func = self._builtin_tools.get(name) if func: return await self._run_function(func, name, args) # 2. 自定义工具 config = self._custom_tool_configs.get(name) if not config: return json.dumps({"error": f"工具 '{name}' 不存在"}, ensure_ascii=False) impl_type = config.get("_type", "") try: if impl_type == "http": return await self._execute_http_tool(name, config, args) elif impl_type == "code": return await self._execute_code_tool(name, config, args) elif impl_type == "workflow": return json.dumps({"error": "工作流工具暂不支持动态执行"}, ensure_ascii=False) else: return json.dumps({"error": f"不支持的实现类型: {impl_type}"}, ensure_ascii=False) except Exception as e: logger.error("工具 %s 执行失败: %s", name, e, exc_info=True) return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"}, ensure_ascii=False) @staticmethod async def _run_function(func: Callable, name: str, args: Dict[str, Any]) -> str: """执行内置工具函数。""" import asyncio try: if asyncio.iscoroutinefunction(func): result = await func(**args) else: result = func(**args) if isinstance(result, (dict, list)): return json.dumps(result, ensure_ascii=False) return str(result) except Exception as e: logger.error("工具 '%s' 执行失败: %s", name, e, exc_info=True) return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"}, ensure_ascii=False) # ─── HTTP 工具 ──────────────────────────────────────────── async def _execute_http_tool( self, name: str, config: Dict[str, Any], args: Dict[str, Any] ) -> str: """ 执行 HTTP 工具。 implementation_config 格式: { "url": "https://api.example.com/{path_param}", "method": "GET", "headers": {"Authorization": "Bearer xxx"}, "body_template": {"key": "{param}"}, # 可选 "timeout": 30 } URL 和 body_template 中的 {param} 会被 args 替换。 """ import httpx url = config.get("url", "") method = (config.get("method") or "GET").upper() headers = config.get("headers") or {} body_template = config.get("body_template") timeout = config.get("timeout", 30) if not url: return json.dumps({"error": "HTTP 工具未配置 URL"}, ensure_ascii=False) # 模板替换:{param} → args["param"] def _fmt(template: str) -> str: try: return template.format(**args) except KeyError: return template url = _fmt(url) body: Any = None if body_template: body_str = json.dumps(body_template) body_str = _fmt(body_str) try: body = json.loads(body_str) except json.JSONDecodeError: body = body_str async with httpx.AsyncClient(timeout=timeout) as client: response = await client.request(method, url, headers=headers, json=body) text = response.text[:10000] # 截断过长的响应 result = { "status_code": response.status_code, "body": text, } return json.dumps(result, ensure_ascii=False) # ─── 代码工具 ───────────────────────────────────────────── async def _execute_code_tool( self, name: str, config: Dict[str, Any], args: Dict[str, Any] ) -> str: """ 执行代码工具。 implementation_config 格式: { "source": "def run(args):\\n return {'sum': args['a'] + args['b']}", "language": "python" } 代码工具需定义一个 run(args) 函数作为入口。 source 在沙箱环境中执行,不可访问文件系统/网络。 """ source = config.get("source", "") if not source: return json.dumps({"error": "代码工具未配置 source"}, ensure_ascii=False) # 编译代码,限制可访问的全局变量 safe_globals = _CODE_SAFE_GLOBALS.copy() safe_globals["__builtins__"] = {} # 禁用所有内置函数 try: exec(source, safe_globals) except Exception as e: return json.dumps({ "error": f"代码编译失败: {e}", "traceback": traceback.format_exc(), }, ensure_ascii=False) run_func = safe_globals.get("run") if not run_func or not callable(run_func): return json.dumps({ "error": "代码工具必须定义一个 run(args) 函数" }, ensure_ascii=False) try: result = run_func(args) if isinstance(result, (dict, list)): return json.dumps(result, ensure_ascii=False) return str(result) except Exception as e: return json.dumps({ "error": f"代码执行失败: {e}", "traceback": traceback.format_exc(), }, ensure_ascii=False) # ─── 测试工具(不保存,直接执行)───────────────────────── async def test_http_tool( self, url: str, method: str, headers: Dict[str, str], body: Optional[Dict[str, Any]], args: Dict[str, Any], timeout: int = 30, ) -> Dict[str, Any]: """测试 HTTP 工具配置(不保存到 DB)。""" import httpx def _fmt(template: str) -> str: try: return template.format(**args) except KeyError: return template url = _fmt(url) body_sent = None if body: body_str = json.dumps(body) body_str = _fmt(body_str) try: body_sent = json.loads(body_str) except json.JSONDecodeError: body_sent = body_str start = __import__("time").time() try: async with httpx.AsyncClient(timeout=timeout) as client: response = await client.request(method.upper(), url, headers=headers, json=body_sent) elapsed_ms = int((__import__("time").time() - start) * 1000) return { "success": True, "status_code": response.status_code, "elapsed_ms": elapsed_ms, "body": response.text[:10000], } except Exception as e: return {"success": False, "error": str(e)} async def test_code_tool( self, source: str, args: Dict[str, Any] ) -> Dict[str, Any]: """测试代码工具(不保存到 DB)。""" safe_globals = _CODE_SAFE_GLOBALS.copy() safe_globals["__builtins__"] = {} try: exec(source, safe_globals) except Exception as e: return {"success": False, "error": f"编译失败: {e}"} run_func = safe_globals.get("run") if not run_func: return {"success": False, "error": "代码须定义 run(args) 函数"} try: start = __import__("time").time() result = run_func(args) elapsed_ms = int((__import__("time").time() - start) * 1000) return {"success": True, "elapsed_ms": elapsed_ms, "result": result} except Exception as e: return {"success": False, "error": f"执行失败: {e}"} # 全局单例 tool_registry = ToolRegistry()