2026-01-23 09:49:45 +08:00
|
|
|
|
"""
|
2026-05-01 22:30:46 +08:00
|
|
|
|
工具注册表 — 管理所有可用工具(内置 / HTTP / 代码段)。
|
|
|
|
|
|
|
|
|
|
|
|
提供统一的工具注册、查找和执行接口。
|
2026-01-23 09:49:45 +08:00
|
|
|
|
"""
|
2026-05-01 22:30:46 +08:00
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
2026-01-23 09:49:45 +08:00
|
|
|
|
import json
|
|
|
|
|
|
import logging
|
2026-05-01 22:30:46 +08:00
|
|
|
|
import traceback
|
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
|
|
|
2026-01-23 09:49:45 +08:00
|
|
|
|
from app.models.tool import Tool
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)


# Globals made available to user-supplied "code" tools. In this module the
# code-tool executors copy this dict, set ``__builtins__`` to ``{}`` and
# ``exec`` the tool source against the copy, so only these names (plus
# Python syntax/keywords) are directly reachable *by name*.
#
# SECURITY WARNING: this is an allow-list of convenient names, NOT a
# sandbox. ``exec`` with restricted globals is not a security boundary:
# untrusted code can still reach the real builtins (and from there
# ``open``/``__import__``, i.e. the file system and network) through object
# introspection, e.g. ``json.dumps.__globals__["__builtins__"]`` or
# ``().__class__.__base__.__subclasses__()``. Only run trusted source here,
# or move execution into a separate, OS-level isolated process.
#
# "True"/"False"/"None" are keywords in Python 3 and are never looked up in
# this dict; those entries are redundant but harmless.
_CODE_SAFE_GLOBALS = {
    "json": json,
    "dict": dict,
    "list": list,
    "str": str,
    "int": int,
    "float": float,
    "bool": bool,
    "len": len,
    "range": range,
    "enumerate": enumerate,
    "zip": zip,
    "map": map,
    "filter": filter,
    "sorted": sorted,
    "min": min,
    "max": max,
    "sum": sum,
    "abs": abs,
    "round": round,
    "isinstance": isinstance,
    "type": type,
    "True": True,
    "False": False,
    "None": None,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-23 09:49:45 +08:00
|
|
|
|
class ToolRegistry:
    """Registry of all available tools.

    Three kinds of tools are tracked:

    * builtin tools -- in-process Python callables registered through
      :meth:`register_builtin_tool`;
    * custom tools loaded from the ``Tool`` table by
      :meth:`load_tools_from_db` (implementation type ``"http"``,
      ``"code"`` or ``"workflow"``);
    * external/plugin tools registered through
      :meth:`register_external_tool`, which only contribute a schema.

    It offers a single interface to register, look up and execute tools.
    """

    def __init__(self):
        # name -> callable, for in-process builtin tools.
        self._builtin_tools: Dict[str, Callable] = {}
        # name -> function-calling schema, for every kind of tool.
        self._tool_schemas: Dict[str, Dict[str, Any]] = {}
        # name -> config of non-builtin tools (loaded from the DB, or the
        # schema-only entries added by register_external_tool).
        self._custom_tool_configs: Dict[str, Dict[str, Any]] = {}

    # ─── Builtin tool registration ────────────────────────────

    def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]):
        """Register an in-process builtin tool.

        Args:
            name: Tool name; replaces any tool already registered under it.
            func: Implementation, sync or async. It is called with the tool
                arguments as keyword arguments.
            schema: OpenAI function-calling schema for the tool.
        """
        self._builtin_tools[name] = func
        self._tool_schemas[name] = schema
        logger.debug("注册内置工具: %s", name)

    def register_external_tool(self, name: str, description: str, parameters: dict, category: str = "plugin"):
        """Register the schema of an external/plugin tool.

        No executable function is stored; the schema exists only so the
        workflow editor can show the tool. The stored config has no
        ``_type`` key, so :meth:`execute_tool` reports such a tool as having
        an unsupported implementation type.

        Args:
            name: Tool name.
            description: Human-readable description.
            parameters: Parameter schema placed under ``"parameters"``.
            category: Category label kept in the tool's config.
        """
        schema = {
            "name": name,
            "description": description,
            "parameters": parameters,
        }
        self._tool_schemas[name] = schema
        self._custom_tool_configs[name] = {"category": category, "description": description}
        logger.debug("注册外部工具 schema: %s (category=%s)", name, category)

    # ─── Tool lookup ──────────────────────────────────────────

    def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the schema registered for *name*, or ``None``."""
        return self._tool_schemas.get(name)

    def get_tool_function(self, name: str) -> Optional[Callable]:
        """Return the builtin callable for *name*, or ``None`` if not builtin."""
        return self._builtin_tools.get(name)

    def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return every registered schema (all tool kinds)."""
        return list(self._tool_schemas.values())

    def builtin_tool_count(self) -> int:
        """Return how many builtin tools are registered."""
        return len(self._builtin_tools)

    def builtin_tool_names(self) -> List[str]:
        """Return the builtin tool names in sorted order."""
        return sorted(self._builtin_tools.keys())

    def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]:
        """Return the schemas for *tool_names*, in order.

        Unknown names (or names whose schema is empty) are skipped with a
        warning.
        """
        tools = []
        for name in tool_names:
            schema = self.get_tool_schema(name)
            if schema:
                tools.append(schema)
            else:
                logger.warning("工具 %s 未找到", name)
        return tools

    # ─── Loading custom tools from the database ───────────────

    def load_tools_from_db(self, db: Session, tool_names: Optional[List[str]] = None):
        """Load public tool definitions from the database.

        Every loaded tool's ``function_schema`` is (re)registered. Non-builtin
        tools also get a copy of their ``implementation_config`` stored,
        tagged with ``_type`` (the implementation type) and ``_db_id``.

        Args:
            db: Database session.
            tool_names: Names to load. ``None`` (and also an empty list)
                loads every public tool.
        """
        # SQLAlchemy needs ``==`` here to build a SQL expression.
        query = db.query(Tool).filter(Tool.is_public == True)  # noqa: E712
        if tool_names:
            query = query.filter(Tool.name.in_(tool_names))

        count = 0
        for tool in query.all():
            self._tool_schemas[tool.name] = tool.function_schema
            if tool.implementation_type == "builtin":
                # A DB-declared builtin only runs if code registered it.
                if tool.name not in self._builtin_tools:
                    logger.warning("工具 %s 标记为 builtin 但未注册", tool.name)
                continue

            # Copy before tagging. Writing into the ORM-loaded dict would
            # mutate the row's attribute in place, so the private keys could
            # be persisted if the session later flushes that row.
            config = dict(tool.implementation_config or {})
            config["_type"] = tool.implementation_type
            config["_db_id"] = tool.id
            self._custom_tool_configs[tool.name] = config
            count += 1

        if count:
            logger.info("从数据库加载了 %d 个自定义工具", count)

    # ─── Tool execution ───────────────────────────────────────

    async def execute_tool(self, name: str, args: Dict[str, Any]) -> str:
        """Execute any registered tool (builtin / HTTP / code).

        A builtin tool wins over a custom tool registered under the same name.

        Args:
            name: Tool name.
            args: Tool arguments.

        Returns:
            The result as a string. Errors are returned as a JSON object with
            an ``"error"`` key rather than raised.
        """
        # 1. Builtin tools.
        func = self._builtin_tools.get(name)
        if func:
            return await self._run_function(func, name, args)

        # 2. Custom tools.
        config = self._custom_tool_configs.get(name)
        if not config:
            return json.dumps({"error": f"工具 '{name}' 不存在"}, ensure_ascii=False)

        impl_type = config.get("_type", "")
        try:
            if impl_type == "http":
                return await self._execute_http_tool(name, config, args)
            if impl_type == "code":
                return await self._execute_code_tool(name, config, args)
            if impl_type == "workflow":
                return json.dumps({"error": "工作流工具暂不支持动态执行"},
                                  ensure_ascii=False)
            return json.dumps({"error": f"不支持的实现类型: {impl_type}"},
                              ensure_ascii=False)
        except Exception as e:
            logger.error("工具 %s 执行失败: %s", name, e, exc_info=True)
            return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
                              ensure_ascii=False)

    @staticmethod
    async def _run_function(func: Callable, name: str, args: Dict[str, Any]) -> str:
        """Call a builtin tool function and serialise its result.

        The function is called with *args* as keyword arguments. Whenever
        the call returns an awaitable it is awaited. This covers ``async def``
        functions and also sync wrappers that return a coroutine, such as a
        lambda or an object with an ``async`` ``__call__``.

        Returns:
            JSON for dict/list results, ``str(result)`` otherwise, or a JSON
            error object if the call (or serialisation) raised.
        """
        import inspect

        try:
            result = func(**args)
            if inspect.isawaitable(result):
                result = await result
            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            return str(result)
        except Exception as e:
            logger.error("工具 '%s' 执行失败: %s", name, e, exc_info=True)
            return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
                              ensure_ascii=False)

    # ─── HTTP tools ───────────────────────────────────────────

    @staticmethod
    def _substitute(template: str, args: Dict[str, Any]) -> str:
        """Replace each ``{name}`` in *template* with ``str(args[name])``.

        Placeholders whose name is not in *args* are left untouched, so one
        missing argument does not stop the others from being filled in.
        Values are inserted verbatim: no URL percent-encoding is applied.
        """
        import re

        def _replace(match):
            key = match.group(1)
            return str(args[key]) if key in args else match.group(0)

        # A function replacement is used literally (no backslash escapes).
        return re.sub(r"\{(\w+)\}", _replace, template)

    @staticmethod
    def _render_body(template: Any, args: Dict[str, Any]) -> Any:
        """Recursively fill ``{name}`` placeholders in a JSON body template.

        * A string that is exactly one known placeholder (e.g. ``"{limit}"``)
          is replaced by the raw argument, keeping its JSON type.
        * Any other string is passed through :meth:`_substitute`.
        * Dict values and list items are rendered recursively; dict keys are
          kept unchanged.
        * Other values (numbers, booleans, ``None``) are returned unchanged.

        The template is walked as a structure instead of formatting its JSON
        text, whose own braces would clash with ``str.format``.
        """
        import re

        if isinstance(template, str):
            whole = re.fullmatch(r"\{(\w+)\}", template)
            if whole and whole.group(1) in args:
                return args[whole.group(1)]
            return ToolRegistry._substitute(template, args)
        if isinstance(template, dict):
            return {key: ToolRegistry._render_body(value, args)
                    for key, value in template.items()}
        if isinstance(template, list):
            return [ToolRegistry._render_body(item, args) for item in template]
        return template

    async def _execute_http_tool(
        self, name: str, config: Dict[str, Any], args: Dict[str, Any]
    ) -> str:
        """Execute an HTTP tool.

        ``implementation_config`` format::

            {
                "url": "https://api.example.com/{path_param}",
                "method": "GET",
                "headers": {"Authorization": "Bearer xxx"},
                "body_template": {"key": "{param}"},  # optional
                "timeout": 30
            }

        ``{param}`` placeholders in ``url`` and ``body_template`` are filled
        from *args* (see :meth:`_substitute` and :meth:`_render_body`).
        Headers are sent as configured. The response body is truncated to
        10,000 characters.

        Returns:
            JSON ``{"status_code": ..., "body": ...}``. Request errors
            propagate to the caller.
        """
        import httpx

        url = config.get("url", "")
        method = (config.get("method") or "GET").upper()
        headers = config.get("headers") or {}
        body_template = config.get("body_template")
        timeout = config.get("timeout", 30)

        if not url:
            return json.dumps({"error": "HTTP 工具未配置 URL"}, ensure_ascii=False)

        url = self._substitute(url, args)
        body: Any = self._render_body(body_template, args) if body_template else None

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.request(method, url, headers=headers, json=body)
            text = response.text[:10000]  # Truncate overly long responses.

        result = {
            "status_code": response.status_code,
            "body": text,
        }
        return json.dumps(result, ensure_ascii=False)

    # ─── Code tools ───────────────────────────────────────────

    @staticmethod
    def _exec_code_tool_source(source: str) -> Dict[str, Any]:
        """Execute *source* against a fresh copy of ``_CODE_SAFE_GLOBALS``.

        ``__builtins__`` is set to an empty dict, so builtins not listed in
        ``_CODE_SAFE_GLOBALS`` cannot be referenced by name. Exceptions
        raised by ``exec`` propagate.

        Returns:
            The global namespace after execution.

        SECURITY: this is NOT a sandbox. Code can still reach the real
        builtins (file system, network, ``__import__``) via object
        introspection. Never run untrusted source here.
        """
        namespace = _CODE_SAFE_GLOBALS.copy()
        namespace["__builtins__"] = {}  # Hide builtins from name lookup.
        exec(source, namespace)  # noqa: S102 -- see SECURITY note above
        return namespace

    async def _execute_code_tool(
        self, name: str, config: Dict[str, Any], args: Dict[str, Any]
    ) -> str:
        """Execute a code tool.

        ``implementation_config`` format::

            {
                "source": "def run(args):\\n    return {'sum': args['a'] + args['b']}",
                "language": "python"
            }

        The source must define a ``run(args)`` entry point. It is executed
        with a restricted set of global names, which is NOT an isolation
        boundary (see :meth:`_exec_code_tool_source`).

        NOTE(review): the code runs synchronously on the event loop with no
        timeout, so a long-running ``run`` blocks every other coroutine.

        Returns:
            JSON for dict/list results, ``str(result)`` otherwise, or a JSON
            error object (with traceback) on failure.
        """
        source = config.get("source", "")
        if not source:
            return json.dumps({"error": "代码工具未配置 source"}, ensure_ascii=False)

        try:
            namespace = self._exec_code_tool_source(source)
        except Exception as e:
            return json.dumps({
                "error": f"代码编译失败: {e}",
                "traceback": traceback.format_exc(),
            }, ensure_ascii=False)

        run_func = namespace.get("run")
        if not callable(run_func):
            return json.dumps({
                "error": "代码工具必须定义一个 run(args) 函数"
            }, ensure_ascii=False)

        try:
            result = run_func(args)
            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            return str(result)
        except Exception as e:
            return json.dumps({
                "error": f"代码执行失败: {e}",
                "traceback": traceback.format_exc(),
            }, ensure_ascii=False)

    # ─── Tool testing (executes directly, nothing is saved) ───

    async def test_http_tool(
        self, url: str, method: str, headers: Dict[str, str],
        body: Optional[Dict[str, Any]], args: Dict[str, Any],
        timeout: int = 30,
    ) -> Dict[str, Any]:
        """Try an HTTP tool configuration without saving it to the DB.

        Placeholders are filled exactly as in :meth:`_execute_http_tool`.
        A missing *method* defaults to ``GET``.

        Returns:
            ``{"success": True, "status_code", "elapsed_ms", "body"}`` with
            the body truncated to 10,000 characters, or
            ``{"success": False, "error"}`` if the request failed.
        """
        import time

        import httpx

        url = self._substitute(url, args)
        body_sent = self._render_body(body, args) if body else None

        start = time.perf_counter()
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.request((method or "GET").upper(), url,
                                                headers=headers, json=body_sent)
                elapsed_ms = int((time.perf_counter() - start) * 1000)
                return {
                    "success": True,
                    "status_code": response.status_code,
                    "elapsed_ms": elapsed_ms,
                    "body": response.text[:10000],
                }
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def test_code_tool(
        self, source: str, args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Try a code tool without saving it to the DB.

        Uses the same restricted (NOT isolated) execution as
        :meth:`_execute_code_tool`, and the same ``callable(run)`` check.

        Returns:
            ``{"success": True, "elapsed_ms", "result"}`` with the raw
            ``run`` result, or ``{"success": False, "error"}``.
        """
        import time

        try:
            namespace = self._exec_code_tool_source(source)
        except Exception as e:
            return {"success": False, "error": f"编译失败: {e}"}

        run_func = namespace.get("run")
        if not callable(run_func):
            return {"success": False, "error": "代码须定义 run(args) 函数"}

        try:
            start = time.perf_counter()
            result = run_func(args)
            elapsed_ms = int((time.perf_counter() - start) * 1000)
            return {"success": True, "elapsed_ms": elapsed_ms, "result": result}
        except Exception as e:
            return {"success": False, "error": f"执行失败: {e}"}
|
2026-01-23 09:49:45 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-05-01 22:30:46 +08:00
|
|
|
|
# 全局单例
|
2026-01-23 09:49:45 +08:00
|
|
|
|
# Module-level singleton: code importing `tool_registry` from this module shares this one instance.
tool_registry = ToolRegistry()
|