Files
aiagent/backend/app/api/plugins.py

335 lines
10 KiB
Python
Raw Normal View History

"""
插件管理 API：上传、管理、市场、执行第三方节点插件。
"""
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy import or_
from typing import Dict, List, Optional, Any
import json
import logging
from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.models.plugin import NodePlugin
from app.services.plugin_loader import (
validate_manifest, load_plugin_code, unload_plugin_code,
execute_plugin_sandbox, list_plugin_node_types, register_plugin_node_type,
)
# Module logger and the router every endpoint below is attached to.
logger = logging.getLogger(name=__name__)
router = APIRouter(
    tags=["plugins"],
    prefix="/api/v1/plugins",
)
# ---------- Models ----------
class PluginManifest(BaseModel):
    """Plugin manifest supplied on upload.

    create_plugin passes its ``.dict()`` form to ``validate_manifest`` and
    stores it in ``NodePlugin.manifest``; individual fields are also copied
    onto the plugin row.
    """

    # Must be unique across all plugins (create_plugin returns 409 otherwise).
    name: str
    version: str = "1.0.0"
    description: Optional[str] = None
    # When omitted, create_plugin uses the uploading user's username.
    author: Optional[str] = None
    # Node type identifier; must be unique across all plugins (409 otherwise).
    node_type: str
    node_label: Optional[str] = None
    category: str = "custom"
    # Free-form schema dicts; structure is not checked here (validate_manifest may).
    inputs_schema: Optional[Dict[str, Any]] = None
    outputs_schema: Optional[Dict[str, Any]] = None
    icon: Optional[str] = None
    tags: Optional[List[str]] = None
class PluginCreate(BaseModel):
    """Request body for uploading a new plugin (POST /api/v1/plugins)."""

    manifest: PluginManifest
    # Python source; create_plugin rejects it unless it contains "def execute".
    code: str = Field(..., description="Python 插件代码(须包含 execute 函数)")
    # Public plugins are listed in the market and can be installed by others.
    is_public: bool = False
class PluginUpdate(BaseModel):
    """Partial-update body for PUT /api/v1/plugins/{plugin_id}.

    update_plugin reads it with ``exclude_unset=True``, so only fields the
    client actually sent are applied to the plugin row.
    """

    name: Optional[str] = None
    description: Optional[str] = None
    node_label: Optional[str] = None
    category: Optional[str] = None
    # New plugin source; when provided, update_plugin reloads the plugin code.
    code: Optional[str] = None
    inputs_schema: Optional[Dict[str, Any]] = None
    outputs_schema: Optional[Dict[str, Any]] = None
    enabled: Optional[bool] = None
    is_public: Optional[bool] = None
    icon: Optional[str] = None
    tags: Optional[List[str]] = None
class PluginExecuteRequest(BaseModel):
    """Request body for a sandboxed test run (POST /{plugin_id}/test)."""

    # Passed as ``inputs`` to execute_plugin_sandbox; defaults to an empty dict.
    inputs: Dict[str, Any] = Field(default_factory=dict)
    # Optional execution context forwarded unchanged to the sandbox.
    context: Optional[Dict[str, Any]] = None
class PluginResponse(BaseModel):
    """Plugin representation returned by the endpoints (plugin source is not included)."""

    id: str
    name: str
    version: str
    description: Optional[str]
    author: Optional[str]
    node_type: str
    node_label: Optional[str]
    category: str
    manifest: Optional[Dict[str, Any]]
    inputs_schema: Optional[Dict[str, Any]]
    outputs_schema: Optional[Dict[str, Any]]
    enabled: bool
    is_public: bool
    install_count: int
    # NOTE(review): declared int; confirm the underlying column is not fractional.
    rating_avg: int
    icon: Optional[str]
    tags: Optional[List[str]]
    # Any: passed through from the ORM object as-is (presumably datetimes).
    created_at: Any
    updated_at: Any

    class Config:
        # Allows validating NodePlugin ORM instances directly (attribute access).
        from_attributes = True
# ---------- 我的插件 ----------
@router.get("/my", response_model=List[PluginResponse])
async def get_my_plugins(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return every plugin uploaded by the current user, most recently updated first."""
    owned = db.query(NodePlugin).filter(NodePlugin.user_id == current_user.id)
    newest_first = owned.order_by(NodePlugin.updated_at.desc())
    return newest_first.all()
# ---------- 插件市场 ----------
@router.get("/market", response_model=List[PluginResponse])
async def get_market_plugins(
    search: Optional[str] = None,
    category: Optional[str] = None,
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
):
    """List public, enabled plugins for the plugin market.

    Args:
        search: Optional substring matched against name, description and
            node label. LIKE metacharacters (``%``, ``_``) in it are matched
            literally rather than as wildcards.
        category: Optional exact category filter.
        skip: Number of rows to skip (offset pagination).
        limit: Page size (1-100).

    Returns:
        Plugins ordered by install count (descending), ties broken by id so
        consecutive pages neither repeat nor skip rows.
    """
    query = db.query(NodePlugin).filter(
        NodePlugin.is_public == True,  # noqa: E712 - SQLAlchemy column comparison
        NodePlugin.enabled == True,  # noqa: E712
    )
    if search:
        # autoescape=True escapes % and _ in the user's text so e.g. "100%"
        # or "a_b" are searched literally instead of matching everything.
        query = query.filter(
            or_(
                NodePlugin.name.contains(search, autoescape=True),
                NodePlugin.description.contains(search, autoescape=True),
                NodePlugin.node_label.contains(search, autoescape=True),
            )
        )
    if category:
        query = query.filter(NodePlugin.category == category)
    # install_count is not unique; without a tiebreaker, OFFSET/LIMIT paging
    # over equal counts is nondeterministic.
    plugins = (
        query.order_by(NodePlugin.install_count.desc(), NodePlugin.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
    return plugins
# ---------- 创建插件 ----------
@router.post("", response_model=PluginResponse, status_code=201)
async def create_plugin(
    body: PluginCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Upload a new plugin owned by the current user.

    Validates the manifest and code and enforces uniqueness of both the
    plugin name and its node type. It then persists the record, writes the
    code via ``load_plugin_code`` and registers the node type when the
    plugin is enabled. Loader/registration failures are logged and do not
    abort the upload (best-effort); the record is committed regardless.

    Raises:
        HTTPException(400): manifest validation fails, or the code has no
            ``execute`` function.
        HTTPException(409): the plugin name or node type is already taken.
    """
    manifest = body.manifest.dict()

    ok, err = validate_manifest(manifest)
    if not ok:
        raise HTTPException(status_code=400, detail=f"manifest 校验失败: {err}")

    # "async def execute" contains "def execute", so a single substring test
    # accepts both sync and async entry points.
    if "def execute" not in body.code:
        raise HTTPException(status_code=400, detail="代码中必须包含 execute(inputs, context) 函数")

    existing = db.query(NodePlugin).filter(NodePlugin.name == manifest["name"]).first()
    if existing:
        raise HTTPException(status_code=409, detail=f"插件名称 '{manifest['name']}' 已存在")

    existing_type = db.query(NodePlugin).filter(NodePlugin.node_type == manifest["node_type"]).first()
    if existing_type:
        raise HTTPException(status_code=409, detail=f"节点类型 '{manifest['node_type']}' 已被使用")

    plugin = NodePlugin(
        name=manifest["name"],
        version=manifest.get("version", "1.0.0"),
        description=manifest.get("description"),
        author=manifest.get("author") or current_user.username,
        node_type=manifest["node_type"],
        # .dict() always contains every manifest key (None when unset), so a
        # .get() default would never apply; fall back explicitly instead.
        node_label=manifest.get("node_label") or manifest["name"],
        category=manifest.get("category", "custom"),
        manifest=manifest,
        inputs_schema=manifest.get("inputs_schema"),
        outputs_schema=manifest.get("outputs_schema"),
        code=body.code,
        icon=manifest.get("icon"),
        tags=manifest.get("tags") or [],
        is_public=body.is_public,
        user_id=current_user.id,
    )
    db.add(plugin)
    # Flush so plugin.id is assigned before the code is written/registered.
    db.flush()

    try:
        load_plugin_code(plugin.id, body.code, manifest["node_type"])
        if plugin.enabled:
            register_plugin_node_type(plugin)
    except Exception as e:
        # Deliberately best-effort: keep the upload, but log the traceback.
        logger.warning("插件注册警告: %s", e, exc_info=True)

    db.commit()
    db.refresh(plugin)
    logger.info("插件已创建: %s (node_type=%s)", plugin.name, plugin.node_type)
    return plugin
# ---------- 获取单个插件 ----------
@router.get("/{plugin_id}", response_model=PluginResponse)
async def get_plugin(
    plugin_id: str,
    db: Session = Depends(get_db),
):
    """Return one plugin by id, or respond 404 when it does not exist."""
    found = (
        db.query(NodePlugin)
        .filter(NodePlugin.id == plugin_id)
        .first()
    )
    if not found:
        raise HTTPException(status_code=404, detail="插件不存在")
    return found
# ---------- 更新插件 ----------
@router.put("/{plugin_id}", response_model=PluginResponse)
async def update_plugin(
    plugin_id: str,
    body: PluginUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Partially update a plugin and keep its loaded/registered state in sync.

    Only fields the client sent are applied. The loader state is handled the
    same way as in toggle_plugin:

    - enabled -> disabled: the plugin code is unloaded.
    - disabled -> enabled: the stored code is loaded and the node type is
      registered.
    - still enabled: new code is reloaded, and the node type is re-registered
      so schema/label changes take effect.
    - still disabled: nothing is loaded; toggle_plugin loads the stored code
      from the database when the plugin is enabled later.

    Raises:
        HTTPException(404): no plugin with this id.
        HTTPException(403): the plugin belongs to another user.
        HTTPException(400): new code lacks an ``execute`` function.
        HTTPException(409): the new name is already used by another plugin.
    """
    plugin = db.query(NodePlugin).filter(NodePlugin.id == plugin_id).first()
    if not plugin:
        raise HTTPException(status_code=404, detail="插件不存在")
    # NOTE(review): plugins without an owner (user_id falsy) are editable by
    # any authenticated user; confirm this is intended.
    if plugin.user_id and plugin.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权修改此插件")

    # These fields are non-Optional on PluginResponse; an explicit JSON null
    # would be committed and then fail response validation, so ignore it.
    required_fields = {"name", "category", "enabled", "is_public"}
    updates = {
        field: value
        for field, value in body.dict(exclude_unset=True).items()
        if value is not None or field not in required_fields
    }

    new_code = updates.get("code")
    # Same rule create_plugin enforces on upload.
    if new_code is not None and "def execute" not in new_code:
        raise HTTPException(status_code=400, detail="代码中必须包含 execute(inputs, context) 函数")

    new_name = updates.get("name")
    if new_name is not None and new_name != plugin.name:
        # Same uniqueness rule create_plugin enforces on upload.
        clash = (
            db.query(NodePlugin)
            .filter(NodePlugin.name == new_name, NodePlugin.id != plugin.id)
            .first()
        )
        if clash:
            raise HTTPException(status_code=409, detail=f"插件名称 '{new_name}' 已存在")

    was_enabled = plugin.enabled
    for field, value in updates.items():
        setattr(plugin, field, value)

    if was_enabled and not plugin.enabled:
        unload_plugin_code(plugin.id)
    elif plugin.enabled:
        if new_code or not was_enabled:
            load_plugin_code(plugin.id, plugin.code or "", plugin.node_type)
        if updates:
            register_plugin_node_type(plugin)

    db.commit()
    db.refresh(plugin)
    return plugin
# ---------- 删除插件 ----------
@router.delete("/{plugin_id}")
async def delete_plugin(
    plugin_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Unload and delete a plugin owned by the current user (or with no owner).

    Responds 404 when the plugin does not exist and 403 when another user
    owns it.
    """
    target = db.query(NodePlugin).filter(NodePlugin.id == plugin_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="插件不存在")

    owner_id = target.user_id
    if owner_id and owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权删除此插件")

    # Drop the loaded code first, then remove the database row.
    unload_plugin_code(target.id)
    db.delete(target)
    db.commit()
    return {"message": "插件已删除"}
# ---------- 启用/禁用 ----------
@router.post("/{plugin_id}/toggle")
async def toggle_plugin(
    plugin_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Flip a plugin between enabled and disabled.

    Enabling loads the stored code and registers the node type; disabling
    unloads the code.

    Raises:
        HTTPException(404): no plugin with this id.
        HTTPException(403): the plugin belongs to another user. This is the
            same ownership rule as update/delete; previously any
            authenticated user could toggle any plugin.
    """
    plugin = db.query(NodePlugin).filter(NodePlugin.id == plugin_id).first()
    if not plugin:
        raise HTTPException(status_code=404, detail="插件不存在")
    if plugin.user_id and plugin.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权修改此插件")

    plugin.enabled = not plugin.enabled
    if plugin.enabled:
        load_plugin_code(plugin.id, plugin.code or "", plugin.node_type)
        register_plugin_node_type(plugin)
    else:
        unload_plugin_code(plugin.id)
    db.commit()
    return {"enabled": plugin.enabled, "message": f"插件已{'启用' if plugin.enabled else '禁用'}"}
# ---------- 安装插件(从市场) ----------
@router.post("/{plugin_id}/install")
async def install_plugin(
    plugin_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Install a public plugin from the market.

    Increments the plugin's install counter and registers its node type when
    the plugin is enabled.

    Returns:
        A message and the updated install count.

    Raises:
        HTTPException(404): no plugin with this id.
        HTTPException(403): the plugin is not public.
    """
    plugin = db.query(NodePlugin).filter(NodePlugin.id == plugin_id).first()
    if not plugin:
        raise HTTPException(status_code=404, detail="插件不存在")
    if not plugin.is_public:
        raise HTTPException(status_code=403, detail="此插件未公开")

    # Increment in SQL (UPDATE ... SET install_count = install_count + 1)
    # rather than read-modify-write in Python, so concurrent installs cannot
    # overwrite each other. After the flush SQLAlchemy expires the attribute
    # and re-reads the stored value on next access.
    plugin.install_count = NodePlugin.install_count + 1
    db.flush()

    if plugin.enabled:
        register_plugin_node_type(plugin)
    db.commit()
    return {"message": "插件已安装", "install_count": plugin.install_count}
# ---------- 沙箱测试执行 ----------
@router.post("/{plugin_id}/test")
async def test_plugin(
    plugin_id: str,
    body: PluginExecuteRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Run a plugin's stored code in the sandbox with the given inputs.

    Execution runs plugin code for up to 30 seconds. It therefore requires an
    authenticated user, and a private plugin may only be run by its owner
    (the same ownership convention as update/delete).

    Returns:
        Whatever ``execute_plugin_sandbox`` returns for this run.

    Raises:
        HTTPException(404): no plugin with this id.
        HTTPException(403): the plugin is private and owned by another user.
        HTTPException(400): the plugin has no code.
    """
    plugin = db.query(NodePlugin).filter(NodePlugin.id == plugin_id).first()
    if not plugin:
        raise HTTPException(status_code=404, detail="插件不存在")
    if not plugin.is_public and plugin.user_id and plugin.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="此插件未公开")
    if not plugin.code:
        raise HTTPException(status_code=400, detail="插件无代码")

    result = await execute_plugin_sandbox(
        code=plugin.code,
        inputs=body.inputs,
        context=body.context,
        timeout_seconds=30,
    )
    return result
# ---------- 插件节点类型列表(供工作流编辑器用) ----------
@router.get("/internal/node-types")
async def get_node_types():
    """Return the node types of all enabled plugins, for the workflow editor."""
    node_types = list_plugin_node_types()
    return node_types