635 lines
20 KiB
Python
635 lines
20 KiB
Python
"""
|
||
工作流API
|
||
"""
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from sqlalchemy.orm import Session
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional, Dict, Any
|
||
from datetime import datetime
|
||
import logging
|
||
from app.core.database import get_db
|
||
from app.models.workflow import Workflow
|
||
from app.models.workflow_version import WorkflowVersion
|
||
from app.api.auth import get_current_user, UserResponse
|
||
from app.models.user import User
|
||
from app.core.exceptions import NotFoundError, ValidationError, ConflictError
|
||
from app.services.workflow_validator import validate_workflow
|
||
from app.services.workflow_templates import list_templates, create_from_template, get_template
|
||
from app.services.permission_service import check_workflow_permission
|
||
|
||
# Module-level logger for this API module.
logger = logging.getLogger(__name__)


# Router for every workflow endpoint below: all routes are mounted under the
# "/api/v1/workflows" prefix and grouped under the "workflows" OpenAPI tag.
router = APIRouter(prefix="/api/v1/workflows", tags=["workflows"])
|
||
|
||
|
||
class WorkflowCreate(BaseModel):
    """Request body for creating (or validating) a workflow."""

    name: str
    description: Optional[str] = None
    # Free-form dicts at this layer; the endpoints pass nodes/edges to
    # validate_workflow before anything is persisted.
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
|
||
|
||
|
||
class WorkflowUpdate(BaseModel):
    """Request body for a partial workflow update.

    Every field defaults to None; the update endpoint only overwrites the
    fields that are not None.
    """

    name: Optional[str] = None
    description: Optional[str] = None
    nodes: Optional[List[Dict[str, Any]]] = None
    edges: Optional[List[Dict[str, Any]]] = None
    # Free-form string; it is not checked against a fixed set of values here.
    status: Optional[str] = None
|
||
|
||
|
||
class WorkflowResponse(BaseModel):
    """Response model for a single workflow."""

    id: str
    name: str
    description: Optional[str]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    # Incremented by the update and rollback endpoints.
    version: int
    status: str
    # Owner of the workflow.
    user_id: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # Lets the endpoints return ORM Workflow objects directly; fields are
        # read from object attributes.
        from_attributes = True
|
||
|
||
|
||
@router.post("/validate", status_code=status.HTTP_200_OK)
async def validate_workflow_endpoint(
    workflow_data: WorkflowCreate,
    current_user: User = Depends(get_current_user)
):
    """Validate a workflow definition without saving it.

    Returns the report produced by ``validate_workflow`` for the submitted
    nodes and edges unchanged.
    """
    graph_nodes, graph_edges = workflow_data.nodes, workflow_data.edges
    return validate_workflow(graph_nodes, graph_edges)
|
||
|
||
|
||
@router.get("/templates", status_code=status.HTTP_200_OK)
async def get_workflow_templates(
    current_user: User = Depends(get_current_user)
):
    """Return the list of available workflow templates (any authenticated user)."""
    return list_templates()
|
||
|
||
|
||
@router.get("/templates/{template_id}", status_code=status.HTTP_200_OK)
async def get_workflow_template(
    template_id: str,
    current_user: User = Depends(get_current_user)
):
    """Return one workflow template by id.

    Raises:
        NotFoundError: ``get_template`` returned a falsy value for the id.
    """
    found_template = get_template(template_id)
    if found_template:
        return found_template
    raise NotFoundError("模板", template_id)
|
||
|
||
|
||
@router.post("/templates/{template_id}/create", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def create_workflow_from_template(
    template_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a new workflow, owned by the current user, from a template.

    Args:
        template_id: Template to instantiate.
        name: Optional override passed to ``create_from_template``.
        description: Optional override passed to ``create_from_template``.

    Returns:
        The persisted Workflow.

    Raises:
        NotFoundError: ``create_from_template`` raised ``ValueError``.
        ValidationError: the generated graph fails ``validate_workflow``.
    """
    try:
        workflow_data = create_from_template(template_id, name, description)
    except ValueError as e:
        # Chain the cause so the underlying reason survives in tracebacks.
        raise NotFoundError("模板", template_id) from e

    validation_result = validate_workflow(workflow_data["nodes"], workflow_data["edges"])
    if not validation_result["valid"]:
        raise ValidationError(f"模板工作流验证失败: {', '.join(validation_result['errors'])}")

    # Warnings do not block creation; log them like create_workflow does.
    if validation_result.get("warnings"):
        logger.warning(f"模板工作流创建警告: {', '.join(validation_result['warnings'])}")

    workflow = Workflow(
        name=workflow_data["name"],
        description=workflow_data.get("description"),
        nodes=workflow_data["nodes"],
        edges=workflow_data["edges"],
        user_id=current_user.id
    )
    db.add(workflow)
    db.commit()
    db.refresh(workflow)
    return workflow
|
||
|
||
|
||
# Escape character used for LIKE/ILIKE patterns built from user input. "/" is
# used instead of backslash to avoid dialect-specific backslash handling.
_LIKE_ESCAPE_CHAR = "/"


def _escape_like(term: str) -> str:
    """Escape LIKE wildcards in *term* so it is matched literally."""
    return (
        term.replace(_LIKE_ESCAPE_CHAR, _LIKE_ESCAPE_CHAR * 2)
        .replace("%", _LIKE_ESCAPE_CHAR + "%")
        .replace("_", _LIKE_ESCAPE_CHAR + "_")
    )


@router.get("", response_model=List[WorkflowResponse])
async def get_workflows(
    skip: int = 0,
    limit: int = 100,
    search: Optional[str] = None,
    status: Optional[str] = None,
    sort_by: Optional[str] = "created_at",
    sort_order: Optional[str] = "desc",
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List workflows visible to the current user (search, filter, sort, page).

    Users with ``role == "admin"`` see every workflow. Other users see the
    workflows they own plus those granted to them, directly or through one
    of their roles, by a ``"read"`` WorkflowPermission.

    Args:
        skip: Rows to skip; negative values are treated as 0.
        limit: Maximum rows to return; negative values are treated as 0.
        search: Case-insensitive substring matched literally (``%`` and
            ``_`` are not wildcards) against name or description.
        status: If given, keep only workflows with exactly this status.
        sort_by: ``"name"``, ``"created_at"`` or ``"updated_at"``; any other
            value falls back to ``"created_at"``.
        sort_order: ``"asc"`` sorts ascending; anything else descending.
    """
    # NOTE: ``status`` shadows the fastapi ``status`` module inside this
    # function; the name is kept because it is the public query parameter.
    from sqlalchemy import or_

    if current_user.role == "admin":
        query = db.query(Workflow)
    else:
        from app.models.permission import WorkflowPermission

        role_ids = [role.id for role in current_user.roles]
        readable_ids = db.query(WorkflowPermission.workflow_id).filter(
            WorkflowPermission.permission_type == "read",
            or_(
                WorkflowPermission.user_id == current_user.id,
                WorkflowPermission.role_id.in_(role_ids),
            ),
        )
        query = db.query(Workflow).filter(
            or_(
                Workflow.user_id == current_user.id,
                Workflow.id.in_(readable_ids),
            )
        )

    if search:
        pattern = f"%{_escape_like(search)}%"
        query = query.filter(
            or_(
                Workflow.name.ilike(pattern, escape=_LIKE_ESCAPE_CHAR),
                Workflow.description.ilike(pattern, escape=_LIKE_ESCAPE_CHAR),
            )
        )

    if status:
        query = query.filter(Workflow.status == status)

    sort_columns = {
        "name": Workflow.name,
        "created_at": Workflow.created_at,
        "updated_at": Workflow.updated_at,
    }
    sort_column = sort_columns.get(sort_by, Workflow.created_at)
    primary_order = sort_column.asc() if sort_order == "asc" else sort_column.desc()
    # Tie-break on id so offset/limit pagination is stable when keys collide.
    query = query.order_by(primary_order, Workflow.id)

    # Negative OFFSET/LIMIT are rejected by some databases (e.g. PostgreSQL).
    return query.offset(max(skip, 0)).limit(max(limit, 0)).all()
|
||
|
||
|
||
@router.post("", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def create_workflow(
    workflow_data: WorkflowCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Validate and persist a new workflow owned by the current user.

    Raises:
        ValidationError: ``validate_workflow`` reports the graph as invalid.
    """
    report = validate_workflow(workflow_data.nodes, workflow_data.edges)
    if not report["valid"]:
        joined_errors = ', '.join(report['errors'])
        raise ValidationError(f"工作流验证失败: {joined_errors}")

    # Warnings never block creation; they are only logged.
    warning_list = report["warnings"]
    if warning_list:
        joined_warnings = ', '.join(warning_list)
        logger.warning(f"工作流创建警告: {joined_warnings}")

    new_workflow = Workflow(
        name=workflow_data.name,
        description=workflow_data.description,
        nodes=workflow_data.nodes,
        edges=workflow_data.edges,
        user_id=current_user.id
    )
    db.add(new_workflow)
    db.commit()
    db.refresh(new_workflow)
    return new_workflow
|
||
|
||
|
||
@router.get("/{workflow_id}", response_model=WorkflowResponse)
async def get_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return one workflow, provided the user holds "read" permission on it.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException: 403 when the permission check fails.
    """
    found = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not found:
        raise NotFoundError("工作流", workflow_id)

    can_read = check_workflow_permission(db, current_user, found, "read")
    if not can_read:
        raise HTTPException(status_code=403, detail="无权访问此工作流")
    return found
|
||
|
||
|
||
@router.put("/{workflow_id}", response_model=WorkflowResponse)
async def update_workflow(
    workflow_id: str,
    workflow_data: WorkflowUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Partially update a workflow, snapshotting its previous state first.

    Requires "write" permission. The merged graph (submitted nodes/edges
    where given, stored ones otherwise) is always re-validated. The current
    state is saved as a WorkflowVersion row on a best-effort basis, then the
    non-None fields are applied and ``workflow.version`` is incremented.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException: 403 if the user lacks "write" permission.
        ValidationError: the merged graph fails ``validate_workflow``.
    """
    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)

    if not check_workflow_permission(db, current_user, workflow, "write"):
        raise HTTPException(status_code=403, detail="无权修改此工作流")

    # Validate the graph as it will look after the update; fields that are
    # not supplied fall back to the stored values.
    nodes_to_validate = workflow_data.nodes if workflow_data.nodes is not None else workflow.nodes
    edges_to_validate = workflow_data.edges if workflow_data.edges is not None else workflow.edges

    validation_result = validate_workflow(nodes_to_validate, edges_to_validate)
    if not validation_result["valid"]:
        raise ValidationError(f"工作流验证失败: {', '.join(validation_result['errors'])}")

    if validation_result["warnings"]:
        logger.warning(f"工作流更新警告: {', '.join(validation_result['warnings'])}")

    # Best-effort snapshot of the pre-update state. The INSERT is flushed
    # inside a SAVEPOINT so that a database error (e.g. the history table
    # does not exist) is raised here and rolled back on its own, instead of
    # surfacing at the final commit and aborting the whole update.
    # NOTE(review): SAVEPOINT on SQLite/pysqlite needs SQLAlchemy's driver
    # workaround — confirm if SQLite is used.
    try:
        with db.begin_nested():
            db.add(
                WorkflowVersion(
                    workflow_id=workflow.id,
                    version=workflow.version,
                    name=workflow.name,
                    description=workflow.description,
                    nodes=workflow.nodes,
                    edges=workflow.edges,
                    status=workflow.status,
                    created_by=current_user.id
                )
            )
    except Exception as e:
        logger.warning(f"保存版本历史失败: {str(e)},继续执行更新")

    # Apply only the fields the client actually sent.
    if workflow_data.name is not None:
        workflow.name = workflow_data.name
    if workflow_data.description is not None:
        workflow.description = workflow_data.description
    if workflow_data.nodes is not None:
        workflow.nodes = workflow_data.nodes
    if workflow_data.edges is not None:
        workflow.edges = workflow_data.edges
    if workflow_data.status is not None:
        workflow.status = workflow_data.status

    workflow.version += 1
    db.commit()
    db.refresh(workflow)
    return workflow
|
||
|
||
|
||
@router.delete("/{workflow_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete a workflow; only its owner or a user with role "admin" may do so.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException: 403 for anyone other than the owner or an admin.
    """
    target = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not target:
        raise NotFoundError("工作流", workflow_id)

    if not (target.user_id == current_user.id or current_user.role == "admin"):
        raise HTTPException(status_code=403, detail="无权删除此工作流")

    db.delete(target)
    db.commit()
    return None
|
||
|
||
|
||
@router.get("/{workflow_id}/export", status_code=status.HTTP_200_OK)
async def export_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Export a workflow as a JSON-serialisable dict.

    Requires "read" permission. ``exported_at`` is a timezone-aware UTC
    ISO-8601 timestamp (e.g. ``2024-01-01T00:00:00+00:00``).

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException: 403 if the user lacks "read" permission.
    """
    # datetime.utcnow() is deprecated since Python 3.12 and yields a naive
    # value; an aware UTC timestamp makes the offset explicit.
    from datetime import timezone

    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)

    if not check_workflow_permission(db, current_user, workflow, "read"):
        raise HTTPException(status_code=403, detail="无权导出此工作流")

    return {
        "id": str(workflow.id),
        "name": workflow.name,
        "description": workflow.description,
        "nodes": workflow.nodes,
        "edges": workflow.edges,
        "version": workflow.version,
        "status": workflow.status,
        "exported_at": datetime.now(timezone.utc).isoformat(),
    }
|
||
|
||
|
||
@router.post("/import", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def import_workflow(
    workflow_data: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import a workflow from a JSON document and save it for the current user.

    Only ``name``, ``description``, ``nodes`` and ``edges`` are read; any
    other keys are ignored. Node ids are rewritten to
    ``node_<n>_<old id>`` and edge ``source``/``target`` values that refer
    to a known node id are remapped. The request payload is not mutated.

    Raises:
        ValidationError: ``nodes``/``edges`` are not lists of objects, a node
            has no ``id``, or the graph fails ``validate_workflow``.
    """
    # Fall back to defaults when a key is missing or explicitly null.
    name = workflow_data.get("name") or "导入的工作流"
    description = workflow_data.get("description")
    nodes = workflow_data.get("nodes") or []
    edges = workflow_data.get("edges") or []

    # Reject malformed payloads with a 4xx-style validation error rather than
    # letting the id-rewriting below fail with an unhandled TypeError/KeyError.
    if not isinstance(nodes, list) or not isinstance(edges, list):
        raise ValidationError("导入的工作流验证失败: nodes 和 edges 必须是列表")
    if not all(isinstance(item, dict) for item in (*nodes, *edges)):
        raise ValidationError("导入的工作流验证失败: 节点和边必须是对象")

    validation_result = validate_workflow(nodes, edges)
    if not validation_result["valid"]:
        raise ValidationError(f"导入的工作流验证失败: {', '.join(validation_result['errors'])}")

    if validation_result.get("warnings"):
        logger.warning(f"工作流导入警告: {', '.join(validation_result['warnings'])}")

    # Regenerate node ids (to avoid id collisions), working on shallow copies
    # so the incoming payload is left untouched.
    node_id_mapping: Dict[Any, str] = {}
    imported_nodes: List[Dict[str, Any]] = []
    for node in nodes:
        if "id" not in node:
            raise ValidationError("导入的工作流验证失败: 节点缺少id")
        old_id = node["id"]
        new_id = f"node_{len(node_id_mapping)}_{old_id}"
        node_id_mapping[old_id] = new_id
        imported_nodes.append({**node, "id": new_id})

    # Point edge endpoints at the regenerated node ids.
    imported_edges: List[Dict[str, Any]] = []
    for edge in edges:
        new_edge = dict(edge)
        for endpoint in ("source", "target"):
            if new_edge.get(endpoint) in node_id_mapping:
                new_edge[endpoint] = node_id_mapping[new_edge[endpoint]]
        imported_edges.append(new_edge)

    workflow = Workflow(
        name=name,
        description=description,
        nodes=imported_nodes,
        edges=imported_edges,
        user_id=current_user.id
    )
    db.add(workflow)
    db.commit()
    db.refresh(workflow)
    return workflow
|
||
|
||
|
||
@router.post("/{workflow_id}/execute")
async def execute_workflow(
    workflow_id: str,
    input_data: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Start an execution of a workflow the user holds "execute" permission on.

    Delegates to ``app.api.executions.create_execution`` and returns its
    result.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException: 403 when the permission check fails.
    """
    target = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not target:
        raise NotFoundError("工作流", workflow_id)

    may_execute = check_workflow_permission(db, current_user, target, "execute")
    if not may_execute:
        raise HTTPException(status_code=403, detail="无权执行此工作流")

    # Imported lazily; presumably to avoid a circular import between the API
    # modules — TODO confirm.
    from app.api.executions import create_execution, ExecutionCreate

    payload = ExecutionCreate(workflow_id=workflow_id, input_data=input_data)
    return await create_execution(payload, db, current_user)
|
||
|
||
|
||
# Version management API


class WorkflowVersionResponse(BaseModel):
    """Response model for one workflow version (historical or current)."""

    id: str
    workflow_id: str
    version: int
    name: str
    description: Optional[str]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    status: str
    # User id recorded for this version; nullable.
    # NOTE(review): no default, so under pydantic v2 this field is still
    # required — confirm that is intended.
    created_by: Optional[str]
    created_at: datetime
    # Free-text note attached to the version; nullable (same NOTE as above).
    comment: Optional[str]

    class Config:
        # Allow construction from ORM objects via attribute access.
        from_attributes = True
|
||
|
||
|
||
@router.get("/{workflow_id}/versions", response_model=List[WorkflowVersionResponse])
async def get_workflow_versions(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List a workflow's versions: the live version first, then history newest-first.

    Access uses ``check_workflow_permission(..., "read")``, the same rule as
    the workflow detail and export endpoints. If the history query fails
    (e.g. the table does not exist), only the live version is returned.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException: 403 if the user lacks "read" permission.
    """
    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)

    if not check_workflow_permission(db, current_user, workflow, "read"):
        raise HTTPException(status_code=403, detail="无权访问此工作流")

    # Build the live-version entry before querying history: a failed query
    # below rolls the session back, which would expire ``workflow``.
    result = [
        WorkflowVersionResponse(
            id=workflow.id,
            workflow_id=workflow.id,
            version=workflow.version,
            name=workflow.name,
            description=workflow.description,
            nodes=workflow.nodes,
            edges=workflow.edges,
            status=workflow.status,
            created_by=workflow.user_id,
            created_at=workflow.updated_at or workflow.created_at,
            comment="当前版本"
        )
    ]

    try:
        versions = db.query(WorkflowVersion).filter(
            WorkflowVersion.workflow_id == workflow_id
        ).order_by(WorkflowVersion.version.desc()).all()
    except Exception as e:
        # Reset the session so the failed statement does not leave the
        # transaction aborted for the rest of the request.
        db.rollback()
        logger.warning(f"查询版本历史失败: {str(e)},仅返回当前版本")
        versions = []

    result.extend(
        WorkflowVersionResponse(
            id=v.id,
            workflow_id=v.workflow_id,
            version=v.version,
            name=v.name,
            description=v.description,
            nodes=v.nodes,
            edges=v.edges,
            status=v.status,
            created_by=v.created_by,
            created_at=v.created_at,
            comment=v.comment
        )
        for v in versions
    )
    return result
|
||
|
||
|
||
@router.get("/{workflow_id}/versions/{version}", response_model=WorkflowVersionResponse)
async def get_workflow_version(
    workflow_id: str,
    version: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return one version of a workflow (the live one or a historical snapshot).

    Access uses ``check_workflow_permission(..., "read")``, the same rule as
    the workflow detail and export endpoints.

    Raises:
        NotFoundError: the workflow, or the requested version, does not exist.
        HTTPException: 403 if the user lacks "read" permission.
    """
    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)

    if not check_workflow_permission(db, current_user, workflow, "read"):
        raise HTTPException(status_code=403, detail="无权访问此工作流")

    # The live version is not stored in the history table; build it from
    # the workflow row itself.
    if version == workflow.version:
        return WorkflowVersionResponse(
            id=workflow.id,
            workflow_id=workflow.id,
            version=workflow.version,
            name=workflow.name,
            description=workflow.description,
            nodes=workflow.nodes,
            edges=workflow.edges,
            status=workflow.status,
            created_by=workflow.user_id,
            created_at=workflow.updated_at or workflow.created_at,
            comment="当前版本"
        )

    workflow_version = db.query(WorkflowVersion).filter(
        WorkflowVersion.workflow_id == workflow_id,
        WorkflowVersion.version == version
    ).first()
    if not workflow_version:
        raise NotFoundError("工作流版本", f"{workflow_id} v{version}")

    return WorkflowVersionResponse(
        id=workflow_version.id,
        workflow_id=workflow_version.workflow_id,
        version=workflow_version.version,
        name=workflow_version.name,
        description=workflow_version.description,
        nodes=workflow_version.nodes,
        edges=workflow_version.edges,
        status=workflow_version.status,
        created_by=workflow_version.created_by,
        created_at=workflow_version.created_at,
        comment=workflow_version.comment
    )
|
||
|
||
|
||
class WorkflowVersionRollback(BaseModel):
    """Optional request body for the version-rollback endpoint."""

    # Optional free-text note describing the rollback.
    comment: Optional[str] = None
|
||
|
||
|
||
@router.post("/{workflow_id}/versions/{version}/rollback", response_model=WorkflowResponse)
async def rollback_workflow_version(
    workflow_id: str,
    version: int,
    rollback_data: Optional[WorkflowVersionRollback] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Restore a workflow's content from a historical version.

    Requires "write" permission, the same rule as ``update_workflow``. The
    current state is first saved as a history row on a best-effort basis;
    its comment includes ``rollback_data.comment`` when one is given. Then
    name/description/nodes/edges/status are copied from the target version,
    and ``workflow.version`` is incremented. A rollback is recorded as a new
    version, not by rewinding the number.

    Raises:
        NotFoundError: the workflow, or the requested version, does not exist.
        HTTPException: 403 if the user lacks "write" permission.
        ValidationError: ``version`` is the current version.
    """
    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)

    if not check_workflow_permission(db, current_user, workflow, "write"):
        raise HTTPException(status_code=403, detail="无权修改此工作流")

    if version == workflow.version:
        raise ValidationError("不能回滚到当前版本")

    workflow_version = db.query(WorkflowVersion).filter(
        WorkflowVersion.workflow_id == workflow_id,
        WorkflowVersion.version == version
    ).first()
    if not workflow_version:
        raise NotFoundError("工作流版本", f"{workflow_id} v{version}")

    # Record the caller's note (previously accepted but silently ignored).
    user_comment = rollback_data.comment if rollback_data else None
    snapshot_comment = f"回滚前保存: {user_comment}" if user_comment else "回滚前保存"

    # Best-effort snapshot of the pre-rollback state. The INSERT is flushed
    # inside a SAVEPOINT so a database error is raised and rolled back here,
    # instead of surfacing at the final commit and aborting the rollback.
    # NOTE(review): SAVEPOINT on SQLite/pysqlite needs SQLAlchemy's driver
    # workaround — confirm if SQLite is used.
    try:
        with db.begin_nested():
            db.add(
                WorkflowVersion(
                    workflow_id=workflow.id,
                    version=workflow.version,
                    name=workflow.name,
                    description=workflow.description,
                    nodes=workflow.nodes,
                    edges=workflow.edges,
                    status=workflow.status,
                    created_by=current_user.id,
                    comment=snapshot_comment
                )
            )
    except Exception as e:
        logger.warning(f"保存版本历史失败: {str(e)},继续执行回滚")

    workflow.name = workflow_version.name
    workflow.description = workflow_version.description
    workflow.nodes = workflow_version.nodes
    workflow.edges = workflow_version.edges
    workflow.status = workflow_version.status
    workflow.version += 1

    db.commit()
    db.refresh(workflow)

    if user_comment:
        logger.info(f"工作流 {workflow_id} 已回滚到版本 {version}: {user_comment}")
    else:
        logger.info(f"工作流 {workflow_id} 已回滚到版本 {version}")
    return workflow
|