Files
aiagent/backend/app/services/workflow_validator.py
2026-01-22 09:59:02 +08:00

269 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工作流验证服务
验证工作流的节点连接、数据流、循环检测等
"""
from typing import Dict, Any, List, Optional, Tuple
from collections import defaultdict, deque
import logging
# Module-level logger named after this module (not referenced by the code visible here).
logger = logging.getLogger(__name__)
class WorkflowValidator:
    """Workflow validator.

    Checks a workflow graph (a list of node dicts plus a list of edge dicts)
    for structural problems: missing start nodes, dangling or self-referencing
    edges, cycles, unreachable nodes and incomplete node configuration.
    Findings that make the workflow invalid are collected in ``errors``;
    advisory findings are collected in ``warnings``.
    """

    # Node types accepted without an "unknown type" warning. 'data' is listed
    # because ``_validate_node_configs`` validates it as an alias of 'transform'.
    _KNOWN_NODE_TYPES = frozenset({
        'start', 'input', 'llm', 'condition', 'transform', 'data', 'output',
        'end', 'default', 'loop', 'foreach', 'loop_end', 'agent', 'http',
        'request', 'database', 'db', 'file', 'file_operation', 'schedule',
        'delay', 'timer', 'webhook', 'email', 'mail', 'message_queue', 'mq',
        'rabbitmq', 'kafka', 'switch', 'merge', 'wait', 'json', 'text',
        'cache', 'vector_db', 'log', 'error_handler', 'csv', 'object_storage',
        'slack', 'dingtalk', 'dingding', 'wechat_work', 'wecom', 'sms', 'pdf',
        'image', 'excel', 'subworkflow', 'code', 'oauth', 'validator', 'batch',
    })

    def __init__(self, nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]):
        """Initialise the validator.

        Args:
            nodes: Node dicts. Each must contain an ``'id'`` key (a missing id
                raises ``KeyError``); ``'type'`` and ``'data'`` are read by the
                checks when present.
            edges: Edge dicts with ``'source'`` / ``'target'`` node ids and,
                optionally, ``'id'`` and ``'sourceHandle'``.
        """
        self.nodes: Dict[Any, Dict[str, Any]] = {}
        # The id-keyed dict keeps only the last node for a repeated id, so
        # repeated ids must be remembered here or they would vanish silently.
        self._duplicate_node_ids: List[Any] = []
        for node in nodes:
            node_id = node['id']
            if node_id in self.nodes and node_id not in self._duplicate_node_ids:
                self._duplicate_node_ids.append(node_id)
            self.nodes[node_id] = node
        self.edges = edges
        self.errors: List[str] = []
        self.warnings: List[str] = []

    def validate(self) -> Tuple[bool, List[str], List[str]]:
        """Run every check and return the outcome.

        Errors and warnings from a previous call are discarded first.

        Returns:
            ``(is_valid, errors, warnings)``; ``is_valid`` is True when no
            errors were recorded.
        """
        self.errors = []
        self.warnings = []
        # Basic checks
        self._validate_nodes()
        self._validate_edges()
        # Structural checks
        self._validate_has_start_node()
        self._validate_has_end_node()
        self._validate_no_cycles()
        self._validate_all_nodes_reachable()
        # Connection checks
        self._validate_node_connections()
        self._validate_condition_branches()
        # Configuration checks
        self._validate_node_configs()
        return len(self.errors) == 0, self.errors, self.warnings

    @staticmethod
    def _node_data(node: Dict[str, Any]) -> Dict[str, Any]:
        """Return the node's ``'data'`` dict, treating a missing or null value as empty."""
        return node.get('data') or {}

    def _build_adjacency(self) -> Dict[Any, List[Any]]:
        """Map each source id to its target ids.

        Edges missing either endpoint are skipped; ``_validate_edges`` already
        reports them as errors.
        """
        graph: Dict[Any, List[Any]] = defaultdict(list)
        for edge in self.edges:
            source = edge.get('source')
            target = edge.get('target')
            if source and target:
                graph[source].append(target)
        return graph

    def _validate_nodes(self):
        """Check that nodes exist, have unique ids and carry a known type."""
        if not self.nodes:
            self.errors.append("工作流必须包含至少一个节点")
            return
        duplicates = set(self._duplicate_node_ids)
        for node_id, node in self.nodes.items():
            # Node id uniqueness (recorded in __init__, before dict collapsing)
            if node_id in duplicates:
                self.errors.append(f"节点ID重复: {node_id}")
            # Node type
            node_type = node.get('type')
            if not node_type:
                self.errors.append(f"节点 {node_id} 缺少类型")
            elif not isinstance(node_type, str) or node_type not in self._KNOWN_NODE_TYPES:
                # The isinstance guard keeps unhashable JSON values (lists,
                # dicts) from raising on frozenset membership.
                self.warnings.append(f"节点 {node_id} 使用了未知类型: {node_type}")

    def _validate_edges(self):
        """Check that every edge has both endpoints, both exist, and it is not a self-loop."""
        for edge in self.edges:
            source = edge.get('source')
            target = edge.get('target')
            if not source or not target:
                self.errors.append(f"边缺少源节点或目标节点: {edge.get('id', 'unknown')}")
                continue
            if source not in self.nodes:
                self.errors.append(f"边的源节点不存在: {source}")
            if target not in self.nodes:
                self.errors.append(f"边的目标节点不存在: {target}")
            if source == target:
                self.errors.append(f"节点 {source} 不能连接到自身")

    def _validate_has_start_node(self):
        """Require at least one 'start' node; warn when there is more than one."""
        start_nodes = [node for node in self.nodes.values() if node.get('type') == 'start']
        if not start_nodes:
            self.errors.append("工作流必须包含至少一个开始节点")
        elif len(start_nodes) > 1:
            self.warnings.append(f"工作流包含多个开始节点: {len(start_nodes)}")

    def _validate_has_end_node(self):
        """Warn when the workflow has no 'end' node."""
        end_nodes = [node for node in self.nodes.values() if node.get('type') == 'end']
        if not end_nodes:
            self.warnings.append("工作流建议包含至少一个结束节点")

    def _validate_no_cycles(self):
        """Report a cycle using an iterative depth-first search.

        At most one cycle is reported per DFS tree, as the first back edge
        found (``node -> ancestor``). The search is iterative so long chains
        cannot hit Python's recursion limit. The on-path set is per tree, so
        an aborted tree cannot leak stale entries into later searches, which
        previously produced spurious cycle reports.
        """
        graph = self._build_adjacency()
        visited = set()
        for root in self.nodes:
            if root in visited:
                continue
            visited.add(root)
            on_path = {root}
            # Each frame holds a node and the iterator over its remaining neighbours.
            stack = [(root, iter(graph.get(root, ())))]
            while stack:
                node_id, neighbors = stack[-1]
                for neighbor in neighbors:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        on_path.add(neighbor)
                        stack.append((neighbor, iter(graph.get(neighbor, ()))))
                        break
                    if neighbor in on_path:
                        # Back edge found: a cycle
                        self.errors.append(f"检测到循环: {node_id} -> {neighbor}")
                        stack.clear()  # abandon the rest of this tree
                        break
                else:
                    # All neighbours explored: leave this node's path
                    stack.pop()
                    on_path.discard(node_id)

    def _validate_all_nodes_reachable(self):
        """Warn about nodes that cannot be reached from any start node.

        Skipped when there is no start node, because ``_validate_has_start_node``
        already reports that. Unreachable ids are listed in declaration order.
        """
        start_nodes = [node_id for node_id, node in self.nodes.items() if node.get('type') == 'start']
        if not start_nodes:
            return
        graph = self._build_adjacency()
        reachable = set()
        queue = deque(start_nodes)
        while queue:
            node_id = queue.popleft()
            if node_id in reachable:
                continue
            reachable.add(node_id)
            queue.extend(target for target in graph.get(node_id, ()) if target not in reachable)
        unreachable = [node_id for node_id in self.nodes if node_id not in reachable]
        if unreachable:
            # str() so that non-string ids (e.g. ints) can be joined
            self.warnings.append(f"以下节点不可达: {', '.join(map(str, unreachable))}")

    def _validate_node_connections(self):
        """Warn when a start node has incoming edges or an end node has outgoing edges."""
        targets = {edge.get('target') for edge in self.edges}
        sources = {edge.get('source') for edge in self.edges}
        for node_id, node in self.nodes.items():
            if node.get('type') == 'start' and node_id in targets:
                self.warnings.append(f"开始节点 {node_id} 不应该有入边")
        for node_id, node in self.nodes.items():
            if node.get('type') == 'end' and node_id in sources:
                self.warnings.append(f"结束节点 {node_id} 不应该有出边")

    def _validate_condition_branches(self):
        """Check condition nodes for an expression and for 'true'/'false' branch edges.

        Branches are identified by the edge's ``'sourceHandle'`` value.
        """
        for node_id, node in self.nodes.items():
            if node.get('type') != 'condition':
                continue
            condition = self._node_data(node).get('condition', '')
            if not condition:
                self.warnings.append(f"条件节点 {node_id} 没有配置条件表达式")
            has_true = any(
                e.get('source') == node_id and e.get('sourceHandle') == 'true' for e in self.edges
            )
            has_false = any(
                e.get('source') == node_id and e.get('sourceHandle') == 'false' for e in self.edges
            )
            if not has_true and not has_false:
                self.warnings.append(f"条件节点 {node_id} 没有配置分支连接")
            elif not has_true:
                self.warnings.append(f"条件节点 {node_id} 缺少True分支")
            elif not has_false:
                self.warnings.append(f"条件节点 {node_id} 缺少False分支")

    def _validate_node_configs(self):
        """Warn about incomplete configuration of LLM and transform/data nodes."""
        for node_id, node in self.nodes.items():
            node_type = node.get('type')
            node_data = self._node_data(node)
            if node_type == 'llm':
                if not node_data.get('prompt', ''):
                    self.warnings.append(f"LLM节点 {node_id} 没有配置提示词")
                if not node_data.get('model'):
                    self.warnings.append(f"LLM节点 {node_id} 没有配置模型")
            elif node_type in ('transform', 'data'):
                # The mode defaults to 'mapping' when it is not configured
                mode = node_data.get('mode', 'mapping')
                if mode == 'mapping' and not node_data.get('mapping', {}):
                    self.warnings.append(f"转换节点 {node_id} 选择了映射模式但没有配置映射规则")
                elif mode == 'filter' and not node_data.get('filter_rules', []):
                    self.warnings.append(f"转换节点 {node_id} 选择了过滤模式但没有配置过滤规则")
                elif mode == 'compute' and not node_data.get('compute_rules', {}):
                    self.warnings.append(f"转换节点 {node_id} 选择了计算模式但没有配置计算规则")
def validate_workflow(nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Validate a workflow graph and package the outcome as a plain dict.

    Args:
        nodes: Node dicts, passed unchanged to ``WorkflowValidator``.
        edges: Edge dicts, passed unchanged to ``WorkflowValidator``.

    Returns:
        ``{"valid": bool, "errors": List[str], "warnings": List[str]}``
    """
    is_valid, error_list, warning_list = WorkflowValidator(nodes, edges).validate()
    result: Dict[str, Any] = {"valid": is_valid}
    result["errors"] = error_list
    result["warnings"] = warning_list
    return result