Files
aiagent/backend/app/api/workflows.py
2026-01-19 00:09:36 +08:00

635 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工作流API
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from datetime import datetime
import logging
from app.core.database import get_db
from app.models.workflow import Workflow
from app.models.workflow_version import WorkflowVersion
from app.api.auth import get_current_user, UserResponse
from app.models.user import User
from app.core.exceptions import NotFoundError, ValidationError, ConflictError
from app.services.workflow_validator import validate_workflow
from app.services.workflow_templates import list_templates, create_from_template, get_template
from app.services.permission_service import check_workflow_permission
# Module-level logger, plus the router every endpoint in this module registers on.
logger = logging.getLogger(__name__)
router = APIRouter(tags=["workflows"], prefix="/api/v1/workflows")
class WorkflowCreate(BaseModel):
    """Request body used to create (or merely validate) a workflow."""

    # Display name of the workflow.
    name: str
    # Optional free-text description.
    description: Optional[str] = None
    # Node definitions; passed unchanged to validate_workflow by the endpoints.
    nodes: List[Dict[str, Any]]
    # Edge definitions; passed unchanged to validate_workflow by the endpoints.
    edges: List[Dict[str, Any]]
class WorkflowUpdate(BaseModel):
    """Partial-update request body; a field left as ``None`` is not changed."""

    # New display name.
    name: Optional[str] = None
    # New description.
    description: Optional[str] = None
    # Replacement node list.
    nodes: Optional[List[Dict[str, Any]]] = None
    # Replacement edge list.
    edges: Optional[List[Dict[str, Any]]] = None
    # New status value (stored as given).
    status: Optional[str] = None
class WorkflowResponse(BaseModel):
    """Workflow representation returned by the API, built from ORM rows."""

    id: str
    name: str
    # May be null.
    description: Optional[str]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    # Incremented by the update and rollback endpoints.
    version: int
    status: str
    # Owner of the workflow.
    user_id: str
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow populating the model from attributes of ORM instances.
        from_attributes = True
@router.post("/validate", status_code=status.HTTP_200_OK)
async def validate_workflow_endpoint(
    workflow_data: WorkflowCreate,
    current_user: User = Depends(get_current_user),
):
    """Validate a workflow definition without saving it.

    Requires an authenticated user and returns the result of
    ``validate_workflow`` for the submitted nodes and edges unchanged.
    """
    return validate_workflow(workflow_data.nodes, workflow_data.edges)
@router.get("/templates", status_code=status.HTTP_200_OK)
async def get_workflow_templates(
    current_user: User = Depends(get_current_user),
):
    """Return the available workflow templates (authenticated users only)."""
    return list_templates()
@router.get("/templates/{template_id}", status_code=status.HTTP_200_OK)
async def get_workflow_template(
    template_id: str,
    current_user: User = Depends(get_current_user),
):
    """Return a single workflow template.

    Raises:
        NotFoundError: ``get_template`` returned a falsy value for ``template_id``.
    """
    found = get_template(template_id)
    if found:
        return found
    raise NotFoundError("模板", template_id)
@router.post("/templates/{template_id}/create", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def create_workflow_from_template(
    template_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a new workflow, owned by the current user, from a template.

    Args:
        template_id: Template identifier passed to ``create_from_template``.
        name: Optional name override (query parameter).
        description: Optional description override (query parameter).

    Returns:
        The persisted ``Workflow`` row.

    Raises:
        NotFoundError: ``create_from_template`` raised ``ValueError``.
        ValidationError: The generated nodes/edges fail ``validate_workflow``.
    """
    try:
        workflow_data = create_from_template(template_id, name, description)
    except ValueError as exc:
        # Chain the cause so the underlying template error stays in tracebacks.
        raise NotFoundError("模板", template_id) from exc

    validation_result = validate_workflow(workflow_data["nodes"], workflow_data["edges"])
    if not validation_result["valid"]:
        raise ValidationError(f"模板工作流验证失败: {', '.join(validation_result['errors'])}")
    # Surface non-fatal warnings the same way create_workflow does.
    if validation_result.get("warnings"):
        logger.warning(f"模板工作流创建警告: {', '.join(validation_result['warnings'])}")

    workflow = Workflow(
        name=workflow_data["name"],
        # Tolerate templates that omit a description.
        description=workflow_data.get("description"),
        nodes=workflow_data["nodes"],
        edges=workflow_data["edges"],
        user_id=current_user.id,
    )
    db.add(workflow)
    db.commit()
    db.refresh(workflow)
    return workflow
@router.get("", response_model=List[WorkflowResponse])
async def get_workflows(
    skip: int = 0,
    limit: int = 100,
    search: Optional[str] = None,
    status: Optional[str] = None,
    sort_by: Optional[str] = "created_at",
    sort_order: Optional[str] = "desc",
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the workflows visible to the current user.

    Admins (``role == "admin"``) see every workflow. Other users see the
    workflows they own plus those granted to them through a
    ``WorkflowPermission`` row of type ``read``, either directly (by user id)
    or via one of their roles.

    Args:
        skip: Number of rows to skip; negative values are treated as 0.
        limit: Maximum rows to return; negative values are treated as 0.
        search: Case-insensitive substring matched against name or
            description. ``%`` and ``_`` are matched literally.
        status: Exact match on ``Workflow.status``.
        sort_by: ``name``, ``created_at`` or ``updated_at``; anything else
            sorts by ``created_at``.
        sort_order: ``asc`` for ascending; anything else sorts descending.
    """
    # NOTE: the ``status`` query parameter shadows the module-level
    # ``fastapi.status`` import inside this function. It is kept for API
    # compatibility and is only used as the filter value below.
    if current_user.role == "admin":
        query = db.query(Workflow)
    else:
        from sqlalchemy import or_
        from app.models.permission import WorkflowPermission

        role_ids = [role.id for role in current_user.roles]
        grantee_filters = [WorkflowPermission.user_id == current_user.id]
        if role_ids:
            # Skip an always-false empty IN clause for users without roles.
            grantee_filters.append(WorkflowPermission.role_id.in_(role_ids))

        readable_ids = db.query(WorkflowPermission.workflow_id).filter(
            WorkflowPermission.permission_type == "read",
            or_(*grantee_filters),
        )
        query = db.query(Workflow).filter(
            or_(
                Workflow.user_id == current_user.id,
                Workflow.id.in_(readable_ids),
            )
        )

    if search:
        # Escape LIKE wildcards so user input is matched literally ("/" as the
        # escape character, the same convention SQLAlchemy's autoescape uses).
        escaped = search.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        query = query.filter(
            Workflow.name.ilike(pattern, escape="/")
            | Workflow.description.ilike(pattern, escape="/")
        )

    if status:
        query = query.filter(Workflow.status == status)

    sortable_columns = {
        "name": Workflow.name,
        "created_at": Workflow.created_at,
        "updated_at": Workflow.updated_at,
    }
    order_column = sortable_columns.get(sort_by, Workflow.created_at)
    query = query.order_by(order_column.asc() if sort_order == "asc" else order_column.desc())

    # Negative OFFSET/LIMIT are rejected by some databases (e.g. PostgreSQL).
    return query.offset(max(skip, 0)).limit(max(limit, 0)).all()
@router.post("", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def create_workflow(
    workflow_data: WorkflowCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Validate and persist a new workflow owned by the current user.

    Raises:
        ValidationError: ``validate_workflow`` reports the graph as invalid.
    """
    report = validate_workflow(workflow_data.nodes, workflow_data.edges)
    if not report["valid"]:
        problems = ', '.join(report['errors'])
        raise ValidationError(f"工作流验证失败: {problems}")

    # Non-fatal issues are only logged.
    warnings = report["warnings"]
    if warnings:
        logger.warning(f"工作流创建警告: {', '.join(warnings)}")

    new_workflow = Workflow(
        name=workflow_data.name,
        description=workflow_data.description,
        nodes=workflow_data.nodes,
        edges=workflow_data.edges,
        user_id=current_user.id,
    )
    db.add(new_workflow)
    db.commit()
    db.refresh(new_workflow)
    return new_workflow
@router.get("/{workflow_id}", response_model=WorkflowResponse)
async def get_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one workflow if the current user holds ``read`` permission on it.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException(403): ``check_workflow_permission`` denied ``read``.
    """
    found = db.query(Workflow).filter_by(id=workflow_id).first()
    if not found:
        raise NotFoundError("工作流", workflow_id)

    allowed = check_workflow_permission(db, current_user, found, "read")
    if not allowed:
        raise HTTPException(status_code=403, detail="无权访问此工作流")
    return found
@router.put("/{workflow_id}", response_model=WorkflowResponse)
async def update_workflow(
    workflow_id: str,
    workflow_data: WorkflowUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Partially update a workflow and record its previous state as a version.

    Only fields that are not ``None`` in ``workflow_data`` are applied. Before
    anything is written, the resulting node/edge graph is validated. The
    pre-update state is then saved as a ``WorkflowVersion`` (best effort), and
    the workflow's ``version`` counter is incremented.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException(403): the user lacks ``write`` permission.
        ValidationError: the resulting nodes/edges fail ``validate_workflow``.
    """
    from sqlalchemy.exc import SQLAlchemyError

    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)
    if not check_workflow_permission(db, current_user, workflow, "write"):
        raise HTTPException(status_code=403, detail="无权修改此工作流")

    # Validate the graph as it will look after the update: fields that are not
    # being replaced fall back to the stored values.
    nodes_to_validate = workflow_data.nodes if workflow_data.nodes is not None else workflow.nodes
    edges_to_validate = workflow_data.edges if workflow_data.edges is not None else workflow.edges
    validation_result = validate_workflow(nodes_to_validate, edges_to_validate)
    if not validation_result["valid"]:
        raise ValidationError(f"工作流验证失败: {', '.join(validation_result['errors'])}")
    if validation_result["warnings"]:
        logger.warning(f"工作流更新警告: {', '.join(validation_result['warnings'])}")

    # Snapshot the current state. ``db.add`` alone does not touch the database,
    # so a missing or broken history table would otherwise surface only at the
    # final commit and abort the whole update. Flushing inside a SAVEPOINT
    # raises the error here and rolls back just the snapshot, keeping it best
    # effort.
    try:
        with db.begin_nested():
            db.add(
                WorkflowVersion(
                    workflow_id=workflow.id,
                    version=workflow.version,
                    name=workflow.name,
                    description=workflow.description,
                    nodes=workflow.nodes,
                    edges=workflow.edges,
                    status=workflow.status,
                    created_by=current_user.id,
                )
            )
            db.flush()
    except SQLAlchemyError as e:
        logger.warning(f"保存版本历史失败: {str(e)},继续执行更新")

    # Apply only the fields the client actually sent.
    for field in ("name", "description", "nodes", "edges", "status"):
        new_value = getattr(workflow_data, field)
        if new_value is not None:
            setattr(workflow, field, new_value)
    workflow.version += 1

    db.commit()
    db.refresh(workflow)
    return workflow
@router.delete("/{workflow_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a workflow. Only its owner or an admin may do so.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException(403): the caller is neither the owner nor an admin.
    """
    target = db.query(Workflow).filter_by(id=workflow_id).first()
    if not target:
        raise NotFoundError("工作流", workflow_id)

    if not (target.user_id == current_user.id or current_user.role == "admin"):
        raise HTTPException(status_code=403, detail="无权删除此工作流")

    db.delete(target)
    db.commit()
    # 204 responses carry no body.
    return None
@router.get("/{workflow_id}/export", status_code=status.HTTP_200_OK)
async def export_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Export a workflow as a JSON-serialisable dict.

    ``exported_at`` is a timezone-aware UTC ISO-8601 timestamp (with a
    ``+00:00`` offset), so consumers cannot mistake it for local time.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException(403): the user lacks ``read`` permission.
    """
    from datetime import timezone

    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)
    if not check_workflow_permission(db, current_user, workflow, "read"):
        raise HTTPException(status_code=403, detail="无权导出此工作流")

    return {
        "id": str(workflow.id),
        "name": workflow.name,
        "description": workflow.description,
        "nodes": workflow.nodes,
        "edges": workflow.edges,
        "version": workflow.version,
        "status": workflow.status,
        # datetime.utcnow() is deprecated (Python 3.12) and returns a naive value.
        "exported_at": datetime.now(timezone.utc).isoformat(),
    }
@router.post("/import", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def import_workflow(
    workflow_data: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Import a workflow from a JSON document (e.g. the export endpoint's output).

    The document's ``name``/``description``/``nodes``/``edges`` keys are used.
    Node ids are rewritten to ``node_<index>_<old id>``, and edge
    ``source``/``target`` references are remapped to match. The new workflow
    is owned by the current user.

    Raises:
        ValidationError: ``nodes``/``edges`` are not lists of objects, or the
            graph fails ``validate_workflow``.
    """
    # ``or`` also covers keys that are present but explicitly null.
    name = workflow_data.get("name") or "导入的工作流"
    description = workflow_data.get("description")
    nodes = workflow_data.get("nodes") or []
    edges = workflow_data.get("edges") or []

    # Reject malformed payloads with a 4xx instead of crashing while iterating.
    if not isinstance(nodes, list) or not isinstance(edges, list):
        raise ValidationError("导入的工作流格式错误: nodes 和 edges 必须是数组")
    if not all(isinstance(item, dict) for item in (*nodes, *edges)):
        raise ValidationError("导入的工作流格式错误: nodes 和 edges 的元素必须是对象")

    validation_result = validate_workflow(nodes, edges)
    if not validation_result["valid"]:
        raise ValidationError(f"导入的工作流验证失败: {', '.join(validation_result['errors'])}")
    if validation_result.get("warnings"):
        logger.warning(f"导入的工作流警告: {', '.join(validation_result['warnings'])}")

    # Rewrite node ids; the positional index keeps the new ids unique.
    node_id_mapping: Dict[Any, str] = {}
    for index, node in enumerate(nodes):
        old_id = node["id"]
        new_id = f"node_{index}_{old_id}"
        node_id_mapping[old_id] = new_id
        node["id"] = new_id

    # Point edge endpoints at the renamed nodes; unknown references are left as-is.
    for edge in edges:
        for endpoint in ("source", "target"):
            if edge.get(endpoint) in node_id_mapping:
                edge[endpoint] = node_id_mapping[edge[endpoint]]

    workflow = Workflow(
        name=name,
        description=description,
        nodes=nodes,
        edges=edges,
        user_id=current_user.id,
    )
    db.add(workflow)
    db.commit()
    db.refresh(workflow)
    return workflow
@router.post("/{workflow_id}/execute")
async def execute_workflow(
    workflow_id: str,
    input_data: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Start an execution of a workflow on behalf of the current user.

    Checks ``execute`` permission, then delegates to
    ``app.api.executions.create_execution`` and returns its result.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException(403): the user lacks ``execute`` permission.
    """
    target = db.query(Workflow).filter_by(id=workflow_id).first()
    if not target:
        raise NotFoundError("工作流", workflow_id)
    if not check_workflow_permission(db, current_user, target, "execute"):
        raise HTTPException(status_code=403, detail="无权执行此工作流")

    # Imported lazily, after the checks, as in the original. Presumably this
    # avoids a circular import with app.api.executions — TODO confirm.
    from app.api.executions import ExecutionCreate, create_execution

    payload = ExecutionCreate(workflow_id=workflow_id, input_data=input_data)
    return await create_execution(payload, db, current_user)
# 版本管理API
class WorkflowVersionResponse(BaseModel):
    """One entry of a workflow's version history (or the live version)."""

    # History row id, or the workflow id for the live version.
    id: str
    workflow_id: str
    version: int
    name: str
    description: Optional[str]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    status: str
    # User who created this snapshot (the owner for the live version); may be null.
    created_by: Optional[str]
    created_at: datetime
    # Free-text note attached to the version; may be null.
    comment: Optional[str]

    class Config:
        # Allow populating the model from attributes of ORM instances.
        from_attributes = True
@router.get("/{workflow_id}/versions", response_model=List[WorkflowVersionResponse])
async def get_workflow_versions(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List a workflow's versions: the live one first, then the history rows.

    The first entry is built from the workflow row itself (``comment`` =
    "当前版本"). It is followed by the ``WorkflowVersion`` rows in descending
    ``version`` order. If the history query fails, only the live entry is
    returned.

    Access follows the same ``read`` permission rule as ``get_workflow``.

    Raises:
        NotFoundError: no workflow with ``workflow_id``.
        HTTPException(403): the user lacks ``read`` permission.
    """
    from sqlalchemy.exc import SQLAlchemyError

    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)
    if not check_workflow_permission(db, current_user, workflow, "read"):
        raise HTTPException(status_code=403, detail="无权访问此工作流")

    # Build the live entry before the history query, so it does not depend on
    # session state after the possible rollback below.
    current_version = WorkflowVersionResponse(
        id=workflow.id,
        workflow_id=workflow.id,
        version=workflow.version,
        name=workflow.name,
        description=workflow.description,
        nodes=workflow.nodes,
        edges=workflow.edges,
        status=workflow.status,
        created_by=workflow.user_id,
        created_at=workflow.updated_at or workflow.created_at,
        comment="当前版本",
    )

    try:
        versions = (
            db.query(WorkflowVersion)
            .filter(WorkflowVersion.workflow_id == workflow_id)
            .order_by(WorkflowVersion.version.desc())
            .all()
        )
    except SQLAlchemyError as e:
        # A failed statement can leave the transaction aborted (e.g. on
        # PostgreSQL); roll back so the session stays usable.
        db.rollback()
        logger.warning(f"查询版本历史失败: {str(e)},仅返回当前版本")
        versions = []

    history = [
        WorkflowVersionResponse(
            id=v.id,
            workflow_id=v.workflow_id,
            version=v.version,
            name=v.name,
            description=v.description,
            nodes=v.nodes,
            edges=v.edges,
            status=v.status,
            created_by=v.created_by,
            created_at=v.created_at,
            comment=v.comment,
        )
        for v in versions
    ]
    return [current_version, *history]
@router.get("/{workflow_id}/versions/{version}", response_model=WorkflowVersionResponse)
async def get_workflow_version(
    workflow_id: str,
    version: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one version of a workflow.

    If ``version`` equals the workflow's current ``version`` number, the live
    workflow row is returned (``comment`` = "当前版本"). Otherwise the matching
    ``WorkflowVersion`` history row is returned.

    Access follows the same ``read`` permission rule as ``get_workflow``.

    Raises:
        NotFoundError: unknown workflow, or no history row for ``version``.
        HTTPException(403): the user lacks ``read`` permission.
    """
    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)
    if not check_workflow_permission(db, current_user, workflow, "read"):
        raise HTTPException(status_code=403, detail="无权访问此工作流")

    if version == workflow.version:
        return WorkflowVersionResponse(
            id=workflow.id,
            workflow_id=workflow.id,
            version=workflow.version,
            name=workflow.name,
            description=workflow.description,
            nodes=workflow.nodes,
            edges=workflow.edges,
            status=workflow.status,
            created_by=workflow.user_id,
            created_at=workflow.updated_at or workflow.created_at,
            comment="当前版本",
        )

    workflow_version = db.query(WorkflowVersion).filter(
        WorkflowVersion.workflow_id == workflow_id,
        WorkflowVersion.version == version,
    ).first()
    if not workflow_version:
        raise NotFoundError("工作流版本", f"{workflow_id} v{version}")

    return WorkflowVersionResponse(
        id=workflow_version.id,
        workflow_id=workflow_version.workflow_id,
        version=workflow_version.version,
        name=workflow_version.name,
        description=workflow_version.description,
        nodes=workflow_version.nodes,
        edges=workflow_version.edges,
        status=workflow_version.status,
        created_by=workflow_version.created_by,
        created_at=workflow_version.created_at,
        comment=workflow_version.comment,
    )
class WorkflowVersionRollback(BaseModel):
    """Optional request body for the version-rollback endpoint."""

    # Free-text note about the rollback; may be omitted.
    comment: Optional[str] = None
@router.post("/{workflow_id}/versions/{version}/rollback", response_model=WorkflowResponse)
async def rollback_workflow_version(
    workflow_id: str,
    version: int,
    rollback_data: Optional[WorkflowVersionRollback] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Roll a workflow back to the contents of a stored history version.

    The current state is first saved as a ``WorkflowVersion`` (best effort),
    with the optional ``rollback_data.comment`` appended to its comment. The
    workflow's name/description/nodes/edges/status are then copied from the
    target version, and its ``version`` counter is incremented. The same
    ``write`` permission as ``update_workflow`` is required.

    Raises:
        NotFoundError: unknown workflow, or no history row for ``version``.
        HTTPException(403): the user lacks ``write`` permission.
        ValidationError: ``version`` is the workflow's current version.
    """
    from sqlalchemy.exc import SQLAlchemyError

    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)
    if not check_workflow_permission(db, current_user, workflow, "write"):
        raise HTTPException(status_code=403, detail="无权修改此工作流")

    if version == workflow.version:
        raise ValidationError("不能回滚到当前版本")

    workflow_version = db.query(WorkflowVersion).filter(
        WorkflowVersion.workflow_id == workflow_id,
        WorkflowVersion.version == version,
    ).first()
    if not workflow_version:
        raise NotFoundError("工作流版本", f"{workflow_id} v{version}")

    # Keep the caller's note instead of silently discarding it.
    user_comment = rollback_data.comment if rollback_data is not None else None
    snapshot_comment = f"回滚前保存: {user_comment}" if user_comment else "回滚前保存"

    # Snapshot the current state inside a SAVEPOINT. A failing insert (e.g. a
    # missing history table) is raised by the flush here and rolled back on
    # its own, instead of aborting the rollback at the final commit.
    try:
        with db.begin_nested():
            db.add(
                WorkflowVersion(
                    workflow_id=workflow.id,
                    version=workflow.version,
                    name=workflow.name,
                    description=workflow.description,
                    nodes=workflow.nodes,
                    edges=workflow.edges,
                    status=workflow.status,
                    created_by=current_user.id,
                    comment=snapshot_comment,
                )
            )
            db.flush()
    except SQLAlchemyError as e:
        logger.warning(f"保存版本历史失败: {str(e)},继续执行回滚")

    workflow.name = workflow_version.name
    workflow.description = workflow_version.description
    workflow.nodes = workflow_version.nodes
    workflow.edges = workflow_version.edges
    workflow.status = workflow_version.status
    workflow.version += 1

    db.commit()
    db.refresh(workflow)
    suffix = f": {user_comment}" if user_comment else ""
    logger.info(f"工作流 {workflow_id} 已回滚到版本 {version}{suffix}")
    return workflow