# aiagent/backend/app/services/tool_registry.py
"""
Tool registry: manages every available tool (builtin / HTTP / code snippet).

Provides a single interface for registering, looking up and executing tools.
"""
from __future__ import annotations

import inspect
import json
import logging
import time
import traceback
from typing import Any, Callable, Dict, List, Optional

from sqlalchemy.orm import Session

from app.models.tool import Tool
logger = logging.getLogger(__name__)

# Names made available to user-supplied "code" tools when their source is
# exec'd. Each builtin is exposed under its own name; `json` and the three
# constants are added explicitly.
# NOTE(review): `type` is part of this set; confirm that exposing it to
# user-supplied code is acceptable.
_CODE_SAFE_GLOBALS = {
    "json": json,
    **{
        builtin.__name__: builtin
        for builtin in (
            dict, list, str, int, float, bool,
            len, range, enumerate, zip, map, filter,
            sorted, min, max, sum, abs, round,
            isinstance, type,
        )
    },
    "True": True,
    "False": False,
    "None": None,
}
# ─── Tool registry ────────────────────────────────────────
class ToolRegistry:
    """Registry of every tool the agent can call.

    Three name-keyed tables are maintained:

    * ``_builtin_tools``       - callables registered in-process.
    * ``_tool_schemas``        - function-calling schema of every known tool.
    * ``_custom_tool_configs`` - implementation config of non-builtin tools.
    """

    def __init__(self):
        self._builtin_tools: Dict[str, Callable] = {}
        self._tool_schemas: Dict[str, Dict[str, Any]] = {}
        # Implementation configs of custom (non-builtin) tools.
        self._custom_tool_configs: Dict[str, Dict[str, Any]] = {}

    # ─── Builtin tool registration ────────────────────────────

    def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]):
        """Register a builtin tool.

        Args:
            name: Tool name.
            func: Tool implementation; may be sync or async.
            schema: Schema in OpenAI function-calling format.
        """
        self._builtin_tools[name] = func
        self._tool_schemas[name] = schema
        logger.debug("注册内置工具: %s", name)

    def register_external_tool(self, name: str, description: str, parameters: dict, category: str = "plugin"):
        """Register the schema of an external/plugin tool.

        No executable function is attached; only the schema and a small
        descriptive config (category, description) are stored.
        """
        self._tool_schemas[name] = {
            "name": name,
            "description": description,
            "parameters": parameters,
        }
        self._custom_tool_configs[name] = {"category": category, "description": description}
        logger.debug("注册外部工具 schema: %s (category=%s)", name, category)

    # ─── Tool lookup ──────────────────────────────────────────

    def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the schema registered under *name*, or None."""
        return self._tool_schemas.get(name)

    def get_tool_function(self, name: str) -> Optional[Callable]:
        """Return the builtin callable registered under *name*, or None."""
        return self._builtin_tools.get(name)

    def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return the schemas of all known tools."""
        return list(self._tool_schemas.values())

    def builtin_tool_count(self) -> int:
        """Return how many builtin tools are registered."""
        return len(self._builtin_tools)

    def builtin_tool_names(self) -> List[str]:
        """Return the builtin tool names in sorted order."""
        return sorted(self._builtin_tools)

    def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]:
        """Return the schemas for *tool_names*, skipping (and logging) unknown names."""
        tools = []
        for name in tool_names:
            schema = self.get_tool_schema(name)
            if schema:
                tools.append(schema)
            else:
                logger.warning("工具 %s 未找到", name)
        return tools

    # ─── Loading custom tools from the database ───────────────

    def load_tools_from_db(self, db: "Session", tool_names: Optional[List[str]] = None):
        """Load tool definitions from the database.

        Args:
            db: Database session.
            tool_names: Names to restrict the load to. ``None`` (or an empty
                list) loads every public tool.
        """
        # `== True` builds a SQL expression here; it is not a Python comparison.
        query = db.query(Tool).filter(Tool.is_public == True)  # noqa: E712
        if tool_names:
            query = query.filter(Tool.name.in_(tool_names))

        count = 0
        for tool in query.all():
            self._tool_schemas[tool.name] = tool.function_schema
            if tool.implementation_type == "builtin":
                if tool.name not in self._builtin_tools:
                    logger.warning("工具 %s 标记为 builtin 但未注册", tool.name)
                continue
            # Copy before adding the private keys so the ORM object's own
            # implementation_config dict is never mutated.
            config = dict(tool.implementation_config or {})
            config["_type"] = tool.implementation_type
            config["_db_id"] = tool.id
            self._custom_tool_configs[tool.name] = config
            count += 1
        if count:
            logger.info("从数据库加载了 %d 个自定义工具", count)

    # ─── Tool execution ───────────────────────────────────────

    async def execute_tool(self, name: str, args: Dict[str, Any]) -> str:
        """Execute any tool (builtin / HTTP / code snippet).

        Args:
            name: Tool name.
            args: Argument dict; ``None`` is treated as no arguments.

        Returns:
            The result as a string. Failures are reported as a JSON object
            with an ``error`` key rather than raised.
        """
        args = args or {}

        # 1. Builtin tools take precedence.
        func = self._builtin_tools.get(name)
        if func:
            return await self._run_function(func, name, args)

        # 2. Custom tools.
        config = self._custom_tool_configs.get(name)
        if not config:
            return json.dumps({"error": f"工具 '{name}' 不存在"}, ensure_ascii=False)

        impl_type = config.get("_type", "")
        try:
            if impl_type == "http":
                return await self._execute_http_tool(name, config, args)
            if impl_type == "code":
                return await self._execute_code_tool(name, config, args)
            if impl_type == "workflow":
                return json.dumps({"error": "工作流工具暂不支持动态执行"},
                                  ensure_ascii=False)
            return json.dumps({"error": f"不支持的实现类型: {impl_type}"},
                              ensure_ascii=False)
        except Exception as e:
            logger.error("工具 %s 执行失败: %s", name, e, exc_info=True)
            return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
                              ensure_ascii=False)

    @staticmethod
    async def _run_function(func: Callable, name: str, args: Dict[str, Any]) -> str:
        """Call a builtin tool function and stringify its result.

        The function is called first and its result awaited if it is
        awaitable, so sync wrappers that return a coroutine are handled too.
        dict/list results are JSON-encoded; anything else goes through str().
        """
        try:
            result = func(**args)
            if inspect.isawaitable(result):
                result = await result
            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            return str(result)
        except Exception as e:
            logger.error("工具 '%s' 执行失败: %s", name, e, exc_info=True)
            return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
                              ensure_ascii=False)

    # ─── HTTP tools ───────────────────────────────────────────

    @staticmethod
    def _render_template(template: Any, args: Dict[str, Any]) -> Any:
        """Substitute ``{param}`` placeholders in *template* with values from *args*.

        Strings are rendered with ``str.format``; dicts and lists are walked
        recursively and only their string leaves are rendered (dict keys are
        kept as-is). A string that cannot be formatted (unknown placeholder,
        stray brace, positional field) is returned unchanged.
        """
        if isinstance(template, str):
            try:
                return template.format(**args)
            except (KeyError, IndexError, ValueError):
                return template
        if isinstance(template, dict):
            return {
                key: ToolRegistry._render_template(value, args)
                for key, value in template.items()
            }
        if isinstance(template, list):
            return [ToolRegistry._render_template(item, args) for item in template]
        return template

    async def _execute_http_tool(
        self, name: str, config: Dict[str, Any], args: Dict[str, Any]
    ) -> str:
        """Execute an HTTP tool.

        implementation_config format::

            {
                "url": "https://api.example.com/{path_param}",
                "method": "GET",
                "headers": {"Authorization": "Bearer xxx"},
                "body_template": {"key": "{param}"},   # optional
                "timeout": 30
            }

        ``{param}`` placeholders in the URL and in the string values of
        body_template are replaced with the matching entries of *args*.

        Returns:
            JSON string with ``status_code`` and ``body`` (response text
            truncated to 10000 characters), or an ``error`` object when no
            URL is configured.
        """
        import httpx  # imported at call time, not at module import

        url = config.get("url", "")
        if not url:
            return json.dumps({"error": "HTTP 工具未配置 URL"}, ensure_ascii=False)

        method = (config.get("method") or "GET").upper()
        headers = config.get("headers") or {}
        body_template = config.get("body_template")
        timeout = config.get("timeout", 30)

        url = self._render_template(url, args)
        # Render the structure itself rather than its JSON text: the braces of
        # a serialized JSON object would be parsed as format fields.
        body = self._render_template(body_template, args) if body_template else None

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.request(method, url, headers=headers, json=body)

        return json.dumps(
            {
                "status_code": response.status_code,
                "body": response.text[:10000],  # truncate overly long responses
            },
            ensure_ascii=False,
        )

    # ─── Code tools ───────────────────────────────────────────

    async def _execute_code_tool(
        self, name: str, config: Dict[str, Any], args: Dict[str, Any]
    ) -> str:
        """Execute a code tool.

        implementation_config format::

            {
                "source": "<python source defining run(args)>",
                "language": "python"
            }

        The source must define a ``run(args)`` function, which is used as
        the entry point.

        SECURITY NOTE(review): the source is run with ``exec`` using a
        restricted globals dict and an empty ``__builtins__``. That limits
        which names are directly visible but is not a real sandbox; only
        trusted sources should be stored as code tools. ``run`` is also
        called synchronously inside this coroutine, so long-running code
        blocks the event loop.
        """
        source = config.get("source", "")
        if not source:
            return json.dumps({"error": "代码工具未配置 source"}, ensure_ascii=False)

        # Restrict the names visible to the executed source.
        safe_globals = dict(_CODE_SAFE_GLOBALS)
        safe_globals["__builtins__"] = {}
        try:
            exec(source, safe_globals)
        except Exception as e:
            return json.dumps({
                "error": f"代码编译失败: {e}",
                "traceback": traceback.format_exc(),
            }, ensure_ascii=False)

        run_func = safe_globals.get("run")
        if not run_func or not callable(run_func):
            return json.dumps({
                "error": "代码工具必须定义一个 run(args) 函数"
            }, ensure_ascii=False)

        try:
            result = run_func(args)
            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            return str(result)
        except Exception as e:
            return json.dumps({
                "error": f"代码执行失败: {e}",
                "traceback": traceback.format_exc(),
            }, ensure_ascii=False)

    # ─── Test helpers (execute directly, nothing is saved) ────

    async def test_http_tool(
        self, url: str, method: str, headers: Dict[str, str],
        body: Optional[Dict[str, Any]], args: Dict[str, Any],
        timeout: int = 30,
    ) -> Dict[str, Any]:
        """Try out an HTTP tool configuration without saving it to the DB.

        Returns:
            ``{"success": True, "status_code", "elapsed_ms", "body"}`` on a
            completed request, otherwise ``{"success": False, "error"}``.
        """
        import httpx  # imported at call time, not at module import

        url = self._render_template(url, args)
        body_sent = self._render_template(body, args) if body else None

        start = time.perf_counter()
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.request(method.upper(), url,
                                                headers=headers, json=body_sent)
            elapsed_ms = int((time.perf_counter() - start) * 1000)
            return {
                "success": True,
                "status_code": response.status_code,
                "elapsed_ms": elapsed_ms,
                "body": response.text[:10000],
            }
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def test_code_tool(
        self, source: str, args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Try out a code tool without saving it to the DB.

        SECURITY NOTE(review): same ``exec``-based execution as
        ``_execute_code_tool``; restricted globals are not a real sandbox.

        Returns:
            ``{"success": True, "elapsed_ms", "result"}`` when ``run(args)``
            returns, otherwise ``{"success": False, "error"}``.
        """
        safe_globals = dict(_CODE_SAFE_GLOBALS)
        safe_globals["__builtins__"] = {}
        try:
            exec(source, safe_globals)
        except Exception as e:
            return {"success": False, "error": f"编译失败: {e}"}

        run_func = safe_globals.get("run")
        # Same entry-point check as _execute_code_tool: must exist and be callable.
        if not run_func or not callable(run_func):
            return {"success": False, "error": "代码须定义 run(args) 函数"}

        try:
            start = time.perf_counter()
            result = run_func(args)
            elapsed_ms = int((time.perf_counter() - start) * 1000)
            return {"success": True, "elapsed_ms": elapsed_ms, "result": result}
        except Exception as e:
            return {"success": False, "error": f"执行失败: {e}"}
# Global singleton
tool_registry: ToolRegistry = ToolRegistry()