Files
aiagent/backend/app/api/workflows.py

669 lines
22 KiB
Python
Raw Normal View History

2026-01-19 00:09:36 +08:00
"""
工作流API
"""
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
2026-01-19 00:09:36 +08:00
from sqlalchemy.orm import Session
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from datetime import datetime
import logging
from app.core.database import get_db
from app.models.workflow import Workflow
from app.models.workflow_version import WorkflowVersion
from app.api.auth import get_current_user, UserResponse
from app.models.user import User
from app.core.exceptions import NotFoundError, ValidationError, ConflictError
from app.services.workflow_validator import validate_workflow
from app.services.workflow_templates import list_templates, create_from_template, get_template
from app.services.permission_service import check_workflow_permission
# Module-level logger for this API module.
logger = logging.getLogger(__name__)
# Every route below is mounted under /api/v1/workflows and tagged "workflows" in the OpenAPI docs.
router = APIRouter(prefix="/api/v1/workflows", tags=["workflows"])
2026-03-06 22:31:41 +08:00
def _workflow_to_response(workflow: Workflow) -> dict:
    """Convert a Workflow ORM object into a dict matching ``WorkflowResponse``.

    Fallbacks keep the response valid for rows with missing data:

    * ``nodes`` / ``edges`` default to empty lists (the response schema requires lists);
    * ``updated_at`` falls back to ``created_at`` (a row that was never updated was
      last touched when it was created), and both fall back to one shared "now"
      timestamp so they agree with each other.

    Args:
        workflow: The persisted workflow to serialise.

    Returns:
        A plain dict with the fields declared on ``WorkflowResponse``.
    """
    now = datetime.now()
    created_at = workflow.created_at or now
    return {
        "id": workflow.id,
        "name": workflow.name,
        "description": workflow.description,
        "nodes": workflow.nodes or [],
        "edges": workflow.edges or [],
        "version": workflow.version,
        "status": workflow.status,
        # Normalise falsy owner ids (e.g. "") to None; the schema allows no owner.
        "user_id": workflow.user_id or None,
        "created_at": created_at,
        "updated_at": workflow.updated_at or created_at,
    }
2026-01-19 00:09:36 +08:00
class WorkflowCreate(BaseModel):
    """Request body for creating a workflow (also accepted by the /validate endpoint)."""
    # Display name of the workflow (required).
    name: str
    # Optional free-text description.
    description: Optional[str] = None
    # Graph definition. The per-node / per-edge dict schema is not enforced here;
    # it is whatever validate_workflow accepts.
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
class WorkflowUpdate(BaseModel):
    """Partial-update body for PUT /{workflow_id}.

    Every field is optional; fields left as None are not changed by update_workflow.
    """
    name: Optional[str] = None
    description: Optional[str] = None
    # When omitted, the stored nodes/edges are used for re-validation.
    nodes: Optional[List[Dict[str, Any]]] = None
    edges: Optional[List[Dict[str, Any]]] = None
    # Stored as-is; the accepted values are not checked in this module.
    status: Optional[str] = None
class WorkflowResponse(BaseModel):
    """Response schema for a single workflow, as returned by the workflow endpoints."""
    id: str
    name: str
    description: Optional[str]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    # Monotonic counter, incremented on every update or rollback.
    version: int
    status: str
    user_id: Optional[str]  # may be None (workflow without an owner)
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow populating the model from object attributes (e.g. ORM instances), not only dicts.
        from_attributes = True
@router.post("/validate", status_code=status.HTTP_200_OK)
async def validate_workflow_endpoint(
    workflow_data: WorkflowCreate,
    current_user: User = Depends(get_current_user)
):
    """Validate a workflow graph without persisting anything.

    ``current_user`` is unused in the body; presumably it is injected only so the
    endpoint requires an authenticated caller.

    Returns:
        The result of ``validate_workflow`` for the submitted nodes and edges, unchanged.
    """
    return validate_workflow(workflow_data.nodes, workflow_data.edges)
@router.get("/templates", status_code=status.HTTP_200_OK)
async def get_workflow_templates(
    current_user: User = Depends(get_current_user)
):
    """Return every available workflow template, as provided by ``list_templates``.

    ``current_user`` only gates access; it is not used to filter the list.
    """
    return list_templates()
@router.get("/templates/{template_id}", status_code=status.HTTP_200_OK)
async def get_workflow_template(
    template_id: str,
    current_user: User = Depends(get_current_user)
):
    """Return one workflow template by its ID.

    Raises:
        NotFoundError: if ``get_template`` yields a falsy value for ``template_id``.
    """
    found = get_template(template_id)
    if found:
        return found
    raise NotFoundError("模板", template_id)
@router.post("/templates/{template_id}/create", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def create_workflow_from_template(
    template_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Instantiate a template as a new workflow owned by the current user.

    Args:
        template_id: ID of the template to instantiate.
        name: Optional name override, forwarded to ``create_from_template``.
        description: Optional description override, forwarded to ``create_from_template``.

    Returns:
        The created workflow in ``WorkflowResponse`` shape.

    Raises:
        NotFoundError: ``create_from_template`` rejected ``template_id`` with a ValueError.
        ValidationError: the generated workflow graph fails validation.
    """
    try:
        workflow_data = create_from_template(template_id, name, description)
    except ValueError as exc:
        # Chain the cause so the underlying reason is preserved in tracebacks/logs.
        raise NotFoundError("模板", template_id) from exc

    validation_result = validate_workflow(workflow_data["nodes"], workflow_data["edges"])
    if not validation_result["valid"]:
        raise ValidationError(f"模板工作流验证失败: {', '.join(validation_result['errors'])}")
    # Surface non-fatal issues the same way create_workflow / update_workflow do.
    if validation_result.get("warnings"):
        logger.warning(f"模板工作流创建警告: {', '.join(validation_result['warnings'])}")

    workflow = Workflow(
        name=workflow_data["name"],
        description=workflow_data["description"],
        nodes=workflow_data["nodes"],
        edges=workflow_data["edges"],
        user_id=current_user.id
    )
    db.add(workflow)
    db.commit()
    db.refresh(workflow)
    return _workflow_to_response(workflow)
2026-01-19 00:09:36 +08:00
@router.get("", response_model=List[WorkflowResponse])
async def get_workflows(
    skip: int = 0,
    limit: int = 100,
    search: Optional[str] = None,
    status: Optional[str] = None,
    workspace_id: Optional[str] = Query(None, description="工作区ID筛选"),
    sort_by: Optional[str] = "created_at",
    sort_order: Optional[str] = "desc",
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the workflows visible to the current user, with search, filters and sorting.

    Admins (``role == "admin"``) see every workflow. Other users see the workflows they
    own plus those granted a "read" WorkflowPermission, either directly (by user id) or
    through one of their roles.

    Args:
        skip: Number of rows to skip (applied after filtering and sorting).
        limit: Maximum number of rows to return.
        search: Case-insensitive substring match on name or description. The text is
            matched literally: LIKE wildcards (``%``, ``_``) are escaped so user input
            cannot act as a pattern.
        status: Exact match on ``Workflow.status``. NOTE: inside this function this query
            parameter shadows the module-level ``fastapi.status`` import. The name is part
            of the public API, so it is kept.
        workspace_id: Exact match on ``Workflow.workspace_id``.
        sort_by: "name", "created_at" or "updated_at"; any other value sorts by created_at.
        sort_order: "asc" sorts ascending; any other value sorts descending.
    """
    if current_user.role == "admin":
        query = db.query(Workflow)
    else:
        from sqlalchemy import or_
        from app.models.permission import WorkflowPermission

        role_ids = [role.id for role in current_user.roles]
        # IDs of workflows shared with this user for reading, directly or via a role.
        readable_ids = db.query(WorkflowPermission.workflow_id).filter(
            WorkflowPermission.permission_type == "read",
            or_(
                WorkflowPermission.user_id == current_user.id,
                WorkflowPermission.role_id.in_(role_ids),
            ),
        )
        # Ownership is a plain column comparison; no IN-subquery over Workflow is needed.
        query = db.query(Workflow).filter(
            or_(
                Workflow.user_id == current_user.id,
                Workflow.id.in_(readable_ids),
            )
        )

    if workspace_id:
        query = query.filter(Workflow.workspace_id == workspace_id)

    if search:
        # Escape the escape character first, then the LIKE wildcards, so the user's text
        # is matched literally ("/" is the escape char, as in SQLAlchemy's autoescape).
        literal = search.replace("/", "//").replace("%", "/%").replace("_", "/_")
        search_pattern = f"%{literal}%"
        query = query.filter(
            Workflow.name.ilike(search_pattern, escape="/")
            | Workflow.description.ilike(search_pattern, escape="/")
        )

    if status:
        query = query.filter(Workflow.status == status)

    sortable_columns = {
        "name": Workflow.name,
        "created_at": Workflow.created_at,
        "updated_at": Workflow.updated_at,
    }
    order_column = sortable_columns.get(sort_by, Workflow.created_at)
    query = query.order_by(order_column.asc() if sort_order == "asc" else order_column.desc())

    workflows = query.offset(skip).limit(limit).all()
    # Normalise user_id and timestamp fields for the response model.
    return [_workflow_to_response(w) for w in workflows]
2026-01-19 00:09:36 +08:00
@router.post("", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def create_workflow(
    workflow_data: WorkflowCreate,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a workflow owned by the current user.

    The workspace is taken from the "ws" claim of the bearer token in the Authorization
    header, when present and truthy; otherwise the workflow has no workspace.

    Raises:
        ValidationError: the submitted graph fails validation.
    """
    from app.core.security import decode_access_token

    bearer_prefix = "Bearer "
    authorization = request.headers.get("Authorization", "")
    workspace_from_token = None
    if authorization.startswith(bearer_prefix):
        claims = decode_access_token(authorization[len(bearer_prefix):])
        workspace_from_token = (claims.get("ws") or None) if claims else None

    result = validate_workflow(workflow_data.nodes, workflow_data.edges)
    if not result["valid"]:
        raise ValidationError(f"工作流验证失败: {', '.join(result['errors'])}")
    # Non-fatal issues are only logged.
    if result["warnings"]:
        logger.warning(f"工作流创建警告: {', '.join(result['warnings'])}")

    new_workflow = Workflow(
        name=workflow_data.name,
        description=workflow_data.description,
        nodes=workflow_data.nodes,
        edges=workflow_data.edges,
        user_id=current_user.id,
        workspace_id=workspace_from_token,
    )
    db.add(new_workflow)
    db.commit()
    db.refresh(new_workflow)
    return _workflow_to_response(new_workflow)
2026-01-19 00:09:36 +08:00
@router.get("/{workflow_id}", response_model=WorkflowResponse)
async def get_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return one workflow, provided the caller has "read" permission on it.

    Raises:
        NotFoundError: no workflow has this ID.
        HTTPException: 403 when ``check_workflow_permission`` denies "read".
    """
    found = db.query(Workflow).filter_by(id=workflow_id).first()
    if found is None:
        raise NotFoundError("工作流", workflow_id)

    can_read = check_workflow_permission(db, current_user, found, "read")
    if not can_read:
        raise HTTPException(status_code=403, detail="无权访问此工作流")

    return _workflow_to_response(found)
2026-01-19 00:09:36 +08:00
@router.put("/{workflow_id}", response_model=WorkflowResponse)
async def update_workflow(
    workflow_id: str,
    workflow_data: WorkflowUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Update a workflow, first snapshotting its previous state into the version history.

    Fields of ``workflow_data`` that are None are left unchanged. The merged node/edge
    graph is validated before anything is written. Saving the history snapshot is
    best-effort: if it fails, a warning is logged and the update still proceeds. The
    ``version`` counter is incremented on every successful update.

    Raises:
        NotFoundError: no workflow has this ID.
        HTTPException: 403 when the caller lacks "write" permission.
        ValidationError: the resulting graph fails validation.
    """
    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)
    if not check_workflow_permission(db, current_user, workflow, "write"):
        raise HTTPException(status_code=403, detail="无权修改此工作流")

    # Validate the graph as it will look after the update (new values override stored ones).
    nodes_to_validate = workflow_data.nodes if workflow_data.nodes is not None else workflow.nodes
    edges_to_validate = workflow_data.edges if workflow_data.edges is not None else workflow.edges
    validation_result = validate_workflow(nodes_to_validate, edges_to_validate)
    if not validation_result["valid"]:
        raise ValidationError(f"工作流验证失败: {', '.join(validation_result['errors'])}")
    if validation_result["warnings"]:
        logger.warning(f"工作流更新警告: {', '.join(validation_result['warnings'])}")

    # Snapshot the current state. db.add() alone emits no SQL, so a database error
    # (e.g. a missing version table) would otherwise only surface at commit() and
    # abort the whole update. Flushing here raises it inside the try; rolling back
    # discards the failed pending row so the update below can still be committed.
    try:
        snapshot = WorkflowVersion(
            workflow_id=workflow.id,
            version=workflow.version,
            name=workflow.name,
            description=workflow.description,
            nodes=workflow.nodes,
            edges=workflow.edges,
            status=workflow.status,
            created_by=current_user.id
        )
        db.add(snapshot)
        db.flush()
    except Exception as e:
        db.rollback()
        logger.warning(f"保存版本历史失败: {str(e)},继续执行更新")

    # Apply only the fields the client actually sent.
    for field in ("name", "description", "nodes", "edges", "status"):
        value = getattr(workflow_data, field)
        if value is not None:
            setattr(workflow, field, value)
    workflow.version += 1
    db.commit()
    db.refresh(workflow)
    return _workflow_to_response(workflow)
2026-01-19 00:09:36 +08:00
@router.delete("/{workflow_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete a workflow; only its owner or an admin may do so.

    Raises:
        NotFoundError: no workflow has this ID.
        HTTPException: 403 when the caller is neither the owner nor an admin.
    """
    target = db.query(Workflow).filter_by(id=workflow_id).first()
    if target is None:
        raise NotFoundError("工作流", workflow_id)

    is_owner = target.user_id == current_user.id
    is_admin = current_user.role == "admin"
    if not (is_owner or is_admin):
        raise HTTPException(status_code=403, detail="无权删除此工作流")

    db.delete(target)
    db.commit()
    return None
@router.get("/{workflow_id}/export", status_code=status.HTTP_200_OK)
async def export_workflow(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Export a workflow as a JSON document (requires "read" permission).

    Returns:
        A dict with the workflow's id, name, description, nodes, edges, version and
        status, plus ``exported_at`` as an ISO-8601 UTC timestamp with explicit offset.

    Raises:
        NotFoundError: no workflow has this ID.
        HTTPException: 403 when the caller lacks "read" permission.
    """
    from datetime import timezone

    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)
    if not check_workflow_permission(db, current_user, workflow, "read"):
        raise HTTPException(status_code=403, detail="无权导出此工作流")
    return {
        "id": str(workflow.id),
        "name": workflow.name,
        "description": workflow.description,
        "nodes": workflow.nodes,
        "edges": workflow.edges,
        "version": workflow.version,
        "status": workflow.status,
        # Timezone-aware UTC. datetime.utcnow() is deprecated (Python 3.12) and yields a
        # naive value whose ISO string carries no offset, which is ambiguous to consumers.
        "exported_at": datetime.now(timezone.utc).isoformat()
    }
@router.post("/import", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def import_workflow(
    workflow_data: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import a workflow from a JSON document (e.g. one produced by the export endpoint).

    Only ``name``, ``description``, ``nodes`` and ``edges`` are read; other keys are
    ignored. Node IDs are rewritten as ``node_<n>_<old id>``, and edge ``source``/``target``
    references are remapped to match. The new workflow is owned by the current user.

    Raises:
        ValidationError: nodes/edges are structurally malformed, or the graph fails
            validation.
    """
    # `or` (rather than a .get default) also covers keys that are present but null.
    name = workflow_data.get("name") or "导入的工作流"
    description = workflow_data.get("description")
    nodes = workflow_data.get("nodes") or []
    edges = workflow_data.get("edges") or []

    # Reject structurally malformed payloads with a ValidationError instead of letting
    # a KeyError/AttributeError in the ID remapping below escape as a server error.
    if not isinstance(nodes, list) or not all(isinstance(n, dict) and "id" in n for n in nodes):
        raise ValidationError("导入的工作流验证失败: nodes 必须是包含 id 字段的对象数组")
    if not isinstance(edges, list) or not all(isinstance(e, dict) for e in edges):
        raise ValidationError("导入的工作流验证失败: edges 必须是对象数组")

    validation_result = validate_workflow(nodes, edges)
    if not validation_result["valid"]:
        raise ValidationError(f"导入的工作流验证失败: {', '.join(validation_result['errors'])}")
    if validation_result.get("warnings"):
        logger.warning(f"导入的工作流警告: {', '.join(validation_result['warnings'])}")

    # Regenerate node IDs to avoid collisions with the source document's IDs.
    node_id_mapping: Dict[Any, str] = {}
    for node in nodes:
        old_id = node["id"]
        new_id = f"node_{len(node_id_mapping)}_{old_id}"
        node_id_mapping[old_id] = new_id
        node["id"] = new_id
    # Point edges at the renamed nodes; endpoints not in the mapping are left as-is.
    for edge in edges:
        if edge.get("source") in node_id_mapping:
            edge["source"] = node_id_mapping[edge["source"]]
        if edge.get("target") in node_id_mapping:
            edge["target"] = node_id_mapping[edge["target"]]

    workflow = Workflow(
        name=name,
        description=description,
        nodes=nodes,
        edges=edges,
        user_id=current_user.id
    )
    db.add(workflow)
    db.commit()
    db.refresh(workflow)
    return _workflow_to_response(workflow)
2026-01-19 00:09:36 +08:00
@router.post("/{workflow_id}/execute")
async def execute_workflow(
    workflow_id: str,
    input_data: Dict[str, Any],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Start an execution of a workflow (requires "execute" permission).

    Delegates to ``app.api.executions.create_execution`` and returns its result.

    Raises:
        NotFoundError: no workflow has this ID.
        HTTPException: 403 when the caller lacks "execute" permission.
    """
    target = db.query(Workflow).filter_by(id=workflow_id).first()
    if target is None:
        raise NotFoundError("工作流", workflow_id)
    if not check_workflow_permission(db, current_user, target, "execute"):
        raise HTTPException(status_code=403, detail="无权执行此工作流")

    # Imported lazily, presumably to avoid a circular import between the two API modules.
    from app.api.executions import create_execution, ExecutionCreate

    request_body = ExecutionCreate(workflow_id=workflow_id, input_data=input_data)
    return await create_execution(request_body, db, current_user)
# Version management API
class WorkflowVersionResponse(BaseModel):
    """Response schema for one entry of a workflow's version history.

    The version endpoints also use it for the live ("current") version, in which
    case ``id`` is the workflow's own id and ``comment`` is "当前版本".
    """
    id: str
    workflow_id: str
    version: int
    name: str
    description: Optional[str]
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    status: str
    # ID of the user who created this snapshot (may be None).
    created_by: Optional[str]
    created_at: datetime
    comment: Optional[str]

    class Config:
        # Allow populating the model from object attributes (e.g. ORM instances), not only dicts.
        from_attributes = True
@router.get("/{workflow_id}/versions", response_model=List[WorkflowVersionResponse])
async def get_workflow_versions(
    workflow_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List a workflow's versions: the live version first, then history, newest first.

    Only the workflow's owner may list versions; anyone else gets NotFoundError. If the
    history query fails (for example, the version table does not exist), a warning is
    logged and only the live version is returned.
    """
    owned = (
        db.query(Workflow)
        .filter(Workflow.id == workflow_id, Workflow.user_id == current_user.id)
        .first()
    )
    if not owned:
        raise NotFoundError("工作流", workflow_id)

    try:
        history = (
            db.query(WorkflowVersion)
            .filter(WorkflowVersion.workflow_id == workflow_id)
            .order_by(WorkflowVersion.version.desc())
            .all()
        )
    except Exception as e:
        logger.warning(f"查询版本历史失败: {str(e)},仅返回当前版本")
        history = []

    # The live row is reported as a pseudo-version keyed by the workflow's own id.
    live = WorkflowVersionResponse(
        id=owned.id,
        workflow_id=owned.id,
        version=owned.version,
        name=owned.name,
        description=owned.description,
        nodes=owned.nodes,
        edges=owned.edges,
        status=owned.status,
        created_by=owned.user_id,
        created_at=owned.updated_at or owned.created_at,
        comment="当前版本"
    )
    return [live] + [
        WorkflowVersionResponse(
            id=snap.id,
            workflow_id=snap.workflow_id,
            version=snap.version,
            name=snap.name,
            description=snap.description,
            nodes=snap.nodes,
            edges=snap.edges,
            status=snap.status,
            created_by=snap.created_by,
            created_at=snap.created_at,
            comment=snap.comment
        )
        for snap in history
    ]
@router.get("/{workflow_id}/versions/{version}", response_model=WorkflowVersionResponse)
async def get_workflow_version(
    workflow_id: str,
    version: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return one version of a workflow, either the live version or a history entry.

    Only the workflow's owner may read versions; anyone else gets NotFoundError.

    Raises:
        NotFoundError: the workflow is missing or not owned by the caller, or no history
            entry exists for ``version``.
    """
    owned = (
        db.query(Workflow)
        .filter(Workflow.id == workflow_id, Workflow.user_id == current_user.id)
        .first()
    )
    if not owned:
        raise NotFoundError("工作流", workflow_id)

    if version != owned.version:
        snap = (
            db.query(WorkflowVersion)
            .filter(
                WorkflowVersion.workflow_id == workflow_id,
                WorkflowVersion.version == version
            )
            .first()
        )
        if not snap:
            raise NotFoundError("工作流版本", f"{workflow_id} v{version}")
        return WorkflowVersionResponse(
            id=snap.id,
            workflow_id=snap.workflow_id,
            version=snap.version,
            name=snap.name,
            description=snap.description,
            nodes=snap.nodes,
            edges=snap.edges,
            status=snap.status,
            created_by=snap.created_by,
            created_at=snap.created_at,
            comment=snap.comment
        )

    # The requested number is the live version: report the workflow row itself.
    return WorkflowVersionResponse(
        id=owned.id,
        workflow_id=owned.id,
        version=owned.version,
        name=owned.name,
        description=owned.description,
        nodes=owned.nodes,
        edges=owned.edges,
        status=owned.status,
        created_by=owned.user_id,
        created_at=owned.updated_at or owned.created_at,
        comment="当前版本"
    )
class WorkflowVersionRollback(BaseModel):
    """Optional request body for POST /{workflow_id}/versions/{version}/rollback."""
    # Free-text note describing the rollback.
    comment: Optional[str] = None
@router.post("/{workflow_id}/versions/{version}/rollback", response_model=WorkflowResponse)
async def rollback_workflow_version(
    workflow_id: str,
    version: int,
    rollback_data: Optional[WorkflowVersionRollback] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Roll a workflow back to the content of a historical version.

    Only the workflow's owner may roll back; anyone else gets NotFoundError. The current
    state is first snapshotted into the history (best-effort: on failure a warning is
    logged and the rollback proceeds). Then name, description, nodes, edges and status
    are copied from the target version and the ``version`` counter is incremented, so
    history is never rewritten.

    Args:
        version: The historical version number to restore.
        rollback_data: Optional body; its ``comment`` is recorded on the pre-rollback
            snapshot.

    Raises:
        NotFoundError: the workflow is missing or not owned by the caller, or the target
            version does not exist.
        ValidationError: ``version`` is the current version.
    """
    workflow = db.query(Workflow).filter(
        Workflow.id == workflow_id,
        Workflow.user_id == current_user.id
    ).first()
    if not workflow:
        raise NotFoundError("工作流", workflow_id)
    if version == workflow.version:
        raise ValidationError("不能回滚到当前版本")

    workflow_version = db.query(WorkflowVersion).filter(
        WorkflowVersion.workflow_id == workflow_id,
        WorkflowVersion.version == version
    ).first()
    if not workflow_version:
        raise NotFoundError("工作流版本", f"{workflow_id} v{version}")

    # Keep the caller's note with the snapshot so the history explains the rollback.
    user_comment = rollback_data.comment if rollback_data else None
    snapshot_comment = f"回滚前保存: {user_comment}" if user_comment else "回滚前保存"

    # db.add() alone emits no SQL, so a database error (e.g. a missing version table)
    # would only surface at commit() and abort the rollback. Flushing raises it here;
    # rolling back discards the failed pending row so the rollback can still commit.
    try:
        current_version = WorkflowVersion(
            workflow_id=workflow.id,
            version=workflow.version,
            name=workflow.name,
            description=workflow.description,
            nodes=workflow.nodes,
            edges=workflow.edges,
            status=workflow.status,
            created_by=current_user.id,
            comment=snapshot_comment
        )
        db.add(current_version)
        db.flush()
    except Exception as e:
        db.rollback()
        logger.warning(f"保存版本历史失败: {str(e)},继续执行回滚")

    workflow.name = workflow_version.name
    workflow.description = workflow_version.description
    workflow.nodes = workflow_version.nodes
    workflow.edges = workflow_version.edges
    workflow.status = workflow_version.status
    workflow.version += 1
    db.commit()
    db.refresh(workflow)
    logger.info(f"工作流 {workflow_id} 已回滚到版本 {version}")
    return _workflow_to_response(workflow)