"""
批量操作API

支持批量执行、批量导出等功能
"""
|
||
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.api.auth import get_current_user
from app.core.database import get_db
from app.core.exceptions import NotFoundError
from app.models.execution import Execution
from app.models.user import User
from app.models.workflow import Workflow
from app.tasks.workflow_tasks import execute_workflow_task
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
router = APIRouter(prefix="/api/v1/batch", tags=["batch"])
|
||
|
||
|
||
class BatchExecuteRequest(BaseModel):
|
||
"""批量执行请求模型"""
|
||
workflow_ids: List[str]
|
||
input_data: Dict[str, Any] = {}
|
||
|
||
|
||
class BatchExecuteResponse(BaseModel):
|
||
"""批量执行响应模型"""
|
||
total: int
|
||
success: int
|
||
failed: int
|
||
executions: List[Dict[str, Any]]
|
||
|
||
|
||
class BatchExportRequest(BaseModel):
|
||
"""批量导出请求模型"""
|
||
workflow_ids: List[str]
|
||
|
||
|
||
@router.post("/execute", response_model=BatchExecuteResponse, status_code=status.HTTP_200_OK)
|
||
async def batch_execute(
|
||
request: BatchExecuteRequest,
|
||
db: Session = Depends(get_db),
|
||
current_user: User = Depends(get_current_user)
|
||
):
|
||
"""批量执行工作流"""
|
||
if not request.workflow_ids:
|
||
raise HTTPException(status_code=400, detail="工作流ID列表不能为空")
|
||
|
||
if len(request.workflow_ids) > 50:
|
||
raise HTTPException(status_code=400, detail="一次最多执行50个工作流")
|
||
|
||
executions = []
|
||
success_count = 0
|
||
failed_count = 0
|
||
|
||
for workflow_id in request.workflow_ids:
|
||
try:
|
||
# 验证工作流是否存在且属于当前用户
|
||
workflow = db.query(Workflow).filter(
|
||
Workflow.id == workflow_id,
|
||
Workflow.user_id == current_user.id
|
||
).first()
|
||
|
||
if not workflow:
|
||
executions.append({
|
||
"workflow_id": workflow_id,
|
||
"status": "failed",
|
||
"error": "工作流不存在或无权限"
|
||
})
|
||
failed_count += 1
|
||
continue
|
||
|
||
# 创建执行记录
|
||
execution = Execution(
|
||
workflow_id=workflow_id,
|
||
input_data=request.input_data,
|
||
status="pending"
|
||
)
|
||
db.add(execution)
|
||
db.commit()
|
||
db.refresh(execution)
|
||
|
||
# 异步执行工作流
|
||
workflow_data = {
|
||
'nodes': workflow.nodes,
|
||
'edges': workflow.edges
|
||
}
|
||
task = execute_workflow_task.delay(
|
||
str(execution.id),
|
||
workflow_id,
|
||
workflow_data,
|
||
request.input_data
|
||
)
|
||
|
||
# 更新执行记录的task_id
|
||
execution.task_id = task.id
|
||
db.commit()
|
||
db.refresh(execution)
|
||
|
||
executions.append({
|
||
"workflow_id": workflow_id,
|
||
"workflow_name": workflow.name,
|
||
"execution_id": str(execution.id),
|
||
"status": "pending",
|
||
"task_id": task.id
|
||
})
|
||
success_count += 1
|
||
|
||
except Exception as e:
|
||
logger.error(f"批量执行工作流失败: {workflow_id} - {str(e)}")
|
||
executions.append({
|
||
"workflow_id": workflow_id,
|
||
"status": "failed",
|
||
"error": str(e)
|
||
})
|
||
failed_count += 1
|
||
|
||
return BatchExecuteResponse(
|
||
total=len(request.workflow_ids),
|
||
success=success_count,
|
||
failed=failed_count,
|
||
executions=executions
|
||
)
|
||
|
||
|
||
@router.post("/export", status_code=status.HTTP_200_OK)
|
||
async def batch_export(
|
||
request: BatchExportRequest,
|
||
db: Session = Depends(get_db),
|
||
current_user: User = Depends(get_current_user)
|
||
):
|
||
"""批量导出工作流(JSON格式)"""
|
||
if not request.workflow_ids:
|
||
raise HTTPException(status_code=400, detail="工作流ID列表不能为空")
|
||
|
||
if len(request.workflow_ids) > 100:
|
||
raise HTTPException(status_code=400, detail="一次最多导出100个工作流")
|
||
|
||
workflows = db.query(Workflow).filter(
|
||
Workflow.id.in_(request.workflow_ids),
|
||
Workflow.user_id == current_user.id
|
||
).all()
|
||
|
||
if len(workflows) != len(request.workflow_ids):
|
||
raise HTTPException(status_code=403, detail="部分工作流不存在或无权限")
|
||
|
||
# 构建导出数据
|
||
export_data = {
|
||
"exported_at": datetime.now().isoformat(),
|
||
"total": len(workflows),
|
||
"workflows": []
|
||
}
|
||
|
||
for workflow in workflows:
|
||
export_data["workflows"].append({
|
||
"id": str(workflow.id),
|
||
"name": workflow.name,
|
||
"description": workflow.description,
|
||
"nodes": workflow.nodes,
|
||
"edges": workflow.edges,
|
||
"version": workflow.version,
|
||
"status": workflow.status
|
||
})
|
||
|
||
return export_data
|
||
|
||
|
||
@router.post("/delete", status_code=status.HTTP_200_OK)
|
||
async def batch_delete(
|
||
workflow_ids: List[str],
|
||
db: Session = Depends(get_db),
|
||
current_user: User = Depends(get_current_user)
|
||
):
|
||
"""批量删除工作流"""
|
||
if not workflow_ids:
|
||
raise HTTPException(status_code=400, detail="工作流ID列表不能为空")
|
||
|
||
if len(workflow_ids) > 100:
|
||
raise HTTPException(status_code=400, detail="一次最多删除100个工作流")
|
||
|
||
workflows = db.query(Workflow).filter(
|
||
Workflow.id.in_(workflow_ids),
|
||
Workflow.user_id == current_user.id
|
||
).all()
|
||
|
||
deleted_count = 0
|
||
for workflow in workflows:
|
||
db.delete(workflow)
|
||
deleted_count += 1
|
||
|
||
db.commit()
|
||
|
||
return {
|
||
"message": f"成功删除 {deleted_count} 个工作流",
|
||
"deleted_count": deleted_count,
|
||
"total_requested": len(workflow_ids)
|
||
}
|