# Source file: aiagent/backend/app/services/workflow_validator.py
# Snapshot metadata from the repository web view: 269 lines, 10 KiB, Python,
# revision dated 2026-01-19 00:09:36 +08:00.
"""
Workflow validation service.

Validates a workflow's node connections, data flow, and cycles, among other checks.
"""
from typing import Dict, Any, List, Optional, Tuple
from collections import defaultdict, deque
import logging
logger = logging.getLogger(__name__)
class WorkflowValidator:
    """Validate the structure and configuration of a workflow graph.

    A workflow is described by a list of node dicts (each with an ``id`` and
    normally a ``type`` and ``data``) and a list of edge dicts (``source``,
    ``target`` and optionally ``id`` / ``sourceHandle``).  Problems are
    collected as human-readable strings in ``errors`` (make the workflow
    invalid) and ``warnings`` (do not).
    """

    # Node types accepted without an "unknown type" warning.
    # NOTE(review): `_validate_node_configs` also treats 'data' as a transform
    # node, but 'data' is not listed here, so such nodes still get the
    # unknown-type warning — confirm whether 'data' should be added.
    KNOWN_NODE_TYPES = frozenset({
        'start', 'input', 'llm', 'condition', 'transform', 'output', 'end',
        'default', 'loop', 'foreach', 'loop_end', 'agent', 'http', 'request',
        'database', 'db', 'file', 'file_operation', 'schedule', 'delay',
        'timer', 'webhook', 'email', 'mail', 'message_queue', 'mq',
        'rabbitmq', 'kafka', 'switch', 'merge', 'wait', 'json', 'text',
        'cache', 'vector_db', 'log', 'error_handler', 'csv',
        'object_storage', 'slack', 'dingtalk', 'dingding', 'wechat_work',
        'wecom', 'sms', 'pdf', 'image', 'excel', 'subworkflow', 'code',
        'oauth', 'validator', 'batch',
    })

    def __init__(self, nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]):
        """Index the nodes by id and remember which ids occur more than once.

        Args:
            nodes: Node dicts; every node must have an ``id`` key
                (a missing key raises ``KeyError``).
            edges: Edge dicts.
        """
        self.nodes: Dict[Any, Dict[str, Any]] = {}
        # Duplicates must be captured here: once the nodes are stored in a
        # dict keyed by id, repeated ids are no longer observable.
        self._duplicate_node_ids: List[Any] = []
        for node in nodes:
            node_id = node['id']
            if node_id in self.nodes and node_id not in self._duplicate_node_ids:
                self._duplicate_node_ids.append(node_id)
            # Same outcome as a dict comprehension: the last node with a given
            # id wins, at the position of its first occurrence.
            self.nodes[node_id] = node
        self.edges = edges
        self.errors: List[str] = []
        self.warnings: List[str] = []

    def validate(self) -> Tuple[bool, List[str], List[str]]:
        """Run every validation pass.

        Returns:
            ``(is_valid, errors, warnings)`` where ``is_valid`` is True when
            no errors were recorded.
        """
        self.errors = []
        self.warnings = []

        # Basic checks
        self._validate_nodes()
        self._validate_edges()

        # Structural checks
        self._validate_has_start_node()
        self._validate_has_end_node()
        self._validate_no_cycles()
        self._validate_all_nodes_reachable()

        # Connection checks
        self._validate_node_connections()
        self._validate_condition_branches()

        # Configuration checks
        self._validate_node_configs()

        return not self.errors, self.errors, self.warnings

    def _build_adjacency(self) -> Dict[Any, List[Any]]:
        """Map each edge source to its targets, in edge order.

        Edges with a falsy ``source`` or ``target`` are skipped; they are
        reported separately by ``_validate_edges``.
        """
        adjacency: Dict[Any, List[Any]] = defaultdict(list)
        for edge in self.edges:
            source = edge.get('source')
            target = edge.get('target')
            if source and target:
                adjacency[source].append(target)
        return adjacency

    def _validate_nodes(self):
        """Check that nodes exist, have unique ids and carry a type."""
        if not self.nodes:
            self.errors.append("工作流必须包含至少一个节点")
            return

        for node_id in self._duplicate_node_ids:
            self.errors.append(f"节点ID重复: {node_id}")

        for node_id, node in self.nodes.items():
            node_type = node.get('type')
            if not node_type:
                self.errors.append(f"节点 {node_id} 缺少类型")
            elif not isinstance(node_type, str) or node_type not in self.KNOWN_NODE_TYPES:
                # The isinstance guard keeps unhashable type values on the
                # warning path instead of failing the frozenset lookup.
                self.warnings.append(f"节点 {node_id} 使用了未知类型: {node_type}")

    def _validate_edges(self):
        """Check that each edge has existing endpoints and is not a self-loop."""
        for edge in self.edges:
            source = edge.get('source')
            target = edge.get('target')
            if not source or not target:
                self.errors.append(f"边缺少源节点或目标节点: {edge.get('id', 'unknown')}")
                continue

            if source not in self.nodes:
                self.errors.append(f"边的源节点不存在: {source}")
            if target not in self.nodes:
                self.errors.append(f"边的目标节点不存在: {target}")
            if source == target:
                self.errors.append(f"节点 {source} 不能连接到自身")

    def _validate_has_start_node(self):
        """Require at least one ``start`` node; warn when there are several."""
        start_count = sum(1 for node in self.nodes.values() if node.get('type') == 'start')
        if start_count == 0:
            self.errors.append("工作流必须包含至少一个开始节点")
        elif start_count > 1:
            self.warnings.append(f"工作流包含多个开始节点: {start_count}")

    def _validate_has_end_node(self):
        """Warn when the workflow has no ``end`` node."""
        if not any(node.get('type') == 'end' for node in self.nodes.values()):
            self.warnings.append("工作流建议包含至少一个结束节点")

    def _validate_no_cycles(self):
        """Report a cycle for each traversal that meets a back edge.

        Depth-first search with an explicit stack, so long chains cannot hit
        the interpreter recursion limit.  The traversal started from a given
        root stops at its first back edge; the set of nodes on the current
        path is local to that traversal, so an already-reported cycle cannot
        cause spurious reports from later roots.
        """
        adjacency = self._build_adjacency()
        visited = set()

        for root in self.nodes:
            if root in visited:
                continue
            visited.add(root)
            on_path = {root}
            stack = [(root, iter(adjacency.get(root, ())))]
            while stack:
                current, neighbors = stack[-1]
                for neighbor in neighbors:
                    if neighbor not in visited:
                        # Descend; the parent's iterator resumes afterwards.
                        visited.add(neighbor)
                        on_path.add(neighbor)
                        stack.append((neighbor, iter(adjacency.get(neighbor, ()))))
                        break
                    if neighbor in on_path:
                        self.errors.append(f"检测到循环: {current} -> {neighbor}")
                        stack.clear()  # abandon this traversal
                        break
                else:
                    # All neighbors handled: leave the current path.
                    on_path.discard(current)
                    stack.pop()

    def _validate_all_nodes_reachable(self):
        """Warn about nodes that cannot be reached from any ``start`` node.

        Skipped when there is no start node.  Unreachable ids are listed in
        node order so the message is deterministic.
        """
        start_ids = [node_id for node_id, node in self.nodes.items() if node.get('type') == 'start']
        if not start_ids:
            return

        adjacency = self._build_adjacency()
        reachable = set()
        queue = deque(start_ids)
        while queue:
            node_id = queue.popleft()
            if node_id in reachable:
                continue
            reachable.add(node_id)
            queue.extend(
                target for target in adjacency.get(node_id, ()) if target not in reachable
            )

        unreachable = [node_id for node_id in self.nodes if node_id not in reachable]
        if unreachable:
            self.warnings.append(f"以下节点不可达: {', '.join(map(str, unreachable))}")

    def _validate_node_connections(self):
        """Warn when a start node has incoming edges or an end node has outgoing ones."""
        # Two separate passes keep all start-node warnings ahead of end-node ones.
        for node_id, node in self.nodes.items():
            if node.get('type') != 'start':
                continue
            if any(edge.get('target') == node_id for edge in self.edges):
                self.warnings.append(f"开始节点 {node_id} 不应该有入边")

        for node_id, node in self.nodes.items():
            if node.get('type') != 'end':
                continue
            if any(edge.get('source') == node_id for edge in self.edges):
                self.warnings.append(f"结束节点 {node_id} 不应该有出边")

    def _validate_condition_branches(self):
        """Warn about condition nodes lacking an expression or a true/false branch."""
        for node_id, node in self.nodes.items():
            if node.get('type') != 'condition':
                continue

            # `data` may be present but None; treat that like a missing dict.
            node_data = node.get('data') or {}
            if not node_data.get('condition', ''):
                self.warnings.append(f"条件节点 {node_id} 没有配置条件表达式")

            handles = [
                edge.get('sourceHandle')
                for edge in self.edges
                if edge.get('source') == node_id
            ]
            has_true = 'true' in handles
            has_false = 'false' in handles
            if not has_true and not has_false:
                self.warnings.append(f"条件节点 {node_id} 没有配置分支连接")
            elif not has_true:
                self.warnings.append(f"条件节点 {node_id} 缺少True分支")
            elif not has_false:
                self.warnings.append(f"条件节点 {node_id} 缺少False分支")

    def _validate_node_configs(self):
        """Warn about incomplete configuration on LLM and transform nodes."""
        for node_id, node in self.nodes.items():
            node_type = node.get('type')
            # `data` may be present but None; treat that like a missing dict.
            node_data = node.get('data') or {}

            if node_type == 'llm':
                if not node_data.get('prompt', ''):
                    self.warnings.append(f"LLM节点 {node_id} 没有配置提示词")
                if not node_data.get('model'):
                    self.warnings.append(f"LLM节点 {node_id} 没有配置模型")

            elif node_type == 'transform' or node_type == 'data':
                mode = node_data.get('mode', 'mapping')
                if mode == 'mapping' and not node_data.get('mapping', {}):
                    self.warnings.append(f"转换节点 {node_id} 选择了映射模式但没有配置映射规则")
                elif mode == 'filter' and not node_data.get('filter_rules', []):
                    self.warnings.append(f"转换节点 {node_id} 选择了过滤模式但没有配置过滤规则")
                elif mode == 'compute' and not node_data.get('compute_rules', {}):
                    self.warnings.append(f"转换节点 {node_id} 选择了计算模式但没有配置计算规则")
def validate_workflow(nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Validate a workflow and return the outcome as a plain dict.

    Args:
        nodes: Node dicts of the workflow.
        edges: Edge dicts of the workflow.

    Returns:
        A dict of the form::

            {
                "valid": bool,
                "errors": List[str],
                "warnings": List[str]
            }
    """
    result_keys = ("valid", "errors", "warnings")
    # validate() yields (is_valid, errors, warnings), in the same order as the keys.
    outcome = WorkflowValidator(nodes, edges).validate()
    return dict(zip(result_keys, outcome))