1667 lines
82 KiB
Python
1667 lines
82 KiB
Python
|
|
"""
|
|||
|
|
工作流执行引擎
|
|||
|
|
"""
|
|||
|
|
from typing import Dict, Any, List, Optional
|
|||
|
|
import asyncio
|
|||
|
|
from collections import defaultdict, deque
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
from app.services.llm_service import llm_service
|
|||
|
|
from app.services.condition_parser import condition_parser
|
|||
|
|
from app.services.data_transformer import data_transformer
|
|||
|
|
from app.core.exceptions import WorkflowExecutionError
|
|||
|
|
from app.core.database import SessionLocal
|
|||
|
|
from app.models.agent import Agent
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class WorkflowEngine:
|
|||
|
|
"""工作流执行引擎"""
|
|||
|
|
|
|||
|
|
def __init__(self, workflow_id: str, workflow_data: Dict[str, Any], logger=None, db=None):
|
|||
|
|
"""
|
|||
|
|
初始化工作流引擎
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
workflow_id: 工作流ID
|
|||
|
|
workflow_data: 工作流数据(包含nodes和edges)
|
|||
|
|
logger: 执行日志记录器(可选)
|
|||
|
|
db: 数据库会话(可选,用于Agent节点加载Agent配置)
|
|||
|
|
"""
|
|||
|
|
self.workflow_id = workflow_id
|
|||
|
|
self.nodes = {node['id']: node for node in workflow_data.get('nodes', [])}
|
|||
|
|
self.edges = workflow_data.get('edges', [])
|
|||
|
|
self.execution_graph = None
|
|||
|
|
self.node_outputs = {}
|
|||
|
|
self.logger = logger
|
|||
|
|
self.db = db
|
|||
|
|
|
|||
|
|
def build_execution_graph(self, active_edges: Optional[List[Dict[str, Any]]] = None) -> List[str]:
    """
    Build the execution graph (DAG) and return a topological ordering of it.

    Uses Kahn's algorithm over ``self.nodes``. Edges whose ``source`` or
    ``target`` is not a known node id are ignored. Otherwise such an edge
    would either inject an unknown id into the result or keep its target's
    in-degree above zero forever.

    Nodes that sit on a cycle (and nodes only reachable through one) never
    reach in-degree 0. They are omitted from the result and logged at DEBUG
    level instead of raising.

    Args:
        active_edges: Active edges (used for conditional-branch filtering).
            When None, all of ``self.edges`` are used.

    Returns:
        Node ids in topological order; ties keep ``self.nodes`` insertion
        order. The list is also stored in ``self.execution_graph``.
    """
    edges_to_use = active_edges if active_edges is not None else self.edges

    # Adjacency list and in-degree table; every known node starts at 0.
    graph = defaultdict(list)
    in_degree = {node_id: 0 for node_id in self.nodes}

    for edge in edges_to_use:
        source = edge.get('source')
        target = edge.get('target')
        # Skip dangling edges (e.g. pointing at a deleted node).
        if source not in self.nodes or target not in self.nodes:
            continue
        graph[source].append(target)
        in_degree[target] += 1

    # Kahn's algorithm: seed with all roots (in-degree 0), in node order.
    queue = deque(node_id for node_id in self.nodes if in_degree[node_id] == 0)
    result = []

    while queue:
        node_id = queue.popleft()
        result.append(node_id)
        for neighbor in graph[node_id]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

    if len(result) < len(self.nodes):
        scheduled = set(result)
        unscheduled = [node_id for node_id in self.nodes if node_id not in scheduled]
        logging.getLogger(__name__).debug(
            "build_execution_graph: %d node(s) not scheduled (cycle or downstream of one): %s",
            len(unscheduled), unscheduled,
        )

    self.execution_graph = result
    return result
|
|||
|
|
|
|||
|
|
def get_node_input(self, node_id: str, node_outputs: Dict[str, Any], active_edges: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
    """
    Assemble a node's input from the outputs of its upstream nodes.

    For every edge targeting ``node_id``, the source's output is taken from
    ``node_outputs`` (``{}`` if the source has no recorded output). It is then
    combined as follows:

    * edge has a truthy ``sourceHandle``: stored under that handle key;
    * otherwise, a dict output is merged into the result (later edges
      overwrite keys from earlier ones);
    * otherwise, a non-dict output is stored under ``'input'``.

    NOTE: with several handle-less, non-dict upstream outputs, only the last
    one survives under ``'input'``.

    Args:
        node_id: ID of the node whose input is being built.
        node_outputs: Outputs of all executed nodes, keyed by node id.
        active_edges: Active edges (used for conditional-branch filtering).
            When None, all of ``self.edges`` are used.

    Returns:
        The assembled input dict.
    """
    edges_to_use = active_edges if active_edges is not None else self.edges

    input_data: Dict[str, Any] = {}

    for edge in edges_to_use:
        if edge['target'] != node_id:
            continue

        source_id = edge['source']
        source_output = node_outputs.get(source_id, {})
        source_handle = edge.get('sourceHandle')
        # Lazy %-args: payloads (possibly large) are only stringified when
        # DEBUG logging is actually enabled.
        logger.debug(
            "[rjb] 获取节点输入: target=%s, source=%s, source_output=%s, sourceHandle=%s",
            node_id, source_id, source_output, source_handle,
        )

        if source_handle:
            input_data[source_handle] = source_output
        elif isinstance(source_output, dict):
            input_data.update(source_output)
        else:
            input_data['input'] = source_output

    logger.debug("[rjb] 节点输入结果: node_id=%s, input_data=%s", node_id, input_data)
    return input_data
|
|||
|
|
|
|||
|
|
def _get_nested_value(self, data: Dict[str, Any], path: str) -> Any:
|
|||
|
|
"""
|
|||
|
|
从嵌套字典中获取值(支持点号路径和数组索引)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
data: 数据字典
|
|||
|
|
path: 路径,如 "user.name" 或 "items[0].price"
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
路径对应的值
|
|||
|
|
"""
|
|||
|
|
if not path:
|
|||
|
|
return data
|
|||
|
|
|
|||
|
|
parts = path.split('.')
|
|||
|
|
result = data
|
|||
|
|
|
|||
|
|
for part in parts:
|
|||
|
|
if '[' in part and ']' in part:
|
|||
|
|
# 处理数组索引,如 "items[0]"
|
|||
|
|
key = part[:part.index('[')]
|
|||
|
|
index_str = part[part.index('[') + 1:part.index(']')]
|
|||
|
|
|
|||
|
|
if isinstance(result, dict):
|
|||
|
|
result = result.get(key)
|
|||
|
|
elif isinstance(result, list):
|
|||
|
|
try:
|
|||
|
|
result = result[int(index_str)]
|
|||
|
|
except (ValueError, IndexError):
|
|||
|
|
return None
|
|||
|
|
else:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
if result is None:
|
|||
|
|
return None
|
|||
|
|
else:
|
|||
|
|
# 普通键访问
|
|||
|
|
if isinstance(result, dict):
|
|||
|
|
result = result.get(part)
|
|||
|
|
else:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
if result is None:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
async def _execute_loop_body(self, loop_node_id: str, loop_input: Dict[str, Any], iteration_index: int) -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
执行循环体
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
loop_node_id: 循环节点ID
|
|||
|
|
loop_input: 循环体的输入数据
|
|||
|
|
iteration_index: 当前迭代索引
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
循环体的执行结果
|
|||
|
|
"""
|
|||
|
|
# 找到循环节点的直接子节点(循环体开始节点)
|
|||
|
|
loop_body_start_nodes = []
|
|||
|
|
for edge in self.edges:
|
|||
|
|
if edge.get('source') == loop_node_id:
|
|||
|
|
target_id = edge.get('target')
|
|||
|
|
if target_id and target_id in self.nodes:
|
|||
|
|
loop_body_start_nodes.append(target_id)
|
|||
|
|
|
|||
|
|
if not loop_body_start_nodes:
|
|||
|
|
# 如果没有子节点,直接返回输入数据
|
|||
|
|
return {'output': loop_input, 'status': 'success'}
|
|||
|
|
|
|||
|
|
# 执行循环体:从循环体开始节点执行到循环结束节点或没有更多节点
|
|||
|
|
# 简化处理:只执行第一个子节点链
|
|||
|
|
executed_in_loop = set()
|
|||
|
|
loop_results = {}
|
|||
|
|
current_node_id = loop_body_start_nodes[0] # 简化:只执行第一个子节点链
|
|||
|
|
|
|||
|
|
# 执行循环体内的节点(简化版本:只执行直接连接的子节点)
|
|||
|
|
max_iterations = 100 # 防止无限循环
|
|||
|
|
iteration = 0
|
|||
|
|
|
|||
|
|
while current_node_id and iteration < max_iterations:
|
|||
|
|
iteration += 1
|
|||
|
|
|
|||
|
|
if current_node_id in executed_in_loop:
|
|||
|
|
break # 避免循环体内部循环
|
|||
|
|
|
|||
|
|
if current_node_id not in self.nodes:
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
node = self.nodes[current_node_id]
|
|||
|
|
executed_in_loop.add(current_node_id)
|
|||
|
|
|
|||
|
|
# 如果是循环结束节点,停止执行
|
|||
|
|
if node.get('type') == 'loop_end' or node.get('type') == 'end':
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
# 执行节点
|
|||
|
|
result = await self.execute_node(node, loop_input)
|
|||
|
|
loop_results[current_node_id] = result
|
|||
|
|
|
|||
|
|
if result.get('status') != 'success':
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
# 更新输入数据为当前节点的输出
|
|||
|
|
if result.get('output'):
|
|||
|
|
if isinstance(result.get('output'), dict):
|
|||
|
|
loop_input = {**loop_input, **result.get('output')}
|
|||
|
|
else:
|
|||
|
|
loop_input = {**loop_input, 'result': result.get('output')}
|
|||
|
|
|
|||
|
|
# 找到下一个节点(简化:只找第一个子节点)
|
|||
|
|
next_node_id = None
|
|||
|
|
for edge in self.edges:
|
|||
|
|
if edge.get('source') == current_node_id:
|
|||
|
|
target_id = edge.get('target')
|
|||
|
|
if target_id and target_id in self.nodes and target_id not in executed_in_loop:
|
|||
|
|
# 跳过循环节点本身
|
|||
|
|
if target_id != loop_node_id:
|
|||
|
|
next_node_id = target_id
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
current_node_id = next_node_id
|
|||
|
|
|
|||
|
|
# 返回最后一个节点的输出
|
|||
|
|
if loop_results:
|
|||
|
|
last_result = list(loop_results.values())[-1]
|
|||
|
|
return last_result
|
|||
|
|
|
|||
|
|
return {'output': loop_input, 'status': 'success'}
|
|||
|
|
|
|||
|
|
def _mark_loop_body_executed(self, node_id: str, executed_nodes: set, active_edges: List[Dict[str, Any]]):
|
|||
|
|
"""
|
|||
|
|
递归标记循环体内的节点为已执行
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
node_id: 当前节点ID
|
|||
|
|
executed_nodes: 已执行节点集合
|
|||
|
|
active_edges: 活跃的边列表
|
|||
|
|
"""
|
|||
|
|
if node_id in executed_nodes:
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
executed_nodes.add(node_id)
|
|||
|
|
|
|||
|
|
# 查找所有子节点
|
|||
|
|
for edge in active_edges:
|
|||
|
|
if edge.get('source') == node_id:
|
|||
|
|
target_id = edge.get('target')
|
|||
|
|
if target_id in self.nodes:
|
|||
|
|
target_node = self.nodes[target_id]
|
|||
|
|
# 如果是循环结束节点,停止递归
|
|||
|
|
if target_node.get('type') in ['loop_end', 'end']:
|
|||
|
|
continue
|
|||
|
|
# 递归标记子节点
|
|||
|
|
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
|
|||
|
|
|
|||
|
|
async def execute_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
执行单个节点
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
node: 节点配置
|
|||
|
|
input_data: 输入数据
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
节点执行结果
|
|||
|
|
"""
|
|||
|
|
# 确保可以访问全局的 json 模块
|
|||
|
|
import json as json_module
|
|||
|
|
|
|||
|
|
node_type = node.get('type', 'unknown')
|
|||
|
|
node_id = node.get('id')
|
|||
|
|
import time
|
|||
|
|
start_time = time.time()
|
|||
|
|
|
|||
|
|
# 记录节点开始执行
|
|||
|
|
if self.logger:
|
|||
|
|
self.logger.log_node_start(node_id, node_type, input_data)
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
if node_type == 'start':
|
|||
|
|
# 起始节点:返回输入数据
|
|||
|
|
logger.debug(f"[rjb] 开始节点执行: node_id={node_id}, input_data={input_data}")
|
|||
|
|
result = {'output': input_data, 'status': 'success'}
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
|||
|
|
logger.debug(f"[rjb] 开始节点输出: node_id={node_id}, output={result.get('output')}")
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
elif node_type == 'input':
|
|||
|
|
# 输入节点:处理输入数据
|
|||
|
|
result = {'output': input_data, 'status': 'success'}
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
elif node_type == 'llm' or node_type == 'template':
|
|||
|
|
# LLM节点:调用AI模型
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
logger.debug(f"[rjb] LLM节点执行: node_id={node_id}, input_data={input_data}, input_data type={type(input_data)}")
|
|||
|
|
logger.debug(f"[rjb] LLM节点数据: node_id={node_id}, node_data keys={list(node_data.keys())}, api_key={'已配置' if node_data.get('api_key') else '未配置'}")
|
|||
|
|
prompt = node_data.get('prompt', '')
|
|||
|
|
|
|||
|
|
# 如果prompt为空,使用默认提示词
|
|||
|
|
if not prompt:
|
|||
|
|
prompt = "请处理以下输入数据:\n{input}"
|
|||
|
|
|
|||
|
|
# 格式化prompt,替换变量
|
|||
|
|
try:
|
|||
|
|
# 将input_data转换为字符串用于格式化
|
|||
|
|
if isinstance(input_data, dict):
|
|||
|
|
# 如果prompt中包含变量,尝试格式化
|
|||
|
|
if '{' in prompt and '}' in prompt:
|
|||
|
|
# 尝试格式化所有input_data中的键
|
|||
|
|
formatted_prompt = prompt
|
|||
|
|
for key, value in input_data.items():
|
|||
|
|
placeholder = f'{{{key}}}'
|
|||
|
|
if placeholder in formatted_prompt:
|
|||
|
|
formatted_prompt = formatted_prompt.replace(
|
|||
|
|
placeholder,
|
|||
|
|
json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value)
|
|||
|
|
)
|
|||
|
|
# 如果还有{input}占位符,替换为整个input_data
|
|||
|
|
if '{input}' in formatted_prompt:
|
|||
|
|
formatted_prompt = formatted_prompt.replace(
|
|||
|
|
'{input}',
|
|||
|
|
json_module.dumps(input_data, ensure_ascii=False)
|
|||
|
|
)
|
|||
|
|
prompt = formatted_prompt
|
|||
|
|
else:
|
|||
|
|
# 如果没有占位符,将input_data作为JSON附加到prompt
|
|||
|
|
prompt = f"{prompt}\n\n输入数据:\n{json_module.dumps(input_data, ensure_ascii=False)}"
|
|||
|
|
else:
|
|||
|
|
# 如果input_data不是dict,直接转换为字符串
|
|||
|
|
if '{input}' in prompt:
|
|||
|
|
prompt = prompt.replace('{input}', str(input_data))
|
|||
|
|
else:
|
|||
|
|
prompt = f"{prompt}\n\n输入:{str(input_data)}"
|
|||
|
|
except Exception as e:
|
|||
|
|
# 格式化失败,使用原始prompt和input_data
|
|||
|
|
try:
|
|||
|
|
prompt = f"{prompt}\n\n输入数据:\n{json_module.dumps(input_data, ensure_ascii=False)}"
|
|||
|
|
except:
|
|||
|
|
prompt = f"{prompt}\n\n输入数据:{str(input_data)}"
|
|||
|
|
|
|||
|
|
# 获取LLM配置
|
|||
|
|
provider = node_data.get('provider', 'openai')
|
|||
|
|
model = node_data.get('model', 'gpt-3.5-turbo')
|
|||
|
|
temperature = node_data.get('temperature', 0.7)
|
|||
|
|
max_tokens = node_data.get('max_tokens')
|
|||
|
|
# 不传递 api_key 和 base_url,让 LLM 服务使用系统默认配置(与节点测试保持一致)
|
|||
|
|
api_key = None
|
|||
|
|
base_url = None
|
|||
|
|
|
|||
|
|
# 调用LLM服务
|
|||
|
|
try:
|
|||
|
|
if self.logger:
|
|||
|
|
logger.debug(f"[rjb] LLM节点配置: provider={provider}, model={model}, 使用系统默认API Key配置")
|
|||
|
|
self.logger.info(f"调用LLM服务: {provider}/{model}", node_id=node_id, node_type=node_type)
|
|||
|
|
result = await llm_service.call_llm(
|
|||
|
|
prompt=prompt,
|
|||
|
|
provider=provider,
|
|||
|
|
model=model,
|
|||
|
|
temperature=temperature,
|
|||
|
|
max_tokens=max_tokens
|
|||
|
|
# 不传递 api_key 和 base_url,使用系统默认配置
|
|||
|
|
)
|
|||
|
|
exec_result = {'output': result, 'status': 'success'}
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result, duration)
|
|||
|
|
return exec_result
|
|||
|
|
except Exception as e:
|
|||
|
|
# LLM调用失败,返回错误
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'LLM调用失败: {str(e)}'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif node_type == 'condition':
|
|||
|
|
# 条件节点:判断分支
|
|||
|
|
condition = node.get('data', {}).get('condition', '')
|
|||
|
|
|
|||
|
|
if not condition:
|
|||
|
|
# 如果没有条件表达式,默认返回False
|
|||
|
|
return {
|
|||
|
|
'output': False,
|
|||
|
|
'status': 'success',
|
|||
|
|
'branch': 'false'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# 使用条件解析器评估表达式
|
|||
|
|
try:
|
|||
|
|
result = condition_parser.evaluate_condition(condition, input_data)
|
|||
|
|
exec_result = {
|
|||
|
|
'output': result,
|
|||
|
|
'status': 'success',
|
|||
|
|
'branch': 'true' if result else 'false'
|
|||
|
|
}
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, {'result': result, 'branch': exec_result['branch']}, duration)
|
|||
|
|
return exec_result
|
|||
|
|
except Exception as e:
|
|||
|
|
# 条件评估失败
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': False,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'条件评估失败: {str(e)}',
|
|||
|
|
'branch': 'false'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif node_type == 'data' or node_type == 'transform':
|
|||
|
|
# 数据转换节点
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
mapping = node_data.get('mapping', {})
|
|||
|
|
filter_rules = node_data.get('filter_rules', [])
|
|||
|
|
compute_rules = node_data.get('compute_rules', {})
|
|||
|
|
mode = node_data.get('mode', 'mapping')
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
result = data_transformer.transform_data(
|
|||
|
|
input_data=input_data,
|
|||
|
|
mapping=mapping,
|
|||
|
|
filter_rules=filter_rules,
|
|||
|
|
compute_rules=compute_rules,
|
|||
|
|
mode=mode
|
|||
|
|
)
|
|||
|
|
exec_result = {'output': result, 'status': 'success'}
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result, duration)
|
|||
|
|
return exec_result
|
|||
|
|
except Exception as e:
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'数据转换失败: {str(e)}'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif node_type == 'loop' or node_type == 'foreach':
|
|||
|
|
# 循环节点:对数组进行循环处理
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
items_path = node_data.get('items_path', 'items') # 数组数据路径
|
|||
|
|
item_variable = node_data.get('item_variable', 'item') # 循环变量名
|
|||
|
|
|
|||
|
|
# 从输入数据中获取数组
|
|||
|
|
items = self._get_nested_value(input_data, items_path)
|
|||
|
|
|
|||
|
|
if not isinstance(items, list):
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type,
|
|||
|
|
ValueError(f"路径 {items_path} 的值不是数组"), duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'路径 {items_path} 的值不是数组,当前类型: {type(items).__name__}'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if self.logger:
|
|||
|
|
self.logger.info(f"循环节点开始处理 {len(items)} 个元素",
|
|||
|
|
node_id=node_id, node_type=node_type,
|
|||
|
|
data={"items_count": len(items)})
|
|||
|
|
|
|||
|
|
# 执行循环:对每个元素执行循环体
|
|||
|
|
loop_results = []
|
|||
|
|
for index, item in enumerate(items):
|
|||
|
|
if self.logger:
|
|||
|
|
self.logger.info(f"循环迭代 {index + 1}/{len(items)}",
|
|||
|
|
node_id=node_id, node_type=node_type,
|
|||
|
|
data={"index": index, "item": item})
|
|||
|
|
|
|||
|
|
# 准备循环体的输入数据
|
|||
|
|
loop_input = {
|
|||
|
|
**input_data, # 保留原始输入数据
|
|||
|
|
item_variable: item, # 当前循环项
|
|||
|
|
f'{item_variable}_index': index, # 索引
|
|||
|
|
f'{item_variable}_total': len(items) # 总数
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# 执行循环体(获取循环节点的子节点)
|
|||
|
|
loop_body_result = await self._execute_loop_body(
|
|||
|
|
node_id, loop_input, index
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if loop_body_result.get('status') == 'success':
|
|||
|
|
loop_results.append(loop_body_result.get('output', item))
|
|||
|
|
else:
|
|||
|
|
# 如果循环体执行失败,可以选择继续或停止
|
|||
|
|
error_handling = node_data.get('error_handling', 'continue')
|
|||
|
|
if error_handling == 'stop':
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type,
|
|||
|
|
Exception(f"循环体执行失败,停止循环: {loop_body_result.get('error')}"), duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'循环体执行失败: {loop_body_result.get("error")}',
|
|||
|
|
'completed_items': index,
|
|||
|
|
'results': loop_results
|
|||
|
|
}
|
|||
|
|
else:
|
|||
|
|
# continue: 继续执行,记录错误
|
|||
|
|
if self.logger:
|
|||
|
|
self.logger.warn(f"循环迭代 {index + 1} 失败,继续执行",
|
|||
|
|
node_id=node_id, node_type=node_type,
|
|||
|
|
data={"error": loop_body_result.get('error')})
|
|||
|
|
loop_results.append(None)
|
|||
|
|
|
|||
|
|
exec_result = {
|
|||
|
|
'output': loop_results,
|
|||
|
|
'status': 'success',
|
|||
|
|
'items_processed': len(items),
|
|||
|
|
'results_count': len(loop_results)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type,
|
|||
|
|
{'results_count': len(loop_results)}, duration)
|
|||
|
|
|
|||
|
|
return exec_result
|
|||
|
|
|
|||
|
|
elif node_type == 'http' or node_type == 'request':
|
|||
|
|
# HTTP请求节点:发送HTTP请求
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
url = node_data.get('url', '')
|
|||
|
|
method = node_data.get('method', 'GET').upper()
|
|||
|
|
headers = node_data.get('headers', {})
|
|||
|
|
params = node_data.get('params', {})
|
|||
|
|
body = node_data.get('body', {})
|
|||
|
|
timeout = node_data.get('timeout', 30)
|
|||
|
|
|
|||
|
|
# 如果URL、headers、params、body中包含变量,从input_data中替换
|
|||
|
|
import re
|
|||
|
|
def replace_variables(text: str, data: Dict[str, Any]) -> str:
|
|||
|
|
"""替换字符串中的变量占位符"""
|
|||
|
|
if not isinstance(text, str):
|
|||
|
|
return text
|
|||
|
|
# 支持 {key} 或 ${key} 格式
|
|||
|
|
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
|
|||
|
|
def replacer(match):
|
|||
|
|
key = match.group(1) or match.group(2)
|
|||
|
|
value = self._get_nested_value(data, key)
|
|||
|
|
return str(value) if value is not None else match.group(0)
|
|||
|
|
return re.sub(pattern, replacer, text)
|
|||
|
|
|
|||
|
|
# 替换URL中的变量
|
|||
|
|
if url:
|
|||
|
|
url = replace_variables(url, input_data)
|
|||
|
|
|
|||
|
|
# 替换headers中的变量
|
|||
|
|
if isinstance(headers, dict):
|
|||
|
|
headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
|
|||
|
|
elif isinstance(headers, str):
|
|||
|
|
try:
|
|||
|
|
headers = json.loads(replace_variables(headers, input_data))
|
|||
|
|
except:
|
|||
|
|
headers = {}
|
|||
|
|
|
|||
|
|
# 替换params中的变量
|
|||
|
|
if isinstance(params, dict):
|
|||
|
|
params = {k: replace_variables(str(v), input_data) if isinstance(v, str) else v
|
|||
|
|
for k, v in params.items()}
|
|||
|
|
elif isinstance(params, str):
|
|||
|
|
try:
|
|||
|
|
params = json.loads(replace_variables(params, input_data))
|
|||
|
|
except:
|
|||
|
|
params = {}
|
|||
|
|
|
|||
|
|
# 替换body中的变量
|
|||
|
|
if isinstance(body, dict):
|
|||
|
|
# 递归替换字典中的变量
|
|||
|
|
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
|
|||
|
|
result = {}
|
|||
|
|
for k, v in d.items():
|
|||
|
|
new_k = replace_variables(k, data)
|
|||
|
|
if isinstance(v, dict):
|
|||
|
|
result[new_k] = replace_dict_vars(v, data)
|
|||
|
|
elif isinstance(v, str):
|
|||
|
|
result[new_k] = replace_variables(v, data)
|
|||
|
|
else:
|
|||
|
|
result[new_k] = v
|
|||
|
|
return result
|
|||
|
|
body = replace_dict_vars(body, input_data)
|
|||
|
|
elif isinstance(body, str):
|
|||
|
|
body = replace_variables(body, input_data)
|
|||
|
|
try:
|
|||
|
|
body = json.loads(body)
|
|||
|
|
except:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import httpx
|
|||
|
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
|||
|
|
if method == 'GET':
|
|||
|
|
response = await client.get(url, params=params, headers=headers)
|
|||
|
|
elif method == 'POST':
|
|||
|
|
response = await client.post(url, json=body, params=params, headers=headers)
|
|||
|
|
elif method == 'PUT':
|
|||
|
|
response = await client.put(url, json=body, params=params, headers=headers)
|
|||
|
|
elif method == 'DELETE':
|
|||
|
|
response = await client.delete(url, params=params, headers=headers)
|
|||
|
|
elif method == 'PATCH':
|
|||
|
|
response = await client.patch(url, json=body, params=params, headers=headers)
|
|||
|
|
else:
|
|||
|
|
raise ValueError(f"不支持的HTTP方法: {method}")
|
|||
|
|
|
|||
|
|
# 尝试解析JSON响应
|
|||
|
|
try:
|
|||
|
|
response_data = response.json()
|
|||
|
|
except:
|
|||
|
|
response_data = response.text
|
|||
|
|
|
|||
|
|
result = {
|
|||
|
|
'output': {
|
|||
|
|
'status_code': response.status_code,
|
|||
|
|
'headers': dict(response.headers),
|
|||
|
|
'data': response_data
|
|||
|
|
},
|
|||
|
|
'status': 'success'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
|||
|
|
return result
|
|||
|
|
except Exception as e:
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'HTTP请求失败: {str(e)}'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif node_type == 'database' or node_type == 'db':
|
|||
|
|
# 数据库操作节点:执行数据库操作
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
data_source_id = node_data.get('data_source_id')
|
|||
|
|
operation = node_data.get('operation', 'query') # query/insert/update/delete
|
|||
|
|
sql = node_data.get('sql', '')
|
|||
|
|
table = node_data.get('table', '')
|
|||
|
|
data = node_data.get('data', {})
|
|||
|
|
where = node_data.get('where', {})
|
|||
|
|
|
|||
|
|
# 如果SQL中包含变量,从input_data中替换
|
|||
|
|
if sql and isinstance(sql, str):
|
|||
|
|
import re
|
|||
|
|
def replace_sql_vars(text: str, data: Dict[str, Any]) -> str:
|
|||
|
|
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
|
|||
|
|
def replacer(match):
|
|||
|
|
key = match.group(1) or match.group(2)
|
|||
|
|
value = self._get_nested_value(data, key)
|
|||
|
|
if value is None:
|
|||
|
|
return match.group(0)
|
|||
|
|
# 如果是字符串,需要转义SQL注入
|
|||
|
|
if isinstance(value, str):
|
|||
|
|
# 简单转义,实际应该使用参数化查询
|
|||
|
|
escaped_value = value.replace("'", "''")
|
|||
|
|
return f"'{escaped_value}'"
|
|||
|
|
return str(value)
|
|||
|
|
return re.sub(pattern, replacer, text)
|
|||
|
|
sql = replace_sql_vars(sql, input_data)
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 从数据库加载数据源配置
|
|||
|
|
if not self.db:
|
|||
|
|
raise ValueError("数据库会话未提供,无法执行数据库操作")
|
|||
|
|
|
|||
|
|
from app.models.data_source import DataSource
|
|||
|
|
from app.services.data_source_connector import create_connector
|
|||
|
|
|
|||
|
|
data_source = self.db.query(DataSource).filter(
|
|||
|
|
DataSource.id == data_source_id
|
|||
|
|
).first()
|
|||
|
|
|
|||
|
|
if not data_source:
|
|||
|
|
raise ValueError(f"数据源不存在: {data_source_id}")
|
|||
|
|
|
|||
|
|
connector = create_connector(data_source.type, data_source.config)
|
|||
|
|
|
|||
|
|
if operation == 'query':
|
|||
|
|
# 查询操作
|
|||
|
|
if not sql:
|
|||
|
|
raise ValueError("查询操作需要提供SQL语句")
|
|||
|
|
query_params = {'query': sql}
|
|||
|
|
result_data = connector.query(query_params)
|
|||
|
|
result = {'output': result_data, 'status': 'success'}
|
|||
|
|
elif operation == 'insert':
|
|||
|
|
# 插入操作
|
|||
|
|
if not table:
|
|||
|
|
raise ValueError("插入操作需要提供表名")
|
|||
|
|
# 构建INSERT SQL
|
|||
|
|
columns = ', '.join(data.keys())
|
|||
|
|
# 处理字符串值,转义单引号
|
|||
|
|
def escape_value(v):
|
|||
|
|
if isinstance(v, str):
|
|||
|
|
escaped = v.replace("'", "''")
|
|||
|
|
return f"'{escaped}'"
|
|||
|
|
return str(v)
|
|||
|
|
values = ', '.join([escape_value(v) for v in data.values()])
|
|||
|
|
insert_sql = f"INSERT INTO {table} ({columns}) VALUES ({values})"
|
|||
|
|
query_params = {'query': insert_sql}
|
|||
|
|
result_data = connector.query(query_params)
|
|||
|
|
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
|
|||
|
|
elif operation == 'update':
|
|||
|
|
# 更新操作
|
|||
|
|
if not table or not where:
|
|||
|
|
raise ValueError("更新操作需要提供表名和WHERE条件")
|
|||
|
|
set_clause = ', '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in data.items()])
|
|||
|
|
where_clause = ' AND '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in where.items()])
|
|||
|
|
update_sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
|
|||
|
|
query_params = {'query': update_sql}
|
|||
|
|
result_data = connector.query(query_params)
|
|||
|
|
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
|
|||
|
|
elif operation == 'delete':
|
|||
|
|
# 删除操作
|
|||
|
|
if not table or not where:
|
|||
|
|
raise ValueError("删除操作需要提供表名和WHERE条件")
|
|||
|
|
# 处理字符串值,转义单引号
|
|||
|
|
def escape_sql_value(k, v):
|
|||
|
|
if isinstance(v, str):
|
|||
|
|
escaped = v.replace("'", "''")
|
|||
|
|
return f"{k} = '{escaped}'"
|
|||
|
|
return f"{k} = {v}"
|
|||
|
|
where_clause = ' AND '.join([escape_sql_value(k, v) for k, v in where.items()])
|
|||
|
|
delete_sql = f"DELETE FROM {table} WHERE {where_clause}"
|
|||
|
|
query_params = {'query': delete_sql}
|
|||
|
|
result_data = connector.query(query_params)
|
|||
|
|
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
|
|||
|
|
else:
|
|||
|
|
raise ValueError(f"不支持的数据库操作: {operation}")
|
|||
|
|
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
|||
|
|
return result
|
|||
|
|
except Exception as e:
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'数据库操作失败: {str(e)}'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif node_type == 'file' or node_type == 'file_operation':
|
|||
|
|
# 文件操作节点:文件读取、写入、上传、下载
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
operation = node_data.get('operation', 'read') # read/write/upload/download
|
|||
|
|
file_path = node_data.get('file_path', '')
|
|||
|
|
content = node_data.get('content', '')
|
|||
|
|
encoding = node_data.get('encoding', 'utf-8')
|
|||
|
|
|
|||
|
|
# 替换文件路径和内容中的变量
|
|||
|
|
import re
|
|||
|
|
def replace_variables(text: str, data: Dict[str, Any]) -> str:
|
|||
|
|
"""替换字符串中的变量占位符"""
|
|||
|
|
if not isinstance(text, str):
|
|||
|
|
return text
|
|||
|
|
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
|
|||
|
|
def replacer(match):
|
|||
|
|
key = match.group(1) or match.group(2)
|
|||
|
|
value = self._get_nested_value(data, key)
|
|||
|
|
return str(value) if value is not None else match.group(0)
|
|||
|
|
return re.sub(pattern, replacer, text)
|
|||
|
|
|
|||
|
|
if file_path:
|
|||
|
|
file_path = replace_variables(file_path, input_data)
|
|||
|
|
if isinstance(content, str):
|
|||
|
|
content = replace_variables(content, input_data)
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import os
|
|||
|
|
import json
|
|||
|
|
import base64
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
if operation == 'read':
|
|||
|
|
# 读取文件
|
|||
|
|
if not file_path:
|
|||
|
|
raise ValueError("读取操作需要提供文件路径")
|
|||
|
|
|
|||
|
|
if not os.path.exists(file_path):
|
|||
|
|
raise FileNotFoundError(f"文件不存在: {file_path}")
|
|||
|
|
|
|||
|
|
# 根据文件扩展名决定读取方式
|
|||
|
|
file_ext = Path(file_path).suffix.lower()
|
|||
|
|
if file_ext == '.json':
|
|||
|
|
with open(file_path, 'r', encoding=encoding) as f:
|
|||
|
|
data = json.load(f)
|
|||
|
|
elif file_ext in ['.txt', '.md', '.log']:
|
|||
|
|
with open(file_path, 'r', encoding=encoding) as f:
|
|||
|
|
data = f.read()
|
|||
|
|
else:
|
|||
|
|
# 二进制文件,返回base64编码
|
|||
|
|
with open(file_path, 'rb') as f:
|
|||
|
|
data = base64.b64encode(f.read()).decode('utf-8')
|
|||
|
|
|
|||
|
|
result = {'output': data, 'status': 'success'}
|
|||
|
|
|
|||
|
|
elif operation == 'write':
|
|||
|
|
# 写入文件
|
|||
|
|
if not file_path:
|
|||
|
|
raise ValueError("写入操作需要提供文件路径")
|
|||
|
|
|
|||
|
|
# 确保目录存在
|
|||
|
|
os.makedirs(os.path.dirname(file_path) if os.path.dirname(file_path) else '.', exist_ok=True)
|
|||
|
|
|
|||
|
|
# 如果content是字典或列表,转换为JSON
|
|||
|
|
if isinstance(content, (dict, list)):
|
|||
|
|
content = json.dumps(content, ensure_ascii=False, indent=2)
|
|||
|
|
|
|||
|
|
# 根据文件扩展名决定写入方式
|
|||
|
|
file_ext = Path(file_path).suffix.lower()
|
|||
|
|
if file_ext == '.json':
|
|||
|
|
with open(file_path, 'w', encoding=encoding) as f:
|
|||
|
|
json.dump(json.loads(content) if isinstance(content, str) else content, f, ensure_ascii=False, indent=2)
|
|||
|
|
else:
|
|||
|
|
with open(file_path, 'w', encoding=encoding) as f:
|
|||
|
|
f.write(str(content))
|
|||
|
|
|
|||
|
|
result = {'output': {'file_path': file_path, 'message': '文件写入成功'}, 'status': 'success'}
|
|||
|
|
|
|||
|
|
elif operation == 'upload':
|
|||
|
|
# 文件上传(从base64或URL上传)
|
|||
|
|
upload_type = node_data.get('upload_type', 'base64') # base64/url
|
|||
|
|
target_path = node_data.get('target_path', '')
|
|||
|
|
|
|||
|
|
if upload_type == 'base64':
|
|||
|
|
# 从输入数据中获取base64编码的文件内容
|
|||
|
|
file_data = input_data.get('file_data') or input_data.get('content')
|
|||
|
|
if not file_data:
|
|||
|
|
raise ValueError("上传操作需要提供file_data或content字段")
|
|||
|
|
|
|||
|
|
# 解码base64
|
|||
|
|
if isinstance(file_data, str):
|
|||
|
|
file_bytes = base64.b64decode(file_data)
|
|||
|
|
else:
|
|||
|
|
file_bytes = file_data
|
|||
|
|
|
|||
|
|
# 写入目标路径
|
|||
|
|
if not target_path:
|
|||
|
|
raise ValueError("上传操作需要提供target_path")
|
|||
|
|
|
|||
|
|
os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
|
|||
|
|
with open(target_path, 'wb') as f:
|
|||
|
|
f.write(file_bytes)
|
|||
|
|
|
|||
|
|
result = {'output': {'file_path': target_path, 'message': '文件上传成功'}, 'status': 'success'}
|
|||
|
|
else:
|
|||
|
|
# URL上传(下载后保存)
|
|||
|
|
import httpx
|
|||
|
|
url = node_data.get('url', '')
|
|||
|
|
if not url:
|
|||
|
|
raise ValueError("URL上传需要提供url")
|
|||
|
|
|
|||
|
|
async with httpx.AsyncClient() as client:
|
|||
|
|
response = await client.get(url)
|
|||
|
|
response.raise_for_status()
|
|||
|
|
|
|||
|
|
if not target_path:
|
|||
|
|
# 从URL提取文件名
|
|||
|
|
target_path = os.path.basename(url) or 'downloaded_file'
|
|||
|
|
|
|||
|
|
os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
|
|||
|
|
with open(target_path, 'wb') as f:
|
|||
|
|
f.write(response.content)
|
|||
|
|
|
|||
|
|
result = {'output': {'file_path': target_path, 'message': '文件下载并保存成功'}, 'status': 'success'}
|
|||
|
|
|
|||
|
|
elif operation == 'download':
|
|||
|
|
# 文件下载(返回base64编码或文件URL)
|
|||
|
|
download_format = node_data.get('download_format', 'base64') # base64/url
|
|||
|
|
|
|||
|
|
if not file_path:
|
|||
|
|
raise ValueError("下载操作需要提供文件路径")
|
|||
|
|
|
|||
|
|
if not os.path.exists(file_path):
|
|||
|
|
raise FileNotFoundError(f"文件不存在: {file_path}")
|
|||
|
|
|
|||
|
|
if download_format == 'base64':
|
|||
|
|
# 返回base64编码
|
|||
|
|
with open(file_path, 'rb') as f:
|
|||
|
|
file_bytes = f.read()
|
|||
|
|
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
|
|||
|
|
result = {'output': {'file_name': os.path.basename(file_path), 'content': file_base64, 'format': 'base64'}, 'status': 'success'}
|
|||
|
|
else:
|
|||
|
|
# 返回文件路径(实际应用中可能需要生成临时URL)
|
|||
|
|
result = {'output': {'file_path': file_path, 'format': 'path'}, 'status': 'success'}
|
|||
|
|
|
|||
|
|
else:
|
|||
|
|
raise ValueError(f"不支持的文件操作: {operation}")
|
|||
|
|
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
|||
|
|
return result
|
|||
|
|
except Exception as e:
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'文件操作失败: {str(e)}'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif node_type == 'webhook':
|
|||
|
|
# Webhook节点:发送Webhook请求到外部系统
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
url = node_data.get('url', '')
|
|||
|
|
method = node_data.get('method', 'POST').upper()
|
|||
|
|
headers = node_data.get('headers', {})
|
|||
|
|
body = node_data.get('body', {})
|
|||
|
|
timeout = node_data.get('timeout', 30)
|
|||
|
|
|
|||
|
|
# 如果URL、headers、body中包含变量,从input_data中替换
|
|||
|
|
import re
|
|||
|
|
def replace_variables(text: str, data: Dict[str, Any]) -> str:
|
|||
|
|
"""替换字符串中的变量占位符"""
|
|||
|
|
if not isinstance(text, str):
|
|||
|
|
return text
|
|||
|
|
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
|
|||
|
|
def replacer(match):
|
|||
|
|
key = match.group(1) or match.group(2)
|
|||
|
|
value = self._get_nested_value(data, key)
|
|||
|
|
return str(value) if value is not None else match.group(0)
|
|||
|
|
return re.sub(pattern, replacer, text)
|
|||
|
|
|
|||
|
|
# 替换URL中的变量
|
|||
|
|
if url:
|
|||
|
|
url = replace_variables(url, input_data)
|
|||
|
|
|
|||
|
|
# 替换headers中的变量
|
|||
|
|
if isinstance(headers, dict):
|
|||
|
|
headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
|
|||
|
|
elif isinstance(headers, str):
|
|||
|
|
try:
|
|||
|
|
headers = json.loads(replace_variables(headers, input_data))
|
|||
|
|
except:
|
|||
|
|
headers = {}
|
|||
|
|
|
|||
|
|
# 替换body中的变量
|
|||
|
|
if isinstance(body, dict):
|
|||
|
|
# 递归替换字典中的变量
|
|||
|
|
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
|
|||
|
|
result = {}
|
|||
|
|
for k, v in d.items():
|
|||
|
|
new_k = replace_variables(k, data)
|
|||
|
|
if isinstance(v, dict):
|
|||
|
|
result[new_k] = replace_dict_vars(v, data)
|
|||
|
|
elif isinstance(v, str):
|
|||
|
|
result[new_k] = replace_variables(v, data)
|
|||
|
|
else:
|
|||
|
|
result[new_k] = v
|
|||
|
|
return result
|
|||
|
|
body = replace_dict_vars(body, input_data)
|
|||
|
|
elif isinstance(body, str):
|
|||
|
|
body = replace_variables(body, input_data)
|
|||
|
|
try:
|
|||
|
|
body = json.loads(body)
|
|||
|
|
except:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
# 如果没有配置body,默认使用input_data作为body
|
|||
|
|
if not body:
|
|||
|
|
body = input_data
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import httpx
|
|||
|
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
|||
|
|
if method == 'GET':
|
|||
|
|
response = await client.get(url, headers=headers)
|
|||
|
|
elif method == 'POST':
|
|||
|
|
response = await client.post(url, json=body, headers=headers)
|
|||
|
|
elif method == 'PUT':
|
|||
|
|
response = await client.put(url, json=body, headers=headers)
|
|||
|
|
elif method == 'PATCH':
|
|||
|
|
response = await client.patch(url, json=body, headers=headers)
|
|||
|
|
else:
|
|||
|
|
raise ValueError(f"Webhook不支持HTTP方法: {method}")
|
|||
|
|
|
|||
|
|
# 尝试解析JSON响应
|
|||
|
|
try:
|
|||
|
|
response_data = response.json()
|
|||
|
|
except:
|
|||
|
|
response_data = response.text
|
|||
|
|
|
|||
|
|
result = {
|
|||
|
|
'output': {
|
|||
|
|
'status_code': response.status_code,
|
|||
|
|
'headers': dict(response.headers),
|
|||
|
|
'data': response_data
|
|||
|
|
},
|
|||
|
|
'status': 'success'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
|||
|
|
return result
|
|||
|
|
except Exception as e:
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'Webhook请求失败: {str(e)}'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif node_type == 'schedule' or node_type == 'delay' or node_type == 'timer':
|
|||
|
|
# 定时任务节点:延迟执行或定时执行
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
delay_type = node_data.get('delay_type', 'fixed') # fixed: 固定延迟, cron: cron表达式
|
|||
|
|
delay_value = node_data.get('delay_value', 0) # 延迟值(秒)
|
|||
|
|
delay_unit = node_data.get('delay_unit', 'seconds') # seconds, minutes, hours
|
|||
|
|
|
|||
|
|
# 计算实际延迟时间(毫秒)
|
|||
|
|
if delay_unit == 'seconds':
|
|||
|
|
delay_ms = int(delay_value * 1000)
|
|||
|
|
elif delay_unit == 'minutes':
|
|||
|
|
delay_ms = int(delay_value * 60 * 1000)
|
|||
|
|
elif delay_unit == 'hours':
|
|||
|
|
delay_ms = int(delay_value * 60 * 60 * 1000)
|
|||
|
|
else:
|
|||
|
|
delay_ms = int(delay_value * 1000)
|
|||
|
|
|
|||
|
|
# 如果延迟时间大于0,则等待
|
|||
|
|
if delay_ms > 0:
|
|||
|
|
if self.logger:
|
|||
|
|
self.logger.info(
|
|||
|
|
f"定时任务节点等待 {delay_value} {delay_unit}",
|
|||
|
|
node_id=node_id,
|
|||
|
|
node_type=node_type,
|
|||
|
|
data={'delay_ms': delay_ms, 'delay_value': delay_value, 'delay_unit': delay_unit}
|
|||
|
|
)
|
|||
|
|
await asyncio.sleep(delay_ms / 1000.0)
|
|||
|
|
|
|||
|
|
# 返回输入数据(定时节点只是延迟,不改变数据)
|
|||
|
|
result = {'output': input_data, 'status': 'success', 'delay_ms': delay_ms}
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
elif node_type == 'email' or node_type == 'mail':
|
|||
|
|
# 邮件节点:发送邮件通知
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
smtp_host = node_data.get('smtp_host', '')
|
|||
|
|
smtp_port = node_data.get('smtp_port', 587)
|
|||
|
|
smtp_user = node_data.get('smtp_user', '')
|
|||
|
|
smtp_password = node_data.get('smtp_password', '')
|
|||
|
|
use_tls = node_data.get('use_tls', True)
|
|||
|
|
from_email = node_data.get('from_email', '')
|
|||
|
|
to_email = node_data.get('to_email', '')
|
|||
|
|
cc_email = node_data.get('cc_email', '')
|
|||
|
|
bcc_email = node_data.get('bcc_email', '')
|
|||
|
|
subject = node_data.get('subject', '')
|
|||
|
|
body = node_data.get('body', '')
|
|||
|
|
body_type = node_data.get('body_type', 'text') # text/html
|
|||
|
|
attachments = node_data.get('attachments', []) # 附件列表
|
|||
|
|
|
|||
|
|
# 替换变量
|
|||
|
|
import re
|
|||
|
|
def replace_variables(text: str, data: Dict[str, Any]) -> str:
|
|||
|
|
"""替换字符串中的变量占位符"""
|
|||
|
|
if not isinstance(text, str):
|
|||
|
|
return text
|
|||
|
|
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
|
|||
|
|
def replacer(match):
|
|||
|
|
key = match.group(1) or match.group(2)
|
|||
|
|
value = self._get_nested_value(data, key)
|
|||
|
|
return str(value) if value is not None else match.group(0)
|
|||
|
|
return re.sub(pattern, replacer, text)
|
|||
|
|
|
|||
|
|
# 替换所有配置中的变量
|
|||
|
|
smtp_host = replace_variables(smtp_host, input_data)
|
|||
|
|
smtp_user = replace_variables(smtp_user, input_data)
|
|||
|
|
smtp_password = replace_variables(smtp_password, input_data)
|
|||
|
|
from_email = replace_variables(from_email, input_data)
|
|||
|
|
to_email = replace_variables(to_email, input_data)
|
|||
|
|
cc_email = replace_variables(cc_email, input_data)
|
|||
|
|
bcc_email = replace_variables(bcc_email, input_data)
|
|||
|
|
subject = replace_variables(subject, input_data)
|
|||
|
|
body = replace_variables(body, input_data)
|
|||
|
|
|
|||
|
|
# 验证必需参数
|
|||
|
|
if not smtp_host:
|
|||
|
|
raise ValueError("邮件节点需要配置SMTP服务器地址")
|
|||
|
|
if not from_email:
|
|||
|
|
raise ValueError("邮件节点需要配置发件人邮箱")
|
|||
|
|
if not to_email:
|
|||
|
|
raise ValueError("邮件节点需要配置收件人邮箱")
|
|||
|
|
if not subject:
|
|||
|
|
raise ValueError("邮件节点需要配置邮件主题")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import aiosmtplib
|
|||
|
|
from email.mime.text import MIMEText
|
|||
|
|
from email.mime.multipart import MIMEMultipart
|
|||
|
|
from email.mime.base import MIMEBase
|
|||
|
|
from email import encoders
|
|||
|
|
import base64
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
# 创建邮件消息
|
|||
|
|
msg = MIMEMultipart('alternative')
|
|||
|
|
msg['From'] = from_email
|
|||
|
|
msg['To'] = to_email
|
|||
|
|
if cc_email:
|
|||
|
|
msg['Cc'] = cc_email
|
|||
|
|
msg['Subject'] = subject
|
|||
|
|
|
|||
|
|
# 添加邮件正文
|
|||
|
|
if body_type == 'html':
|
|||
|
|
msg.attach(MIMEText(body, 'html', 'utf-8'))
|
|||
|
|
else:
|
|||
|
|
msg.attach(MIMEText(body, 'plain', 'utf-8'))
|
|||
|
|
|
|||
|
|
# 处理附件
|
|||
|
|
for attachment in attachments:
|
|||
|
|
if isinstance(attachment, dict):
|
|||
|
|
file_path = attachment.get('file_path', '')
|
|||
|
|
file_name = attachment.get('file_name', '')
|
|||
|
|
file_content = attachment.get('file_content', '') # base64编码的内容
|
|||
|
|
|
|||
|
|
# 替换变量
|
|||
|
|
file_path = replace_variables(file_path, input_data)
|
|||
|
|
file_name = replace_variables(file_name, input_data)
|
|||
|
|
|
|||
|
|
if file_path and os.path.exists(file_path):
|
|||
|
|
# 从文件路径读取
|
|||
|
|
with open(file_path, 'rb') as f:
|
|||
|
|
file_data = f.read()
|
|||
|
|
if not file_name:
|
|||
|
|
file_name = os.path.basename(file_path)
|
|||
|
|
elif file_content:
|
|||
|
|
# 从base64内容读取
|
|||
|
|
file_data = base64.b64decode(file_content)
|
|||
|
|
if not file_name:
|
|||
|
|
file_name = 'attachment'
|
|||
|
|
else:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
# 添加附件
|
|||
|
|
part = MIMEBase('application', 'octet-stream')
|
|||
|
|
part.set_payload(file_data)
|
|||
|
|
encoders.encode_base64(part)
|
|||
|
|
part.add_header(
|
|||
|
|
'Content-Disposition',
|
|||
|
|
f'attachment; filename= {file_name}'
|
|||
|
|
)
|
|||
|
|
msg.attach(part)
|
|||
|
|
|
|||
|
|
# 发送邮件
|
|||
|
|
recipients = [to_email]
|
|||
|
|
if cc_email:
|
|||
|
|
recipients.extend([email.strip() for email in cc_email.split(',')])
|
|||
|
|
if bcc_email:
|
|||
|
|
recipients.extend([email.strip() for email in bcc_email.split(',')])
|
|||
|
|
|
|||
|
|
async with aiosmtplib.SMTP(hostname=smtp_host, port=smtp_port) as smtp:
|
|||
|
|
if use_tls:
|
|||
|
|
await smtp.starttls()
|
|||
|
|
if smtp_user and smtp_password:
|
|||
|
|
await smtp.login(smtp_user, smtp_password)
|
|||
|
|
await smtp.send_message(msg, recipients=recipients)
|
|||
|
|
|
|||
|
|
result = {
|
|||
|
|
'output': {
|
|||
|
|
'message': '邮件发送成功',
|
|||
|
|
'from': from_email,
|
|||
|
|
'to': to_email,
|
|||
|
|
'subject': subject,
|
|||
|
|
'recipients_count': len(recipients)
|
|||
|
|
},
|
|||
|
|
'status': 'success'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'邮件发送失败: {str(e)}'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif node_type == 'message_queue' or node_type == 'mq' or node_type == 'rabbitmq' or node_type == 'kafka':
|
|||
|
|
# 消息队列节点:发送消息到RabbitMQ或Kafka
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
queue_type = node_data.get('queue_type', 'rabbitmq') # rabbitmq/kafka
|
|||
|
|
|
|||
|
|
# 替换变量
|
|||
|
|
import re
|
|||
|
|
def replace_variables(text: str, data: Dict[str, Any]) -> str:
|
|||
|
|
"""替换字符串中的变量占位符"""
|
|||
|
|
if not isinstance(text, str):
|
|||
|
|
return text
|
|||
|
|
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
|
|||
|
|
def replacer(match):
|
|||
|
|
key = match.group(1) or match.group(2)
|
|||
|
|
value = self._get_nested_value(data, key)
|
|||
|
|
return str(value) if value is not None else match.group(0)
|
|||
|
|
return re.sub(pattern, replacer, text)
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
if queue_type == 'rabbitmq':
|
|||
|
|
# RabbitMQ实现
|
|||
|
|
import aio_pika
|
|||
|
|
import json
|
|||
|
|
|
|||
|
|
# 获取RabbitMQ配置
|
|||
|
|
host = replace_variables(node_data.get('host', 'localhost'), input_data)
|
|||
|
|
port = node_data.get('port', 5672)
|
|||
|
|
username = replace_variables(node_data.get('username', 'guest'), input_data)
|
|||
|
|
password = replace_variables(node_data.get('password', 'guest'), input_data)
|
|||
|
|
exchange = replace_variables(node_data.get('exchange', ''), input_data)
|
|||
|
|
routing_key = replace_variables(node_data.get('routing_key', ''), input_data)
|
|||
|
|
queue_name = replace_variables(node_data.get('queue_name', ''), input_data)
|
|||
|
|
message = node_data.get('message', input_data)
|
|||
|
|
|
|||
|
|
# 如果message是字符串,尝试替换变量
|
|||
|
|
if isinstance(message, str):
|
|||
|
|
message = replace_variables(message, input_data)
|
|||
|
|
try:
|
|||
|
|
message = json.loads(message)
|
|||
|
|
except:
|
|||
|
|
pass
|
|||
|
|
elif isinstance(message, dict):
|
|||
|
|
# 递归替换字典中的变量
|
|||
|
|
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
|
|||
|
|
result = {}
|
|||
|
|
for k, v in d.items():
|
|||
|
|
new_k = replace_variables(k, data)
|
|||
|
|
if isinstance(v, dict):
|
|||
|
|
result[new_k] = replace_dict_vars(v, data)
|
|||
|
|
elif isinstance(v, str):
|
|||
|
|
result[new_k] = replace_variables(v, data)
|
|||
|
|
else:
|
|||
|
|
result[new_k] = v
|
|||
|
|
return result
|
|||
|
|
message = replace_dict_vars(message, input_data)
|
|||
|
|
|
|||
|
|
# 如果没有配置message,使用input_data
|
|||
|
|
if not message:
|
|||
|
|
message = input_data
|
|||
|
|
|
|||
|
|
# 连接RabbitMQ
|
|||
|
|
connection_url = f"amqp://{username}:{password}@{host}:{port}/"
|
|||
|
|
connection = await aio_pika.connect_robust(connection_url)
|
|||
|
|
channel = await connection.channel()
|
|||
|
|
|
|||
|
|
# 发送消息
|
|||
|
|
message_body = json.dumps(message, ensure_ascii=False).encode('utf-8')
|
|||
|
|
|
|||
|
|
if exchange:
|
|||
|
|
# 使用exchange和routing_key
|
|||
|
|
await channel.default_exchange.publish(
|
|||
|
|
aio_pika.Message(message_body),
|
|||
|
|
routing_key=routing_key or queue_name
|
|||
|
|
)
|
|||
|
|
elif queue_name:
|
|||
|
|
# 直接发送到队列
|
|||
|
|
queue = await channel.declare_queue(queue_name, durable=True)
|
|||
|
|
await channel.default_exchange.publish(
|
|||
|
|
aio_pika.Message(message_body),
|
|||
|
|
routing_key=queue_name
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
raise ValueError("RabbitMQ节点需要配置exchange或queue_name")
|
|||
|
|
|
|||
|
|
await connection.close()
|
|||
|
|
|
|||
|
|
result = {
|
|||
|
|
'output': {
|
|||
|
|
'message': '消息已发送到RabbitMQ',
|
|||
|
|
'queue_type': 'rabbitmq',
|
|||
|
|
'exchange': exchange,
|
|||
|
|
'routing_key': routing_key or queue_name,
|
|||
|
|
'queue_name': queue_name,
|
|||
|
|
'message_size': len(message_body)
|
|||
|
|
},
|
|||
|
|
'status': 'success'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif queue_type == 'kafka':
|
|||
|
|
# Kafka实现
|
|||
|
|
from kafka import KafkaProducer
|
|||
|
|
import json
|
|||
|
|
|
|||
|
|
# 获取Kafka配置
|
|||
|
|
bootstrap_servers = replace_variables(node_data.get('bootstrap_servers', 'localhost:9092'), input_data)
|
|||
|
|
topic = replace_variables(node_data.get('topic', ''), input_data)
|
|||
|
|
message = node_data.get('message', input_data)
|
|||
|
|
|
|||
|
|
# 如果message是字符串,尝试替换变量
|
|||
|
|
if isinstance(message, str):
|
|||
|
|
message = replace_variables(message, input_data)
|
|||
|
|
try:
|
|||
|
|
message = json.loads(message)
|
|||
|
|
except:
|
|||
|
|
pass
|
|||
|
|
elif isinstance(message, dict):
|
|||
|
|
# 递归替换字典中的变量
|
|||
|
|
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
|
|||
|
|
result = {}
|
|||
|
|
for k, v in d.items():
|
|||
|
|
new_k = replace_variables(k, data)
|
|||
|
|
if isinstance(v, dict):
|
|||
|
|
result[new_k] = replace_dict_vars(v, data)
|
|||
|
|
elif isinstance(v, str):
|
|||
|
|
result[new_k] = replace_variables(v, data)
|
|||
|
|
else:
|
|||
|
|
result[new_k] = v
|
|||
|
|
return result
|
|||
|
|
message = replace_dict_vars(message, input_data)
|
|||
|
|
|
|||
|
|
# 如果没有配置message,使用input_data
|
|||
|
|
if not message:
|
|||
|
|
message = input_data
|
|||
|
|
|
|||
|
|
if not topic:
|
|||
|
|
raise ValueError("Kafka节点需要配置topic")
|
|||
|
|
|
|||
|
|
# 创建Kafka生产者(注意:kafka-python是同步的,需要在线程池中运行)
|
|||
|
|
import asyncio
|
|||
|
|
from concurrent.futures import ThreadPoolExecutor
|
|||
|
|
|
|||
|
|
def send_kafka_message():
|
|||
|
|
producer = KafkaProducer(
|
|||
|
|
bootstrap_servers=bootstrap_servers.split(','),
|
|||
|
|
value_serializer=lambda v: json.dumps(v, ensure_ascii=False).encode('utf-8')
|
|||
|
|
)
|
|||
|
|
future = producer.send(topic, message)
|
|||
|
|
record_metadata = future.get(timeout=10)
|
|||
|
|
producer.close()
|
|||
|
|
return record_metadata
|
|||
|
|
|
|||
|
|
# 在线程池中执行同步操作
|
|||
|
|
loop = asyncio.get_event_loop()
|
|||
|
|
with ThreadPoolExecutor() as executor:
|
|||
|
|
record_metadata = await loop.run_in_executor(executor, send_kafka_message)
|
|||
|
|
|
|||
|
|
result = {
|
|||
|
|
'output': {
|
|||
|
|
'message': '消息已发送到Kafka',
|
|||
|
|
'queue_type': 'kafka',
|
|||
|
|
'topic': topic,
|
|||
|
|
'partition': record_metadata.partition,
|
|||
|
|
'offset': record_metadata.offset
|
|||
|
|
},
|
|||
|
|
'status': 'success'
|
|||
|
|
}
|
|||
|
|
else:
|
|||
|
|
raise ValueError(f"不支持的消息队列类型: {queue_type}")
|
|||
|
|
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': f'消息队列发送失败: {str(e)}'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
elif node_type == 'output' or node_type == 'end':
|
|||
|
|
# 输出节点:返回最终结果
|
|||
|
|
# 对于人机交互场景,End节点应该返回纯文本字符串,而不是JSON
|
|||
|
|
logger.debug(f"[rjb] End节点处理: node_id={node_id}, input_data={input_data}, input_data type={type(input_data)}")
|
|||
|
|
final_output = input_data
|
|||
|
|
|
|||
|
|
# 递归解包,提取实际的文本内容
|
|||
|
|
if isinstance(input_data, dict):
|
|||
|
|
# 如果只有一个 key 且是 'input',提取其值
|
|||
|
|
if len(input_data) == 1 and 'input' in input_data:
|
|||
|
|
final_output = input_data['input']
|
|||
|
|
logger.debug(f"[rjb] End节点提取第一层: final_output={final_output}, type={type(final_output)}")
|
|||
|
|
# 如果提取的值仍然是字典且只有一个 'input' key,继续提取
|
|||
|
|
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
|
|||
|
|
final_output = final_output['input']
|
|||
|
|
logger.debug(f"[rjb] End节点提取第二层: final_output={final_output}, type={type(final_output)}")
|
|||
|
|
|
|||
|
|
# 确保最终输出是字符串(对于人机交互场景)
|
|||
|
|
# 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串
|
|||
|
|
if not isinstance(final_output, str):
|
|||
|
|
if isinstance(final_output, dict):
|
|||
|
|
# 如果是字典,尝试提取文本内容或转换为JSON字符串
|
|||
|
|
# 优先查找常见的文本字段
|
|||
|
|
if 'text' in final_output:
|
|||
|
|
final_output = str(final_output['text'])
|
|||
|
|
elif 'content' in final_output:
|
|||
|
|
final_output = str(final_output['content'])
|
|||
|
|
elif 'message' in final_output:
|
|||
|
|
final_output = str(final_output['message'])
|
|||
|
|
elif 'response' in final_output:
|
|||
|
|
final_output = str(final_output['response'])
|
|||
|
|
elif len(final_output) == 1:
|
|||
|
|
# 如果只有一个key,直接使用其值
|
|||
|
|
final_output = str(list(final_output.values())[0])
|
|||
|
|
else:
|
|||
|
|
# 否则转换为JSON字符串
|
|||
|
|
final_output = json_module.dumps(final_output, ensure_ascii=False)
|
|||
|
|
else:
|
|||
|
|
final_output = str(final_output)
|
|||
|
|
|
|||
|
|
logger.debug(f"[rjb] End节点最终输出: final_output={final_output}, type={type(final_output)}")
|
|||
|
|
result = {'output': final_output, 'status': 'success'}
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_complete(node_id, node_type, final_output, duration)
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
else:
|
|||
|
|
# 未知节点类型
|
|||
|
|
return {
|
|||
|
|
'output': input_data,
|
|||
|
|
'status': 'success',
|
|||
|
|
'message': f'节点类型 {node_type} 暂未实现'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"节点执行失败: {node_id} ({node_type}) - {str(e)}", exc_info=True)
|
|||
|
|
if self.logger:
|
|||
|
|
duration = int((time.time() - start_time) * 1000)
|
|||
|
|
self.logger.log_node_error(node_id, node_type, e, duration)
|
|||
|
|
return {
|
|||
|
|
'output': None,
|
|||
|
|
'status': 'failed',
|
|||
|
|
'error': str(e),
|
|||
|
|
'node_id': node_id,
|
|||
|
|
'node_type': node_type
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def execute(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute the whole workflow.

    Nodes run one at a time. On every iteration the execution graph is rebuilt
    from the currently active edges, and the first node in that order whose
    active predecessors have all settled is executed.

    Condition nodes drop their outgoing edges whose ``sourceHandle`` does not
    match the chosen ``branch``. Any node that originally had predecessors but
    thereby loses every incoming edge is skipped: it is not run, and its own
    outgoing edges are dropped so the skip propagates downstream. Loop/foreach
    nodes execute their body internally, so the body nodes are marked as
    settled instead of being run again here.

    Args:
        input_data: Initial workflow input. It is handed to a ``start`` node
            when that node receives no input from predecessors.

    Returns:
        ``{'status': 'completed', 'result': <str>, 'node_results': {...}}``.
        ``result`` is the output of the most recently executed node that has no
        outgoing active edge, unwrapped from up to two ``{'input': ...}``
        layers and coerced to a string. If no node was executed, returns
        ``{'status': 'completed', 'result': None}``.

    Raises:
        WorkflowExecutionError: When any node reports a non-success status.
    """
    if self.logger:
        self.logger.info("工作流开始执行", data={"input": input_data})

    self.node_outputs = {}
    active_edges = list(self.edges)  # edges still allowed to carry data
    settled = set()                  # nodes run, skipped, or covered by a loop body
    execution_sequence = []          # nodes actually run, in execution order
    results = {}

    # Nodes that originally had at least one predecessor. If such a node ends
    # up with no active incoming edge, it is unreachable on this run.
    originally_fed = {edge['target'] for edge in self.edges}

    def prune_unreachable(edges):
        """Skip nodes cut off by condition branches; return the remaining edges."""
        while True:
            still_fed = {edge['target'] for edge in edges}
            orphans = {
                node_id for node_id in self.nodes
                if node_id in originally_fed
                and node_id not in still_fed
                and node_id not in settled
            }
            if not orphans:
                return edges
            settled.update(orphans)
            if self.logger:
                self.logger.info("跳过未命中条件分支的节点", data={"skipped_nodes": sorted(orphans, key=str)})
            # Drop the skipped nodes' out-edges so downstream nodes are not
            # blocked waiting for them (and can themselves be pruned).
            edges = [edge for edge in edges if edge['source'] not in orphans]

    def as_text(value):
        """Unwrap up to two single-key ``{'input': ...}`` layers and coerce to str."""
        for _ in range(2):
            if isinstance(value, dict) and len(value) == 1 and 'input' in value:
                value = value['input']
        if isinstance(value, str):
            return value
        if isinstance(value, dict):
            # Prefer common text-bearing fields (human-interaction scenarios).
            for key in ('text', 'content', 'message', 'response'):
                if key in value:
                    return str(value[key])
            if len(value) == 1:
                return str(next(iter(value.values())))
            return json.dumps(value, ensure_ascii=False)
        return str(value)

    while True:
        execution_order = self.build_execution_graph(active_edges)

        # A node is ready when no active edge feeds it from an unsettled node.
        blocked = {edge['target'] for edge in active_edges if edge['source'] not in settled}
        next_node_id = next(
            (node_id for node_id in execution_order
             if node_id not in settled and node_id not in blocked),
            None,
        )
        if next_node_id is None:
            break  # nothing left to run

        node = self.nodes[next_node_id]
        node_type = node.get('type')
        settled.add(next_node_id)
        execution_sequence.append(next_node_id)

        if node_type == 'llm':
            logger.debug(f"[rjb] 执行LLM节点: node_id={next_node_id}, node keys={list(node.keys())}, data keys={list(node.get('data', {}).keys()) if node.get('data') else []}")

        # Collect input along the active edges only.
        node_input = self.get_node_input(next_node_id, self.node_outputs, active_edges)
        if node_type == 'start' and not node_input:
            node_input = input_data

        if node_type == 'llm':
            logger.debug(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}")

        result = await self.execute_node(node, node_input)
        results[next_node_id] = result

        if result.get('status') != 'success':
            # A failed node aborts the whole workflow.
            error_msg = result.get('error', '未知错误')
            failed_type = node.get('type', 'unknown')
            logger.error(f"工作流执行失败 - 节点: {next_node_id} ({failed_type}), 错误: {error_msg}")
            raise WorkflowExecutionError(
                detail=error_msg,
                node_id=next_node_id
            )

        self.node_outputs[next_node_id] = result.get('output', {})

        if node_type == 'condition':
            branch = result.get('branch', 'false')
            # Keep only this node's out-edges for the chosen branch.
            active_edges = [
                edge for edge in active_edges
                if not (edge['source'] == next_node_id and edge.get('sourceHandle') != branch)
            ]
            active_edges = prune_unreachable(active_edges)

        if node_type in ('loop', 'foreach'):
            # The loop body already ran inside the loop node; mark it settled.
            # Iterate a copy in case the helper mutates active_edges.
            for edge in list(active_edges):
                if edge.get('source') != next_node_id:
                    continue
                target_id = edge.get('target')
                if target_id not in self.nodes:
                    continue
                if self.nodes[target_id].get('type') not in ('loop_end', 'end'):
                    settled.add(target_id)
                    self._mark_loop_body_executed(target_id, settled, active_edges)

    if not execution_sequence:
        if self.logger:
            self.logger.warn("工作流执行完成,但没有执行任何节点")
        return {'status': 'completed', 'result': None}

    # The result comes from the most recently run node with no outgoing active
    # edge; fall back to the last node run.
    sources = {edge['source'] for edge in active_edges}
    last_node_id = next(
        (node_id for node_id in reversed(execution_sequence) if node_id not in sources),
        execution_sequence[-1],
    )

    final_result = {
        'status': 'completed',
        'result': as_text(self.node_outputs.get(last_node_id)),
        'node_results': results
    }

    if self.logger:
        self.logger.info("工作流执行完成", data={"result": final_result.get('result')})

    return final_result
|