Files
aiagent/backend/app/services/workflow_engine.py
2026-01-20 09:40:16 +08:00

2034 lines
108 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工作流执行引擎
"""
from typing import Dict, Any, List, Optional
import asyncio
from collections import defaultdict, deque
import json
import logging
import re
from app.services.llm_service import llm_service
from app.services.condition_parser import condition_parser
from app.services.data_transformer import data_transformer
from app.core.exceptions import WorkflowExecutionError
from app.core.database import SessionLocal
from app.models.agent import Agent
logger = logging.getLogger(__name__)
class WorkflowEngine:
"""工作流执行引擎"""
def __init__(self, workflow_id: str, workflow_data: Dict[str, Any], logger=None, db=None):
"""
初始化工作流引擎
Args:
workflow_id: 工作流ID
workflow_data: 工作流数据包含nodes和edges
logger: 执行日志记录器(可选)
db: 数据库会话可选用于Agent节点加载Agent配置
"""
self.workflow_id = workflow_id
self.nodes = {node['id']: node for node in workflow_data.get('nodes', [])}
self.edges = workflow_data.get('edges', [])
self.execution_graph = None
self.node_outputs = {}
self.logger = logger
self.db = db
def build_execution_graph(self, active_edges: Optional[List[Dict[str, Any]]] = None) -> List[str]:
"""
构建执行图DAG并返回拓扑排序结果
Args:
active_edges: 活跃的边列表(用于条件分支过滤)
Returns:
拓扑排序后的节点ID列表
"""
# 使用活跃的边,如果没有提供则使用所有边
edges_to_use = active_edges if active_edges is not None else self.edges
# 构建邻接表和入度表
graph = defaultdict(list)
in_degree = defaultdict(int)
# 初始化所有节点的入度
for node_id in self.nodes.keys():
in_degree[node_id] = 0
# 构建图
for edge in edges_to_use:
source = edge['source']
target = edge['target']
graph[source].append(target)
in_degree[target] += 1
# 拓扑排序Kahn算法
queue = deque()
result = []
# 找到所有入度为0的节点起始节点
for node_id in self.nodes.keys():
if in_degree[node_id] == 0:
queue.append(node_id)
while queue:
node_id = queue.popleft()
result.append(node_id)
# 处理该节点的所有出边
for neighbor in graph[node_id]:
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
# 检查是否有环(只检查可达节点)
reachable_nodes = set(result)
if len(reachable_nodes) < len(self.nodes):
# 有些节点不可达,这是正常的(条件分支)
pass
self.execution_graph = result
return result
def get_node_input(self, node_id: str, node_outputs: Dict[str, Any], active_edges: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
"""
获取节点的输入数据
Args:
node_id: 节点ID
node_outputs: 所有节点的输出数据
active_edges: 活跃的边列表(用于条件分支过滤)
Returns:
节点的输入数据
"""
# 使用活跃的边,如果没有提供则使用所有边
edges_to_use = active_edges if active_edges is not None else self.edges
# 找到所有指向该节点的边
input_data = {}
for edge in edges_to_use:
if edge['target'] == node_id:
source_id = edge['source']
source_output = node_outputs.get(source_id, {})
logger.info(f"[rjb] 获取节点输入: target={node_id}, source={source_id}, source_output={source_output}, source_output_type={type(source_output)}, sourceHandle={edge.get('sourceHandle')}")
# 如果有sourceHandle使用它作为key
if 'sourceHandle' in edge and edge['sourceHandle']:
input_data[edge['sourceHandle']] = source_output
else:
# 否则合并所有输入
if isinstance(source_output, dict):
# 如果source_output包含output字段展开它
if 'output' in source_output and isinstance(source_output['output'], dict):
# 将output中的内容展开到顶层
input_data.update(source_output['output'])
# 保留其他字段如status
for key, value in source_output.items():
if key != 'output':
input_data[key] = value
else:
# 直接展开source_output的内容
input_data.update(source_output)
logger.info(f"[rjb] 展开source_output后: input_data={input_data}")
else:
# 如果source_output不是字典包装到input字段
input_data['input'] = source_output
logger.info(f"[rjb] source_output不是字典包装到input字段: input_data={input_data}")
# 如果input_data中没有query字段尝试从所有已执行的节点中查找特别是start节点
if 'query' not in input_data:
# 优先查找start节点
for node_id_key in ['start-1', 'start']:
if node_id_key in node_outputs:
node_output = node_outputs[node_id_key]
if isinstance(node_output, dict):
# 检查顶层字段因为node_outputs存储的是output字段的内容
if 'query' in node_output:
input_data['query'] = node_output['query']
logger.debug(f"[rjb] 从节点 {node_id_key} 中获取query: {input_data['query']}")
break
# 检查output字段兼容性
elif 'output' in node_output and isinstance(node_output['output'], dict):
if 'query' in node_output['output']:
input_data['query'] = node_output['output']['query']
logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取query: {input_data['query']}")
break
# 如果还没找到,遍历所有节点
if 'query' not in input_data:
for node_id_key, node_output in node_outputs.items():
if isinstance(node_output, dict):
# 检查顶层字段
if 'query' in node_output:
input_data['query'] = node_output['query']
logger.debug(f"[rjb] 从节点 {node_id_key} 中获取query: {input_data['query']}")
break
# 检查output字段兼容性
elif 'output' in node_output and isinstance(node_output['output'], dict):
if 'query' in node_output['output']:
input_data['query'] = node_output['output']['query']
logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取query: {input_data['query']}")
break
# 如果input_data中没有requirement_analysis字段尝试从所有已执行的节点中查找
if 'requirement_analysis' not in input_data:
# 优先查找requirement-analysis节点
for node_id_key in ['llm-requirement-analysis', 'requirement-analysis']:
if node_id_key in node_outputs:
node_output = node_outputs[node_id_key]
if isinstance(node_output, dict):
# 检查顶层字段因为node_outputs存储的是output字段的内容
if 'requirement_analysis' in node_output:
input_data['requirement_analysis'] = node_output['requirement_analysis']
logger.debug(f"[rjb] 从节点 {node_id_key} 中获取requirement_analysis")
break
# 检查output字段兼容性
elif 'output' in node_output and isinstance(node_output['output'], dict):
if 'requirement_analysis' in node_output['output']:
input_data['requirement_analysis'] = node_output['output']['requirement_analysis']
logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取requirement_analysis")
break
# 如果还没找到,遍历所有节点
if 'requirement_analysis' not in input_data:
for node_id_key, node_output in node_outputs.items():
if isinstance(node_output, dict):
# 检查顶层字段
if 'requirement_analysis' in node_output:
input_data['requirement_analysis'] = node_output['requirement_analysis']
logger.debug(f"[rjb] 从节点 {node_id_key} 中获取requirement_analysis")
break
# 检查output字段兼容性
elif 'output' in node_output and isinstance(node_output['output'], dict):
if 'requirement_analysis' in node_output['output']:
input_data['requirement_analysis'] = node_output['output']['requirement_analysis']
logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取requirement_analysis")
break
logger.debug(f"[rjb] 节点输入结果: node_id={node_id}, input_data={input_data}")
return input_data
def _get_nested_value(self, data: Dict[str, Any], path: str) -> Any:
"""
从嵌套字典中获取值(支持点号路径和数组索引)
Args:
data: 数据字典
path: 路径,如 "user.name""items[0].price"
Returns:
路径对应的值
"""
if not path:
return data
parts = path.split('.')
result = data
for part in parts:
if '[' in part and ']' in part:
# 处理数组索引,如 "items[0]"
key = part[:part.index('[')]
index_str = part[part.index('[') + 1:part.index(']')]
if isinstance(result, dict):
result = result.get(key)
elif isinstance(result, list):
try:
result = result[int(index_str)]
except (ValueError, IndexError):
return None
else:
return None
if result is None:
return None
else:
# 普通键访问
if isinstance(result, dict):
result = result.get(part)
else:
return None
if result is None:
return None
return result
async def _execute_loop_body(self, loop_node_id: str, loop_input: Dict[str, Any], iteration_index: int) -> Dict[str, Any]:
"""
执行循环体
Args:
loop_node_id: 循环节点ID
loop_input: 循环体的输入数据
iteration_index: 当前迭代索引
Returns:
循环体的执行结果
"""
# 找到循环节点的直接子节点(循环体开始节点)
loop_body_start_nodes = []
for edge in self.edges:
if edge.get('source') == loop_node_id:
target_id = edge.get('target')
if target_id and target_id in self.nodes:
loop_body_start_nodes.append(target_id)
if not loop_body_start_nodes:
# 如果没有子节点,直接返回输入数据
return {'output': loop_input, 'status': 'success'}
# 执行循环体:从循环体开始节点执行到循环结束节点或没有更多节点
# 简化处理:只执行第一个子节点链
executed_in_loop = set()
loop_results = {}
current_node_id = loop_body_start_nodes[0] # 简化:只执行第一个子节点链
# 执行循环体内的节点(简化版本:只执行直接连接的子节点)
max_iterations = 100 # 防止无限循环
iteration = 0
while current_node_id and iteration < max_iterations:
iteration += 1
if current_node_id in executed_in_loop:
break # 避免循环体内部循环
if current_node_id not in self.nodes:
break
node = self.nodes[current_node_id]
executed_in_loop.add(current_node_id)
# 如果是循环结束节点,停止执行
if node.get('type') == 'loop_end' or node.get('type') == 'end':
break
# 执行节点
result = await self.execute_node(node, loop_input)
loop_results[current_node_id] = result
if result.get('status') != 'success':
return result
# 更新输入数据为当前节点的输出
if result.get('output'):
if isinstance(result.get('output'), dict):
loop_input = {**loop_input, **result.get('output')}
else:
loop_input = {**loop_input, 'result': result.get('output')}
# 找到下一个节点(简化:只找第一个子节点)
next_node_id = None
for edge in self.edges:
if edge.get('source') == current_node_id:
target_id = edge.get('target')
if target_id and target_id in self.nodes and target_id not in executed_in_loop:
# 跳过循环节点本身
if target_id != loop_node_id:
next_node_id = target_id
break
current_node_id = next_node_id
# 返回最后一个节点的输出
if loop_results:
last_result = list(loop_results.values())[-1]
return last_result
return {'output': loop_input, 'status': 'success'}
def _mark_loop_body_executed(self, node_id: str, executed_nodes: set, active_edges: List[Dict[str, Any]]):
"""
递归标记循环体内的节点为已执行
Args:
node_id: 当前节点ID
executed_nodes: 已执行节点集合
active_edges: 活跃的边列表
"""
if node_id in executed_nodes:
return
executed_nodes.add(node_id)
# 查找所有子节点
for edge in active_edges:
if edge.get('source') == node_id:
target_id = edge.get('target')
if target_id in self.nodes:
target_node = self.nodes[target_id]
# 如果是循环结束节点,停止递归
if target_node.get('type') in ['loop_end', 'end']:
continue
# 递归标记子节点
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
async def execute_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行单个节点
Args:
node: 节点配置
input_data: 输入数据
Returns:
节点执行结果
"""
# 确保可以访问全局的 json 模块
import json as json_module
node_type = node.get('type', 'unknown')
node_id = node.get('id')
import time
start_time = time.time()
# 记录节点开始执行
if self.logger:
self.logger.log_node_start(node_id, node_type, input_data)
try:
if node_type == 'start':
# 起始节点:返回输入数据
logger.debug(f"[rjb] 开始节点执行: node_id={node_id}, input_data={input_data}")
result = {'output': input_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
logger.debug(f"[rjb] 开始节点输出: node_id={node_id}, output={result.get('output')}")
return result
elif node_type == 'input':
# 输入节点:处理输入数据
result = {'output': input_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
elif node_type == 'llm' or node_type == 'template':
# LLM节点调用AI模型
node_data = node.get('data', {})
logger.debug(f"[rjb] LLM节点执行: node_id={node_id}, input_data={input_data}, input_data type={type(input_data)}")
logger.debug(f"[rjb] LLM节点数据: node_id={node_id}, node_data keys={list(node_data.keys())}, api_key={'已配置' if node_data.get('api_key') else '未配置'}")
prompt = node_data.get('prompt', '')
# 如果prompt为空使用默认提示词
if not prompt:
prompt = "请处理以下输入数据:\n{input}"
# 格式化prompt替换变量
try:
# 将input_data转换为字符串用于格式化
if isinstance(input_data, dict):
# 支持两种格式的变量:{key} 和 {{key}}
formatted_prompt = prompt
has_unfilled_variables = False
has_any_placeholder = False
import re
# 检查是否有任何占位符
has_any_placeholder = bool(re.search(r'\{\{?\w+\}?\}', prompt))
# 首先处理 {{variable}} 格式(模板节点常用)
double_brace_vars = re.findall(r'\{\{(\w+)\}\}', prompt)
for var_name in double_brace_vars:
if var_name in input_data:
# 替换 {{variable}} 为实际值
value = input_data[var_name]
replacement = json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value)
formatted_prompt = formatted_prompt.replace(f'{{{{{var_name}}}}}', replacement)
else:
has_unfilled_variables = True
# 然后处理 {key} 格式
for key, value in input_data.items():
placeholder = f'{{{key}}}'
if placeholder in formatted_prompt:
formatted_prompt = formatted_prompt.replace(
placeholder,
json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value)
)
# 如果还有{input}占位符替换为整个input_data
if '{input}' in formatted_prompt:
formatted_prompt = formatted_prompt.replace(
'{input}',
json_module.dumps(input_data, ensure_ascii=False)
)
# 提取用户的实际查询内容(优先提取)
user_query = None
logger.info(f"[rjb] 开始提取user_query: input_data={input_data}, input_data_type={type(input_data)}")
if isinstance(input_data, dict):
# 首先检查是否有嵌套的input字段
nested_input = input_data.get('input')
logger.info(f"[rjb] 检查嵌套input: nested_input={nested_input}, nested_input_type={type(nested_input) if nested_input else None}")
if isinstance(nested_input, dict):
# 从嵌套的input中提取
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
if key in nested_input:
user_query = nested_input[key]
logger.info(f"[rjb] 从嵌套input中提取到user_query: key={key}, user_query={user_query}")
break
# 如果还没有,从顶层提取
if not user_query:
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
if key in input_data:
value = input_data[key]
logger.info(f"[rjb] 从顶层提取: key={key}, value={value}, value_type={type(value)}")
# 如果值是字符串,直接使用
if isinstance(value, str):
user_query = value
logger.info(f"[rjb] 提取到字符串user_query: {user_query}")
break
# 如果值是字典,尝试从中提取
elif isinstance(value, dict):
for sub_key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
if sub_key in value:
user_query = value[sub_key]
logger.info(f"[rjb] 从字典值中提取到user_query: sub_key={sub_key}, user_query={user_query}")
break
if user_query:
break
# 如果还是没有使用整个input_data但排除系统字段
if not user_query:
filtered_data = {k: v for k, v in input_data.items() if not k.startswith('_')}
logger.info(f"[rjb] 使用filtered_data: filtered_data={filtered_data}")
if filtered_data:
# 如果只有一个字段且是字符串,直接使用
if len(filtered_data) == 1:
single_value = list(filtered_data.values())[0]
if isinstance(single_value, str):
user_query = single_value
logger.info(f"[rjb] 从单个字符串字段提取到user_query: {user_query}")
elif isinstance(single_value, dict):
# 从字典中提取第一个字符串值
for v in single_value.values():
if isinstance(v, str):
user_query = v
logger.info(f"[rjb] 从字典的单个字段中提取到user_query: {user_query}")
break
if not user_query:
user_query = json_module.dumps(filtered_data, ensure_ascii=False) if len(filtered_data) > 1 else str(list(filtered_data.values())[0])
logger.info(f"[rjb] 使用JSON或字符串转换: user_query={user_query}")
logger.info(f"[rjb] 最终提取的user_query: {user_query}")
# 如果prompt中没有占位符或者仍有未填充的变量将用户输入附加到prompt
is_generic_instruction = False # 初始化变量
if not has_any_placeholder:
# 如果prompt中没有占位符将用户输入作为主要内容
if user_query:
# 判断是否是通用指令:简短且不包含具体任务描述
prompt_stripped = prompt.strip()
is_generic_instruction = (
len(prompt_stripped) < 30 or # 简短提示词
prompt_stripped in [
"请处理用户请求。", "请处理用户请求",
"请处理以下输入数据:", "请处理以下输入数据",
"请处理输入。", "请处理输入",
"处理用户请求", "处理请求",
"请回答用户问题", "请回答用户问题。",
"请帮助用户", "请帮助用户。"
] or
# 检查是否只包含通用指令关键词
(len(prompt_stripped) < 50 and any(keyword in prompt_stripped for keyword in [
"请处理", "处理", "请回答", "回答", "请帮助", "帮助", "请执行", "执行"
]) and not any(specific in prompt_stripped for specific in [
"翻译", "生成", "分析", "总结", "提取", "转换", "计算"
]))
)
if is_generic_instruction:
# 如果是通用指令直接使用用户输入作为prompt
formatted_prompt = str(user_query)
logger.info(f"[rjb] 检测到通用指令直接使用用户输入作为prompt: {user_query[:50] if user_query else 'None'}")
else:
# 否则将用户输入附加到prompt
formatted_prompt = f"{formatted_prompt}\n\n{user_query}"
logger.info(f"[rjb] 非通用指令将用户输入附加到prompt")
else:
# 如果没有提取到用户查询附加整个input_data
formatted_prompt = f"{formatted_prompt}\n\n{json_module.dumps(input_data, ensure_ascii=False)}"
elif has_unfilled_variables or re.search(r'\{\{(\w+)\}\}', formatted_prompt):
# 如果有占位符但未填充,附加用户需求说明
if user_query:
formatted_prompt = f"{formatted_prompt}\n\n用户需求:{user_query}\n\n请根据以上用户需求,忽略未填充的变量占位符(如{{{{variable}}}}),直接基于用户需求来完成任务。"
logger.info(f"[rjb] LLM节点prompt格式化: node_id={node_id}, original_prompt='{prompt[:50] if len(prompt) > 50 else prompt}', has_any_placeholder={has_any_placeholder}, user_query={user_query}, is_generic_instruction={is_generic_instruction}, final_prompt前200字符='{formatted_prompt[:200] if len(formatted_prompt) > 200 else formatted_prompt}'")
prompt = formatted_prompt
else:
# 如果input_data不是dict直接转换为字符串
if '{input}' in prompt:
prompt = prompt.replace('{input}', str(input_data))
else:
prompt = f"{prompt}\n\n输入:{str(input_data)}"
except Exception as e:
# 格式化失败使用原始prompt和input_data
logger.warning(f"[rjb] Prompt格式化失败: {str(e)}")
try:
prompt = f"{prompt}\n\n输入数据:\n{json_module.dumps(input_data, ensure_ascii=False)}"
except:
prompt = f"{prompt}\n\n输入数据:{str(input_data)}"
# 获取LLM配置
provider = node_data.get('provider', 'openai')
model = node_data.get('model', 'gpt-3.5-turbo')
# 确保temperature是浮点数节点模板中可能是字符串
temperature_raw = node_data.get('temperature', 0.7)
if isinstance(temperature_raw, str):
try:
temperature = float(temperature_raw)
except (ValueError, TypeError):
temperature = 0.7
else:
temperature = float(temperature_raw) if temperature_raw is not None else 0.7
# 确保max_tokens是整数节点模板中可能是字符串
max_tokens_raw = node_data.get('max_tokens')
if max_tokens_raw is not None:
if isinstance(max_tokens_raw, str):
try:
max_tokens = int(max_tokens_raw)
except (ValueError, TypeError):
max_tokens = None
else:
max_tokens = int(max_tokens_raw) if max_tokens_raw is not None else None
else:
max_tokens = None
# 不传递 api_key 和 base_url让 LLM 服务使用系统默认配置(与节点测试保持一致)
api_key = None
base_url = None
# 记录实际发送给LLM的prompt
logger.info(f"[rjb] 准备调用LLM: node_id={node_id}, provider={provider}, model={model}, prompt前200字符='{prompt[:200] if len(prompt) > 200 else prompt}'")
# 调用LLM服务
try:
if self.logger:
logger.debug(f"[rjb] LLM节点配置: provider={provider}, model={model}, 使用系统默认API Key配置")
self.logger.info(f"调用LLM服务: {provider}/{model}", node_id=node_id, node_type=node_type)
result = await llm_service.call_llm(
prompt=prompt,
provider=provider,
model=model,
temperature=temperature,
max_tokens=max_tokens
# 不传递 api_key 和 base_url使用系统默认配置
)
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
# LLM调用失败返回错误
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'LLM调用失败: {str(e)}'
}
elif node_type == 'condition':
# 条件节点:判断分支
condition = node.get('data', {}).get('condition', '')
if not condition:
# 如果没有条件表达式默认返回False
return {
'output': False,
'status': 'success',
'branch': 'false'
}
# 使用条件解析器评估表达式
try:
result = condition_parser.evaluate_condition(condition, input_data)
exec_result = {
'output': result,
'status': 'success',
'branch': 'true' if result else 'false'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, {'result': result, 'branch': exec_result['branch']}, duration)
return exec_result
except Exception as e:
# 条件评估失败
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': False,
'status': 'failed',
'error': f'条件评估失败: {str(e)}',
'branch': 'false'
}
elif node_type == 'data' or node_type == 'transform':
# 数据转换节点
node_data = node.get('data', {})
mapping = node_data.get('mapping', {})
filter_rules = node_data.get('filter_rules', [])
compute_rules = node_data.get('compute_rules', {})
mode = node_data.get('mode', 'mapping')
try:
# 处理mapping中的{{variable}}格式从input_data中提取值
# 首先如果input_data包含output字段需要展开它
expanded_input = input_data.copy()
if 'output' in input_data and isinstance(input_data['output'], dict):
# 将output中的内容展开到顶层但保留output字段
expanded_input.update(input_data['output'])
processed_mapping = {}
import re
for target_key, source_expr in mapping.items():
if isinstance(source_expr, str):
# 支持{{variable}}格式
double_brace_vars = re.findall(r'\{\{(\w+)\}\}', source_expr)
if double_brace_vars:
# 从expanded_input中获取变量值
var_value = None
for var_name in double_brace_vars:
# 尝试从expanded_input中获取支持嵌套路径
var_value = self._get_nested_value(expanded_input, var_name)
if var_value is not None:
break
if var_value is not None:
# 如果只有一个变量,直接使用值;否则替换表达式
if len(double_brace_vars) == 1:
processed_mapping[target_key] = var_value
else:
# 多个变量,替换表达式
processed_expr = source_expr
for var_name in double_brace_vars:
var_val = self._get_nested_value(expanded_input, var_name)
if var_val is not None:
replacement = json_module.dumps(var_val, ensure_ascii=False) if isinstance(var_val, (dict, list)) else str(var_val)
processed_expr = processed_expr.replace(f'{{{{{var_name}}}}}', replacement)
processed_mapping[target_key] = processed_expr
else:
# 变量不存在,保持原表达式
processed_mapping[target_key] = source_expr
else:
# 不是{{variable}}格式,直接使用
processed_mapping[target_key] = source_expr
else:
# 不是字符串,直接使用
processed_mapping[target_key] = source_expr
# 如果mode是merge需要合并所有输入数据
if mode == 'merge':
# 合并所有上游节点的输出(使用展开后的数据)
result = expanded_input.copy()
# 添加mapping的结果
for key, value in processed_mapping.items():
result[key] = value
else:
# 使用处理后的mapping进行转换使用展开后的数据
result = data_transformer.transform_data(
input_data=expanded_input,
mapping=processed_mapping,
filter_rules=filter_rules,
compute_rules=compute_rules,
mode=mode
)
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
logger.error(f"[rjb] Transform节点执行失败: {str(e)}", exc_info=True)
return {
'output': None,
'status': 'failed',
'error': f'数据转换失败: {str(e)}'
}
elif node_type == 'loop' or node_type == 'foreach':
# 循环节点:对数组进行循环处理
node_data = node.get('data', {})
items_path = node_data.get('items_path', 'items') # 数组数据路径
item_variable = node_data.get('item_variable', 'item') # 循环变量名
# 从输入数据中获取数组
items = self._get_nested_value(input_data, items_path)
if not isinstance(items, list):
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type,
ValueError(f"路径 {items_path} 的值不是数组"), duration)
return {
'output': None,
'status': 'failed',
'error': f'路径 {items_path} 的值不是数组,当前类型: {type(items).__name__}'
}
if self.logger:
self.logger.info(f"循环节点开始处理 {len(items)} 个元素",
node_id=node_id, node_type=node_type,
data={"items_count": len(items)})
# 执行循环:对每个元素执行循环体
loop_results = []
for index, item in enumerate(items):
if self.logger:
self.logger.info(f"循环迭代 {index + 1}/{len(items)}",
node_id=node_id, node_type=node_type,
data={"index": index, "item": item})
# 准备循环体的输入数据
loop_input = {
**input_data, # 保留原始输入数据
item_variable: item, # 当前循环项
f'{item_variable}_index': index, # 索引
f'{item_variable}_total': len(items) # 总数
}
# 执行循环体(获取循环节点的子节点)
loop_body_result = await self._execute_loop_body(
node_id, loop_input, index
)
if loop_body_result.get('status') == 'success':
loop_results.append(loop_body_result.get('output', item))
else:
# 如果循环体执行失败,可以选择继续或停止
error_handling = node_data.get('error_handling', 'continue')
if error_handling == 'stop':
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type,
Exception(f"循环体执行失败,停止循环: {loop_body_result.get('error')}"), duration)
return {
'output': None,
'status': 'failed',
'error': f'循环体执行失败: {loop_body_result.get("error")}',
'completed_items': index,
'results': loop_results
}
else:
# continue: 继续执行,记录错误
if self.logger:
self.logger.warn(f"循环迭代 {index + 1} 失败,继续执行",
node_id=node_id, node_type=node_type,
data={"error": loop_body_result.get('error')})
loop_results.append(None)
exec_result = {
'output': loop_results,
'status': 'success',
'items_processed': len(items),
'results_count': len(loop_results)
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type,
{'results_count': len(loop_results)}, duration)
return exec_result
elif node_type == 'http' or node_type == 'request':
# HTTP请求节点发送HTTP请求
node_data = node.get('data', {})
url = node_data.get('url', '')
method = node_data.get('method', 'GET').upper()
headers = node_data.get('headers', {})
params = node_data.get('params', {})
body = node_data.get('body', {})
timeout = node_data.get('timeout', 30)
# 如果URL、headers、params、body中包含变量从input_data中替换
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
# 支持 {key} 或 ${key} 格式
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
# 替换URL中的变量
if url:
url = replace_variables(url, input_data)
# 替换headers中的变量
if isinstance(headers, dict):
headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
elif isinstance(headers, str):
try:
headers = json.loads(replace_variables(headers, input_data))
except:
headers = {}
# 替换params中的变量
if isinstance(params, dict):
params = {k: replace_variables(str(v), input_data) if isinstance(v, str) else v
for k, v in params.items()}
elif isinstance(params, str):
try:
params = json.loads(replace_variables(params, input_data))
except:
params = {}
# 替换body中的变量
if isinstance(body, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
body = replace_dict_vars(body, input_data)
elif isinstance(body, str):
body = replace_variables(body, input_data)
try:
body = json.loads(body)
except:
pass
try:
import httpx
async with httpx.AsyncClient(timeout=timeout) as client:
if method == 'GET':
response = await client.get(url, params=params, headers=headers)
elif method == 'POST':
response = await client.post(url, json=body, params=params, headers=headers)
elif method == 'PUT':
response = await client.put(url, json=body, params=params, headers=headers)
elif method == 'DELETE':
response = await client.delete(url, params=params, headers=headers)
elif method == 'PATCH':
response = await client.patch(url, json=body, params=params, headers=headers)
else:
raise ValueError(f"不支持的HTTP方法: {method}")
# 尝试解析JSON响应
try:
response_data = response.json()
except:
response_data = response.text
result = {
'output': {
'status_code': response.status_code,
'headers': dict(response.headers),
'data': response_data
},
'status': 'success'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'HTTP请求失败: {str(e)}'
}
elif node_type == 'database' or node_type == 'db':
# 数据库操作节点:执行数据库操作
node_data = node.get('data', {})
data_source_id = node_data.get('data_source_id')
operation = node_data.get('operation', 'query') # query/insert/update/delete
sql = node_data.get('sql', '')
table = node_data.get('table', '')
data = node_data.get('data', {})
where = node_data.get('where', {})
# 如果SQL中包含变量从input_data中替换
if sql and isinstance(sql, str):
import re
def replace_sql_vars(text: str, data: Dict[str, Any]) -> str:
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
if value is None:
return match.group(0)
# 如果是字符串需要转义SQL注入
if isinstance(value, str):
# 简单转义,实际应该使用参数化查询
escaped_value = value.replace("'", "''")
return f"'{escaped_value}'"
return str(value)
return re.sub(pattern, replacer, text)
sql = replace_sql_vars(sql, input_data)
try:
# 从数据库加载数据源配置
if not self.db:
raise ValueError("数据库会话未提供,无法执行数据库操作")
from app.models.data_source import DataSource
from app.services.data_source_connector import create_connector
data_source = self.db.query(DataSource).filter(
DataSource.id == data_source_id
).first()
if not data_source:
raise ValueError(f"数据源不存在: {data_source_id}")
connector = create_connector(data_source.type, data_source.config)
if operation == 'query':
# 查询操作
if not sql:
raise ValueError("查询操作需要提供SQL语句")
query_params = {'query': sql}
result_data = connector.query(query_params)
result = {'output': result_data, 'status': 'success'}
elif operation == 'insert':
# 插入操作
if not table:
raise ValueError("插入操作需要提供表名")
# 构建INSERT SQL
columns = ', '.join(data.keys())
# 处理字符串值,转义单引号
def escape_value(v):
if isinstance(v, str):
escaped = v.replace("'", "''")
return f"'{escaped}'"
return str(v)
values = ', '.join([escape_value(v) for v in data.values()])
insert_sql = f"INSERT INTO {table} ({columns}) VALUES ({values})"
query_params = {'query': insert_sql}
result_data = connector.query(query_params)
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
elif operation == 'update':
# 更新操作
if not table or not where:
raise ValueError("更新操作需要提供表名和WHERE条件")
set_clause = ', '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in data.items()])
where_clause = ' AND '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in where.items()])
update_sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
query_params = {'query': update_sql}
result_data = connector.query(query_params)
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
elif operation == 'delete':
# 删除操作
if not table or not where:
raise ValueError("删除操作需要提供表名和WHERE条件")
# 处理字符串值,转义单引号
def escape_sql_value(k, v):
if isinstance(v, str):
escaped = v.replace("'", "''")
return f"{k} = '{escaped}'"
return f"{k} = {v}"
where_clause = ' AND '.join([escape_sql_value(k, v) for k, v in where.items()])
delete_sql = f"DELETE FROM {table} WHERE {where_clause}"
query_params = {'query': delete_sql}
result_data = connector.query(query_params)
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
else:
raise ValueError(f"不支持的数据库操作: {operation}")
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'数据库操作失败: {str(e)}'
}
elif node_type == 'file' or node_type == 'file_operation':
# 文件操作节点:文件读取、写入、上传、下载
node_data = node.get('data', {})
operation = node_data.get('operation', 'read') # read/write/upload/download
file_path = node_data.get('file_path', '')
content = node_data.get('content', '')
encoding = node_data.get('encoding', 'utf-8')
# 替换文件路径和内容中的变量
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
if file_path:
file_path = replace_variables(file_path, input_data)
if isinstance(content, str):
content = replace_variables(content, input_data)
try:
import os
import json
import base64
from pathlib import Path
if operation == 'read':
# 读取文件
if not file_path:
raise ValueError("读取操作需要提供文件路径")
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
# 根据文件扩展名决定读取方式
file_ext = Path(file_path).suffix.lower()
if file_ext == '.json':
with open(file_path, 'r', encoding=encoding) as f:
data = json.load(f)
elif file_ext in ['.txt', '.md', '.log']:
with open(file_path, 'r', encoding=encoding) as f:
data = f.read()
else:
# 二进制文件返回base64编码
with open(file_path, 'rb') as f:
data = base64.b64encode(f.read()).decode('utf-8')
result = {'output': data, 'status': 'success'}
elif operation == 'write':
# 写入文件
if not file_path:
raise ValueError("写入操作需要提供文件路径")
# 确保目录存在
os.makedirs(os.path.dirname(file_path) if os.path.dirname(file_path) else '.', exist_ok=True)
# 如果content是字典或列表转换为JSON
if isinstance(content, (dict, list)):
content = json.dumps(content, ensure_ascii=False, indent=2)
# 根据文件扩展名决定写入方式
file_ext = Path(file_path).suffix.lower()
if file_ext == '.json':
with open(file_path, 'w', encoding=encoding) as f:
json.dump(json.loads(content) if isinstance(content, str) else content, f, ensure_ascii=False, indent=2)
else:
with open(file_path, 'w', encoding=encoding) as f:
f.write(str(content))
result = {'output': {'file_path': file_path, 'message': '文件写入成功'}, 'status': 'success'}
elif operation == 'upload':
# 文件上传从base64或URL上传
upload_type = node_data.get('upload_type', 'base64') # base64/url
target_path = node_data.get('target_path', '')
if upload_type == 'base64':
# 从输入数据中获取base64编码的文件内容
file_data = input_data.get('file_data') or input_data.get('content')
if not file_data:
raise ValueError("上传操作需要提供file_data或content字段")
# 解码base64
if isinstance(file_data, str):
file_bytes = base64.b64decode(file_data)
else:
file_bytes = file_data
# 写入目标路径
if not target_path:
raise ValueError("上传操作需要提供target_path")
os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
with open(target_path, 'wb') as f:
f.write(file_bytes)
result = {'output': {'file_path': target_path, 'message': '文件上传成功'}, 'status': 'success'}
else:
# URL上传下载后保存
import httpx
url = node_data.get('url', '')
if not url:
raise ValueError("URL上传需要提供url")
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
if not target_path:
# 从URL提取文件名
target_path = os.path.basename(url) or 'downloaded_file'
os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
with open(target_path, 'wb') as f:
f.write(response.content)
result = {'output': {'file_path': target_path, 'message': '文件下载并保存成功'}, 'status': 'success'}
elif operation == 'download':
# 文件下载返回base64编码或文件URL
download_format = node_data.get('download_format', 'base64') # base64/url
if not file_path:
raise ValueError("下载操作需要提供文件路径")
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
if download_format == 'base64':
# 返回base64编码
with open(file_path, 'rb') as f:
file_bytes = f.read()
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
result = {'output': {'file_name': os.path.basename(file_path), 'content': file_base64, 'format': 'base64'}, 'status': 'success'}
else:
# 返回文件路径实际应用中可能需要生成临时URL
result = {'output': {'file_path': file_path, 'format': 'path'}, 'status': 'success'}
else:
raise ValueError(f"不支持的文件操作: {operation}")
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'文件操作失败: {str(e)}'
}
elif node_type == 'webhook':
# Webhook节点发送Webhook请求到外部系统
node_data = node.get('data', {})
url = node_data.get('url', '')
method = node_data.get('method', 'POST').upper()
headers = node_data.get('headers', {})
body = node_data.get('body', {})
timeout = node_data.get('timeout', 30)
# 如果URL、headers、body中包含变量从input_data中替换
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
# 替换URL中的变量
if url:
url = replace_variables(url, input_data)
# 替换headers中的变量
if isinstance(headers, dict):
headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
elif isinstance(headers, str):
try:
headers = json.loads(replace_variables(headers, input_data))
except:
headers = {}
# 替换body中的变量
if isinstance(body, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
body = replace_dict_vars(body, input_data)
elif isinstance(body, str):
body = replace_variables(body, input_data)
try:
body = json.loads(body)
except:
pass
# 如果没有配置body默认使用input_data作为body
if not body:
body = input_data
try:
import httpx
async with httpx.AsyncClient(timeout=timeout) as client:
if method == 'GET':
response = await client.get(url, headers=headers)
elif method == 'POST':
response = await client.post(url, json=body, headers=headers)
elif method == 'PUT':
response = await client.put(url, json=body, headers=headers)
elif method == 'PATCH':
response = await client.patch(url, json=body, headers=headers)
else:
raise ValueError(f"Webhook不支持HTTP方法: {method}")
# 尝试解析JSON响应
try:
response_data = response.json()
except:
response_data = response.text
result = {
'output': {
'status_code': response.status_code,
'headers': dict(response.headers),
'data': response_data
},
'status': 'success'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'Webhook请求失败: {str(e)}'
}
elif node_type == 'schedule' or node_type == 'delay' or node_type == 'timer':
# 定时任务节点:延迟执行或定时执行
node_data = node.get('data', {})
delay_type = node_data.get('delay_type', 'fixed') # fixed: 固定延迟, cron: cron表达式
delay_value = node_data.get('delay_value', 0) # 延迟值(秒)
delay_unit = node_data.get('delay_unit', 'seconds') # seconds, minutes, hours
# 计算实际延迟时间(毫秒)
if delay_unit == 'seconds':
delay_ms = int(delay_value * 1000)
elif delay_unit == 'minutes':
delay_ms = int(delay_value * 60 * 1000)
elif delay_unit == 'hours':
delay_ms = int(delay_value * 60 * 60 * 1000)
else:
delay_ms = int(delay_value * 1000)
# 如果延迟时间大于0则等待
if delay_ms > 0:
if self.logger:
self.logger.info(
f"定时任务节点等待 {delay_value} {delay_unit}",
node_id=node_id,
node_type=node_type,
data={'delay_ms': delay_ms, 'delay_value': delay_value, 'delay_unit': delay_unit}
)
await asyncio.sleep(delay_ms / 1000.0)
# 返回输入数据(定时节点只是延迟,不改变数据)
result = {'output': input_data, 'status': 'success', 'delay_ms': delay_ms}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
elif node_type == 'email' or node_type == 'mail':
# 邮件节点:发送邮件通知
node_data = node.get('data', {})
smtp_host = node_data.get('smtp_host', '')
smtp_port = node_data.get('smtp_port', 587)
smtp_user = node_data.get('smtp_user', '')
smtp_password = node_data.get('smtp_password', '')
use_tls = node_data.get('use_tls', True)
from_email = node_data.get('from_email', '')
to_email = node_data.get('to_email', '')
cc_email = node_data.get('cc_email', '')
bcc_email = node_data.get('bcc_email', '')
subject = node_data.get('subject', '')
body = node_data.get('body', '')
body_type = node_data.get('body_type', 'text') # text/html
attachments = node_data.get('attachments', []) # 附件列表
# 替换变量
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
# 替换所有配置中的变量
smtp_host = replace_variables(smtp_host, input_data)
smtp_user = replace_variables(smtp_user, input_data)
smtp_password = replace_variables(smtp_password, input_data)
from_email = replace_variables(from_email, input_data)
to_email = replace_variables(to_email, input_data)
cc_email = replace_variables(cc_email, input_data)
bcc_email = replace_variables(bcc_email, input_data)
subject = replace_variables(subject, input_data)
body = replace_variables(body, input_data)
# 验证必需参数
if not smtp_host:
raise ValueError("邮件节点需要配置SMTP服务器地址")
if not from_email:
raise ValueError("邮件节点需要配置发件人邮箱")
if not to_email:
raise ValueError("邮件节点需要配置收件人邮箱")
if not subject:
raise ValueError("邮件节点需要配置邮件主题")
try:
import aiosmtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.base import MIMEBase
from email import encoders
import base64
import os
# 创建邮件消息
msg = MIMEMultipart('alternative')
msg['From'] = from_email
msg['To'] = to_email
if cc_email:
msg['Cc'] = cc_email
msg['Subject'] = subject
# 添加邮件正文
if body_type == 'html':
msg.attach(MIMEText(body, 'html', 'utf-8'))
else:
msg.attach(MIMEText(body, 'plain', 'utf-8'))
# 处理附件
for attachment in attachments:
if isinstance(attachment, dict):
file_path = attachment.get('file_path', '')
file_name = attachment.get('file_name', '')
file_content = attachment.get('file_content', '') # base64编码的内容
# 替换变量
file_path = replace_variables(file_path, input_data)
file_name = replace_variables(file_name, input_data)
if file_path and os.path.exists(file_path):
# 从文件路径读取
with open(file_path, 'rb') as f:
file_data = f.read()
if not file_name:
file_name = os.path.basename(file_path)
elif file_content:
# 从base64内容读取
file_data = base64.b64decode(file_content)
if not file_name:
file_name = 'attachment'
else:
continue
# 添加附件
part = MIMEBase('application', 'octet-stream')
part.set_payload(file_data)
encoders.encode_base64(part)
part.add_header(
'Content-Disposition',
f'attachment; filename= {file_name}'
)
msg.attach(part)
# 发送邮件
recipients = [to_email]
if cc_email:
recipients.extend([email.strip() for email in cc_email.split(',')])
if bcc_email:
recipients.extend([email.strip() for email in bcc_email.split(',')])
async with aiosmtplib.SMTP(hostname=smtp_host, port=smtp_port) as smtp:
if use_tls:
await smtp.starttls()
if smtp_user and smtp_password:
await smtp.login(smtp_user, smtp_password)
await smtp.send_message(msg, recipients=recipients)
result = {
'output': {
'message': '邮件发送成功',
'from': from_email,
'to': to_email,
'subject': subject,
'recipients_count': len(recipients)
},
'status': 'success'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'邮件发送失败: {str(e)}'
}
elif node_type == 'message_queue' or node_type == 'mq' or node_type == 'rabbitmq' or node_type == 'kafka':
# 消息队列节点发送消息到RabbitMQ或Kafka
node_data = node.get('data', {})
queue_type = node_data.get('queue_type', 'rabbitmq') # rabbitmq/kafka
# 替换变量
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
try:
if queue_type == 'rabbitmq':
# RabbitMQ实现
import aio_pika
import json
# 获取RabbitMQ配置
host = replace_variables(node_data.get('host', 'localhost'), input_data)
port = node_data.get('port', 5672)
username = replace_variables(node_data.get('username', 'guest'), input_data)
password = replace_variables(node_data.get('password', 'guest'), input_data)
exchange = replace_variables(node_data.get('exchange', ''), input_data)
routing_key = replace_variables(node_data.get('routing_key', ''), input_data)
queue_name = replace_variables(node_data.get('queue_name', ''), input_data)
message = node_data.get('message', input_data)
# 如果message是字符串尝试替换变量
if isinstance(message, str):
message = replace_variables(message, input_data)
try:
message = json.loads(message)
except:
pass
elif isinstance(message, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
message = replace_dict_vars(message, input_data)
# 如果没有配置message使用input_data
if not message:
message = input_data
# 连接RabbitMQ
connection_url = f"amqp://{username}:{password}@{host}:{port}/"
connection = await aio_pika.connect_robust(connection_url)
channel = await connection.channel()
# 发送消息
message_body = json.dumps(message, ensure_ascii=False).encode('utf-8')
if exchange:
# 使用exchange和routing_key
await channel.default_exchange.publish(
aio_pika.Message(message_body),
routing_key=routing_key or queue_name
)
elif queue_name:
# 直接发送到队列
queue = await channel.declare_queue(queue_name, durable=True)
await channel.default_exchange.publish(
aio_pika.Message(message_body),
routing_key=queue_name
)
else:
raise ValueError("RabbitMQ节点需要配置exchange或queue_name")
await connection.close()
result = {
'output': {
'message': '消息已发送到RabbitMQ',
'queue_type': 'rabbitmq',
'exchange': exchange,
'routing_key': routing_key or queue_name,
'queue_name': queue_name,
'message_size': len(message_body)
},
'status': 'success'
}
elif queue_type == 'kafka':
# Kafka实现
from kafka import KafkaProducer
import json
# 获取Kafka配置
bootstrap_servers = replace_variables(node_data.get('bootstrap_servers', 'localhost:9092'), input_data)
topic = replace_variables(node_data.get('topic', ''), input_data)
message = node_data.get('message', input_data)
# 如果message是字符串尝试替换变量
if isinstance(message, str):
message = replace_variables(message, input_data)
try:
message = json.loads(message)
except:
pass
elif isinstance(message, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
message = replace_dict_vars(message, input_data)
# 如果没有配置message使用input_data
if not message:
message = input_data
if not topic:
raise ValueError("Kafka节点需要配置topic")
# 创建Kafka生产者注意kafka-python是同步的需要在线程池中运行
import asyncio
from concurrent.futures import ThreadPoolExecutor
def send_kafka_message():
producer = KafkaProducer(
bootstrap_servers=bootstrap_servers.split(','),
value_serializer=lambda v: json.dumps(v, ensure_ascii=False).encode('utf-8')
)
future = producer.send(topic, message)
record_metadata = future.get(timeout=10)
producer.close()
return record_metadata
# 在线程池中执行同步操作
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
record_metadata = await loop.run_in_executor(executor, send_kafka_message)
result = {
'output': {
'message': '消息已发送到Kafka',
'queue_type': 'kafka',
'topic': topic,
'partition': record_metadata.partition,
'offset': record_metadata.offset
},
'status': 'success'
}
else:
raise ValueError(f"不支持的消息队列类型: {queue_type}")
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'消息队列发送失败: {str(e)}'
}
elif node_type == 'output' or node_type == 'end':
# 输出节点:返回最终结果
# 读取节点配置中的输出格式设置
node_data = node.get('data', {})
output_format = node_data.get('output_format', 'text') # 默认纯文本
logger.debug(f"[rjb] End节点处理: node_id={node_id}, output_format={output_format}, input_data={input_data}, input_data type={type(input_data)}")
final_output = input_data
# 如果配置为JSON格式直接返回原始数据或格式化的JSON
if output_format == 'json':
# 如果是字典直接返回JSON格式
if isinstance(input_data, dict):
final_output = json_module.dumps(input_data, ensure_ascii=False, indent=2)
elif isinstance(input_data, str):
# 尝试解析为JSON如果成功则格式化否则直接返回
try:
parsed = json_module.loads(input_data)
final_output = json_module.dumps(parsed, ensure_ascii=False, indent=2)
except:
final_output = input_data
else:
final_output = json_module.dumps({'output': input_data}, ensure_ascii=False, indent=2)
else:
# 默认纯文本格式:递归解包,提取实际的文本内容
if isinstance(input_data, dict):
# 优先提取 'output' 字段LLM节点的标准输出格式
if 'output' in input_data and isinstance(input_data['output'], str):
final_output = input_data['output']
# 如果只有一个 key 且是 'input',提取其值
elif len(input_data) == 1 and 'input' in input_data:
final_output = input_data['input']
# 如果包含 'solution' 字段,提取其值
elif 'solution' in input_data and isinstance(input_data['solution'], str):
final_output = input_data['solution']
# 如果input_data是字符串类型的字典JSON字符串尝试解析
elif isinstance(input_data, str):
try:
parsed = json_module.loads(input_data)
if isinstance(parsed, dict) and 'output' in parsed:
final_output = parsed['output']
elif isinstance(parsed, str):
final_output = parsed
except:
final_output = input_data
logger.debug(f"[rjb] End节点提取第一层: final_output={final_output}, type={type(final_output)}")
# 如果提取的值仍然是字典且只有一个 'input' key继续提取
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
final_output = final_output['input']
logger.debug(f"[rjb] End节点提取第二层: final_output={final_output}, type={type(final_output)}")
# 确保最终输出是字符串(对于人机交互场景)
# 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串
if not isinstance(final_output, str):
if isinstance(final_output, dict):
# 如果是字典尝试提取文本内容或转换为JSON字符串
# 优先查找常见的文本字段
if 'text' in final_output:
final_output = str(final_output['text'])
elif 'content' in final_output:
final_output = str(final_output['content'])
elif 'message' in final_output:
final_output = str(final_output['message'])
elif 'response' in final_output:
final_output = str(final_output['response'])
elif len(final_output) == 1:
# 如果只有一个key直接使用其值
final_output = str(list(final_output.values())[0])
else:
# 否则转换为纯文本不是JSON
# 尝试提取所有文本字段并组合,但排除系统字段和用户查询字段
text_parts = []
exclude_keys = {'status', 'error', 'timestamp', 'node_id', 'execution_time', 'query', 'USER_INPUT', 'user_input', 'user_query'}
# 优先使用input字段LLM的实际输出
if 'input' in final_output and isinstance(final_output['input'], str):
final_output = final_output['input']
else:
for key, value in final_output.items():
if key in exclude_keys:
continue
if isinstance(value, str) and value.strip():
# 如果值本身已经包含 "key: " 格式,直接使用
if value.strip().startswith(f"{key}:"):
text_parts.append(value.strip())
else:
text_parts.append(value.strip())
elif isinstance(value, (int, float, bool)):
text_parts.append(f"{key}: {value}")
if text_parts:
final_output = "\n".join(text_parts)
else:
final_output = str(final_output)
else:
final_output = str(final_output)
# 清理输出文本:移除常见的字段前缀(如 "input: ", "query: " 等)
if isinstance(final_output, str):
import re
# 移除行首的 "input: ", "query: ", "output: " 等前缀
lines = final_output.split('\n')
cleaned_lines = []
for line in lines:
# 匹配行首的 "字段名: " 格式并移除
# 但保留内容本身
line = re.sub(r'^(input|query|output|result|response|message|content|text):\s*', '', line, flags=re.IGNORECASE)
if line.strip(): # 只保留非空行
cleaned_lines.append(line)
# 如果清理后还有内容,使用清理后的版本
if cleaned_lines:
final_output = '\n'.join(cleaned_lines)
# 如果清理后为空,使用原始输出(避免丢失所有内容)
elif final_output.strip():
# 如果原始输出不为空,但清理后为空,说明可能格式特殊,尝试更宽松的清理
# 只移除明显的 "input: " 和 "query: " 前缀,保留其他内容
final_output = re.sub(r'^(input|query):\s*', '', final_output, flags=re.IGNORECASE | re.MULTILINE)
if not final_output.strip():
final_output = str(input_data) # 如果还是空,使用原始输入
logger.debug(f"[rjb] End节点最终输出: output_format={output_format}, final_output={final_output[:100] if isinstance(final_output, str) else type(final_output)}")
result = {'output': final_output, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, final_output, duration)
return result
else:
# 未知节点类型
return {
'output': input_data,
'status': 'success',
'message': f'节点类型 {node_type} 暂未实现'
}
except Exception as e:
logger.error(f"节点执行失败: {node_id} ({node_type}) - {str(e)}", exc_info=True)
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': str(e),
'node_id': node_id,
'node_type': node_type
}
async def execute(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行完整工作流
Args:
input_data: 初始输入数据
Returns:
执行结果
"""
# 记录工作流开始执行
if self.logger:
self.logger.info("工作流开始执行", data={"input": input_data})
# 初始化节点输出
self.node_outputs = {}
active_edges = self.edges.copy() # 活跃的边列表
executed_nodes = set() # 已执行的节点
# 按拓扑顺序执行节点(动态构建执行图)
results = {}
while True:
# 构建当前活跃的执行图
execution_order = self.build_execution_graph(active_edges)
# 找到下一个要执行的节点未执行且入度为0
next_node_id = None
for node_id in execution_order:
if node_id not in executed_nodes:
# 检查所有前置节点是否已执行
can_execute = True
for edge in active_edges:
if edge['target'] == node_id:
if edge['source'] not in executed_nodes:
can_execute = False
break
if can_execute:
next_node_id = node_id
break
if not next_node_id:
break # 没有更多节点可执行
node = self.nodes[next_node_id]
executed_nodes.add(next_node_id)
# 调试:检查节点数据结构
if node.get('type') == 'llm':
logger.debug(f"[rjb] 执行LLM节点: node_id={next_node_id}, node keys={list(node.keys())}, data keys={list(node.get('data', {}).keys()) if node.get('data') else []}")
# 获取节点输入(使用活跃的边)
node_input = self.get_node_input(next_node_id, self.node_outputs, active_edges)
# 如果是起始节点,使用初始输入
if node.get('type') == 'start' and not node_input:
node_input = input_data
logger.info(f"[rjb] Start节点使用初始输入: node_id={next_node_id}, node_input={node_input}")
# 调试:记录节点输入数据
if node.get('type') == 'llm':
logger.info(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}")
if 'start-1' in self.node_outputs:
logger.info(f"[rjb] Start节点输出内容: {self.node_outputs['start-1']}")
# 执行节点
result = await self.execute_node(node, node_input)
results[next_node_id] = result
# 保存节点输出
if result.get('status') == 'success':
output_value = result.get('output', {})
self.node_outputs[next_node_id] = output_value
if node.get('type') == 'start':
logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}")
# 如果是条件节点,根据分支结果过滤边
if node.get('type') == 'condition':
branch = result.get('branch', 'false')
# 移除不符合条件的边
active_edges = [
edge for edge in active_edges
if not (edge['source'] == next_node_id and edge.get('sourceHandle') != branch)
]
# 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行)
if node.get('type') in ['loop', 'foreach']:
# 标记循环体的节点为已执行(简化处理)
for edge in active_edges[:]: # 使用切片复制列表
if edge.get('source') == next_node_id:
target_id = edge.get('target')
if target_id in self.nodes:
# 检查是否是循环结束节点
target_node = self.nodes[target_id]
if target_node.get('type') not in ['loop_end', 'end']:
# 标记为已执行(循环体已在循环节点内部执行)
executed_nodes.add(target_id)
# 继续查找循环体内的节点
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
else:
# 执行失败,停止工作流
error_msg = result.get('error', '未知错误')
node_type = node.get('type', 'unknown')
logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}")
raise WorkflowExecutionError(
detail=error_msg,
node_id=next_node_id
)
# 返回最终结果(最后一个执行的节点的输出)
if executed_nodes:
# 找到最后一个节点(没有出边的节点)
last_node_id = None
for node_id in executed_nodes:
has_outgoing = any(edge['source'] == node_id for edge in active_edges)
if not has_outgoing:
last_node_id = node_id
break
if not last_node_id:
# 如果没有找到,使用最后一个执行的节点
last_node_id = list(executed_nodes)[-1]
# 获取最终结果
final_output = self.node_outputs.get(last_node_id)
# 如果最终输出是字典且只有一个 'input' key提取其值
# 这样可以确保最终结果不是重复包装的格式
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
final_output = final_output['input']
# 如果提取的值仍然是字典且只有一个 'input' key继续提取
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
final_output = final_output['input']
# 确保最终结果是字符串(对于人机交互场景)
# 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串
if not isinstance(final_output, str):
if isinstance(final_output, dict):
# 如果是字典尝试提取文本内容或转换为JSON字符串
# 优先查找常见的文本字段
if 'text' in final_output:
final_output = str(final_output['text'])
elif 'content' in final_output:
final_output = str(final_output['content'])
elif 'message' in final_output:
final_output = str(final_output['message'])
elif 'response' in final_output:
final_output = str(final_output['response'])
elif len(final_output) == 1:
# 如果只有一个key直接使用其值
final_output = str(list(final_output.values())[0])
else:
# 否则转换为JSON字符串
import json as json_module
final_output = json_module.dumps(final_output, ensure_ascii=False)
else:
final_output = str(final_output)
final_result = {
'status': 'completed',
'result': final_output,
'node_results': results
}
# 记录工作流执行完成
if self.logger:
self.logger.info("工作流执行完成", data={"result": final_result.get('result')})
return final_result
if self.logger:
self.logger.warn("工作流执行完成,但没有执行任何节点")
return {'status': 'completed', 'result': None}