Files
aiagent/backend/app/services/plugin_loader.py

220 lines
7.1 KiB
Python
Raw Permalink Normal View History

"""
插件加载器:加载、校验并在沙箱(子进程)中执行第三方节点插件。
插件规范 (manifest.json):
{
"name": "my-plugin",
"version": "1.0.0",
"description": "...",
"author": "...",
"node_type": "custom_action",
"node_label": "自定义操作",
"category": "custom",
"entry": "execute.py",
"inputs_schema": {"type": "object", "properties": {...}},
"outputs_schema": {"type": "object", "properties": {...}}
}
执行函数签名:
async def execute(inputs: dict, context: dict) -> dict
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import subprocess
import sys
import tempfile
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)

# Root directory holding one sub-directory per installed plugin. It sits at
# <three parent steps above this file>/data/plugins; given the
# app/services/ layout this is presumably the backend root — confirm.
PLUGINS_DIR = Path(__file__).resolve().parent.parent.parent.joinpath("data", "plugins")
# Created eagerly at import time so later writes can assume it exists.
PLUGINS_DIR.mkdir(exist_ok=True, parents=True)
def validate_manifest(manifest: dict) -> tuple[bool, str]:
    """Check that a plugin manifest contains the fields the loader relies on.

    Args:
        manifest: parsed ``manifest.json`` content.

    Returns:
        ``(True, "ok")`` when the manifest is valid, otherwise
        ``(False, reason)`` describing the first problem found.
    """
    import re

    # Guard first: a non-dict (e.g. a string) would pass the ``in`` checks
    # below by substring match and then crash on item access.
    if not isinstance(manifest, dict):
        return False, "manifest 必须是 JSON 对象"
    for key in ("name", "version", "node_type"):
        if key not in manifest:
            return False, f"缺少必填字段: {key}"
    name = manifest["name"]
    if not isinstance(name, str) or not name.strip():
        return False, "name 必须是非空字符串"
    node_type = manifest["node_type"]
    if not isinstance(node_type, str) or not node_type.strip():
        return False, "node_type 必须是非空字符串"
    # node_type is used as a file name and as a tool-registry key, so it must
    # be a plain ASCII identifier. fullmatch is required: with re.match, "$"
    # also matches just before a trailing newline, so "abc\n" would pass.
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", node_type):
        return False, "node_type 必须是合法的 Python 标识符"
    return True, "ok"
def load_plugin_code(plugin_id: str, code: str, node_type: str) -> str:
    """Write plugin source code to disk and return the resulting file path.

    The code is written as UTF-8 to ``PLUGINS_DIR/<plugin_id>/<node_type>.py``,
    overwriting any previous version; the plugin directory is created if
    needed.

    Args:
        plugin_id: plugin identifier, used as the plugin's directory name.
        code: plugin source code.
        node_type: node type name, used as the file stem.

    Returns:
        The path of the written file, as a string.

    Raises:
        ValueError: if ``plugin_id`` or ``node_type`` would put the file
            anywhere other than directly inside its own plugin directory
            under ``PLUGINS_DIR`` (empty id, ``..``, path separators, ...).
    """
    plugin_dir = PLUGINS_DIR / plugin_id
    file_path = plugin_dir / f"{node_type}.py"
    # node_type originates from a third-party manifest, so verify the resolved
    # target before touching the disk instead of trusting every caller to have
    # validated both names.
    root = PLUGINS_DIR.resolve()
    resolved_dir = plugin_dir.resolve()
    if resolved_dir.parent != root or file_path.resolve().parent != resolved_dir:
        raise ValueError(f"非法的插件路径: {plugin_id}/{node_type}")
    plugin_dir.mkdir(parents=True, exist_ok=True)
    file_path.write_text(code, encoding="utf-8")
    return str(file_path)
def unload_plugin_code(plugin_id: str):
    """Delete a plugin's directory, and all code in it, from disk.

    Does nothing if the directory does not exist.

    Raises:
        ValueError: if ``plugin_id`` does not resolve to a direct child of
            ``PLUGINS_DIR``. Without this check an empty id would resolve to
            ``PLUGINS_DIR`` itself and remove every plugin, and ``..`` would
            remove data outside it.
    """
    import shutil

    plugin_dir = PLUGINS_DIR / plugin_id
    # rmtree is irreversible: refuse anything that is not exactly one level
    # below the plugins root.
    if plugin_dir.resolve().parent != PLUGINS_DIR.resolve():
        raise ValueError(f"非法的插件 ID: {plugin_id}")
    if plugin_dir.exists():
        shutil.rmtree(plugin_dir)
def _make_sandbox_globals():
"""构建受限的内建函数集。"""
safe_builtins = {
"True": True, "False": False, "None": None,
"abs": abs, "all": all, "any": any, "bool": bool,
"dict": dict, "enumerate": enumerate, "filter": filter,
"float": float, "int": int, "isinstance": isinstance,
"len": len, "list": list, "map": map, "max": max,
"min": min, "range": range, "round": round, "set": set,
"sorted": sorted, "str": str, "sum": sum, "tuple": tuple,
"type": type, "zip": zip,
"print": print, "json": json,
"Exception": Exception, "ValueError": ValueError,
"TypeError": TypeError, "KeyError": KeyError,
}
return {"__builtins__": safe_builtins}
def _build_sandbox_script(code: str) -> str:
    """Return the runner script that executes ``code`` in the child process."""
    # The plugin source is spliced in at module level so its top-level
    # definitions (including ``execute``) land in the script's globals.
    # Doubled braces are literal braces in this f-string. The runner's JSON
    # verdict is always the LAST line it prints, so plugins may print() freely.
    # Both ``async def execute`` (documented) and plain ``def execute`` work.
    return f'''
import json
import sys
import traceback
# 用户插件代码
{code}
# 准备输入
_inputs = json.loads(sys.stdin.read())
_ctx = _inputs.get("__context__", {{}})
_user_inputs = _inputs.get("__inputs__", {{}})
# 查找 execute 函数
if "execute" not in dir():
    print(json.dumps({{"ok": False, "error": "插件缺少 execute 函数"}}))
    sys.exit(1)
try:
    import asyncio
    result = execute(_user_inputs, _ctx)
    if asyncio.iscoroutine(result):
        result = asyncio.run(result)
    print(json.dumps({{"ok": True, "result": result}}, ensure_ascii=False, default=str))
except Exception as e:
    print(json.dumps({{"ok": False, "error": str(e), "traceback": traceback.format_exc()}}, ensure_ascii=False))
'''


def _parse_sandbox_output(stdout: bytes, stderr: bytes) -> Dict[str, Any]:
    """Convert the child's raw stdout/stderr into the public result dict.

    Raises:
        json.JSONDecodeError: if the last non-empty stdout line is not JSON.
    """
    if stderr:
        logger.warning("插件沙箱 stderr: %s", stderr.decode("utf-8", errors="replace")[:500])
    # Split on "\n" only: str.splitlines() would also split on U+2028/U+2029,
    # which json.dumps(ensure_ascii=False) leaves unescaped inside the verdict.
    text = stdout.decode("utf-8", errors="replace")
    lines = [line for line in text.split("\n") if line.strip()]
    if not lines:
        # The runner never reached its final print (e.g. a SyntaxError or a
        # top-level crash in the plugin); surface the tail of stderr instead.
        detail = stderr.decode("utf-8", errors="replace").strip()[-500:]
        return {"success": False, "error": f"插件未返回结果: {detail or '无输出'}"}
    payload = json.loads(lines[-1])
    if not isinstance(payload, dict):
        return {"success": False, "error": "插件返回格式错误"}
    if payload.get("ok"):
        return {"success": True, "result": payload.get("result")}
    return {"success": False, "error": payload.get("error", "未知错误")}


async def execute_plugin_sandbox(
    code: str,
    inputs: Dict[str, Any],
    context: Optional[Dict[str, Any]] = None,
    timeout_seconds: int = 30,
) -> Dict[str, Any]:
    """Run plugin code in a separate Python subprocess and return its result.

    The plugin source is embedded in a small runner script that is executed
    with the current interpreter (``sys.executable -c``). Inputs and context
    are sent as JSON on stdin. The runner calls ``execute(inputs, context)``,
    awaiting the result if it is a coroutine, and prints a JSON verdict as
    its last stdout line.

    NOTE: this provides process isolation plus a wall-clock timeout only. The
    child inherits the host's environment variables and has the same
    filesystem access, so it is not a security sandbox for untrusted code.

    Args:
        code: plugin source; must define ``execute``.
        inputs: JSON-serializable plugin inputs.
        context: optional JSON-serializable execution context (``{}`` if
            omitted).
        timeout_seconds: maximum run time before the child is killed.

    Returns:
        ``{"success": True, "result": ...}`` on success, otherwise
        ``{"success": False, "error": str}``. Timeouts, plugin exceptions,
        malformed output and spawn failures are reported this way rather
        than raised.
    """
    context = context or {}
    proc = None
    try:
        # Serialize before spawning so unserializable inputs never leave a
        # child process blocked on stdin.
        input_data = json.dumps({"__inputs__": inputs, "__context__": context})
        proc = await asyncio.wait_for(
            asyncio.create_subprocess_exec(
                sys.executable, "-c", _build_sandbox_script(code),
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                # The runner prints non-ASCII JSON (ensure_ascii=False); force
                # UTF-8 pipes so decoding here does not depend on host locale.
                env={**os.environ, "PYTHONIOENCODING": "utf-8"},
            ),
            timeout=5,
        )
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(input_data.encode("utf-8")),
            timeout=timeout_seconds,
        )
        return _parse_sandbox_output(stdout, stderr)
    except asyncio.TimeoutError:
        return {"success": False, "error": f"插件执行超时({timeout_seconds}秒)"}
    except json.JSONDecodeError as e:
        return {"success": False, "error": f"插件返回解析失败: {e}"}
    except Exception as e:
        return {"success": False, "error": f"插件执行异常: {e}"}
    finally:
        # wait_for only cancels communicate(); the child keeps running unless
        # it is killed here. This also covers task cancellation.
        if proc is not None and proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # exited on its own between the check and kill()
            await proc.wait()
def list_plugin_node_types() -> List[Dict[str, Any]]:
    """Return one descriptor per enabled plugin node type.

    Intended for the workflow editor. Each descriptor holds the plugin's id,
    node type, display label (``node_label``, falling back to ``name``),
    category, description, input/output schemas and icon.
    """
    from app.core.database import SessionLocal
    from app.models.plugin import NodePlugin

    def describe(row) -> Dict[str, Any]:
        return {
            "id": row.id,
            "node_type": row.node_type,
            "node_label": row.node_label or row.name,
            "category": row.category,
            "description": row.description,
            "inputs_schema": row.inputs_schema,
            "outputs_schema": row.outputs_schema,
            "icon": row.icon,
        }

    session = SessionLocal()
    try:
        # "== True" builds a SQLAlchemy column expression; it is not a Python
        # truthiness test.
        enabled = session.query(NodePlugin).filter(NodePlugin.enabled == True)  # noqa: E712
        return [describe(row) for row in enabled.all()]
    finally:
        session.close()
def register_plugin_node_type(plugin) -> bool:
    """Register a plugin's node type with the tool registry.

    The tool is registered under ``plugin.node_type``. Its parameters are an
    object schema whose ``properties`` and ``required`` come from the
    plugin's ``inputs_schema``, when present.

    Args:
        plugin: plugin record exposing ``node_type``, ``node_label``, ``name``,
            ``description``, ``inputs_schema`` and ``category`` (presumably a
            ``NodePlugin`` row; confirm against callers).

    Returns:
        True if registration succeeded. False if building the parameter
        schema or the registry call raised; the error is logged, not
        propagated.
    """
    from app.services.tool_registry import tool_registry

    node_type = plugin.node_type
    node_label = plugin.node_label or plugin.name
    try:
        # Built inside the try so a malformed stored schema (e.g. a non-dict
        # value) is reported as a failed registration instead of escaping the
        # bool contract as an exception.
        inputs_schema = plugin.inputs_schema or {}
        parameters = {
            "type": "object",
            "properties": inputs_schema.get("properties", {}),
            "required": inputs_schema.get("required", []),
        }
        tool_registry.register_external_tool(
            name=node_type,
            description=plugin.description or f"自定义节点: {node_label}",
            parameters=parameters,
            category=f"plugin:{plugin.category}",
        )
    except Exception as e:
        logger.warning("插件节点类型注册失败 [%s]: %s", node_type, e)
        return False
    logger.info("插件节点类型已注册: %s", node_type)
    return True