""" 工作流执行引擎 """ from typing import Dict, Any, List, Optional import asyncio from collections import defaultdict, deque import json import logging import re import time from app.services.llm_service import llm_service from app.services.condition_parser import condition_parser from app.services.data_transformer import data_transformer from app.core.exceptions import WorkflowExecutionError from app.core.database import SessionLocal from app.models.agent import Agent from app.core.config import settings logger = logging.getLogger(__name__) class WorkflowEngine: """工作流执行引擎""" def __init__(self, workflow_id: str, workflow_data: Dict[str, Any], logger=None, db=None): """ 初始化工作流引擎 Args: workflow_id: 工作流ID workflow_data: 工作流数据(包含nodes和edges) logger: 执行日志记录器(可选) db: 数据库会话(可选,用于Agent节点加载Agent配置) """ self.workflow_id = workflow_id self.nodes = {node['id']: node for node in workflow_data.get('nodes', [])} self.edges = workflow_data.get('edges', []) self.execution_graph = None self.node_outputs = {} self.logger = logger self.db = db def build_execution_graph(self, active_edges: Optional[List[Dict[str, Any]]] = None) -> List[str]: """ 构建执行图(DAG)并返回拓扑排序结果 Args: active_edges: 活跃的边列表(用于条件分支过滤) Returns: 拓扑排序后的节点ID列表 """ # 使用活跃的边,如果没有提供则使用所有边 edges_to_use = active_edges if active_edges is not None else self.edges # 构建邻接表和入度表 graph = defaultdict(list) in_degree = defaultdict(int) # 初始化所有节点的入度 for node_id in self.nodes.keys(): in_degree[node_id] = 0 # 构建图 for edge in edges_to_use: source = edge['source'] target = edge['target'] graph[source].append(target) in_degree[target] += 1 # 拓扑排序(Kahn算法) queue = deque() result = [] # 找到所有入度为0的节点(起始节点) for node_id in self.nodes.keys(): if in_degree[node_id] == 0: queue.append(node_id) while queue: node_id = queue.popleft() result.append(node_id) # 处理该节点的所有出边 for neighbor in graph[node_id]: in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) # 检查是否有环(只检查可达节点) reachable_nodes = set(result) if len(reachable_nodes) < len(self.nodes): # 有些节点不可达,这是正常的(条件分支) pass self.execution_graph = result return result def get_node_input(self, node_id: str, node_outputs: Dict[str, Any], active_edges: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]: """ 获取节点的输入数据 Args: node_id: 节点ID node_outputs: 所有节点的输出数据 active_edges: 活跃的边列表(用于条件分支过滤) Returns: 节点的输入数据 """ # 使用活跃的边,如果没有提供则使用所有边 edges_to_use = active_edges if active_edges is not None else self.edges # 找到所有指向该节点的边 input_data = {} for edge in edges_to_use: if edge['target'] == node_id: source_id = edge['source'] source_output = node_outputs.get(source_id, {}) logger.info(f"[rjb] 获取节点输入: target={node_id}, source={source_id}, source_output={source_output}, source_output_type={type(source_output)}, sourceHandle={edge.get('sourceHandle')}") # 如果有sourceHandle,使用它作为key if 'sourceHandle' in edge and edge['sourceHandle']: input_data[edge['sourceHandle']] = source_output # 重要:即使有sourceHandle,也要保留记忆相关字段(conversation_history、user_profile、context) # 这些字段应该始终传递到下游节点 if isinstance(source_output, dict): memory_fields = ['conversation_history', 'user_profile', 'context', 'memory'] for field in memory_fields: if field in source_output: input_data[field] = source_output[field] logger.info(f"[rjb] 保留记忆字段 {field} 到节点 {node_id} 的输入") else: # 否则合并所有输入 if isinstance(source_output, dict): # 如果source_output包含output字段,展开它 if 'output' in source_output and isinstance(source_output['output'], dict): # 将output中的内容展开到顶层 input_data.update(source_output['output']) # 保留其他字段(如status) for key, value in source_output.items(): if key != 'output': input_data[key] = value else: # 直接展开source_output的内容 input_data.update(source_output) logger.info(f"[rjb] 展开source_output后: input_data={input_data}") else: # 如果source_output不是字典,包装到input字段 input_data['input'] = source_output logger.info(f"[rjb] source_output不是字典,包装到input字段: input_data={input_data}") # 重要:对于LLM节点和cache节点,如果输入中没有memory字段,尝试从所有已执行的节点中查找并合并记忆字段 # 这样可以确保即使上游节点没有传递记忆信息,这些节点也能访问到记忆 node_type = None node = self.nodes.get(node_id) if node: node_type = node.get('type') # 对于LLM节点和cache节点(特别是cache-update),需要memory字段 if node_type in ['llm', 'cache'] and 'memory' not in input_data: # 从所有已执行的节点中查找memory字段 for executed_node_id, node_output in self.node_outputs.items(): if isinstance(node_output, dict): # 检查是否有memory字段 if 'memory' in node_output: input_data['memory'] = node_output['memory'] logger.info(f"[rjb] 为{node_type}节点 {node_id} 从节点 {executed_node_id} 获取memory字段") break # 或者检查是否有conversation_history等记忆字段 elif 'conversation_history' in node_output: # 构建memory对象 memory = {} for field in ['conversation_history', 'user_profile', 'context']: if field in node_output: memory[field] = node_output[field] if memory: input_data['memory'] = memory logger.info(f"[rjb] 为{node_type}节点 {node_id} 从节点 {executed_node_id} 构建memory对象: {list(memory.keys())}") break # 如果input_data中没有query字段,尝试从所有已执行的节点中查找(特别是start节点) if 'query' not in input_data: # 优先查找start节点 for node_id_key in ['start-1', 'start']: if node_id_key in node_outputs: node_output = node_outputs[node_id_key] if isinstance(node_output, dict): # 检查顶层字段(因为node_outputs存储的是output字段的内容) if 'query' in node_output: input_data['query'] = node_output['query'] logger.debug(f"[rjb] 从节点 {node_id_key} 中获取query: {input_data['query']}") break # 检查output字段(兼容性) elif 'output' in node_output and isinstance(node_output['output'], dict): if 'query' in node_output['output']: input_data['query'] = node_output['output']['query'] logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取query: {input_data['query']}") break # 如果还没找到,遍历所有节点 if 'query' not in input_data: for node_id_key, node_output in node_outputs.items(): if isinstance(node_output, dict): # 检查顶层字段 if 'query' in node_output: input_data['query'] = node_output['query'] logger.debug(f"[rjb] 从节点 {node_id_key} 中获取query: {input_data['query']}") break # 检查output字段(兼容性) elif 'output' in node_output and isinstance(node_output['output'], dict): if 'query' in node_output['output']: input_data['query'] = node_output['output']['query'] logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取query: {input_data['query']}") break # 如果input_data中没有requirement_analysis字段,尝试从所有已执行的节点中查找 if 'requirement_analysis' not in input_data: # 优先查找requirement-analysis节点 for node_id_key in ['llm-requirement-analysis', 'requirement-analysis']: if node_id_key in node_outputs: node_output = node_outputs[node_id_key] if isinstance(node_output, dict): # 检查顶层字段(因为node_outputs存储的是output字段的内容) if 'requirement_analysis' in node_output: input_data['requirement_analysis'] = node_output['requirement_analysis'] logger.debug(f"[rjb] 从节点 {node_id_key} 中获取requirement_analysis") break # 检查output字段(兼容性) elif 'output' in node_output and isinstance(node_output['output'], dict): if 'requirement_analysis' in node_output['output']: input_data['requirement_analysis'] = node_output['output']['requirement_analysis'] logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取requirement_analysis") break # 如果还没找到,遍历所有节点 if 'requirement_analysis' not in input_data: for node_id_key, node_output in node_outputs.items(): if isinstance(node_output, dict): # 检查顶层字段 if 'requirement_analysis' in node_output: input_data['requirement_analysis'] = node_output['requirement_analysis'] logger.debug(f"[rjb] 从节点 {node_id_key} 中获取requirement_analysis") break # 检查output字段(兼容性) elif 'output' in node_output and isinstance(node_output['output'], dict): if 'requirement_analysis' in node_output['output']: input_data['requirement_analysis'] = node_output['output']['requirement_analysis'] logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取requirement_analysis") break logger.debug(f"[rjb] 节点输入结果: node_id={node_id}, input_data={input_data}") return input_data def _get_nested_value(self, data: Dict[str, Any], path: str) -> Any: """ 从嵌套字典中获取值(支持点号路径和数组索引) Args: data: 数据字典 path: 路径,如 "user.name" 或 "items[0].price" Returns: 路径对应的值 """ if not path: return data parts = path.split('.') result = data for part in parts: if '[' in part and ']' in part: # 处理数组索引,如 "items[0]" key = part[:part.index('[')] index_str = part[part.index('[') + 1:part.index(']')] if isinstance(result, dict): result = result.get(key) elif isinstance(result, list): try: result = result[int(index_str)] except (ValueError, IndexError): return None else: return None if result is None: return None else: # 普通键访问 if isinstance(result, dict): result = result.get(part) else: return None if result is None: return None return result async def _execute_loop_body(self, loop_node_id: str, loop_input: Dict[str, Any], iteration_index: int) -> Dict[str, Any]: """ 执行循环体 Args: loop_node_id: 循环节点ID loop_input: 循环体的输入数据 iteration_index: 当前迭代索引 Returns: 循环体的执行结果 """ # 找到循环节点的直接子节点(循环体开始节点) loop_body_start_nodes = [] for edge in self.edges: if edge.get('source') == loop_node_id: target_id = edge.get('target') if target_id and target_id in self.nodes: loop_body_start_nodes.append(target_id) if not loop_body_start_nodes: # 如果没有子节点,直接返回输入数据 return {'output': loop_input, 'status': 'success'} # 执行循环体:从循环体开始节点执行到循环结束节点或没有更多节点 # 简化处理:只执行第一个子节点链 executed_in_loop = set() loop_results = {} current_node_id = loop_body_start_nodes[0] # 简化:只执行第一个子节点链 # 执行循环体内的节点(简化版本:只执行直接连接的子节点) max_iterations = 100 # 防止无限循环 iteration = 0 while current_node_id and iteration < max_iterations: iteration += 1 if current_node_id in executed_in_loop: break # 避免循环体内部循环 if current_node_id not in self.nodes: break node = self.nodes[current_node_id] executed_in_loop.add(current_node_id) # 如果是循环结束节点,停止执行 if node.get('type') == 'loop_end' or node.get('type') == 'end': break # 执行节点 result = await self.execute_node(node, loop_input) loop_results[current_node_id] = result if result.get('status') != 'success': return result # 更新输入数据为当前节点的输出 if result.get('output'): if isinstance(result.get('output'), dict): loop_input = {**loop_input, **result.get('output')} else: loop_input = {**loop_input, 'result': result.get('output')} # 找到下一个节点(简化:只找第一个子节点) next_node_id = None for edge in self.edges: if edge.get('source') == current_node_id: target_id = edge.get('target') if target_id and target_id in self.nodes and target_id not in executed_in_loop: # 跳过循环节点本身 if target_id != loop_node_id: next_node_id = target_id break current_node_id = next_node_id # 返回最后一个节点的输出 if loop_results: last_result = list(loop_results.values())[-1] return last_result return {'output': loop_input, 'status': 'success'} def _mark_loop_body_executed(self, node_id: str, executed_nodes: set, active_edges: List[Dict[str, Any]]): """ 递归标记循环体内的节点为已执行 Args: node_id: 当前节点ID executed_nodes: 已执行节点集合 active_edges: 活跃的边列表 """ if node_id in executed_nodes: return executed_nodes.add(node_id) # 查找所有子节点 for edge in active_edges: if edge.get('source') == node_id: target_id = edge.get('target') if target_id in self.nodes: target_node = self.nodes[target_id] # 如果是循环结束节点,停止递归 if target_node.get('type') in ['loop_end', 'end']: continue # 递归标记子节点 self._mark_loop_body_executed(target_id, executed_nodes, active_edges) async def execute_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]: """ 执行单个节点 Args: node: 节点配置 input_data: 输入数据 Returns: 节点执行结果 """ # 确保可以访问全局的 json 模块 import json as json_module node_type = node.get('type', 'unknown') node_id = node.get('id') start_time = time.time() # 记录节点开始执行 if self.logger: self.logger.log_node_start(node_id, node_type, input_data) try: if node_type == 'start': # 起始节点:返回输入数据 logger.debug(f"[rjb] 开始节点执行: node_id={node_id}, input_data={input_data}") result = {'output': input_data, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result.get('output'), duration) logger.debug(f"[rjb] 开始节点输出: node_id={node_id}, output={result.get('output')}") return result elif node_type == 'input': # 输入节点:处理输入数据 result = {'output': input_data, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result.get('output'), duration) return result elif node_type == 'llm' or node_type == 'template': # LLM节点:调用AI模型 node_data = node.get('data', {}) logger.debug(f"[rjb] LLM节点执行: node_id={node_id}, input_data={input_data}, input_data type={type(input_data)}") logger.debug(f"[rjb] LLM节点数据: node_id={node_id}, node_data keys={list(node_data.keys())}, api_key={'已配置' if node_data.get('api_key') else '未配置'}") prompt = node_data.get('prompt', '') # 如果prompt为空,使用默认提示词 if not prompt: prompt = "请处理以下输入数据:\n{input}" # 格式化prompt,替换变量 try: # 将input_data转换为字符串用于格式化 if isinstance(input_data, dict): # 支持两种格式的变量:{key} 和 {{key}} formatted_prompt = prompt has_unfilled_variables = False has_any_placeholder = False import re # 检查是否有任何占位符 has_any_placeholder = bool(re.search(r'\{\{?\w+\}?\}', prompt)) # 首先处理 {{variable}} 和 {{variable.path}} 格式(模板节点常用) # 支持嵌套路径,如 {{memory.conversation_history}} double_brace_vars = re.findall(r'\{\{([^}]+)\}\}', prompt) for var_path in double_brace_vars: # 尝试从input_data中获取值(支持嵌套路径) value = self._get_nested_value(input_data, var_path) # 如果变量未找到,尝试常见的别名映射 if value is None: # user_input 可以映射到 query、input、USER_INPUT 等字段 if var_path == 'user_input': for alias in ['query', 'input', 'USER_INPUT', 'user_input', 'text', 'message', 'content']: value = input_data.get(alias) if value is not None: # 如果值是字典,尝试从中提取字符串值 if isinstance(value, dict): for sub_key in ['query', 'input', 'text', 'message', 'content']: if sub_key in value: value = value[sub_key] break break # output 可以映射到 right 字段(LLM节点的输出通常存储在right字段中) elif var_path == 'output': # 尝试从right字段中提取 right_value = input_data.get('right') logger.info(f"[rjb] LLM节点查找output变量: right_value类型={type(right_value)}, right_value={str(right_value)[:100] if right_value else None}") if right_value is not None: # 如果right是字符串,直接使用 if isinstance(right_value, str): value = right_value logger.info(f"[rjb] LLM节点从right字段(字符串)提取output: {value[:100]}") # 如果right是字典,尝试递归查找字符串值 elif isinstance(right_value, dict): # 尝试从right.right.right...中提取(处理嵌套的right字段) current = right_value depth = 0 while isinstance(current, dict) and depth < 10: if 'right' in current: current = current['right'] depth += 1 if isinstance(current, str): value = current logger.info(f"[rjb] LLM节点从right字段(嵌套{depth}层)提取output: {value[:100]}") break else: # 如果没有right字段,尝试其他可能的字段 for key in ['content', 'text', 'message', 'output']: if key in current and isinstance(current[key], str): value = current[key] logger.info(f"[rjb] LLM节点从right字段中找到{key}字段: {value[:100]}") break if value is not None: break break if value is None: logger.warning(f"[rjb] LLM节点无法从right字段中提取output,right结构: {str(right_value)[:200]}") if value is not None: # 替换 {{variable}} 或 {{variable.path}} 为实际值 # 特殊处理:如果是memory.conversation_history,格式化为易读的对话格式 if var_path == 'memory.conversation_history' and isinstance(value, list): # 将对话历史格式化为易读的文本格式 formatted_history = [] for msg in value: role = msg.get('role', 'unknown') content = msg.get('content', '') if role == 'user': formatted_history.append(f"用户:{content}") elif role == 'assistant': formatted_history.append(f"助手:{content}") else: formatted_history.append(f"{role}:{content}") replacement = '\n'.join(formatted_history) if formatted_history else '(暂无对话历史)' else: # 其他情况使用JSON格式 replacement = json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value) formatted_prompt = formatted_prompt.replace(f'{{{{{var_path}}}}}', replacement) # 对于conversation_history,显示完整内容以便调试 if var_path == 'memory.conversation_history': logger.info(f"[rjb] LLM节点替换变量: {var_path} = {replacement[:500] if len(replacement) > 500 else replacement}") else: logger.info(f"[rjb] LLM节点替换变量: {var_path} = {str(replacement)[:200]}") else: has_unfilled_variables = True logger.warning(f"[rjb] LLM节点变量未找到: {var_path}, input_data keys: {list(input_data.keys()) if isinstance(input_data, dict) else 'not dict'}") # 然后处理 {key} 格式 for key, value in input_data.items(): placeholder = f'{{{key}}}' if placeholder in formatted_prompt: formatted_prompt = formatted_prompt.replace( placeholder, json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value) ) # 如果还有{input}占位符,替换为整个input_data if '{input}' in formatted_prompt: formatted_prompt = formatted_prompt.replace( '{input}', json_module.dumps(input_data, ensure_ascii=False) ) # 提取用户的实际查询内容(优先提取) user_query = None logger.info(f"[rjb] 开始提取user_query: input_data={input_data}, input_data_type={type(input_data)}") if isinstance(input_data, dict): # 首先检查是否有嵌套的input字段 nested_input = input_data.get('input') logger.info(f"[rjb] 检查嵌套input: nested_input={nested_input}, nested_input_type={type(nested_input) if nested_input else None}") if isinstance(nested_input, dict): # 从嵌套的input中提取 for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']: if key in nested_input: user_query = nested_input[key] logger.info(f"[rjb] 从嵌套input中提取到user_query: key={key}, user_query={user_query}") break # 如果还没有,从顶层提取 if not user_query: for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']: if key in input_data: value = input_data[key] logger.info(f"[rjb] 从顶层提取: key={key}, value={value}, value_type={type(value)}") # 如果值是字符串,直接使用 if isinstance(value, str): user_query = value logger.info(f"[rjb] 提取到字符串user_query: {user_query}") break # 如果值是字典,尝试从中提取 elif isinstance(value, dict): for sub_key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']: if sub_key in value: user_query = value[sub_key] logger.info(f"[rjb] 从字典值中提取到user_query: sub_key={sub_key}, user_query={user_query}") break if user_query: break # 如果还是没有,使用整个input_data(但排除系统字段) if not user_query: filtered_data = {k: v for k, v in input_data.items() if not k.startswith('_')} logger.info(f"[rjb] 使用filtered_data: filtered_data={filtered_data}") if filtered_data: # 如果只有一个字段且是字符串,直接使用 if len(filtered_data) == 1: single_value = list(filtered_data.values())[0] if isinstance(single_value, str): user_query = single_value logger.info(f"[rjb] 从单个字符串字段提取到user_query: {user_query}") elif isinstance(single_value, dict): # 从字典中提取第一个字符串值 for v in single_value.values(): if isinstance(v, str): user_query = v logger.info(f"[rjb] 从字典的单个字段中提取到user_query: {user_query}") break if not user_query: user_query = json_module.dumps(filtered_data, ensure_ascii=False) if len(filtered_data) > 1 else str(list(filtered_data.values())[0]) logger.info(f"[rjb] 使用JSON或字符串转换: user_query={user_query}") logger.info(f"[rjb] 最终提取的user_query: {user_query}") # 如果prompt中没有占位符,或者仍有未填充的变量,将用户输入附加到prompt is_generic_instruction = False # 初始化变量 if not has_any_placeholder: # 如果prompt中没有占位符,将用户输入作为主要内容 if user_query: # 判断是否是通用指令:简短且不包含具体任务描述 prompt_stripped = prompt.strip() is_generic_instruction = ( len(prompt_stripped) < 30 or # 简短提示词 prompt_stripped in [ "请处理用户请求。", "请处理用户请求", "请处理以下输入数据:", "请处理以下输入数据", "请处理输入。", "请处理输入", "处理用户请求", "处理请求", "请回答用户问题", "请回答用户问题。", "请帮助用户", "请帮助用户。" ] or # 检查是否只包含通用指令关键词 (len(prompt_stripped) < 50 and any(keyword in prompt_stripped for keyword in [ "请处理", "处理", "请回答", "回答", "请帮助", "帮助", "请执行", "执行" ]) and not any(specific in prompt_stripped for specific in [ "翻译", "生成", "分析", "总结", "提取", "转换", "计算" ])) ) if is_generic_instruction: # 如果是通用指令,直接使用用户输入作为prompt formatted_prompt = str(user_query) logger.info(f"[rjb] 检测到通用指令,直接使用用户输入作为prompt: {user_query[:50] if user_query else 'None'}") else: # 否则,将用户输入附加到prompt formatted_prompt = f"{formatted_prompt}\n\n{user_query}" logger.info(f"[rjb] 非通用指令,将用户输入附加到prompt") else: # 如果没有提取到用户查询,附加整个input_data formatted_prompt = f"{formatted_prompt}\n\n{json_module.dumps(input_data, ensure_ascii=False)}" elif has_unfilled_variables or re.search(r'\{\{(\w+)\}\}', formatted_prompt): # 如果有占位符但未填充,附加用户需求说明 if user_query: formatted_prompt = f"{formatted_prompt}\n\n用户需求:{user_query}\n\n请根据以上用户需求,忽略未填充的变量占位符(如{{{{variable}}}}),直接基于用户需求来完成任务。" logger.info(f"[rjb] LLM节点prompt格式化: node_id={node_id}, original_prompt='{prompt[:50] if len(prompt) > 50 else prompt}', has_any_placeholder={has_any_placeholder}, user_query={user_query}, is_generic_instruction={is_generic_instruction}, final_prompt前200字符='{formatted_prompt[:200] if len(formatted_prompt) > 200 else formatted_prompt}'") prompt = formatted_prompt else: # 如果input_data不是dict,直接转换为字符串 if '{input}' in prompt: prompt = prompt.replace('{input}', str(input_data)) else: prompt = f"{prompt}\n\n输入:{str(input_data)}" except Exception as e: # 格式化失败,使用原始prompt和input_data logger.warning(f"[rjb] Prompt格式化失败: {str(e)}") try: prompt = f"{prompt}\n\n输入数据:\n{json_module.dumps(input_data, ensure_ascii=False)}" except: prompt = f"{prompt}\n\n输入数据:{str(input_data)}" # 获取LLM配置 provider = node_data.get('provider', 'openai') model = node_data.get('model', 'gpt-3.5-turbo') # 确保temperature是浮点数(节点模板中可能是字符串) temperature_raw = node_data.get('temperature', 0.7) if isinstance(temperature_raw, str): try: temperature = float(temperature_raw) except (ValueError, TypeError): temperature = 0.7 else: temperature = float(temperature_raw) if temperature_raw is not None else 0.7 # 确保max_tokens是整数(节点模板中可能是字符串) max_tokens_raw = node_data.get('max_tokens') if max_tokens_raw is not None: if isinstance(max_tokens_raw, str): try: max_tokens = int(max_tokens_raw) except (ValueError, TypeError): max_tokens = None else: max_tokens = int(max_tokens_raw) if max_tokens_raw is not None else None else: max_tokens = None # 不传递 api_key 和 base_url,让 LLM 服务使用系统默认配置(与节点测试保持一致) api_key = None base_url = None # 记录实际发送给LLM的prompt logger.info(f"[rjb] 准备调用LLM: node_id={node_id}, provider={provider}, model={model}, prompt前200字符='{prompt[:200] if len(prompt) > 200 else prompt}'") # 调用LLM服务 try: if self.logger: logger.debug(f"[rjb] LLM节点配置: provider={provider}, model={model}, 使用系统默认API Key配置") self.logger.info(f"调用LLM服务: {provider}/{model}", node_id=node_id, node_type=node_type) result = await llm_service.call_llm( prompt=prompt, provider=provider, model=model, temperature=temperature, max_tokens=max_tokens # 不传递 api_key 和 base_url,使用系统默认配置 ) exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: # LLM调用失败,返回错误 if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'LLM调用失败: {str(e)}' } elif node_type == 'condition': # 条件节点:判断分支 condition = node.get('data', {}).get('condition', '') if not condition: # 如果没有条件表达式,默认返回False return { 'output': False, 'status': 'success', 'branch': 'false' } # 使用条件解析器评估表达式 try: result = condition_parser.evaluate_condition(condition, input_data) exec_result = { 'output': result, 'status': 'success', 'branch': 'true' if result else 'false' } if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, {'result': result, 'branch': exec_result['branch']}, duration) return exec_result except Exception as e: # 条件评估失败 if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': False, 'status': 'failed', 'error': f'条件评估失败: {str(e)}', 'branch': 'false' } elif node_type == 'data' or node_type == 'transform': # 数据转换节点 node_data = node.get('data', {}) mapping = node_data.get('mapping', {}) filter_rules = node_data.get('filter_rules', []) compute_rules = node_data.get('compute_rules', {}) mode = node_data.get('mode', 'mapping') try: # 处理mapping中的{{variable}}格式,从input_data中提取值 # 首先,如果input_data包含output字段,需要展开它 expanded_input = input_data.copy() if 'output' in input_data and isinstance(input_data['output'], dict): # 将output中的内容展开到顶层,但保留output字段 expanded_input.update(input_data['output']) processed_mapping = {} import re for target_key, source_expr in mapping.items(): if isinstance(source_expr, str): # 支持{{variable}}格式 double_brace_vars = re.findall(r'\{\{(\w+)\}\}', source_expr) if double_brace_vars: # 从expanded_input中获取变量值 var_value = None for var_name in double_brace_vars: # 尝试从expanded_input中获取,支持嵌套路径 var_value = self._get_nested_value(expanded_input, var_name) if var_value is not None: break if var_value is not None: # 如果只有一个变量,直接使用值;否则替换表达式 if len(double_brace_vars) == 1: processed_mapping[target_key] = var_value else: # 多个变量,替换表达式 processed_expr = source_expr for var_name in double_brace_vars: var_val = self._get_nested_value(expanded_input, var_name) if var_val is not None: replacement = json_module.dumps(var_val, ensure_ascii=False) if isinstance(var_val, (dict, list)) else str(var_val) processed_expr = processed_expr.replace(f'{{{{{var_name}}}}}', replacement) processed_mapping[target_key] = processed_expr else: # 变量不存在,保持原表达式 processed_mapping[target_key] = source_expr else: # 不是{{variable}}格式,直接使用 processed_mapping[target_key] = source_expr else: # 不是字符串,直接使用 processed_mapping[target_key] = source_expr # 如果mode是merge,需要合并所有输入数据 if mode == 'merge': # 合并所有上游节点的输出(使用展开后的数据) result = expanded_input.copy() # 重要:如果输入数据中包含conversation_history、user_profile、context等记忆字段,先构建memory对象 memory_fields = ['conversation_history', 'user_profile', 'context'] memory_data = {} for field in memory_fields: if field in expanded_input: memory_data[field] = expanded_input[field] # 如果构建了memory对象,添加到result中 if memory_data: result['memory'] = memory_data logger.info(f"[rjb] Transform节点 {node_id} 构建memory对象: {list(memory_data.keys())}") # 添加mapping的结果(mapping可能会覆盖memory字段) for key, value in processed_mapping.items(): # 如果mapping中的value是None或空字符串,且key是memory,尝试从expanded_input构建 if key == 'memory' and (value is None or value == '' or value == '{{output}}'): if memory_data: result[key] = memory_data logger.info(f"[rjb] Transform节点 {node_id} mapping中的memory为空,使用构建的memory对象") elif 'memory' in expanded_input: result[key] = expanded_input['memory'] else: result[key] = value # 确保记忆字段被保留(即使mapping覆盖了它们) for field in memory_fields: if field in expanded_input and field not in result: result[field] = expanded_input[field] # 如果memory字段是dict,也要检查其中的字段 if 'memory' in expanded_input and isinstance(expanded_input['memory'], dict): if 'memory' not in result: result['memory'] = expanded_input['memory'].copy() else: # 合并memory字段 if isinstance(result['memory'], dict): result['memory'].update(expanded_input['memory']) logger.info(f"[rjb] Transform节点 {node_id} merge模式,结果keys: {list(result.keys())}") if 'memory' in result and isinstance(result['memory'], dict): if 'conversation_history' in result['memory']: logger.info(f"[rjb] memory.conversation_history: {len(result['memory']['conversation_history'])} 条") elif 'conversation_history' in result: logger.info(f"[rjb] conversation_history: {len(result['conversation_history'])} 条") else: # 使用处理后的mapping进行转换(使用展开后的数据) result = data_transformer.transform_data( input_data=expanded_input, mapping=processed_mapping, filter_rules=filter_rules, compute_rules=compute_rules, mode=mode ) exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) logger.error(f"[rjb] Transform节点执行失败: {str(e)}", exc_info=True) return { 'output': None, 'status': 'failed', 'error': f'数据转换失败: {str(e)}' } elif node_type == 'loop' or node_type == 'foreach': # 循环节点:对数组进行循环处理 node_data = node.get('data', {}) items_path = node_data.get('items_path', 'items') # 数组数据路径 item_variable = node_data.get('item_variable', 'item') # 循环变量名 # 从输入数据中获取数组 items = self._get_nested_value(input_data, items_path) if not isinstance(items, list): if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, ValueError(f"路径 {items_path} 的值不是数组"), duration) return { 'output': None, 'status': 'failed', 'error': f'路径 {items_path} 的值不是数组,当前类型: {type(items).__name__}' } if self.logger: self.logger.info(f"循环节点开始处理 {len(items)} 个元素", node_id=node_id, node_type=node_type, data={"items_count": len(items)}) # 执行循环:对每个元素执行循环体 loop_results = [] for index, item in enumerate(items): if self.logger: self.logger.info(f"循环迭代 {index + 1}/{len(items)}", node_id=node_id, node_type=node_type, data={"index": index, "item": item}) # 准备循环体的输入数据 loop_input = { **input_data, # 保留原始输入数据 item_variable: item, # 当前循环项 f'{item_variable}_index': index, # 索引 f'{item_variable}_total': len(items) # 总数 } # 执行循环体(获取循环节点的子节点) loop_body_result = await self._execute_loop_body( node_id, loop_input, index ) if loop_body_result.get('status') == 'success': loop_results.append(loop_body_result.get('output', item)) else: # 如果循环体执行失败,可以选择继续或停止 error_handling = node_data.get('error_handling', 'continue') if error_handling == 'stop': if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, Exception(f"循环体执行失败,停止循环: {loop_body_result.get('error')}"), duration) return { 'output': None, 'status': 'failed', 'error': f'循环体执行失败: {loop_body_result.get("error")}', 'completed_items': index, 'results': loop_results } else: # continue: 继续执行,记录错误 if self.logger: self.logger.warn(f"循环迭代 {index + 1} 失败,继续执行", node_id=node_id, node_type=node_type, data={"error": loop_body_result.get('error')}) loop_results.append(None) exec_result = { 'output': loop_results, 'status': 'success', 'items_processed': len(items), 'results_count': len(loop_results) } if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, {'results_count': len(loop_results)}, duration) return exec_result elif node_type == 'http' or node_type == 'request': # HTTP请求节点:发送HTTP请求 node_data = node.get('data', {}) url = node_data.get('url', '') method = node_data.get('method', 'GET').upper() headers = node_data.get('headers', {}) params = node_data.get('params', {}) body = node_data.get('body', {}) timeout = node_data.get('timeout', 30) # 如果URL、headers、params、body中包含变量,从input_data中替换 import re def replace_variables(text: str, data: Dict[str, Any]) -> str: """替换字符串中的变量占位符""" if not isinstance(text, str): return text # 支持 {key} 或 ${key} 格式 pattern = r'\{([^}]+)\}|\$\{([^}]+)\}' def replacer(match): key = match.group(1) or match.group(2) value = self._get_nested_value(data, key) return str(value) if value is not None else match.group(0) return re.sub(pattern, replacer, text) # 替换URL中的变量 if url: url = replace_variables(url, input_data) # 替换headers中的变量 if isinstance(headers, dict): headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()} elif isinstance(headers, str): try: headers = json.loads(replace_variables(headers, input_data)) except: headers = {} # 替换params中的变量 if isinstance(params, dict): params = {k: replace_variables(str(v), input_data) if isinstance(v, str) else v for k, v in params.items()} elif isinstance(params, str): try: params = json.loads(replace_variables(params, input_data)) except: params = {} # 替换body中的变量 if isinstance(body, dict): # 递归替换字典中的变量 def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]: result = {} for k, v in d.items(): new_k = replace_variables(k, data) if isinstance(v, dict): result[new_k] = replace_dict_vars(v, data) elif isinstance(v, str): result[new_k] = replace_variables(v, data) else: result[new_k] = v return result body = replace_dict_vars(body, input_data) elif isinstance(body, str): body = replace_variables(body, input_data) try: body = json.loads(body) except: pass try: import httpx async with httpx.AsyncClient(timeout=timeout) as client: if method == 'GET': response = await client.get(url, params=params, headers=headers) elif method == 'POST': response = await client.post(url, json=body, params=params, headers=headers) elif method == 'PUT': response = await client.put(url, json=body, params=params, headers=headers) elif method == 'DELETE': response = await client.delete(url, params=params, headers=headers) elif method == 'PATCH': response = await client.patch(url, json=body, params=params, headers=headers) else: raise ValueError(f"不支持的HTTP方法: {method}") # 尝试解析JSON响应 try: response_data = response.json() except: response_data = response.text result = { 'output': { 'status_code': response.status_code, 'headers': dict(response.headers), 'data': response_data }, 'status': 'success' } if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result.get('output'), duration) return result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'HTTP请求失败: {str(e)}' } elif node_type == 'database' or node_type == 'db': # 数据库操作节点:执行数据库操作 node_data = node.get('data', {}) data_source_id = node_data.get('data_source_id') operation = node_data.get('operation', 'query') # query/insert/update/delete sql = node_data.get('sql', '') table = node_data.get('table', '') data = node_data.get('data', {}) where = node_data.get('where', {}) # 如果SQL中包含变量,从input_data中替换 if sql and isinstance(sql, str): import re def replace_sql_vars(text: str, data: Dict[str, Any]) -> str: pattern = r'\{([^}]+)\}|\$\{([^}]+)\}' def replacer(match): key = match.group(1) or match.group(2) value = self._get_nested_value(data, key) if value is None: return match.group(0) # 如果是字符串,需要转义SQL注入 if isinstance(value, str): # 简单转义,实际应该使用参数化查询 escaped_value = value.replace("'", "''") return f"'{escaped_value}'" return str(value) return re.sub(pattern, replacer, text) sql = replace_sql_vars(sql, input_data) try: # 从数据库加载数据源配置 if not self.db: raise ValueError("数据库会话未提供,无法执行数据库操作") from app.models.data_source import DataSource from app.services.data_source_connector import create_connector data_source = self.db.query(DataSource).filter( DataSource.id == data_source_id ).first() if not data_source: raise ValueError(f"数据源不存在: {data_source_id}") connector = create_connector(data_source.type, data_source.config) if operation == 'query': # 查询操作 if not sql: raise ValueError("查询操作需要提供SQL语句") query_params = {'query': sql} result_data = connector.query(query_params) result = {'output': result_data, 'status': 'success'} elif operation == 'insert': # 插入操作 if not table: raise ValueError("插入操作需要提供表名") # 构建INSERT SQL columns = ', '.join(data.keys()) # 处理字符串值,转义单引号 def escape_value(v): if isinstance(v, str): escaped = v.replace("'", "''") return f"'{escaped}'" return str(v) values = ', '.join([escape_value(v) for v in data.values()]) insert_sql = f"INSERT INTO {table} ({columns}) VALUES ({values})" query_params = {'query': insert_sql} result_data = connector.query(query_params) result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'} elif operation == 'update': # 更新操作 if not table or not where: raise ValueError("更新操作需要提供表名和WHERE条件") set_clause = ', '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in data.items()]) where_clause = ' AND '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in where.items()]) update_sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}" query_params = {'query': update_sql} result_data = connector.query(query_params) result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'} elif operation == 'delete': # 删除操作 if not table or not where: raise ValueError("删除操作需要提供表名和WHERE条件") # 处理字符串值,转义单引号 def escape_sql_value(k, v): if isinstance(v, str): escaped = v.replace("'", "''") return f"{k} = '{escaped}'" return f"{k} = {v}" where_clause = ' AND '.join([escape_sql_value(k, v) for k, v in where.items()]) delete_sql = f"DELETE FROM {table} WHERE {where_clause}" query_params = {'query': delete_sql} result_data = connector.query(query_params) result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'} else: raise ValueError(f"不支持的数据库操作: {operation}") if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result.get('output'), duration) return result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'数据库操作失败: {str(e)}' } elif node_type == 'file' or node_type == 'file_operation': # 文件操作节点:文件读取、写入、上传、下载 node_data = node.get('data', {}) operation = node_data.get('operation', 'read') # read/write/upload/download file_path = node_data.get('file_path', '') content = node_data.get('content', '') encoding = node_data.get('encoding', 'utf-8') # 替换文件路径和内容中的变量 import re def replace_variables(text: str, data: Dict[str, Any]) -> str: """替换字符串中的变量占位符""" if not isinstance(text, str): return text pattern = r'\{([^}]+)\}|\$\{([^}]+)\}' def replacer(match): key = match.group(1) or match.group(2) value = self._get_nested_value(data, key) return str(value) if value is not None else match.group(0) return re.sub(pattern, replacer, text) if file_path: file_path = replace_variables(file_path, input_data) if isinstance(content, str): content = replace_variables(content, input_data) try: import os import json import base64 from pathlib import Path if operation == 'read': # 读取文件 if not file_path: raise ValueError("读取操作需要提供文件路径") if not os.path.exists(file_path): raise FileNotFoundError(f"文件不存在: {file_path}") # 根据文件扩展名决定读取方式 file_ext = Path(file_path).suffix.lower() if file_ext == '.json': with open(file_path, 'r', encoding=encoding) as f: data = json.load(f) elif file_ext in ['.txt', '.md', '.log']: with open(file_path, 'r', encoding=encoding) as f: data = f.read() else: # 二进制文件,返回base64编码 with open(file_path, 'rb') as f: data = base64.b64encode(f.read()).decode('utf-8') result = {'output': data, 'status': 'success'} elif operation == 'write': # 写入文件 if not file_path: raise ValueError("写入操作需要提供文件路径") # 确保目录存在 os.makedirs(os.path.dirname(file_path) if os.path.dirname(file_path) else '.', exist_ok=True) # 如果content是字典或列表,转换为JSON if isinstance(content, (dict, list)): content = json.dumps(content, ensure_ascii=False, indent=2) # 根据文件扩展名决定写入方式 file_ext = Path(file_path).suffix.lower() if file_ext == '.json': with open(file_path, 'w', encoding=encoding) as f: json.dump(json.loads(content) if isinstance(content, str) else content, f, ensure_ascii=False, indent=2) else: with open(file_path, 'w', encoding=encoding) as f: f.write(str(content)) result = {'output': {'file_path': file_path, 'message': '文件写入成功'}, 'status': 'success'} elif operation == 'upload': # 文件上传(从base64或URL上传) upload_type = node_data.get('upload_type', 'base64') # base64/url target_path = node_data.get('target_path', '') if upload_type == 'base64': # 从输入数据中获取base64编码的文件内容 file_data = input_data.get('file_data') or input_data.get('content') if not file_data: raise ValueError("上传操作需要提供file_data或content字段") # 解码base64 if isinstance(file_data, str): file_bytes = base64.b64decode(file_data) else: file_bytes = file_data # 写入目标路径 if not target_path: raise ValueError("上传操作需要提供target_path") os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True) with open(target_path, 'wb') as f: f.write(file_bytes) result = {'output': {'file_path': target_path, 'message': '文件上传成功'}, 'status': 'success'} else: # URL上传(下载后保存) import httpx url = node_data.get('url', '') if not url: raise ValueError("URL上传需要提供url") async with httpx.AsyncClient() as client: response = await client.get(url) response.raise_for_status() if not target_path: # 从URL提取文件名 target_path = os.path.basename(url) or 'downloaded_file' os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True) with open(target_path, 'wb') as f: f.write(response.content) result = {'output': {'file_path': target_path, 'message': '文件下载并保存成功'}, 'status': 'success'} elif operation == 'download': # 文件下载(返回base64编码或文件URL) download_format = node_data.get('download_format', 'base64') # base64/url if not file_path: raise ValueError("下载操作需要提供文件路径") if not os.path.exists(file_path): raise FileNotFoundError(f"文件不存在: {file_path}") if download_format == 'base64': # 返回base64编码 with open(file_path, 'rb') as f: file_bytes = f.read() file_base64 = base64.b64encode(file_bytes).decode('utf-8') result = {'output': {'file_name': os.path.basename(file_path), 'content': file_base64, 'format': 'base64'}, 'status': 'success'} else: # 返回文件路径(实际应用中可能需要生成临时URL) result = {'output': {'file_path': file_path, 'format': 'path'}, 'status': 'success'} else: raise ValueError(f"不支持的文件操作: {operation}") if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result.get('output'), duration) return result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'文件操作失败: {str(e)}' } elif node_type == 'webhook': # Webhook节点:发送Webhook请求到外部系统 node_data = node.get('data', {}) url = node_data.get('url', '') method = node_data.get('method', 'POST').upper() headers = node_data.get('headers', {}) body = node_data.get('body', {}) timeout = node_data.get('timeout', 30) # 如果URL、headers、body中包含变量,从input_data中替换 import re def replace_variables(text: str, data: Dict[str, Any]) -> str: """替换字符串中的变量占位符""" if not isinstance(text, str): return text pattern = r'\{([^}]+)\}|\$\{([^}]+)\}' def replacer(match): key = match.group(1) or match.group(2) value = self._get_nested_value(data, key) return str(value) if value is not None else match.group(0) return re.sub(pattern, replacer, text) # 替换URL中的变量 if url: url = replace_variables(url, input_data) # 替换headers中的变量 if isinstance(headers, dict): headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()} elif isinstance(headers, str): try: headers = json.loads(replace_variables(headers, input_data)) except: headers = {} # 替换body中的变量 if isinstance(body, dict): # 递归替换字典中的变量 def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]: result = {} for k, v in d.items(): new_k = replace_variables(k, data) if isinstance(v, dict): result[new_k] = replace_dict_vars(v, data) elif isinstance(v, str): result[new_k] = replace_variables(v, data) else: result[new_k] = v return result body = replace_dict_vars(body, input_data) elif isinstance(body, str): body = replace_variables(body, input_data) try: body = json.loads(body) except: pass # 如果没有配置body,默认使用input_data作为body if not body: body = input_data try: import httpx async with httpx.AsyncClient(timeout=timeout) as client: if method == 'GET': response = await client.get(url, headers=headers) elif method == 'POST': response = await client.post(url, json=body, headers=headers) elif method == 'PUT': response = await client.put(url, json=body, headers=headers) elif method == 'PATCH': response = await client.patch(url, json=body, headers=headers) else: raise ValueError(f"Webhook不支持HTTP方法: {method}") # 尝试解析JSON响应 try: response_data = response.json() except: response_data = response.text result = { 'output': { 'status_code': response.status_code, 'headers': dict(response.headers), 'data': response_data }, 'status': 'success' } if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result.get('output'), duration) return result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'Webhook请求失败: {str(e)}' } elif node_type == 'schedule' or node_type == 'delay' or node_type == 'timer': # 定时任务节点:延迟执行或定时执行 node_data = node.get('data', {}) delay_type = node_data.get('delay_type', 'fixed') # fixed: 固定延迟, cron: cron表达式 delay_value = node_data.get('delay_value', 0) # 延迟值(秒) delay_unit = node_data.get('delay_unit', 'seconds') # seconds, minutes, hours # 计算实际延迟时间(毫秒) if delay_unit == 'seconds': delay_ms = int(delay_value * 1000) elif delay_unit == 'minutes': delay_ms = int(delay_value * 60 * 1000) elif delay_unit == 'hours': delay_ms = int(delay_value * 60 * 60 * 1000) else: delay_ms = int(delay_value * 1000) # 如果延迟时间大于0,则等待 if delay_ms > 0: if self.logger: self.logger.info( f"定时任务节点等待 {delay_value} {delay_unit}", node_id=node_id, node_type=node_type, data={'delay_ms': delay_ms, 'delay_value': delay_value, 'delay_unit': delay_unit} ) await asyncio.sleep(delay_ms / 1000.0) # 返回输入数据(定时节点只是延迟,不改变数据) result = {'output': input_data, 'status': 'success', 'delay_ms': delay_ms} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result.get('output'), duration) return result elif node_type == 'email' or node_type == 'mail': # 邮件节点:发送邮件通知 node_data = node.get('data', {}) smtp_host = node_data.get('smtp_host', '') smtp_port = node_data.get('smtp_port', 587) smtp_user = node_data.get('smtp_user', '') smtp_password = node_data.get('smtp_password', '') use_tls = node_data.get('use_tls', True) from_email = node_data.get('from_email', '') to_email = node_data.get('to_email', '') cc_email = node_data.get('cc_email', '') bcc_email = node_data.get('bcc_email', '') subject = node_data.get('subject', '') body = node_data.get('body', '') body_type = node_data.get('body_type', 'text') # text/html attachments = node_data.get('attachments', []) # 附件列表 # 替换变量 import re def replace_variables(text: str, data: Dict[str, Any]) -> str: """替换字符串中的变量占位符""" if not isinstance(text, str): return text pattern = r'\{([^}]+)\}|\$\{([^}]+)\}' def replacer(match): key = match.group(1) or match.group(2) value = self._get_nested_value(data, key) return str(value) if value is not None else match.group(0) return re.sub(pattern, replacer, text) # 替换所有配置中的变量 smtp_host = replace_variables(smtp_host, input_data) smtp_user = replace_variables(smtp_user, input_data) smtp_password = replace_variables(smtp_password, input_data) from_email = replace_variables(from_email, input_data) to_email = replace_variables(to_email, input_data) cc_email = replace_variables(cc_email, input_data) bcc_email = replace_variables(bcc_email, input_data) subject = replace_variables(subject, input_data) body = replace_variables(body, input_data) # 验证必需参数 if not smtp_host: raise ValueError("邮件节点需要配置SMTP服务器地址") if not from_email: raise ValueError("邮件节点需要配置发件人邮箱") if not to_email: raise ValueError("邮件节点需要配置收件人邮箱") if not subject: raise ValueError("邮件节点需要配置邮件主题") try: import aiosmtplib from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart from email.mime.base import MIMEBase from email import encoders import base64 import os # 创建邮件消息 msg = MIMEMultipart('alternative') msg['From'] = from_email msg['To'] = to_email if cc_email: msg['Cc'] = cc_email msg['Subject'] = subject # 添加邮件正文 if body_type == 'html': msg.attach(MIMEText(body, 'html', 'utf-8')) else: msg.attach(MIMEText(body, 'plain', 'utf-8')) # 处理附件 for attachment in attachments: if isinstance(attachment, dict): file_path = attachment.get('file_path', '') file_name = attachment.get('file_name', '') file_content = attachment.get('file_content', '') # base64编码的内容 # 替换变量 file_path = replace_variables(file_path, input_data) file_name = replace_variables(file_name, input_data) if file_path and os.path.exists(file_path): # 从文件路径读取 with open(file_path, 'rb') as f: file_data = f.read() if not file_name: file_name = os.path.basename(file_path) elif file_content: # 从base64内容读取 file_data = base64.b64decode(file_content) if not file_name: file_name = 'attachment' else: continue # 添加附件 part = MIMEBase('application', 'octet-stream') part.set_payload(file_data) encoders.encode_base64(part) part.add_header( 'Content-Disposition', f'attachment; filename= {file_name}' ) msg.attach(part) # 发送邮件 recipients = [to_email] if cc_email: recipients.extend([email.strip() for email in cc_email.split(',')]) if bcc_email: recipients.extend([email.strip() for email in bcc_email.split(',')]) async with aiosmtplib.SMTP(hostname=smtp_host, port=smtp_port) as smtp: if use_tls: await smtp.starttls() if smtp_user and smtp_password: await smtp.login(smtp_user, smtp_password) await smtp.send_message(msg, recipients=recipients) result = { 'output': { 'message': '邮件发送成功', 'from': from_email, 'to': to_email, 'subject': subject, 'recipients_count': len(recipients) }, 'status': 'success' } if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result.get('output'), duration) return result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'邮件发送失败: {str(e)}' } elif node_type == 'message_queue' or node_type == 'mq' or node_type == 'rabbitmq' or node_type == 'kafka': # 消息队列节点:发送消息到RabbitMQ或Kafka node_data = node.get('data', {}) queue_type = node_data.get('queue_type', 'rabbitmq') # rabbitmq/kafka # 替换变量 import re def replace_variables(text: str, data: Dict[str, Any]) -> str: """替换字符串中的变量占位符""" if not isinstance(text, str): return text pattern = r'\{([^}]+)\}|\$\{([^}]+)\}' def replacer(match): key = match.group(1) or match.group(2) value = self._get_nested_value(data, key) return str(value) if value is not None else match.group(0) return re.sub(pattern, replacer, text) try: if queue_type == 'rabbitmq': # RabbitMQ实现 import aio_pika import json # 获取RabbitMQ配置 host = replace_variables(node_data.get('host', 'localhost'), input_data) port = node_data.get('port', 5672) username = replace_variables(node_data.get('username', 'guest'), input_data) password = replace_variables(node_data.get('password', 'guest'), input_data) exchange = replace_variables(node_data.get('exchange', ''), input_data) routing_key = replace_variables(node_data.get('routing_key', ''), input_data) queue_name = replace_variables(node_data.get('queue_name', ''), input_data) message = node_data.get('message', input_data) # 如果message是字符串,尝试替换变量 if isinstance(message, str): message = replace_variables(message, input_data) try: message = json.loads(message) except: pass elif isinstance(message, dict): # 递归替换字典中的变量 def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]: result = {} for k, v in d.items(): new_k = replace_variables(k, data) if isinstance(v, dict): result[new_k] = replace_dict_vars(v, data) elif isinstance(v, str): result[new_k] = replace_variables(v, data) else: result[new_k] = v return result message = replace_dict_vars(message, input_data) # 如果没有配置message,使用input_data if not message: message = input_data # 连接RabbitMQ connection_url = f"amqp://{username}:{password}@{host}:{port}/" connection = await aio_pika.connect_robust(connection_url) channel = await connection.channel() # 发送消息 message_body = json.dumps(message, ensure_ascii=False).encode('utf-8') if exchange: # 使用exchange和routing_key await channel.default_exchange.publish( aio_pika.Message(message_body), routing_key=routing_key or queue_name ) elif queue_name: # 直接发送到队列 queue = await channel.declare_queue(queue_name, durable=True) await channel.default_exchange.publish( aio_pika.Message(message_body), routing_key=queue_name ) else: raise ValueError("RabbitMQ节点需要配置exchange或queue_name") await connection.close() result = { 'output': { 'message': '消息已发送到RabbitMQ', 'queue_type': 'rabbitmq', 'exchange': exchange, 'routing_key': routing_key or queue_name, 'queue_name': queue_name, 'message_size': len(message_body) }, 'status': 'success' } elif queue_type == 'kafka': # Kafka实现 from kafka import KafkaProducer import json # 获取Kafka配置 bootstrap_servers = replace_variables(node_data.get('bootstrap_servers', 'localhost:9092'), input_data) topic = replace_variables(node_data.get('topic', ''), input_data) message = node_data.get('message', input_data) # 如果message是字符串,尝试替换变量 if isinstance(message, str): message = replace_variables(message, input_data) try: message = json.loads(message) except: pass elif isinstance(message, dict): # 递归替换字典中的变量 def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]: result = {} for k, v in d.items(): new_k = replace_variables(k, data) if isinstance(v, dict): result[new_k] = replace_dict_vars(v, data) elif isinstance(v, str): result[new_k] = replace_variables(v, data) else: result[new_k] = v return result message = replace_dict_vars(message, input_data) # 如果没有配置message,使用input_data if not message: message = input_data if not topic: raise ValueError("Kafka节点需要配置topic") # 创建Kafka生产者(注意:kafka-python是同步的,需要在线程池中运行) from concurrent.futures import ThreadPoolExecutor def send_kafka_message(): producer = KafkaProducer( bootstrap_servers=bootstrap_servers.split(','), value_serializer=lambda v: json.dumps(v, ensure_ascii=False).encode('utf-8') ) future = producer.send(topic, message) record_metadata = future.get(timeout=10) producer.close() return record_metadata # 在线程池中执行同步操作 loop = asyncio.get_event_loop() with ThreadPoolExecutor() as executor: record_metadata = await loop.run_in_executor(executor, send_kafka_message) result = { 'output': { 'message': '消息已发送到Kafka', 'queue_type': 'kafka', 'topic': topic, 'partition': record_metadata.partition, 'offset': record_metadata.offset }, 'status': 'success' } else: raise ValueError(f"不支持的消息队列类型: {queue_type}") if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result.get('output'), duration) return result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'消息队列发送失败: {str(e)}' } elif node_type == 'switch': # Switch节点:多分支路由 logger.info(f"[rjb] 执行Switch节点: node_id={node_id}, node_type={node_type}, input_data keys={list(input_data.keys()) if isinstance(input_data, dict) else 'not_dict'}") node_data = node.get('data', {}) field = node_data.get('field', '') cases = node_data.get('cases', {}) default_case = node_data.get('default', 'default') logger.info(f"[rjb] Switch节点配置: field={field}, cases={cases}, default={default_case}") # 处理输入数据:尝试解析JSON字符串(递归处理所有字段) def parse_json_recursively(data, merge_parsed=True): """ 递归解析JSON字符串 Args: data: 要处理的数据 merge_parsed: 是否将解析后的字典内容合并到父级(用于方便字段提取) """ import json if isinstance(data, str): # 如果是字符串,尝试解析为JSON try: parsed = json.loads(data) # 如果解析成功,递归处理解析后的数据 if isinstance(parsed, (dict, list)): return parse_json_recursively(parsed, merge_parsed) return parsed except: # 不是JSON,返回原字符串 return data elif isinstance(data, dict): # 如果是字典,递归处理每个值 result = {} for key, value in data.items(): parsed_value = parse_json_recursively(value, merge_parsed=False) result[key] = parsed_value # 如果merge_parsed为True且解析后的值是字典,将其内容合并到当前层级(方便字段提取) if merge_parsed and isinstance(parsed_value, dict): # 合并时避免覆盖已有的键 for k, v in parsed_value.items(): if k not in result: result[k] = v return result elif isinstance(data, list): # 如果是列表,递归处理每个元素 return [parse_json_recursively(item, merge_parsed) for item in data] else: # 其他类型,直接返回 return data processed_input = parse_json_recursively(input_data, merge_parsed=True) # 从处理后的输入数据中获取字段值 field_value = self._get_nested_value(processed_input, field) field_value_str = str(field_value) if field_value is not None else '' # 查找匹配的case matched_case = default_case if field_value_str in cases: matched_case = cases[field_value_str] elif field_value in cases: matched_case = cases[field_value] # 记录详细的匹配信息(同时输出到控制台和数据库) match_info = { 'field': field, 'field_value': field_value, 'field_value_str': field_value_str, 'matched_case': matched_case, 'processed_input_keys': list(processed_input.keys()) if isinstance(processed_input, dict) else 'not_dict', 'cases_keys': list(cases.keys()) } logger.info(f"[rjb] Switch节点匹配: node_id={node_id}, {match_info}") if self.logger: self.logger.info( f"Switch节点匹配: field={field}, field_value={field_value}, matched_case={matched_case}", node_id=node_id, node_type=node_type, data=match_info ) exec_result = { 'output': processed_input, 'status': 'success', 'branch': matched_case, 'matched_value': field_value } if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, {'branch': matched_case, 'value': field_value}, duration) return exec_result elif node_type == 'merge': # Merge节点:合并多个分支的数据流 node_data = node.get('data', {}) mode = node_data.get('mode', 'merge_all') # merge_all, merge_first, merge_last strategy = node_data.get('strategy', 'array') # array, object, concat # 获取所有上游节点的输出(通过input_data中的特殊字段) # 如果input_data包含多个分支的数据,合并它们 merged_data = {} if strategy == 'array': # 数组策略:将所有输入数据作为数组元素 if isinstance(input_data, list): merged_data = input_data elif isinstance(input_data, dict): # 如果包含多个分支数据,提取为数组 branch_data = [] for key, value in input_data.items(): if not key.startswith('_'): branch_data.append(value) merged_data = branch_data if branch_data else [input_data] else: merged_data = [input_data] elif strategy == 'object': # 对象策略:合并所有字段 if isinstance(input_data, dict): merged_data = input_data.copy() else: merged_data = {'data': input_data} elif strategy == 'concat': # 连接策略:将所有数据连接为字符串 if isinstance(input_data, list): merged_data = '\n'.join(str(item) for item in input_data) elif isinstance(input_data, dict): merged_data = '\n'.join(f"{k}: {v}" for k, v in input_data.items() if not k.startswith('_')) else: merged_data = str(input_data) exec_result = {'output': merged_data, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, merged_data, duration) return exec_result elif node_type == 'wait': # Wait节点:等待条件满足 node_data = node.get('data', {}) wait_type = node_data.get('wait_type', 'condition') # condition, time, event condition = node_data.get('condition', '') timeout = node_data.get('timeout', 300) # 默认5分钟 poll_interval = node_data.get('poll_interval', 5) # 默认5秒 if wait_type == 'condition': # 等待条件满足 start_wait = time.time() while time.time() - start_wait < timeout: try: result = condition_parser.evaluate_condition(condition, input_data) if result: exec_result = {'output': input_data, 'status': 'success', 'waited': time.time() - start_wait} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, exec_result, duration) return exec_result except Exception as e: logger.warning(f"Wait节点条件评估失败: {str(e)}") await asyncio.sleep(poll_interval) # 超时 exec_result = { 'output': input_data, 'status': 'failed', 'error': f'等待条件超时: {timeout}秒' } if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, Exception("等待超时"), duration) return exec_result elif wait_type == 'time': # 等待固定时间 wait_seconds = node_data.get('wait_seconds', 0) await asyncio.sleep(wait_seconds) exec_result = {'output': input_data, 'status': 'success', 'waited': wait_seconds} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, exec_result, duration) return exec_result else: # 其他类型暂不支持 exec_result = {'output': input_data, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, exec_result, duration) return exec_result elif node_type == 'json': # JSON处理节点 node_data = node.get('data', {}) operation = node_data.get('operation', 'parse') # parse, stringify, extract, validate path = node_data.get('path', '') schema = node_data.get('schema', {}) try: if operation == 'parse': # 解析JSON字符串 if isinstance(input_data, str): result = json_module.loads(input_data) elif isinstance(input_data, dict) and 'data' in input_data: # 如果包含data字段,尝试解析 if isinstance(input_data['data'], str): result = json_module.loads(input_data['data']) else: result = input_data['data'] else: result = input_data elif operation == 'stringify': # 转换为JSON字符串 result = json_module.dumps(input_data, ensure_ascii=False, indent=2) elif operation == 'extract': # 使用JSONPath提取数据(简化实现) if path and isinstance(input_data, dict): # 简单的路径提取,支持 $.key 格式 path = path.replace('$.', '').replace('$', '') keys = path.split('.') result = input_data for key in keys: if key.endswith('[*]'): # 数组提取 array_key = key[:-3] if isinstance(result, dict) and array_key in result: result = result[array_key] elif isinstance(result, dict) and key in result: result = result[key] else: result = None break else: result = input_data elif operation == 'validate': # JSON Schema验证(简化实现) # 这里只做基本验证,完整实现需要使用jsonschema库 if schema: # 简单类型检查 if 'type' in schema: expected_type = schema['type'] actual_type = type(input_data).__name__ if expected_type == 'object' and actual_type != 'dict': raise ValueError(f"期望类型 {expected_type},实际类型 {actual_type}") elif expected_type == 'array' and actual_type != 'list': raise ValueError(f"期望类型 {expected_type},实际类型 {actual_type}") result = input_data else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'JSON处理失败: {str(e)}' } elif node_type == 'text': # 文本处理节点 node_data = node.get('data', {}) operation = node_data.get('operation', 'split') # split, join, extract, replace, format delimiter = node_data.get('delimiter', '\n') regex = node_data.get('regex', '') template = node_data.get('template', '') try: # 获取输入文本 input_text = input_data if isinstance(input_data, dict): # 尝试从字典中提取文本 for key in ['text', 'content', 'message', 'input', 'output']: if key in input_data and isinstance(input_data[key], str): input_text = input_data[key] break if isinstance(input_text, dict): input_text = str(input_text) elif not isinstance(input_text, str): input_text = str(input_text) if operation == 'split': # 拆分文本 result = input_text.split(delimiter) elif operation == 'join': # 合并文本(需要输入是数组) if isinstance(input_data, list): result = delimiter.join(str(item) for item in input_data) else: result = input_text elif operation == 'extract': # 使用正则表达式提取 if regex: import re matches = re.findall(regex, input_text) result = matches if len(matches) > 1 else (matches[0] if matches else '') else: result = input_text elif operation == 'replace': # 替换文本 old_text = node_data.get('old_text', '') new_text = node_data.get('new_text', '') if regex: import re result = re.sub(regex, new_text, input_text) else: result = input_text.replace(old_text, new_text) elif operation == 'format': # 格式化文本(使用模板) if template: # 支持 {key} 格式的变量替换 result = template if isinstance(input_data, dict): for key, value in input_data.items(): result = result.replace(f'{{{key}}}', str(value)) else: result = result.replace('{value}', str(input_data)) else: result = input_text else: result = input_text exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'文本处理失败: {str(e)}' } elif node_type == 'cache': # 缓存节点 node_data = node.get('data', {}) operation = node_data.get('operation', 'get') # get, set, delete, clear key = node_data.get('key', '') ttl = node_data.get('ttl', 3600) # 默认1小时 # 默认优先使用redis(如果配置了),否则memory backend = node_data.get('backend') or ('redis' if getattr(settings, 'REDIS_URL', None) else 'memory') # redis, memory default_value = node_data.get('default_value', '{}') value_template = node_data.get('value', '') # 使用Redis作为持久化缓存(如果可用),否则使用内存缓存 # 注意:内存缓存在单次执行会话内有效,跨执行不会保留 use_redis = False redis_client = None # 默认尝试使用Redis(如果配置了),除非明确指定使用memory if backend != 'memory': try: from app.core.redis_client import get_redis_client redis_client = get_redis_client() if redis_client: use_redis = True logger.info(f"[rjb] Cache节点 {node_id} 使用Redis缓存") except Exception as e: logger.warning(f"Redis不可用: {str(e)},使用内存缓存") # 内存缓存(单次执行会话内有效) if not hasattr(self, '_cache_store'): self._cache_store = {} self._cache_timestamps = {} try: # 替换key中的变量 if isinstance(input_data, dict): # 首先处理 {{variable}} 格式 import re double_brace_vars = re.findall(r'\{\{(\w+)\}\}', key) for var_name in double_brace_vars: if var_name in input_data: key = key.replace(f'{{{{{var_name}}}}}', str(input_data[var_name])) else: # 如果变量不存在,使用默认值 if var_name == 'user_id': # 尝试从输入数据中提取user_id,如果没有则使用"default" user_id = input_data.get('user_id') or input_data.get('USER_ID') or 'default' key = key.replace(f'{{{{{var_name}}}}}', str(user_id)) else: key = key.replace(f'{{{{{var_name}}}}}', 'default') # 然后处理 {key} 格式 for k, v in input_data.items(): key = key.replace(f'{{{k}}}', str(v)) # 如果key中还有未替换的变量,使用默认值 if '{' in key: key = key.replace('{user_id}', 'default').replace('{{user_id}}', 'default') # 清理其他未替换的变量 key = re.sub(r'\{[^}]+\}', 'default', key) logger.info(f"[rjb] Cache节点 {node_id} 处理后的key: {key}") if operation == 'get': # 获取缓存 result = None cache_hit = False if use_redis and redis_client: # 从Redis获取 try: cached_data = redis_client.get(key) if cached_data: result = json_module.loads(cached_data) cache_hit = True except Exception as e: logger.warning(f"从Redis获取缓存失败: {str(e)}") if result is None: # 从内存缓存获取 if key in self._cache_store: # 检查是否过期 if key in self._cache_timestamps: if time.time() - self._cache_timestamps[key] > ttl: # 过期,删除 del self._cache_store[key] del self._cache_timestamps[key] else: result = self._cache_store[key] cache_hit = True else: result = self._cache_store[key] cache_hit = True # 如果缓存未命中,使用default_value if result is None: try: if isinstance(default_value, str): result = json_module.loads(default_value) if default_value else {} else: result = default_value except: result = {} cache_hit = False logger.info(f"[rjb] Cache节点 {node_id} cache miss,使用default_value: {result}") # 合并输入数据和缓存结果 output = input_data.copy() if isinstance(input_data, dict) else {} if isinstance(result, dict): output.update(result) else: output['memory'] = result exec_result = {'output': output, 'status': 'success', 'cache_hit': cache_hit, 'memory': result} elif operation == 'set': # 设置缓存 # 处理value模板 if value_template: # 处理模板语法 {{variable}} import re value_str = value_template # 替换 {{variable}} 格式的变量 # 注意:只替换 memory.* 路径的变量,user_input、output、timestamp 等变量在Python表达式执行阶段处理 template_vars = re.findall(r'\{\{(\w+(?:\.\w+)*)\}\}', value_str) for var_path in template_vars: # 跳过 user_input、output、timestamp 等变量,这些在Python表达式执行阶段处理 if var_path in ['user_input', 'output', 'timestamp']: continue # 支持嵌套路径,如 memory.conversation_history var_parts = var_path.split('.') var_value = input_data try: for part in var_parts: if isinstance(var_value, dict) and part in var_value: var_value = var_value[part] else: var_value = None break if var_value is not None: # 替换模板变量 replacement = json_module.dumps(var_value, ensure_ascii=False) if isinstance(var_value, (dict, list)) else str(var_value) value_str = value_str.replace(f'{{{{{var_path}}}}}', replacement) else: # 变量不存在,根据路径使用合适的默认值 if 'conversation_history' in var_path: value_str = value_str.replace(f'{{{{{var_path}}}}}', '[]') elif 'user_profile' in var_path or 'context' in var_path: value_str = value_str.replace(f'{{{{{var_path}}}}}', '{}') else: # 对于其他变量,保留原样,让Python表达式执行阶段处理 pass except Exception as e: logger.warning(f"处理模板变量 {var_path} 失败: {str(e)}") # 替换 {key} 格式的变量(但不要替换 {{variable}} 格式的) for k, v in input_data.items(): placeholder = f'{{{k}}}' # 确保不是 {{variable}} 格式 if placeholder in value_str and f'{{{{{k}}}}}' not in value_str: replacement = json_module.dumps(v, ensure_ascii=False) if isinstance(v, (dict, list)) else str(v) value_str = value_str.replace(placeholder, replacement) # 解析处理后的value(可能是JSON字符串或Python表达式) try: # 尝试作为JSON解析 value = json_module.loads(value_str) except: # 如果不是有效的JSON,尝试作为Python表达式执行(安全限制) try: # 准备安全的环境变量 from datetime import datetime memory = input_data.get('memory', {}) if not isinstance(memory, dict): memory = {} # 确保memory中有必要的字段 if 'conversation_history' not in memory: memory['conversation_history'] = [] if 'user_profile' not in memory: memory['user_profile'] = {} if 'context' not in memory: memory['context'] = {} # 获取user_input:优先从query或USER_INPUT获取 user_input = input_data.get('query') or input_data.get('USER_INPUT') or input_data.get('user_input') or '' # 获取output:从right字段获取,如果是dict则提取right子字段 output = input_data.get('right', '') if isinstance(output, dict): output = output.get('right', '') or output.get('content', '') or str(output) if not output: output = '' timestamp = datetime.now().isoformat() # 在Python表达式执行前,替换 {{user_input}}、{{output}}、{{timestamp}} # 注意:模板中已经有引号了,所以需要转义字符串中的特殊字符,然后直接插入 # 使用json.dumps来正确转义,但去掉外层的引号(因为模板中已经有引号了) user_input_escaped = json_module.dumps(user_input, ensure_ascii=False)[1:-1] # 去掉首尾引号 output_escaped = json_module.dumps(output, ensure_ascii=False)[1:-1] timestamp_escaped = json_module.dumps(timestamp, ensure_ascii=False)[1:-1] value_str = value_str.replace('{{user_input}}', user_input_escaped) value_str = value_str.replace('{{output}}', output_escaped) value_str = value_str.replace('{{timestamp}}', timestamp_escaped) # 只允许基本的字典和列表操作 safe_dict = { 'memory': memory, 'user_input': user_input, 'output': output, 'timestamp': timestamp } logger.info(f"[rjb] Cache节点 {node_id} 执行value模板") logger.info(f"[rjb] value_str前300字符: {value_str[:300]}") logger.info(f"[rjb] user_input: {user_input[:50]}, output: {str(output)[:50]}, timestamp: {timestamp}") value = eval(value_str, {"__builtins__": {}}, safe_dict) logger.info(f"[rjb] Cache节点 {node_id} value模板执行成功,类型: {type(value)}") if isinstance(value, dict): logger.info(f"[rjb] keys: {list(value.keys())}") if 'conversation_history' in value: logger.info(f"[rjb] conversation_history: {len(value['conversation_history'])} 条") if value['conversation_history']: logger.info(f"[rjb] 第一条: {value['conversation_history'][0]}") except Exception as e: logger.error(f"Cache节点 {node_id} value模板执行失败: {str(e)}") logger.error(f"value_str: {value_str[:500]}") logger.error(f"safe_dict: {safe_dict}") import traceback logger.error(f"traceback: {traceback.format_exc()}") # 如果都失败,使用原始输入数据 value = input_data else: # 没有value模板,使用输入数据 value = input_data if isinstance(input_data, dict) and 'value' in input_data: value = input_data['value'] # 存储到缓存 if use_redis and redis_client: try: redis_client.setex(key, ttl, json_module.dumps(value, ensure_ascii=False)) logger.info(f"[rjb] Cache节点 {node_id} 已存储到Redis: key={key}") except Exception as e: logger.warning(f"存储到Redis失败: {str(e)}") # 同时存储到内存缓存 self._cache_store[key] = value self._cache_timestamps[key] = time.time() logger.info(f"[rjb] Cache节点 {node_id} 已存储: key={key}, value类型={type(value)}") exec_result = {'output': input_data, 'status': 'success', 'cached_value': value} elif operation == 'delete': # 删除缓存 if key in self._cache_store: del self._cache_store[key] if key in self._cache_timestamps: del self._cache_timestamps[key] exec_result = {'output': input_data, 'status': 'success'} elif operation == 'clear': # 清空缓存 self._cache_store.clear() self._cache_timestamps.clear() exec_result = {'output': input_data, 'status': 'success'} else: exec_result = {'output': input_data, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, exec_result.get('output'), duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'缓存操作失败: {str(e)}' } elif node_type == 'vector_db': # 向量数据库节点:向量存储、相似度搜索、RAG检索 node_data = node.get('data', {}) operation = node_data.get('operation', 'search') # search, upsert, delete collection = node_data.get('collection', 'default') query_vector = node_data.get('query_vector', '') top_k = node_data.get('top_k', 5) # 简化的内存向量存储实现(实际生产环境应使用ChromaDB、Pinecone等) if not hasattr(self, '_vector_store'): self._vector_store = {} try: if operation == 'search': # 向量相似度搜索(简化实现:使用余弦相似度) if collection not in self._vector_store: self._vector_store[collection] = [] # 获取查询向量 if isinstance(query_vector, str): # 尝试从input_data中获取 query_vec = self._get_nested_value(input_data, query_vector.replace('{', '').replace('}', '')) else: query_vec = query_vector if not isinstance(query_vec, list): # 如果输入数据包含embedding字段,使用它 if isinstance(input_data, dict) and 'embedding' in input_data: query_vec = input_data['embedding'] else: raise ValueError("无法获取查询向量") # 计算相似度并排序 results = [] for item in self._vector_store[collection]: if 'vector' in item: # 计算余弦相似度 import math vec1 = query_vec vec2 = item['vector'] if len(vec1) != len(vec2): continue dot_product = sum(a * b for a, b in zip(vec1, vec2)) magnitude1 = math.sqrt(sum(a * a for a in vec1)) magnitude2 = math.sqrt(sum(a * a for a in vec2)) if magnitude1 == 0 or magnitude2 == 0: similarity = 0 else: similarity = dot_product / (magnitude1 * magnitude2) results.append({ 'id': item.get('id'), 'text': item.get('text', ''), 'metadata': item.get('metadata', {}), 'similarity': similarity }) # 按相似度排序并返回top_k results.sort(key=lambda x: x['similarity'], reverse=True) result = results[:top_k] elif operation == 'upsert': # 插入或更新向量 if collection not in self._vector_store: self._vector_store[collection] = [] # 从输入数据中提取向量和文本 vector = input_data.get('embedding') or input_data.get('vector') text = input_data.get('text') or input_data.get('content', '') metadata = input_data.get('metadata', {}) doc_id = input_data.get('id') or f"doc_{len(self._vector_store[collection])}" # 查找是否已存在 existing_index = None for i, item in enumerate(self._vector_store[collection]): if item.get('id') == doc_id: existing_index = i break doc_item = { 'id': doc_id, 'vector': vector, 'text': text, 'metadata': metadata } if existing_index is not None: self._vector_store[collection][existing_index] = doc_item else: self._vector_store[collection].append(doc_item) result = {'id': doc_id, 'status': 'upserted'} elif operation == 'delete': # 删除向量 if collection in self._vector_store: doc_id = node_data.get('doc_id') or input_data.get('id') if doc_id: self._vector_store[collection] = [ item for item in self._vector_store[collection] if item.get('id') != doc_id ] result = {'id': doc_id, 'status': 'deleted'} else: # 删除整个集合 del self._vector_store[collection] result = {'collection': collection, 'status': 'deleted'} else: result = {'status': 'not_found'} else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'向量数据库操作失败: {str(e)}' } elif node_type == 'log': # 日志节点:记录日志、调试输出、性能监控 node_data = node.get('data', {}) level = node_data.get('level', 'info') # debug, info, warning, error message = node_data.get('message', '') include_data = node_data.get('include_data', True) try: # 格式化消息 if message: # 替换变量 if isinstance(input_data, dict): for key, value in input_data.items(): message = message.replace(f'{{{key}}}', str(value)) # 构建日志内容 log_data = { 'message': message or '节点执行', 'node_id': node_id, 'node_type': node_type, 'timestamp': time.time() } if include_data: log_data['data'] = input_data # 记录日志 log_message = f"[{node_id}] {log_data['message']}" if include_data: log_message += f" | 数据: {json_module.dumps(input_data, ensure_ascii=False)[:200]}" if level == 'debug': logger.debug(log_message) elif level == 'info': logger.info(log_message) elif level == 'warning': logger.warning(log_message) elif level == 'error': logger.error(log_message) else: logger.info(log_message) # 如果使用执行日志记录器,也记录 if self.logger: self.logger.info(log_data['message'], node_id=node_id, node_type=node_type, data=input_data if include_data else None) exec_result = {'output': input_data, 'status': 'success', 'log': log_data} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, exec_result.get('output'), duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': input_data, 'status': 'failed', 'error': f'日志记录失败: {str(e)}' } elif node_type == 'error_handler': # 错误处理节点:捕获错误、错误重试、错误通知 # 注意:这个节点需要特殊处理,因为它应该包装其他节点的执行 # 这里我们实现一个简化版本,主要用于错误重试和通知 node_data = node.get('data', {}) retry_count = node_data.get('retry_count', 3) retry_delay = node_data.get('retry_delay', 1000) # 毫秒 on_error = node_data.get('on_error', 'notify') # notify, retry, stop error_handler_workflow = node_data.get('error_handler_workflow', '') # 这个节点通常用于包装其他节点,但在这里我们只处理输入数据中的错误 try: # 检查输入数据中是否有错误 if isinstance(input_data, dict) and input_data.get('status') == 'failed': error = input_data.get('error', '未知错误') if on_error == 'retry' and retry_count > 0: # 重试逻辑(这里简化处理,实际应该重新执行前一个节点) logger.warning(f"错误处理节点检测到错误,将重试: {error}") # 注意:实际重试需要重新执行前一个节点,这里只记录 exec_result = { 'output': input_data, 'status': 'retry', 'retry_count': retry_count, 'error': error } elif on_error == 'notify': # 通知错误(记录日志) logger.error(f"错误处理节点捕获错误: {error}") exec_result = { 'output': input_data, 'status': 'error_handled', 'error': error, 'notified': True } else: # 停止执行 exec_result = { 'output': input_data, 'status': 'failed', 'error': error, 'stopped': True } else: # 没有错误,正常通过 exec_result = {'output': input_data, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, exec_result.get('output'), duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': input_data, 'status': 'failed', 'error': f'错误处理失败: {str(e)}' } elif node_type == 'csv': # CSV处理节点:CSV解析、生成、转换 node_data = node.get('data', {}) operation = node_data.get('operation', 'parse') # parse, generate, convert delimiter = node_data.get('delimiter', ',') headers = node_data.get('headers', True) encoding = node_data.get('encoding', 'utf-8') try: import csv import io if operation == 'parse': # 解析CSV csv_text = input_data if isinstance(input_data, dict): # 尝试从字典中提取CSV文本 for key in ['csv', 'data', 'content', 'text']: if key in input_data and isinstance(input_data[key], str): csv_text = input_data[key] break if not isinstance(csv_text, str): csv_text = str(csv_text) # 解析CSV csv_reader = csv.DictReader(io.StringIO(csv_text), delimiter=delimiter) if headers else csv.reader(io.StringIO(csv_text), delimiter=delimiter) if headers: result = list(csv_reader) else: # 没有表头,返回数组的数组 result = list(csv_reader) elif operation == 'generate': # 生成CSV data = input_data if isinstance(input_data, dict) and 'data' in input_data: data = input_data['data'] if not isinstance(data, list): data = [data] output = io.StringIO() if data and isinstance(data[0], dict): # 字典列表,使用DictWriter fieldnames = data[0].keys() writer = csv.DictWriter(output, fieldnames=fieldnames, delimiter=delimiter) if headers: writer.writeheader() writer.writerows(data) else: # 数组列表,使用writer writer = csv.writer(output, delimiter=delimiter) if headers and data and isinstance(data[0], list): # 假设第一行是表头 writer.writerow(data[0]) writer.writerows(data[1:]) else: writer.writerows(data) result = output.getvalue() elif operation == 'convert': # 转换CSV格式(改变分隔符等) csv_text = input_data if isinstance(input_data, dict): for key in ['csv', 'data', 'content']: if key in input_data and isinstance(input_data[key], str): csv_text = input_data[key] break if not isinstance(csv_text, str): csv_text = str(csv_text) # 读取并重新写入 reader = csv.reader(io.StringIO(csv_text)) output = io.StringIO() writer = csv.writer(output, delimiter=delimiter) writer.writerows(reader) result = output.getvalue() else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'CSV处理失败: {str(e)}' } elif node_type == 'object_storage': # 对象存储节点:文件上传、下载、删除、列表 node_data = node.get('data', {}) provider = node_data.get('provider', 's3') # oss, s3, cos operation = node_data.get('operation', 'upload') bucket = node_data.get('bucket', '') key = node_data.get('key', '') file_data = node_data.get('file', '') try: # 简化实现:实际生产环境需要使用boto3(AWS S3)、oss2(阿里云OSS)等 # 这里提供一个接口框架,实际使用时需要安装相应的SDK if operation == 'upload': # 上传文件 if not file_data: # 从input_data中获取文件数据 if isinstance(input_data, dict): file_data = input_data.get('file') or input_data.get('data') or input_data.get('content') else: file_data = input_data # 替换key中的变量 if isinstance(input_data, dict): for k, v in input_data.items(): key = key.replace(f'{{{k}}}', str(v)) # 这里只是模拟上传,实际需要调用相应的SDK logger.info(f"对象存储上传: provider={provider}, bucket={bucket}, key={key}") result = { 'provider': provider, 'bucket': bucket, 'key': key, 'status': 'uploaded', 'url': f"{provider}://{bucket}/{key}" # 模拟URL } elif operation == 'download': # 下载文件 if isinstance(input_data, dict): for k, v in input_data.items(): key = key.replace(f'{{{k}}}', str(v)) logger.info(f"对象存储下载: provider={provider}, bucket={bucket}, key={key}") # 这里只是模拟下载,实际需要调用相应的SDK result = { 'provider': provider, 'bucket': bucket, 'key': key, 'status': 'downloaded', 'data': '模拟文件内容' # 实际应该是文件内容 } elif operation == 'delete': # 删除文件 if isinstance(input_data, dict): for k, v in input_data.items(): key = key.replace(f'{{{k}}}', str(v)) logger.info(f"对象存储删除: provider={provider}, bucket={bucket}, key={key}") result = { 'provider': provider, 'bucket': bucket, 'key': key, 'status': 'deleted' } elif operation == 'list': # 列出文件 prefix = node_data.get('prefix', '') logger.info(f"对象存储列表: provider={provider}, bucket={bucket}, prefix={prefix}") result = { 'provider': provider, 'bucket': bucket, 'prefix': prefix, 'files': [] # 实际应该是文件列表 } else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'对象存储操作失败: {str(e)}。注意:实际使用需要安装相应的SDK(如boto3、oss2等)' } elif node_type == 'slack': # Slack节点:发送消息、创建频道、获取消息 node_data = node.get('data', {}) operation = node_data.get('operation', 'send_message') token = node_data.get('token', '') channel = node_data.get('channel', '') message = node_data.get('message', '') attachments = node_data.get('attachments', []) try: import httpx # 替换消息中的变量 if isinstance(input_data, dict): for key, value in input_data.items(): message = message.replace(f'{{{key}}}', str(value)) channel = channel.replace(f'{{{key}}}', str(value)) if operation == 'send_message': # 发送消息 url = 'https://slack.com/api/chat.postMessage' headers = { 'Authorization': f'Bearer {token}', 'Content-Type': 'application/json' } payload = { 'channel': channel, 'text': message } if attachments: payload['attachments'] = attachments # 注意:实际使用时需要安装httpx库,这里提供接口框架 # async with httpx.AsyncClient() as client: # response = await client.post(url, headers=headers, json=payload) # result = response.json() # 模拟响应 logger.info(f"Slack发送消息: channel={channel}, message={message[:50]}") result = { 'ok': True, 'channel': channel, 'ts': str(time.time()), 'message': {'text': message} } elif operation == 'create_channel': # 创建频道 url = 'https://slack.com/api/conversations.create' headers = { 'Authorization': f'Bearer {token}', 'Content-Type': 'application/json' } payload = {'name': channel} logger.info(f"Slack创建频道: channel={channel}") result = {'ok': True, 'channel': {'name': channel, 'id': f'C{int(time.time())}'}} elif operation == 'get_messages': # 获取消息 url = f'https://slack.com/api/conversations.history' headers = { 'Authorization': f'Bearer {token}', 'Content-Type': 'application/json' } params = {'channel': channel} logger.info(f"Slack获取消息: channel={channel}") result = {'ok': True, 'messages': []} else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'Slack操作失败: {str(e)}。注意:需要配置有效的Slack Token' } elif node_type == 'dingtalk' or node_type == 'dingding': # 钉钉节点:发送消息、创建群组、获取消息 node_data = node.get('data', {}) operation = node_data.get('operation', 'send_message') webhook_url = node_data.get('webhook_url', '') access_token = node_data.get('access_token', '') chat_id = node_data.get('chat_id', '') message = node_data.get('message', '') try: import httpx # 替换消息中的变量 if isinstance(input_data, dict): for key, value in input_data.items(): message = message.replace(f'{{{key}}}', str(value)) chat_id = chat_id.replace(f'{{{key}}}', str(value)) if operation == 'send_message': # 发送消息(通过Webhook或API) if webhook_url: # 使用Webhook payload = { 'msgtype': 'text', 'text': {'content': message} } # async with httpx.AsyncClient() as client: # response = await client.post(webhook_url, json=payload) # result = response.json() logger.info(f"钉钉发送消息(Webhook): message={message[:50]}") result = {'errcode': 0, 'errmsg': 'ok'} else: # 使用API url = f'https://oapi.dingtalk.com/chat/send' headers = { 'Content-Type': 'application/json' } payload = { 'access_token': access_token, 'chatid': chat_id, 'msg': { 'msgtype': 'text', 'text': {'content': message} } } logger.info(f"钉钉发送消息(API): chat_id={chat_id}, message={message[:50]}") result = {'errcode': 0, 'errmsg': 'ok'} elif operation == 'create_group': # 创建群组 url = 'https://oapi.dingtalk.com/chat/create' headers = {'Content-Type': 'application/json'} payload = { 'access_token': access_token, 'name': chat_id, 'owner': node_data.get('owner', '') } logger.info(f"钉钉创建群组: name={chat_id}") result = {'errcode': 0, 'errmsg': 'ok', 'chatid': f'chat_{int(time.time())}'} else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'钉钉操作失败: {str(e)}。注意:需要配置有效的Webhook URL或Access Token' } elif node_type == 'wechat_work' or node_type == 'wecom': # 企业微信节点:发送消息、创建群组、获取消息 node_data = node.get('data', {}) operation = node_data.get('operation', 'send_message') corp_id = node_data.get('corp_id', '') corp_secret = node_data.get('corp_secret', '') agent_id = node_data.get('agent_id', '') chat_id = node_data.get('chat_id', '') message = node_data.get('message', '') try: import httpx # 替换消息中的变量 if isinstance(input_data, dict): for key, value in input_data.items(): message = message.replace(f'{{{key}}}', str(value)) chat_id = chat_id.replace(f'{{{key}}}', str(value)) if operation == 'send_message': # 先获取access_token token_url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken' token_params = { 'corpid': corp_id, 'corpsecret': corp_secret } # async with httpx.AsyncClient() as client: # token_response = await client.get(token_url, params=token_params) # token_data = token_response.json() # access_token = token_data.get('access_token') # 模拟获取token access_token = 'mock_token' # 发送消息 url = f'https://qyapi.weixin.qq.com/cgi-bin/message/send' params = {'access_token': access_token} payload = { 'touser': chat_id or '@all', 'msgtype': 'text', 'agentid': agent_id, 'text': {'content': message} } logger.info(f"企业微信发送消息: chat_id={chat_id}, message={message[:50]}") result = {'errcode': 0, 'errmsg': 'ok'} elif operation == 'create_group': # 创建群组 url = 'https://qyapi.weixin.qq.com/cgi-bin/appchat/create' params = {'access_token': access_token} payload = { 'name': chat_id, 'owner': node_data.get('owner', ''), 'userlist': node_data.get('userlist', []) } logger.info(f"企业微信创建群组: name={chat_id}") result = {'errcode': 0, 'errmsg': 'ok', 'chatid': f'chat_{int(time.time())}'} else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'企业微信操作失败: {str(e)}。注意:需要配置有效的Corp ID和Secret' } elif node_type == 'sms': # 短信节点(SMS):发送短信、批量发送、短信模板 node_data = node.get('data', {}) provider = node_data.get('provider', 'aliyun') # aliyun, tencent, twilio operation = node_data.get('operation', 'send') phone = node_data.get('phone', '') template = node_data.get('template', '') sign = node_data.get('sign', '') access_key = node_data.get('access_key', '') access_secret = node_data.get('access_secret', '') try: # 替换模板中的变量 if isinstance(input_data, dict): for key, value in input_data.items(): template = template.replace(f'{{{key}}}', str(value)) phone = phone.replace(f'{{{key}}}', str(value)) if operation == 'send': # 发送短信 if provider == 'aliyun': # 阿里云短信(需要安装alibabacloud-dysmsapi20170525) logger.info(f"阿里云短信发送: phone={phone}, template={template[:50]}") result = { 'provider': 'aliyun', 'phone': phone, 'status': 'sent', 'message_id': f'sms_{int(time.time())}' } elif provider == 'tencent': # 腾讯云短信(需要安装tencentcloud-sdk-python) logger.info(f"腾讯云短信发送: phone={phone}, template={template[:50]}") result = { 'provider': 'tencent', 'phone': phone, 'status': 'sent', 'message_id': f'sms_{int(time.time())}' } elif provider == 'twilio': # Twilio短信(需要安装twilio) logger.info(f"Twilio短信发送: phone={phone}, template={template[:50]}") result = { 'provider': 'twilio', 'phone': phone, 'status': 'sent', 'message_id': f'sms_{int(time.time())}' } else: result = {'error': f'不支持的短信提供商: {provider}'} elif operation == 'batch_send': # 批量发送 phones = node_data.get('phones', []) if isinstance(phones, str): phones = [p.strip() for p in phones.split(',')] logger.info(f"批量发送短信: phones={len(phones)}, provider={provider}") result = { 'provider': provider, 'phones': phones, 'status': 'sent', 'count': len(phones) } else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'短信发送失败: {str(e)}。注意:需要安装相应的SDK(如alibabacloud-dysmsapi20170525、tencentcloud-sdk-python、twilio)' } elif node_type == 'pdf': # PDF处理节点:PDF解析、生成、合并、拆分 node_data = node.get('data', {}) operation = node_data.get('operation', 'extract_text') # extract_text, generate, merge, split pages = node_data.get('pages', '') template = node_data.get('template', '') try: # 注意:需要安装PyPDF2或pdfplumber库 # pip install PyPDF2 pdfplumber if operation == 'extract_text': # 提取文本 pdf_data = input_data if isinstance(input_data, dict): pdf_data = input_data.get('pdf') or input_data.get('data') or input_data.get('file') # 这里只是接口框架,实际需要: # from PyPDF2 import PdfReader # reader = PdfReader(io.BytesIO(pdf_data)) # text = "" # for page in reader.pages: # text += page.extract_text() logger.info(f"PDF提取文本: pages={pages}") result = { 'text': 'PDF文本提取结果(需要安装PyPDF2或pdfplumber)', 'pages': pages or 'all' } elif operation == 'generate': # 生成PDF content = input_data if isinstance(input_data, dict): content = input_data.get('content') or input_data.get('text') or input_data.get('data') # 这里只是接口框架,实际需要: # from reportlab.pdfgen import canvas # 或使用其他PDF生成库 logger.info(f"PDF生成: template={template}") result = { 'pdf': 'PDF生成结果(需要安装reportlab或其他PDF生成库)', 'template': template } elif operation == 'merge': # 合并PDF pdfs = input_data if isinstance(input_data, dict): pdfs = input_data.get('pdfs') or input_data.get('files') if not isinstance(pdfs, list): pdfs = [pdfs] # 这里只是接口框架,实际需要: # from PyPDF2 import PdfMerger # merger = PdfMerger() # for pdf in pdfs: # merger.append(pdf) # result_pdf = merger.write() logger.info(f"PDF合并: count={len(pdfs)}") result = { 'merged_pdf': '合并后的PDF(需要安装PyPDF2)', 'count': len(pdfs) } elif operation == 'split': # 拆分PDF pdf_data = input_data if isinstance(input_data, dict): pdf_data = input_data.get('pdf') or input_data.get('file') # 这里只是接口框架,实际需要: # from PyPDF2 import PdfReader, PdfWriter # reader = PdfReader(pdf_data) # writer = PdfWriter() # for page_num in range(start_page, end_page): # writer.add_page(reader.pages[page_num]) logger.info(f"PDF拆分: pages={pages}") result = { 'split_pdfs': ['拆分后的PDF列表(需要安装PyPDF2)'], 'pages': pages } else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'PDF处理失败: {str(e)}。注意:需要安装PyPDF2或pdfplumber库(pip install PyPDF2 pdfplumber)' } elif node_type == 'image': # 图像处理节点:图像缩放、裁剪、格式转换、OCR识别 node_data = node.get('data', {}) operation = node_data.get('operation', 'resize') # resize, crop, convert, ocr width = node_data.get('width', 800) height = node_data.get('height', 600) format_type = node_data.get('format', 'png') try: # 注意:需要安装Pillow库 # pip install Pillow # OCR需要安装pytesseract和tesseract-ocr # pip install pytesseract image_data = input_data if isinstance(input_data, dict): image_data = input_data.get('image') or input_data.get('data') or input_data.get('file') if operation == 'resize': # 缩放图像 # 这里只是接口框架,实际需要: # from PIL import Image # import io # img = Image.open(io.BytesIO(image_data)) # img_resized = img.resize((width, height)) # output = io.BytesIO() # img_resized.save(output, format=format_type.upper()) # result = output.getvalue() logger.info(f"图像缩放: {width}x{height}, format={format_type}") result = { 'image': '缩放后的图像数据(需要安装Pillow)', 'width': width, 'height': height, 'format': format_type } elif operation == 'crop': # 裁剪图像 x = node_data.get('x', 0) y = node_data.get('y', 0) crop_width = node_data.get('crop_width', width) crop_height = node_data.get('crop_height', height) # 这里只是接口框架,实际需要: # from PIL import Image # img = Image.open(io.BytesIO(image_data)) # img_cropped = img.crop((x, y, x + crop_width, y + crop_height)) logger.info(f"图像裁剪: ({x}, {y}, {crop_width}, {crop_height})") result = { 'image': '裁剪后的图像数据(需要安装Pillow)', 'crop_box': (x, y, crop_width, crop_height) } elif operation == 'convert': # 格式转换 target_format = node_data.get('target_format', format_type) # 这里只是接口框架,实际需要: # from PIL import Image # img = Image.open(io.BytesIO(image_data)) # output = io.BytesIO() # img.save(output, format=target_format.upper()) # result = output.getvalue() logger.info(f"图像格式转换: {format_type} -> {target_format}") result = { 'image': f'转换后的图像数据(需要安装Pillow)', 'format': target_format } elif operation == 'ocr': # OCR识别 # 这里只是接口框架,实际需要: # from PIL import Image # import pytesseract # img = Image.open(io.BytesIO(image_data)) # text = pytesseract.image_to_string(img, lang='chi_sim+eng') logger.info(f"OCR识别") result = { 'text': 'OCR识别结果(需要安装pytesseract和tesseract-ocr)', 'confidence': 0.95 } else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'图像处理失败: {str(e)}。注意:需要安装Pillow库(pip install Pillow),OCR需要pytesseract和tesseract-ocr' } elif node_type == 'excel': # Excel处理节点:Excel读取、写入、格式转换、公式计算 node_data = node.get('data', {}) operation = node_data.get('operation', 'read') # read, write, convert, formula sheet = node_data.get('sheet', 'Sheet1') range_str = node_data.get('range', '') format_type = node_data.get('format', 'xlsx') # xlsx, xls, csv try: # 注意:需要安装openpyxl或pandas库 # pip install openpyxl pandas if operation == 'read': # 读取Excel excel_data = input_data if isinstance(input_data, dict): excel_data = input_data.get('excel') or input_data.get('file') or input_data.get('data') # 这里只是接口框架,实际需要: # import pandas as pd # df = pd.read_excel(io.BytesIO(excel_data), sheet_name=sheet) # if range_str: # # 解析范围,如 "A1:C10" # df = df.loc[range_start:range_end] # result = df.to_dict('records') logger.info(f"Excel读取: sheet={sheet}, range={range_str}") result = { 'data': [{'列1': '值1', '列2': '值2'}], 'sheet': sheet, 'range': range_str } elif operation == 'write': # 写入Excel data = input_data if isinstance(input_data, dict): data = input_data.get('data') or input_data.get('rows') if not isinstance(data, list): data = [data] # 这里只是接口框架,实际需要: # import pandas as pd # df = pd.DataFrame(data) # output = io.BytesIO() # df.to_excel(output, sheet_name=sheet, index=False) # result = output.getvalue() logger.info(f"Excel写入: sheet={sheet}, rows={len(data)}") result = { 'excel': '生成的Excel数据(需要安装openpyxl或pandas)', 'sheet': sheet, 'rows': len(data) } elif operation == 'convert': # 格式转换 target_format = node_data.get('target_format', 'csv') excel_data = input_data if isinstance(input_data, dict): excel_data = input_data.get('excel') or input_data.get('file') # 这里只是接口框架,实际需要: # import pandas as pd # df = pd.read_excel(io.BytesIO(excel_data)) # if target_format == 'csv': # result = df.to_csv(index=False) # elif target_format == 'json': # result = df.to_json(orient='records') logger.info(f"Excel格式转换: {format_type} -> {target_format}") result = { 'data': '转换后的数据(需要安装pandas)', 'format': target_format } elif operation == 'formula': # 公式计算 formula = node_data.get('formula', '') data = input_data if isinstance(input_data, dict): data = input_data.get('data') # 这里只是接口框架,实际需要: # import pandas as pd # df = pd.DataFrame(data) # # 使用eval或更安全的方式计算公式 # result = df.eval(formula) logger.info(f"Excel公式计算: formula={formula}") result = { 'result': '公式计算结果(需要安装pandas)', 'formula': formula } else: result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'Excel处理失败: {str(e)}。注意:需要安装openpyxl或pandas库(pip install openpyxl pandas)' } elif node_type == 'subworkflow': # 子工作流节点:调用其他工作流 node_data = node.get('data', {}) workflow_id = node_data.get('workflow_id', '') input_mapping = node_data.get('input_mapping', {}) try: # 将当前输入根据映射转换为子工作流输入 sub_input = {} if isinstance(input_mapping, dict): for k, v in input_mapping.items(): if isinstance(v, str) and isinstance(input_data, dict): sub_input[k] = input_data.get(v) or input_data.get(v.strip('{}')) or v else: sub_input[k] = v else: sub_input = input_data # 实际调用子工作流的执行,这里简化为回传映射后的输入 # TODO: 集成 WorkflowEngine 执行指定 workflow_id result = { 'workflow_id': workflow_id, 'input': sub_input, 'status': 'success', 'note': '子工作流执行框架占位,需集成实际调用' } exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'子工作流执行失败: {str(e)}' } elif node_type == 'code': # 代码执行节点:支持简单的Python/JavaScript片段执行(注意安全) node_data = node.get('data', {}) language = node_data.get('language', 'python') code = node_data.get('code', '') timeout = node_data.get('timeout', 30) try: if language.lower() == 'python': # 受限执行环境 local_vars = {'input_data': input_data, 'result': None} exec(code, {'__builtins__': {}}, local_vars) # 注意:生产环境需更严格沙箱 result = local_vars.get('result', local_vars.get('output', input_data)) elif language.lower() == 'javascript': # JS 执行需要外部运行时,这里仅占位 result = { 'status': 'not_implemented', 'message': 'JavaScript执行需集成运行时' } else: result = {'status': 'failed', 'error': f'不支持的语言: {language}'} exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'代码执行失败: {str(e)}' } elif node_type == 'oauth': # OAuth 节点:获取/刷新 Token node_data = node.get('data', {}) provider = node_data.get('provider', 'google') client_id = node_data.get('client_id', '') client_secret = node_data.get('client_secret', '') scopes = node_data.get('scopes', []) try: # 简化占位实现,返回模拟 token token_data = { 'access_token': f'mock_access_token_{provider}', 'expires_in': 3600, 'token_type': 'Bearer', 'scope': scopes } exec_result = {'output': token_data, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, token_data, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'OAuth处理失败: {str(e)}' } elif node_type == 'validator': # 数据验证节点:基于简化的schema检查 node_data = node.get('data', {}) schema = node_data.get('schema', {}) on_error = node_data.get('on_error', 'reject') # reject, continue, transform try: # 简单类型检查 if 'type' in schema: expected_type = schema['type'] actual_type = type(input_data).__name__ if expected_type == 'object' and not isinstance(input_data, dict): raise ValueError(f'期望类型object,实际类型{actual_type}') if expected_type == 'array' and not isinstance(input_data, list): raise ValueError(f'期望类型array,实际类型{actual_type}') result = input_data exec_result = {'output': result, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if on_error == 'continue': return {'output': input_data, 'status': 'success', 'warning': str(e)} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'数据验证失败: {str(e)}' } elif node_type == 'batch': # 批处理节点:数据分批处理 node_data = node.get('data', {}) batch_size = node_data.get('batch_size', 100) mode = node_data.get('mode', 'split') # split, group, aggregate wait_for_completion = node_data.get('wait_for_completion', True) try: data_list = input_data if isinstance(input_data, list) else [input_data] if mode == 'split': batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)] result = batches elif mode == 'group': batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)] result = batches elif mode == 'aggregate': result = { 'count': len(data_list), 'samples': data_list[:min(3, len(data_list))] } else: result = data_list exec_result = {'output': result, 'status': 'success', 'wait': wait_for_completion} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, result, duration) return exec_result except Exception as e: if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': f'批处理失败: {str(e)}' } elif node_type == 'output' or node_type == 'end': # 输出节点:返回最终结果 # 读取节点配置中的输出格式设置 node_data = node.get('data', {}) output_format = node_data.get('output_format', 'text') # 默认纯文本 logger.debug(f"[rjb] End节点处理: node_id={node_id}, output_format={output_format}, input_data={input_data}, input_data type={type(input_data)}") final_output = input_data # 如果配置为JSON格式,直接返回原始数据(或格式化的JSON) if output_format == 'json': # 如果是字典,直接返回JSON格式 if isinstance(input_data, dict): final_output = json_module.dumps(input_data, ensure_ascii=False, indent=2) elif isinstance(input_data, str): # 尝试解析为JSON,如果成功则格式化,否则直接返回 try: parsed = json_module.loads(input_data) final_output = json_module.dumps(parsed, ensure_ascii=False, indent=2) except: final_output = input_data else: final_output = json_module.dumps({'output': input_data}, ensure_ascii=False, indent=2) else: # 默认纯文本格式:递归解包,提取实际的文本内容 if isinstance(input_data, dict): # 优先提取 'output' 字段(LLM节点的标准输出格式) if 'output' in input_data and isinstance(input_data['output'], str): final_output = input_data['output'] # 如果只有一个 key 且是 'input',提取其值 elif len(input_data) == 1 and 'input' in input_data: final_output = input_data['input'] # 如果包含 'solution' 字段,提取其值 elif 'solution' in input_data and isinstance(input_data['solution'], str): final_output = input_data['solution'] # 如果input_data是字符串类型的字典(JSON字符串),尝试解析 elif isinstance(input_data, str): try: parsed = json_module.loads(input_data) if isinstance(parsed, dict) and 'output' in parsed: final_output = parsed['output'] elif isinstance(parsed, str): final_output = parsed except: final_output = input_data logger.debug(f"[rjb] End节点提取第一层: final_output={final_output}, type={type(final_output)}") # 如果提取的值仍然是字典且只有一个 'input' key,继续提取 if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output: final_output = final_output['input'] logger.debug(f"[rjb] End节点提取第二层: final_output={final_output}, type={type(final_output)}") # 确保最终输出是字符串(对于人机交互场景) # 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串 if not isinstance(final_output, str): if isinstance(final_output, dict): # 如果是字典,尝试提取文本内容或转换为JSON字符串 # 优先查找常见的文本字段 if 'text' in final_output: final_output = str(final_output['text']) elif 'content' in final_output: final_output = str(final_output['content']) elif 'message' in final_output: final_output = str(final_output['message']) elif 'response' in final_output: final_output = str(final_output['response']) elif len(final_output) == 1: # 如果只有一个key,直接使用其值 final_output = str(list(final_output.values())[0]) else: # 否则转换为纯文本(不是JSON) # 尝试提取所有文本字段并组合,但排除系统字段和用户查询字段 text_parts = [] exclude_keys = {'status', 'error', 'timestamp', 'node_id', 'execution_time', 'query', 'USER_INPUT', 'user_input', 'user_query'} # 优先使用input字段(LLM的实际输出) if 'input' in final_output and isinstance(final_output['input'], str): final_output = final_output['input'] else: for key, value in final_output.items(): if key in exclude_keys: continue if isinstance(value, str) and value.strip(): # 如果值本身已经包含 "key: " 格式,直接使用 if value.strip().startswith(f"{key}:"): text_parts.append(value.strip()) else: text_parts.append(value.strip()) elif isinstance(value, (int, float, bool)): text_parts.append(f"{key}: {value}") if text_parts: final_output = "\n".join(text_parts) else: final_output = str(final_output) else: final_output = str(final_output) # 清理输出文本:移除常见的字段前缀(如 "input: ", "query: " 等) if isinstance(final_output, str): import re # 移除行首的 "input: ", "query: ", "output: " 等前缀 lines = final_output.split('\n') cleaned_lines = [] for line in lines: # 匹配行首的 "字段名: " 格式并移除 # 但保留内容本身 line = re.sub(r'^(input|query|output|result|response|message|content|text):\s*', '', line, flags=re.IGNORECASE) if line.strip(): # 只保留非空行 cleaned_lines.append(line) # 如果清理后还有内容,使用清理后的版本 if cleaned_lines: final_output = '\n'.join(cleaned_lines) # 如果清理后为空,使用原始输出(避免丢失所有内容) elif final_output.strip(): # 如果原始输出不为空,但清理后为空,说明可能格式特殊,尝试更宽松的清理 # 只移除明显的 "input: " 和 "query: " 前缀,保留其他内容 final_output = re.sub(r'^(input|query):\s*', '', final_output, flags=re.IGNORECASE | re.MULTILINE) if not final_output.strip(): final_output = str(input_data) # 如果还是空,使用原始输入 logger.debug(f"[rjb] End节点最终输出: output_format={output_format}, final_output={final_output[:100] if isinstance(final_output, str) else type(final_output)}") result = {'output': final_output, 'status': 'success'} if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_complete(node_id, node_type, final_output, duration) return result else: # 未知节点类型 logger.warning(f"[rjb] 未知节点类型: node_id={node_id}, node_type={node_type}, node keys={list(node.keys())}") return { 'output': input_data, 'status': 'success', 'message': f'节点类型 {node_type} 暂未实现' } except Exception as e: logger.error(f"节点执行失败: {node_id} ({node_type}) - {str(e)}", exc_info=True) if self.logger: duration = int((time.time() - start_time) * 1000) self.logger.log_node_error(node_id, node_type, e, duration) return { 'output': None, 'status': 'failed', 'error': str(e), 'node_id': node_id, 'node_type': node_type } async def execute(self, input_data: Dict[str, Any]) -> Dict[str, Any]: """ 执行完整工作流 Args: input_data: 初始输入数据 Returns: 执行结果 """ # 记录工作流开始执行 if self.logger: self.logger.info("工作流开始执行", data={"input": input_data}) # 初始化节点输出 self.node_outputs = {} active_edges = self.edges.copy() # 活跃的边列表 executed_nodes = set() # 已执行的节点 # 按拓扑顺序执行节点(动态构建执行图) results = {} while True: # 构建当前活跃的执行图 execution_order = self.build_execution_graph(active_edges) logger.debug(f"[rjb] 当前执行图: {execution_order}, 活跃边数: {len(active_edges)}, 已执行节点: {executed_nodes}") # 找到下一个要执行的节点(未执行且入度为0) next_node_id = None for node_id in execution_order: if node_id not in executed_nodes: # 检查所有前置节点是否已执行 can_execute = True incoming_edges = [e for e in active_edges if e['target'] == node_id] if not incoming_edges: # 没有入边,可能是起始节点或孤立节点 if node_id not in [n['id'] for n in self.nodes.values() if n.get('type') == 'start']: # 不是起始节点,但有入边被过滤了,不应该执行 logger.debug(f"[rjb] 节点 {node_id} 没有入边,跳过执行") continue for edge in incoming_edges: if edge['source'] not in executed_nodes: can_execute = False logger.debug(f"[rjb] 节点 {node_id} 的前置节点 {edge['source']} 未执行,不能执行") break if can_execute: next_node_id = node_id logger.info(f"[rjb] 选择执行节点: {next_node_id}, 类型: {self.nodes[next_node_id].get('type')}, 入边数: {len(incoming_edges)}") break if not next_node_id: break # 没有更多节点可执行 node = self.nodes[next_node_id] executed_nodes.add(next_node_id) # 调试:检查节点数据结构 if node.get('type') == 'llm': logger.debug(f"[rjb] 执行LLM节点: node_id={next_node_id}, node keys={list(node.keys())}, data keys={list(node.get('data', {}).keys()) if node.get('data') else []}") # 获取节点输入(使用活跃的边) node_input = self.get_node_input(next_node_id, self.node_outputs, active_edges) # 如果是起始节点,使用初始输入 if node.get('type') == 'start' and not node_input: node_input = input_data logger.info(f"[rjb] Start节点使用初始输入: node_id={next_node_id}, node_input={node_input}") # 调试:记录节点输入数据 if node.get('type') == 'llm': logger.info(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}") if 'start-1' in self.node_outputs: logger.info(f"[rjb] Start节点输出内容: {self.node_outputs['start-1']}") # 执行节点 result = await self.execute_node(node, node_input) results[next_node_id] = result # 保存节点输出 if result.get('status') == 'success': output_value = result.get('output', {}) self.node_outputs[next_node_id] = output_value if node.get('type') == 'start': logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}") # 如果是条件节点或Switch节点,根据分支结果过滤边 if node.get('type') == 'condition': branch = result.get('branch', 'false') logger.info(f"[rjb] 条件节点分支过滤: node_id={next_node_id}, branch={branch}") # 移除不符合条件的边 # 只保留:1) 不是从条件节点出发的边,或 2) 从条件节点出发且sourceHandle匹配分支的边 edges_to_remove = [] edges_to_keep = [] for edge in active_edges: if edge['source'] == next_node_id: # 这是从条件节点出发的边 edge_handle = edge.get('sourceHandle') if edge_handle == branch: # sourceHandle匹配分支,保留 edges_to_keep.append(edge) logger.info(f"[rjb] 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})") else: # sourceHandle不匹配或为None,移除 edges_to_remove.append(edge) logger.info(f"[rjb] 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch})") else: # 不是从条件节点出发的边,保留 edges_to_keep.append(edge) active_edges = edges_to_keep elif node.get('type') == 'switch': branch = result.get('branch', 'default') logger.info(f"[rjb] Switch节点分支过滤: node_id={next_node_id}, branch={branch}") # 记录过滤前的边信息 edges_before = [e for e in active_edges if e['source'] == next_node_id] logger.info(f"[rjb] Switch节点过滤前: 从节点出发的边有{len(edges_before)}条") for edge in edges_before: logger.info(f"[rjb] 边 {edge.get('id')}: sourceHandle={edge.get('sourceHandle')}, target={edge.get('target')}") # 移除不匹配的边 edges_to_keep = [] edges_removed_count = 0 removed_source_nodes = set() # 记录被移除边的源节点 for edge in active_edges: if edge['source'] == next_node_id: # 这是从Switch节点出发的边 edge_handle = edge.get('sourceHandle') if edge_handle == branch: # sourceHandle匹配分支,保留 edges_to_keep.append(edge) logger.info(f"[rjb] ✅ 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})") else: # sourceHandle不匹配,移除 edges_removed_count += 1 target_id = edge.get('target') removed_source_nodes.add(target_id) # 记录目标节点(这些节点将不再可达) logger.info(f"[rjb] ❌ 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch}, target={target_id})") else: # 不是从Switch节点出发的边,保留 edges_to_keep.append(edge) # 重要:移除那些指向被过滤节点的边(这些边来自被过滤的LLM节点) # 例如:如果llm-question被过滤了,那么llm-question → merge-response的边也应该被移除 additional_removed = 0 for edge in list(edges_to_keep): # 使用list副本,因为我们要修改原列表 if edge['source'] in removed_source_nodes: # 这条边来自被过滤的节点,也应该被移除 edges_to_keep.remove(edge) additional_removed += 1 logger.info(f"[rjb] ❌ 移除来自被过滤节点的边: {edge.get('id')} ({edge.get('source')} → {edge.get('target')})") edges_removed_count += additional_removed active_edges = edges_to_keep filter_info = { 'branch': branch, 'edges_before': len(edges_before), 'edges_kept': len([e for e in edges_to_keep if e['source'] == next_node_id]), 'edges_removed': edges_removed_count } logger.info(f"[rjb] Switch节点过滤后: 保留{len(active_edges)}条边(其中从Switch节点出发的{filter_info['edges_kept']}条),移除{edges_removed_count}条边") # 记录过滤后的活跃边 remaining_switch_edges = [e for e in active_edges if e['source'] == next_node_id] logger.info(f"[rjb] Switch节点过滤后剩余的边: {[e.get('id') + '->' + e.get('target') for e in remaining_switch_edges]}") # 重要:找出那些不再可达的节点(这些节点只通过被移除的边连接) removed_targets = set() for edge in edges_before: if edge not in edges_to_keep: target_id = edge.get('target') removed_targets.add(target_id) logger.info(f"[rjb] ❌ 节点 {target_id} 的边已被移除,该节点将不会被执行") # 关键修复:立即重新构建执行图,确保不再可达的节点不在执行图中 # 这样在下次循环时,这些节点就不会被选择执行 logger.info(f"[rjb] Switch节点过滤后,重新构建执行图(排除 {len(removed_targets)} 个不再可达的节点)") # 同时记录到数据库 if self.logger: self.logger.info( f"Switch节点分支过滤: branch={branch}, 保留{filter_info['edges_kept']}条边,移除{edges_removed_count}条边", node_id=next_node_id, node_type='switch', data=filter_info ) # 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行) if node.get('type') in ['loop', 'foreach']: # 标记循环体的节点为已执行(简化处理) for edge in active_edges[:]: # 使用切片复制列表 if edge.get('source') == next_node_id: target_id = edge.get('target') if target_id in self.nodes: # 检查是否是循环结束节点 target_node = self.nodes[target_id] if target_node.get('type') not in ['loop_end', 'end']: # 标记为已执行(循环体已在循环节点内部执行) executed_nodes.add(target_id) # 继续查找循环体内的节点 self._mark_loop_body_executed(target_id, executed_nodes, active_edges) else: # 执行失败,停止工作流 error_msg = result.get('error', '未知错误') node_type = node.get('type', 'unknown') logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}") raise WorkflowExecutionError( detail=error_msg, node_id=next_node_id ) # 返回最终结果(最后一个执行的节点的输出) if executed_nodes: # 找到最后一个节点(没有出边的节点) last_node_id = None for node_id in executed_nodes: has_outgoing = any(edge['source'] == node_id for edge in active_edges) if not has_outgoing: last_node_id = node_id break if not last_node_id: # 如果没有找到,使用最后一个执行的节点 last_node_id = list(executed_nodes)[-1] # 获取最终结果 final_output = self.node_outputs.get(last_node_id) # 如果最终输出是字典且只有一个 'input' key,提取其值 # 这样可以确保最终结果不是重复包装的格式 if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output: final_output = final_output['input'] # 如果提取的值仍然是字典且只有一个 'input' key,继续提取 if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output: final_output = final_output['input'] # 确保最终结果是字符串(对于人机交互场景) # 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串 if not isinstance(final_output, str): if isinstance(final_output, dict): # 如果是字典,尝试提取文本内容或转换为JSON字符串 # 优先查找常见的文本字段 if 'text' in final_output: final_output = str(final_output['text']) elif 'content' in final_output: final_output = str(final_output['content']) elif 'message' in final_output: final_output = str(final_output['message']) elif 'response' in final_output: final_output = str(final_output['response']) elif len(final_output) == 1: # 如果只有一个key,直接使用其值 final_output = str(list(final_output.values())[0]) else: # 否则转换为JSON字符串 import json as json_module final_output = json_module.dumps(final_output, ensure_ascii=False) else: final_output = str(final_output) final_result = { 'status': 'completed', 'result': final_output, 'node_results': results } # 记录工作流执行完成 if self.logger: self.logger.info("工作流执行完成", data={"result": final_result.get('result')}) return final_result if self.logger: self.logger.warn("工作流执行完成,但没有执行任何节点") return {'status': 'completed', 'result': None}