"""Workflow API: CRUD, templates, import/export, execution and version history."""

import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.api.auth import UserResponse, get_current_user
from app.core.database import get_db
from app.core.exceptions import ConflictError, NotFoundError, ValidationError
from app.models.user import User
from app.models.workflow import Workflow
from app.models.workflow_version import WorkflowVersion
from app.services.permission_service import check_workflow_permission
from app.services.workflow_templates import create_from_template, get_template, list_templates
from app.services.workflow_validator import validate_workflow

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/workflows", tags=["workflows"])


class WorkflowCreate(BaseModel):
    """Payload for creating (or validating) a workflow."""

    name: str
    description: Optional[str] = None
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]


class WorkflowUpdate(BaseModel):
    """Partial update payload; fields left as None are not changed."""

    name: Optional[str] = None
    description: Optional[str] = None
    nodes: Optional[List[Dict[str, Any]]] = None
    edges: Optional[List[Dict[str, Any]]] = None
    status: Optional[str] = None


class WorkflowResponse(BaseModel):
    """Workflow as returned by the API (built from the ORM object)."""

    id: str
    name: str
    description: Optional[str]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    version: int
    status: str
    user_id: str
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True


# Columns `GET /workflows` may sort by; any other `sort_by` value falls back to created_at.
_SORT_COLUMNS = {
    "name": Workflow.name,
    "created_at": Workflow.created_at,
    "updated_at": Workflow.updated_at,
}


def _get_workflow_or_404(db: Session, workflow_id: str) -> Workflow:
    """Load a workflow by id or raise NotFoundError."""
    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)
    return workflow


def _require_permission(
    db: Session, user: User, workflow: Workflow, permission: str, detail: str
) -> None:
    """Raise HTTP 403 with *detail* unless *user* holds *permission* on *workflow*."""
    if not check_workflow_permission(db, user, workflow, permission):
        raise HTTPException(status_code=403, detail=detail)


def _validate_or_raise(
    nodes: Any, edges: Any, error_prefix: str, warning_prefix: Optional[str] = None
) -> None:
    """Validate a workflow graph.

    Raises ValidationError("<error_prefix>: <errors>") when the graph is invalid.
    When *warning_prefix* is given, any validator warnings are logged under it.
    """
    result = validate_workflow(nodes, edges)
    if not result["valid"]:
        raise ValidationError(f"{error_prefix}: {', '.join(result['errors'])}")
    if warning_prefix and result["warnings"]:
        logger.warning("%s: %s", warning_prefix, ", ".join(result["warnings"]))


def _create_workflow_record(
    db: Session,
    *,
    name: Any,
    description: Any,
    nodes: Any,
    edges: Any,
    user_id: Any,
) -> Workflow:
    """Insert a new Workflow row, commit, and return the refreshed instance."""
    workflow = Workflow(
        name=name,
        description=description,
        nodes=nodes,
        edges=edges,
        user_id=user_id,
    )
    db.add(workflow)
    db.commit()
    db.refresh(workflow)
    return workflow


def _snapshot_version(
    db: Session,
    workflow: Workflow,
    created_by: Any,
    failure_suffix: str,
    comment: Optional[str] = None,
) -> None:
    """Best-effort: copy *workflow*'s current state into the version-history table.

    Must be called BEFORE *workflow* is modified. On failure the session
    transaction is rolled back, which is only harmless while no other changes
    are pending. A warning is logged and the caller carries on without a
    history entry.

    Args:
        db: Active session.
        workflow: Workflow whose current (pre-change) state is recorded.
        created_by: Value stored in WorkflowVersion.created_by.
        failure_suffix: Text appended to the warning logged on failure.
        comment: Optional WorkflowVersion.comment. It is omitted entirely when
            None so that any model-level default still applies.
    """
    try:
        fields: Dict[str, Any] = dict(
            workflow_id=workflow.id,
            version=workflow.version,
            name=workflow.name,
            description=workflow.description,
            nodes=workflow.nodes,
            edges=workflow.edges,
            status=workflow.status,
            created_by=created_by,
        )
        if comment is not None:
            fields["comment"] = comment
        db.add(WorkflowVersion(**fields))
        # Flush here so that a missing or broken history table fails inside
        # this try. Otherwise the INSERT would be deferred to the caller's
        # final commit and would abort the real update.
        db.flush()
    except Exception as e:
        # Discard the failed INSERT. Nothing else is pending yet (see
        # docstring), so the caller can continue in a fresh transaction.
        db.rollback()
        logger.warning("保存版本历史失败: %s,%s", e, failure_suffix)


@router.post("/validate", status_code=status.HTTP_200_OK)
async def validate_workflow_endpoint(
    workflow_data: WorkflowCreate,
    current_user: User = Depends(get_current_user),
):
    """Validate a workflow definition without saving it."""
    return validate_workflow(workflow_data.nodes, workflow_data.edges)


@router.get("/templates", status_code=status.HTTP_200_OK)
async def get_workflow_templates(
    current_user: User = Depends(get_current_user),
):
    """List the available workflow templates."""
    return list_templates()


@router.get("/templates/{template_id}", status_code=status.HTTP_200_OK)
async def get_workflow_template(
    template_id: str,
    current_user: User = Depends(get_current_user),
):
    """Return a single workflow template by id."""
    template = get_template(template_id)
    if not template:
        raise NotFoundError("模板", template_id)
    return template


@router.post(
    "/templates/{template_id}/create",
    response_model=WorkflowResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_workflow_from_template(
    template_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a workflow owned by the caller from a template."""
    try:
        workflow_data = create_from_template(template_id, name, description)
    except ValueError as e:
        # A ValueError from the template service is reported as an unknown template.
        raise NotFoundError("模板", template_id) from e

    _validate_or_raise(workflow_data["nodes"], workflow_data["edges"], "模板工作流验证失败")

    return _create_workflow_record(
        db,
        name=workflow_data["name"],
        description=workflow_data["description"],
        nodes=workflow_data["nodes"],
        edges=workflow_data["edges"],
        user_id=current_user.id,
    )


@router.get("", response_model=List[WorkflowResponse])
async def get_workflows(
    skip: int = 0,
    limit: int = 100,
    search: Optional[str] = None,
    status: Optional[str] = None,
    sort_by: Optional[str] = "created_at",
    sort_order: Optional[str] = "desc",
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List workflows visible to the caller, with search, status filter and sorting.

    Admins see every workflow. Other users see the workflows they own plus
    those they hold a "read" permission on, granted to them directly or to
    one of their roles.
    """
    # NOTE: the `status` parameter shadows the fastapi `status` module inside
    # this function. The name is kept because it is the public query-parameter name.
    if current_user.role == "admin":
        query = db.query(Workflow)
    else:
        from sqlalchemy import or_
        from app.models.permission import WorkflowPermission

        readable_ids = db.query(WorkflowPermission.workflow_id).filter(
            WorkflowPermission.permission_type == "read",
            or_(
                WorkflowPermission.user_id == current_user.id,
                WorkflowPermission.role_id.in_([r.id for r in current_user.roles]),
            ),
        )
        query = db.query(Workflow).filter(
            or_(
                Workflow.user_id == current_user.id,
                Workflow.id.in_(readable_ids),
            )
        )

    # Case-insensitive substring match on name or description.
    if search:
        pattern = f"%{search}%"
        query = query.filter(
            Workflow.name.ilike(pattern) | Workflow.description.ilike(pattern)
        )

    if status:
        query = query.filter(Workflow.status == status)

    column = _SORT_COLUMNS.get(sort_by, Workflow.created_at)
    query = query.order_by(column.asc() if sort_order == "asc" else column.desc())

    # NOTE(review): skip/limit are not bounded; consider capping `limit`.
    return query.offset(skip).limit(limit).all()


@router.post("", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def create_workflow(
    workflow_data: WorkflowCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Validate and create a workflow owned by the caller."""
    _validate_or_raise(
        workflow_data.nodes, workflow_data.edges, "工作流验证失败", "工作流创建警告"
    )
    return _create_workflow_record(
        db,
        name=workflow_data.name,
        description=workflow_data.description,
        nodes=workflow_data.nodes,
        edges=workflow_data.edges,
        user_id=current_user.id,
    )


@router.get("/{workflow_id}", response_model=WorkflowResponse)
async def get_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one workflow (requires read permission)."""
    workflow = _get_workflow_or_404(db, workflow_id)
    _require_permission(db, current_user, workflow, "read", "无权访问此工作流")
    return workflow


@router.put("/{workflow_id}", response_model=WorkflowResponse)
async def update_workflow(
    workflow_id: str,
    workflow_data: WorkflowUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update a workflow (requires write permission), keeping the old state as a version."""
    workflow = _get_workflow_or_404(db, workflow_id)
    _require_permission(db, current_user, workflow, "write", "无权修改此工作流")

    # Validate the graph as it will look after the update; parts not being
    # replaced come from the stored workflow.
    nodes = workflow_data.nodes if workflow_data.nodes is not None else workflow.nodes
    edges = workflow_data.edges if workflow_data.edges is not None else workflow.edges
    _validate_or_raise(nodes, edges, "工作流验证失败", "工作流更新警告")

    # Snapshot BEFORE mutating: on failure the helper rolls the session back.
    _snapshot_version(db, workflow, current_user.id, "继续执行更新")

    for field in ("name", "description", "nodes", "edges", "status"):
        value = getattr(workflow_data, field)
        if value is not None:
            setattr(workflow, field, value)
    workflow.version += 1

    db.commit()
    db.refresh(workflow)
    return workflow


@router.delete("/{workflow_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a workflow; only its owner or an admin may do so."""
    workflow = _get_workflow_or_404(db, workflow_id)

    if workflow.user_id != current_user.id and current_user.role != "admin":
        raise HTTPException(status_code=403, detail="无权删除此工作流")

    db.delete(workflow)
    db.commit()
    return None


@router.get("/{workflow_id}/export", status_code=status.HTTP_200_OK)
async def export_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Export a workflow as a JSON document (requires read permission)."""
    workflow = _get_workflow_or_404(db, workflow_id)
    _require_permission(db, current_user, workflow, "read", "无权导出此工作流")

    return {
        "id": str(workflow.id),
        "name": workflow.name,
        "description": workflow.description,
        "nodes": workflow.nodes,
        "edges": workflow.edges,
        "version": workflow.version,
        "status": workflow.status,
        "exported_at": datetime.utcnow().isoformat(),
    }


@router.post("/import", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def import_workflow(
    workflow_data: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Import a workflow from a JSON document (as produced by the export endpoint).

    Node ids are rewritten to avoid clashes, and edge endpoints are remapped
    to match. The imported workflow is owned by the caller.
    """
    # `or` also covers keys that are present but null/empty.
    name = workflow_data.get("name") or "导入的工作流"
    description = workflow_data.get("description")
    nodes = workflow_data.get("nodes") or []
    edges = workflow_data.get("edges") or []

    _validate_or_raise(nodes, edges, "导入的工作流验证失败")

    # Regenerate node ids and remember old -> new so edges can be rewritten.
    id_map: Dict[Any, str] = {}
    for node in nodes:
        old_id = node["id"]
        new_id = f"node_{len(id_map)}_{old_id}"
        id_map[old_id] = new_id
        node["id"] = new_id

    for edge in edges:
        for endpoint in ("source", "target"):
            if edge.get(endpoint) in id_map:
                edge[endpoint] = id_map[edge[endpoint]]

    return _create_workflow_record(
        db,
        name=name,
        description=description,
        nodes=nodes,
        edges=edges,
        user_id=current_user.id,
    )


@router.post("/{workflow_id}/execute")
async def execute_workflow(
    workflow_id: str,
    input_data: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Start an execution of a workflow (requires execute permission)."""
    _get_workflow_or_404(db, workflow_id)
    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    _require_permission(db, current_user, workflow, "execute", "无权执行此工作流")

    # Imported lazily; presumably to avoid a circular import with app.api.executions.
    from app.api.executions import ExecutionCreate, create_execution

    execution_data = ExecutionCreate(workflow_id=workflow_id, input_data=input_data)
    return await create_execution(execution_data, db, current_user)


# Version management API


class WorkflowVersionResponse(BaseModel):
    """One entry of a workflow's version history (or its current state)."""

    id: str
    workflow_id: str
    version: int
    name: str
    description: Optional[str]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    status: str
    created_by: Optional[str]
    created_at: datetime
    comment: Optional[str]

    class Config:
        from_attributes = True


def _current_version_response(workflow: Workflow) -> WorkflowVersionResponse:
    """Describe the live workflow row as a version entry marked "当前版本"."""
    return WorkflowVersionResponse(
        id=workflow.id,
        workflow_id=workflow.id,
        version=workflow.version,
        name=workflow.name,
        description=workflow.description,
        nodes=workflow.nodes,
        edges=workflow.edges,
        status=workflow.status,
        created_by=workflow.user_id,
        created_at=workflow.updated_at or workflow.created_at,
        comment="当前版本",
    )


def _history_version_response(entry: WorkflowVersion) -> WorkflowVersionResponse:
    """Convert a stored WorkflowVersion row into the API response model."""
    return WorkflowVersionResponse(
        id=entry.id,
        workflow_id=entry.workflow_id,
        version=entry.version,
        name=entry.name,
        description=entry.description,
        nodes=entry.nodes,
        edges=entry.edges,
        status=entry.status,
        created_by=entry.created_by,
        created_at=entry.created_at,
        comment=entry.comment,
    )


@router.get("/{workflow_id}/versions", response_model=List[WorkflowVersionResponse])
async def get_workflow_versions(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List a workflow's versions, current first, then history newest first (requires read permission)."""
    workflow = _get_workflow_or_404(db, workflow_id)
    _require_permission(db, current_user, workflow, "read", "无权访问此工作流")

    try:
        history = (
            db.query(WorkflowVersion)
            .filter(WorkflowVersion.workflow_id == workflow_id)
            .order_by(WorkflowVersion.version.desc())
            .all()
        )
    except Exception as e:
        # E.g. the history table does not exist: fall back to the current version only.
        logger.warning("查询版本历史失败: %s,仅返回当前版本", e)
        history = []

    return [_current_version_response(workflow)] + [
        _history_version_response(entry) for entry in history
    ]


@router.get("/{workflow_id}/versions/{version}", response_model=WorkflowVersionResponse)
async def get_workflow_version(
    workflow_id: str,
    version: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one version of a workflow, current or historical (requires read permission)."""
    workflow = _get_workflow_or_404(db, workflow_id)
    _require_permission(db, current_user, workflow, "read", "无权访问此工作流")

    if version == workflow.version:
        return _current_version_response(workflow)

    entry = (
        db.query(WorkflowVersion)
        .filter(
            WorkflowVersion.workflow_id == workflow_id,
            WorkflowVersion.version == version,
        )
        .first()
    )
    if not entry:
        raise NotFoundError("工作流版本", f"{workflow_id} v{version}")
    return _history_version_response(entry)


class WorkflowVersionRollback(BaseModel):
    """Optional body for a rollback request."""

    # Free-text reason; appended to the comment of the pre-rollback snapshot.
    comment: Optional[str] = None


@router.post("/{workflow_id}/versions/{version}/rollback", response_model=WorkflowResponse)
async def rollback_workflow_version(
    workflow_id: str,
    version: int,
    rollback_data: Optional[WorkflowVersionRollback] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Restore a historical version's content as a new version (requires write permission).

    The current state is snapshotted into the history first, so the rollback
    itself can be undone. The version number always moves forward.
    """
    workflow = _get_workflow_or_404(db, workflow_id)
    _require_permission(db, current_user, workflow, "write", "无权回滚此工作流")

    if version == workflow.version:
        raise ValidationError("不能回滚到当前版本")

    target = (
        db.query(WorkflowVersion)
        .filter(
            WorkflowVersion.workflow_id == workflow_id,
            WorkflowVersion.version == version,
        )
        .first()
    )
    if not target:
        raise NotFoundError("工作流版本", f"{workflow_id} v{version}")

    # Previously the caller's comment was accepted but silently dropped.
    user_comment = rollback_data.comment if rollback_data else None
    snapshot_comment = f"回滚前保存: {user_comment}" if user_comment else "回滚前保存"

    # Snapshot BEFORE mutating: on failure the helper rolls the session back.
    _snapshot_version(
        db, workflow, current_user.id, "继续执行回滚", comment=snapshot_comment
    )

    workflow.name = target.name
    workflow.description = target.description
    workflow.nodes = target.nodes
    workflow.edges = target.edges
    workflow.status = target.status
    workflow.version += 1

    db.commit()
    db.refresh(workflow)

    logger.info("工作流 %s 已回滚到版本 %s", workflow_id, version)
    return workflow