Files
aiagent/backend/app/services/workflow_engine.py
2026-01-22 09:59:02 +08:00

4217 lines
229 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工作流执行引擎
"""
from typing import Dict, Any, List, Optional
import asyncio
from collections import defaultdict, deque
import json
import logging
import re
import time
from app.services.llm_service import llm_service
from app.services.condition_parser import condition_parser
from app.services.data_transformer import data_transformer
from app.core.exceptions import WorkflowExecutionError
from app.core.database import SessionLocal
from app.models.agent import Agent
from app.core.config import settings
logger = logging.getLogger(__name__)
class WorkflowEngine:
"""工作流执行引擎"""
def __init__(self, workflow_id: str, workflow_data: Dict[str, Any], logger=None, db=None):
"""
初始化工作流引擎
Args:
workflow_id: 工作流ID
workflow_data: 工作流数据包含nodes和edges
logger: 执行日志记录器(可选)
db: 数据库会话可选用于Agent节点加载Agent配置
"""
self.workflow_id = workflow_id
self.nodes = {node['id']: node for node in workflow_data.get('nodes', [])}
self.edges = workflow_data.get('edges', [])
self.execution_graph = None
self.node_outputs = {}
self.logger = logger
self.db = db
def build_execution_graph(self, active_edges: Optional[List[Dict[str, Any]]] = None) -> List[str]:
"""
构建执行图DAG并返回拓扑排序结果
Args:
active_edges: 活跃的边列表(用于条件分支过滤)
Returns:
拓扑排序后的节点ID列表
"""
# 使用活跃的边,如果没有提供则使用所有边
edges_to_use = active_edges if active_edges is not None else self.edges
# 构建邻接表和入度表
graph = defaultdict(list)
in_degree = defaultdict(int)
# 初始化所有节点的入度
for node_id in self.nodes.keys():
in_degree[node_id] = 0
# 构建图
for edge in edges_to_use:
source = edge['source']
target = edge['target']
graph[source].append(target)
in_degree[target] += 1
# 拓扑排序Kahn算法
queue = deque()
result = []
# 找到所有入度为0的节点起始节点
for node_id in self.nodes.keys():
if in_degree[node_id] == 0:
queue.append(node_id)
while queue:
node_id = queue.popleft()
result.append(node_id)
# 处理该节点的所有出边
for neighbor in graph[node_id]:
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
# 检查是否有环(只检查可达节点)
reachable_nodes = set(result)
if len(reachable_nodes) < len(self.nodes):
# 有些节点不可达,这是正常的(条件分支)
pass
self.execution_graph = result
return result
def get_node_input(self, node_id: str, node_outputs: Dict[str, Any], active_edges: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
"""
获取节点的输入数据
Args:
node_id: 节点ID
node_outputs: 所有节点的输出数据
active_edges: 活跃的边列表(用于条件分支过滤)
Returns:
节点的输入数据
"""
# 使用活跃的边,如果没有提供则使用所有边
edges_to_use = active_edges if active_edges is not None else self.edges
# 找到所有指向该节点的边
input_data = {}
for edge in edges_to_use:
if edge['target'] == node_id:
source_id = edge['source']
source_output = node_outputs.get(source_id, {})
logger.info(f"[rjb] 获取节点输入: target={node_id}, source={source_id}, source_output={source_output}, source_output_type={type(source_output)}, sourceHandle={edge.get('sourceHandle')}")
# 如果有sourceHandle使用它作为key
if 'sourceHandle' in edge and edge['sourceHandle']:
input_data[edge['sourceHandle']] = source_output
# 重要即使有sourceHandle也要保留记忆相关字段conversation_history、user_profile、context
# 这些字段应该始终传递到下游节点
if isinstance(source_output, dict):
memory_fields = ['conversation_history', 'user_profile', 'context', 'memory']
for field in memory_fields:
if field in source_output:
input_data[field] = source_output[field]
logger.info(f"[rjb] 保留记忆字段 {field} 到节点 {node_id} 的输入")
else:
# 否则合并所有输入
if isinstance(source_output, dict):
# 如果source_output包含output字段展开它
if 'output' in source_output and isinstance(source_output['output'], dict):
# 将output中的内容展开到顶层
input_data.update(source_output['output'])
# 保留其他字段如status
for key, value in source_output.items():
if key != 'output':
input_data[key] = value
else:
# 直接展开source_output的内容
input_data.update(source_output)
logger.info(f"[rjb] 展开source_output后: input_data={input_data}")
else:
# 如果source_output不是字典包装到input字段
input_data['input'] = source_output
logger.info(f"[rjb] source_output不是字典包装到input字段: input_data={input_data}")
# 重要对于LLM节点和cache节点如果输入中没有memory字段尝试从所有已执行的节点中查找并合并记忆字段
# 这样可以确保即使上游节点没有传递记忆信息,这些节点也能访问到记忆
node_type = None
node = self.nodes.get(node_id)
if node:
node_type = node.get('type')
# 对于LLM节点和cache节点特别是cache-update需要memory字段
if node_type in ['llm', 'cache'] and 'memory' not in input_data:
# 从所有已执行的节点中查找memory字段
for executed_node_id, node_output in self.node_outputs.items():
if isinstance(node_output, dict):
# 检查是否有memory字段
if 'memory' in node_output:
input_data['memory'] = node_output['memory']
logger.info(f"[rjb] 为{node_type}节点 {node_id} 从节点 {executed_node_id} 获取memory字段")
break
# 或者检查是否有conversation_history等记忆字段
elif 'conversation_history' in node_output:
# 构建memory对象
memory = {}
for field in ['conversation_history', 'user_profile', 'context']:
if field in node_output:
memory[field] = node_output[field]
if memory:
input_data['memory'] = memory
logger.info(f"[rjb] 为{node_type}节点 {node_id} 从节点 {executed_node_id} 构建memory对象: {list(memory.keys())}")
break
# 如果input_data中没有query字段尝试从所有已执行的节点中查找特别是start节点
if 'query' not in input_data:
# 优先查找start节点
for node_id_key in ['start-1', 'start']:
if node_id_key in node_outputs:
node_output = node_outputs[node_id_key]
if isinstance(node_output, dict):
# 检查顶层字段因为node_outputs存储的是output字段的内容
if 'query' in node_output:
input_data['query'] = node_output['query']
logger.debug(f"[rjb] 从节点 {node_id_key} 中获取query: {input_data['query']}")
break
# 检查output字段兼容性
elif 'output' in node_output and isinstance(node_output['output'], dict):
if 'query' in node_output['output']:
input_data['query'] = node_output['output']['query']
logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取query: {input_data['query']}")
break
# 如果还没找到,遍历所有节点
if 'query' not in input_data:
for node_id_key, node_output in node_outputs.items():
if isinstance(node_output, dict):
# 检查顶层字段
if 'query' in node_output:
input_data['query'] = node_output['query']
logger.debug(f"[rjb] 从节点 {node_id_key} 中获取query: {input_data['query']}")
break
# 检查output字段兼容性
elif 'output' in node_output and isinstance(node_output['output'], dict):
if 'query' in node_output['output']:
input_data['query'] = node_output['output']['query']
logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取query: {input_data['query']}")
break
# 如果input_data中没有requirement_analysis字段尝试从所有已执行的节点中查找
if 'requirement_analysis' not in input_data:
# 优先查找requirement-analysis节点
for node_id_key in ['llm-requirement-analysis', 'requirement-analysis']:
if node_id_key in node_outputs:
node_output = node_outputs[node_id_key]
if isinstance(node_output, dict):
# 检查顶层字段因为node_outputs存储的是output字段的内容
if 'requirement_analysis' in node_output:
input_data['requirement_analysis'] = node_output['requirement_analysis']
logger.debug(f"[rjb] 从节点 {node_id_key} 中获取requirement_analysis")
break
# 检查output字段兼容性
elif 'output' in node_output and isinstance(node_output['output'], dict):
if 'requirement_analysis' in node_output['output']:
input_data['requirement_analysis'] = node_output['output']['requirement_analysis']
logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取requirement_analysis")
break
# 如果还没找到,遍历所有节点
if 'requirement_analysis' not in input_data:
for node_id_key, node_output in node_outputs.items():
if isinstance(node_output, dict):
# 检查顶层字段
if 'requirement_analysis' in node_output:
input_data['requirement_analysis'] = node_output['requirement_analysis']
logger.debug(f"[rjb] 从节点 {node_id_key} 中获取requirement_analysis")
break
# 检查output字段兼容性
elif 'output' in node_output and isinstance(node_output['output'], dict):
if 'requirement_analysis' in node_output['output']:
input_data['requirement_analysis'] = node_output['output']['requirement_analysis']
logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取requirement_analysis")
break
logger.debug(f"[rjb] 节点输入结果: node_id={node_id}, input_data={input_data}")
return input_data
def _get_nested_value(self, data: Dict[str, Any], path: str) -> Any:
"""
从嵌套字典中获取值(支持点号路径和数组索引)
Args:
data: 数据字典
path: 路径,如 "user.name""items[0].price"
Returns:
路径对应的值
"""
if not path:
return data
parts = path.split('.')
result = data
for part in parts:
if '[' in part and ']' in part:
# 处理数组索引,如 "items[0]"
key = part[:part.index('[')]
index_str = part[part.index('[') + 1:part.index(']')]
if isinstance(result, dict):
result = result.get(key)
elif isinstance(result, list):
try:
result = result[int(index_str)]
except (ValueError, IndexError):
return None
else:
return None
if result is None:
return None
else:
# 普通键访问
if isinstance(result, dict):
result = result.get(part)
else:
return None
if result is None:
return None
return result
async def _execute_loop_body(self, loop_node_id: str, loop_input: Dict[str, Any], iteration_index: int) -> Dict[str, Any]:
"""
执行循环体
Args:
loop_node_id: 循环节点ID
loop_input: 循环体的输入数据
iteration_index: 当前迭代索引
Returns:
循环体的执行结果
"""
# 找到循环节点的直接子节点(循环体开始节点)
loop_body_start_nodes = []
for edge in self.edges:
if edge.get('source') == loop_node_id:
target_id = edge.get('target')
if target_id and target_id in self.nodes:
loop_body_start_nodes.append(target_id)
if not loop_body_start_nodes:
# 如果没有子节点,直接返回输入数据
return {'output': loop_input, 'status': 'success'}
# 执行循环体:从循环体开始节点执行到循环结束节点或没有更多节点
# 简化处理:只执行第一个子节点链
executed_in_loop = set()
loop_results = {}
current_node_id = loop_body_start_nodes[0] # 简化:只执行第一个子节点链
# 执行循环体内的节点(简化版本:只执行直接连接的子节点)
max_iterations = 100 # 防止无限循环
iteration = 0
while current_node_id and iteration < max_iterations:
iteration += 1
if current_node_id in executed_in_loop:
break # 避免循环体内部循环
if current_node_id not in self.nodes:
break
node = self.nodes[current_node_id]
executed_in_loop.add(current_node_id)
# 如果是循环结束节点,停止执行
if node.get('type') == 'loop_end' or node.get('type') == 'end':
break
# 执行节点
result = await self.execute_node(node, loop_input)
loop_results[current_node_id] = result
if result.get('status') != 'success':
return result
# 更新输入数据为当前节点的输出
if result.get('output'):
if isinstance(result.get('output'), dict):
loop_input = {**loop_input, **result.get('output')}
else:
loop_input = {**loop_input, 'result': result.get('output')}
# 找到下一个节点(简化:只找第一个子节点)
next_node_id = None
for edge in self.edges:
if edge.get('source') == current_node_id:
target_id = edge.get('target')
if target_id and target_id in self.nodes and target_id not in executed_in_loop:
# 跳过循环节点本身
if target_id != loop_node_id:
next_node_id = target_id
break
current_node_id = next_node_id
# 返回最后一个节点的输出
if loop_results:
last_result = list(loop_results.values())[-1]
return last_result
return {'output': loop_input, 'status': 'success'}
def _mark_loop_body_executed(self, node_id: str, executed_nodes: set, active_edges: List[Dict[str, Any]]):
"""
递归标记循环体内的节点为已执行
Args:
node_id: 当前节点ID
executed_nodes: 已执行节点集合
active_edges: 活跃的边列表
"""
if node_id in executed_nodes:
return
executed_nodes.add(node_id)
# 查找所有子节点
for edge in active_edges:
if edge.get('source') == node_id:
target_id = edge.get('target')
if target_id in self.nodes:
target_node = self.nodes[target_id]
# 如果是循环结束节点,停止递归
if target_node.get('type') in ['loop_end', 'end']:
continue
# 递归标记子节点
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
async def execute_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行单个节点
Args:
node: 节点配置
input_data: 输入数据
Returns:
节点执行结果
"""
# 确保可以访问全局的 json 模块
import json as json_module
node_type = node.get('type', 'unknown')
node_id = node.get('id')
start_time = time.time()
# 记录节点开始执行
if self.logger:
self.logger.log_node_start(node_id, node_type, input_data)
try:
if node_type == 'start':
# 起始节点:返回输入数据
logger.debug(f"[rjb] 开始节点执行: node_id={node_id}, input_data={input_data}")
result = {'output': input_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
logger.debug(f"[rjb] 开始节点输出: node_id={node_id}, output={result.get('output')}")
return result
elif node_type == 'input':
# 输入节点:处理输入数据
result = {'output': input_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
elif node_type == 'llm' or node_type == 'template':
# LLM节点调用AI模型
node_data = node.get('data', {})
logger.debug(f"[rjb] LLM节点执行: node_id={node_id}, input_data={input_data}, input_data type={type(input_data)}")
logger.debug(f"[rjb] LLM节点数据: node_id={node_id}, node_data keys={list(node_data.keys())}, api_key={'已配置' if node_data.get('api_key') else '未配置'}")
prompt = node_data.get('prompt', '')
# 如果prompt为空使用默认提示词
if not prompt:
prompt = "请处理以下输入数据:\n{input}"
# 格式化prompt替换变量
try:
# 将input_data转换为字符串用于格式化
if isinstance(input_data, dict):
# 支持两种格式的变量:{key} 和 {{key}}
formatted_prompt = prompt
has_unfilled_variables = False
has_any_placeholder = False
import re
# 检查是否有任何占位符
has_any_placeholder = bool(re.search(r'\{\{?\w+\}?\}', prompt))
# 首先处理 {{variable}} 和 {{variable.path}} 格式(模板节点常用)
# 支持嵌套路径,如 {{memory.conversation_history}}
double_brace_vars = re.findall(r'\{\{([^}]+)\}\}', prompt)
for var_path in double_brace_vars:
# 尝试从input_data中获取值支持嵌套路径
value = self._get_nested_value(input_data, var_path)
# 如果变量未找到,尝试常见的别名映射
if value is None:
# user_input 可以映射到 query、input、USER_INPUT 等字段
if var_path == 'user_input':
for alias in ['query', 'input', 'USER_INPUT', 'user_input', 'text', 'message', 'content']:
value = input_data.get(alias)
if value is not None:
# 如果值是字典,尝试从中提取字符串值
if isinstance(value, dict):
for sub_key in ['query', 'input', 'text', 'message', 'content']:
if sub_key in value:
value = value[sub_key]
break
break
# output 可以映射到 right 字段LLM节点的输出通常存储在right字段中
elif var_path == 'output':
# 尝试从right字段中提取
right_value = input_data.get('right')
logger.info(f"[rjb] LLM节点查找output变量: right_value类型={type(right_value)}, right_value={str(right_value)[:100] if right_value else None}")
if right_value is not None:
# 如果right是字符串直接使用
if isinstance(right_value, str):
value = right_value
logger.info(f"[rjb] LLM节点从right字段字符串提取output: {value[:100]}")
# 如果right是字典尝试递归查找字符串值
elif isinstance(right_value, dict):
# 尝试从right.right.right...中提取处理嵌套的right字段
current = right_value
depth = 0
while isinstance(current, dict) and depth < 10:
if 'right' in current:
current = current['right']
depth += 1
if isinstance(current, str):
value = current
logger.info(f"[rjb] LLM节点从right字段嵌套{depth}提取output: {value[:100]}")
break
else:
# 如果没有right字段尝试其他可能的字段
for key in ['content', 'text', 'message', 'output']:
if key in current and isinstance(current[key], str):
value = current[key]
logger.info(f"[rjb] LLM节点从right字段中找到{key}字段: {value[:100]}")
break
if value is not None:
break
break
if value is None:
logger.warning(f"[rjb] LLM节点无法从right字段中提取outputright结构: {str(right_value)[:200]}")
if value is not None:
# 替换 {{variable}} 或 {{variable.path}} 为实际值
# 特殊处理如果是memory.conversation_history格式化为易读的对话格式
if var_path == 'memory.conversation_history' and isinstance(value, list):
# 将对话历史格式化为易读的文本格式
formatted_history = []
for msg in value:
role = msg.get('role', 'unknown')
content = msg.get('content', '')
if role == 'user':
formatted_history.append(f"用户:{content}")
elif role == 'assistant':
formatted_history.append(f"助手:{content}")
else:
formatted_history.append(f"{role}{content}")
replacement = '\n'.join(formatted_history) if formatted_history else '(暂无对话历史)'
else:
# 其他情况使用JSON格式
replacement = json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value)
formatted_prompt = formatted_prompt.replace(f'{{{{{var_path}}}}}', replacement)
# 对于conversation_history显示完整内容以便调试
if var_path == 'memory.conversation_history':
logger.info(f"[rjb] LLM节点替换变量: {var_path} = {replacement[:500] if len(replacement) > 500 else replacement}")
else:
logger.info(f"[rjb] LLM节点替换变量: {var_path} = {str(replacement)[:200]}")
else:
has_unfilled_variables = True
logger.warning(f"[rjb] LLM节点变量未找到: {var_path}, input_data keys: {list(input_data.keys()) if isinstance(input_data, dict) else 'not dict'}")
# 然后处理 {key} 格式
for key, value in input_data.items():
placeholder = f'{{{key}}}'
if placeholder in formatted_prompt:
formatted_prompt = formatted_prompt.replace(
placeholder,
json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value)
)
# 如果还有{input}占位符替换为整个input_data
if '{input}' in formatted_prompt:
formatted_prompt = formatted_prompt.replace(
'{input}',
json_module.dumps(input_data, ensure_ascii=False)
)
# 提取用户的实际查询内容(优先提取)
user_query = None
logger.info(f"[rjb] 开始提取user_query: input_data={input_data}, input_data_type={type(input_data)}")
if isinstance(input_data, dict):
# 首先检查是否有嵌套的input字段
nested_input = input_data.get('input')
logger.info(f"[rjb] 检查嵌套input: nested_input={nested_input}, nested_input_type={type(nested_input) if nested_input else None}")
if isinstance(nested_input, dict):
# 从嵌套的input中提取
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
if key in nested_input:
user_query = nested_input[key]
logger.info(f"[rjb] 从嵌套input中提取到user_query: key={key}, user_query={user_query}")
break
# 如果还没有,从顶层提取
if not user_query:
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
if key in input_data:
value = input_data[key]
logger.info(f"[rjb] 从顶层提取: key={key}, value={value}, value_type={type(value)}")
# 如果值是字符串,直接使用
if isinstance(value, str):
user_query = value
logger.info(f"[rjb] 提取到字符串user_query: {user_query}")
break
# 如果值是字典,尝试从中提取
elif isinstance(value, dict):
for sub_key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
if sub_key in value:
user_query = value[sub_key]
logger.info(f"[rjb] 从字典值中提取到user_query: sub_key={sub_key}, user_query={user_query}")
break
if user_query:
break
# 如果还是没有使用整个input_data但排除系统字段
if not user_query:
filtered_data = {k: v for k, v in input_data.items() if not k.startswith('_')}
logger.info(f"[rjb] 使用filtered_data: filtered_data={filtered_data}")
if filtered_data:
# 如果只有一个字段且是字符串,直接使用
if len(filtered_data) == 1:
single_value = list(filtered_data.values())[0]
if isinstance(single_value, str):
user_query = single_value
logger.info(f"[rjb] 从单个字符串字段提取到user_query: {user_query}")
elif isinstance(single_value, dict):
# 从字典中提取第一个字符串值
for v in single_value.values():
if isinstance(v, str):
user_query = v
logger.info(f"[rjb] 从字典的单个字段中提取到user_query: {user_query}")
break
if not user_query:
user_query = json_module.dumps(filtered_data, ensure_ascii=False) if len(filtered_data) > 1 else str(list(filtered_data.values())[0])
logger.info(f"[rjb] 使用JSON或字符串转换: user_query={user_query}")
logger.info(f"[rjb] 最终提取的user_query: {user_query}")
# 如果prompt中没有占位符或者仍有未填充的变量将用户输入附加到prompt
is_generic_instruction = False # 初始化变量
if not has_any_placeholder:
# 如果prompt中没有占位符将用户输入作为主要内容
if user_query:
# 判断是否是通用指令:简短且不包含具体任务描述
prompt_stripped = prompt.strip()
is_generic_instruction = (
len(prompt_stripped) < 30 or # 简短提示词
prompt_stripped in [
"请处理用户请求。", "请处理用户请求",
"请处理以下输入数据:", "请处理以下输入数据",
"请处理输入。", "请处理输入",
"处理用户请求", "处理请求",
"请回答用户问题", "请回答用户问题。",
"请帮助用户", "请帮助用户。"
] or
# 检查是否只包含通用指令关键词
(len(prompt_stripped) < 50 and any(keyword in prompt_stripped for keyword in [
"请处理", "处理", "请回答", "回答", "请帮助", "帮助", "请执行", "执行"
]) and not any(specific in prompt_stripped for specific in [
"翻译", "生成", "分析", "总结", "提取", "转换", "计算"
]))
)
if is_generic_instruction:
# 如果是通用指令直接使用用户输入作为prompt
formatted_prompt = str(user_query)
logger.info(f"[rjb] 检测到通用指令直接使用用户输入作为prompt: {user_query[:50] if user_query else 'None'}")
else:
# 否则将用户输入附加到prompt
formatted_prompt = f"{formatted_prompt}\n\n{user_query}"
logger.info(f"[rjb] 非通用指令将用户输入附加到prompt")
else:
# 如果没有提取到用户查询附加整个input_data
formatted_prompt = f"{formatted_prompt}\n\n{json_module.dumps(input_data, ensure_ascii=False)}"
elif has_unfilled_variables or re.search(r'\{\{(\w+)\}\}', formatted_prompt):
# 如果有占位符但未填充,附加用户需求说明
if user_query:
formatted_prompt = f"{formatted_prompt}\n\n用户需求:{user_query}\n\n请根据以上用户需求,忽略未填充的变量占位符(如{{{{variable}}}}),直接基于用户需求来完成任务。"
logger.info(f"[rjb] LLM节点prompt格式化: node_id={node_id}, original_prompt='{prompt[:50] if len(prompt) > 50 else prompt}', has_any_placeholder={has_any_placeholder}, user_query={user_query}, is_generic_instruction={is_generic_instruction}, final_prompt前200字符='{formatted_prompt[:200] if len(formatted_prompt) > 200 else formatted_prompt}'")
prompt = formatted_prompt
else:
# 如果input_data不是dict直接转换为字符串
if '{input}' in prompt:
prompt = prompt.replace('{input}', str(input_data))
else:
prompt = f"{prompt}\n\n输入:{str(input_data)}"
except Exception as e:
# 格式化失败使用原始prompt和input_data
logger.warning(f"[rjb] Prompt格式化失败: {str(e)}")
try:
prompt = f"{prompt}\n\n输入数据:\n{json_module.dumps(input_data, ensure_ascii=False)}"
except:
prompt = f"{prompt}\n\n输入数据:{str(input_data)}"
# 获取LLM配置
provider = node_data.get('provider', 'openai')
model = node_data.get('model', 'gpt-3.5-turbo')
# 确保temperature是浮点数节点模板中可能是字符串
temperature_raw = node_data.get('temperature', 0.7)
if isinstance(temperature_raw, str):
try:
temperature = float(temperature_raw)
except (ValueError, TypeError):
temperature = 0.7
else:
temperature = float(temperature_raw) if temperature_raw is not None else 0.7
# 确保max_tokens是整数节点模板中可能是字符串
max_tokens_raw = node_data.get('max_tokens')
if max_tokens_raw is not None:
if isinstance(max_tokens_raw, str):
try:
max_tokens = int(max_tokens_raw)
except (ValueError, TypeError):
max_tokens = None
else:
max_tokens = int(max_tokens_raw) if max_tokens_raw is not None else None
else:
max_tokens = None
# 不传递 api_key 和 base_url让 LLM 服务使用系统默认配置(与节点测试保持一致)
api_key = None
base_url = None
# 记录实际发送给LLM的prompt
logger.info(f"[rjb] 准备调用LLM: node_id={node_id}, provider={provider}, model={model}, prompt前200字符='{prompt[:200] if len(prompt) > 200 else prompt}'")
# 调用LLM服务
try:
if self.logger:
logger.debug(f"[rjb] LLM节点配置: provider={provider}, model={model}, 使用系统默认API Key配置")
self.logger.info(f"调用LLM服务: {provider}/{model}", node_id=node_id, node_type=node_type)
result = await llm_service.call_llm(
prompt=prompt,
provider=provider,
model=model,
temperature=temperature,
max_tokens=max_tokens
# 不传递 api_key 和 base_url使用系统默认配置
)
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
# LLM调用失败返回错误
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'LLM调用失败: {str(e)}'
}
elif node_type == 'condition':
# 条件节点:判断分支
condition = node.get('data', {}).get('condition', '')
if not condition:
# 如果没有条件表达式默认返回False
return {
'output': False,
'status': 'success',
'branch': 'false'
}
# 使用条件解析器评估表达式
try:
result = condition_parser.evaluate_condition(condition, input_data)
exec_result = {
'output': result,
'status': 'success',
'branch': 'true' if result else 'false'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, {'result': result, 'branch': exec_result['branch']}, duration)
return exec_result
except Exception as e:
# 条件评估失败
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': False,
'status': 'failed',
'error': f'条件评估失败: {str(e)}',
'branch': 'false'
}
elif node_type == 'data' or node_type == 'transform':
# 数据转换节点
node_data = node.get('data', {})
mapping = node_data.get('mapping', {})
filter_rules = node_data.get('filter_rules', [])
compute_rules = node_data.get('compute_rules', {})
mode = node_data.get('mode', 'mapping')
try:
# 处理mapping中的{{variable}}格式从input_data中提取值
# 首先如果input_data包含output字段需要展开它
expanded_input = input_data.copy()
if 'output' in input_data and isinstance(input_data['output'], dict):
# 将output中的内容展开到顶层但保留output字段
expanded_input.update(input_data['output'])
processed_mapping = {}
import re
for target_key, source_expr in mapping.items():
if isinstance(source_expr, str):
# 支持{{variable}}格式
double_brace_vars = re.findall(r'\{\{(\w+)\}\}', source_expr)
if double_brace_vars:
# 从expanded_input中获取变量值
var_value = None
for var_name in double_brace_vars:
# 尝试从expanded_input中获取支持嵌套路径
var_value = self._get_nested_value(expanded_input, var_name)
if var_value is not None:
break
if var_value is not None:
# 如果只有一个变量,直接使用值;否则替换表达式
if len(double_brace_vars) == 1:
processed_mapping[target_key] = var_value
else:
# 多个变量,替换表达式
processed_expr = source_expr
for var_name in double_brace_vars:
var_val = self._get_nested_value(expanded_input, var_name)
if var_val is not None:
replacement = json_module.dumps(var_val, ensure_ascii=False) if isinstance(var_val, (dict, list)) else str(var_val)
processed_expr = processed_expr.replace(f'{{{{{var_name}}}}}', replacement)
processed_mapping[target_key] = processed_expr
else:
# 变量不存在,保持原表达式
processed_mapping[target_key] = source_expr
else:
# 不是{{variable}}格式,直接使用
processed_mapping[target_key] = source_expr
else:
# 不是字符串,直接使用
processed_mapping[target_key] = source_expr
# 如果mode是merge需要合并所有输入数据
if mode == 'merge':
# 合并所有上游节点的输出(使用展开后的数据)
result = expanded_input.copy()
# 重要如果输入数据中包含conversation_history、user_profile、context等记忆字段先构建memory对象
memory_fields = ['conversation_history', 'user_profile', 'context']
memory_data = {}
for field in memory_fields:
if field in expanded_input:
memory_data[field] = expanded_input[field]
# 如果构建了memory对象添加到result中
if memory_data:
result['memory'] = memory_data
logger.info(f"[rjb] Transform节点 {node_id} 构建memory对象: {list(memory_data.keys())}")
# 添加mapping的结果mapping可能会覆盖memory字段
for key, value in processed_mapping.items():
# 如果mapping中的value是None或空字符串且key是memory尝试从expanded_input构建
if key == 'memory' and (value is None or value == '' or value == '{{output}}'):
if memory_data:
result[key] = memory_data
logger.info(f"[rjb] Transform节点 {node_id} mapping中的memory为空使用构建的memory对象")
elif 'memory' in expanded_input:
result[key] = expanded_input['memory']
else:
result[key] = value
# 确保记忆字段被保留即使mapping覆盖了它们
for field in memory_fields:
if field in expanded_input and field not in result:
result[field] = expanded_input[field]
# 如果memory字段是dict也要检查其中的字段
if 'memory' in expanded_input and isinstance(expanded_input['memory'], dict):
if 'memory' not in result:
result['memory'] = expanded_input['memory'].copy()
else:
# 合并memory字段
if isinstance(result['memory'], dict):
result['memory'].update(expanded_input['memory'])
logger.info(f"[rjb] Transform节点 {node_id} merge模式结果keys: {list(result.keys())}")
if 'memory' in result and isinstance(result['memory'], dict):
if 'conversation_history' in result['memory']:
logger.info(f"[rjb] memory.conversation_history: {len(result['memory']['conversation_history'])}")
elif 'conversation_history' in result:
logger.info(f"[rjb] conversation_history: {len(result['conversation_history'])}")
else:
# 使用处理后的mapping进行转换使用展开后的数据
result = data_transformer.transform_data(
input_data=expanded_input,
mapping=processed_mapping,
filter_rules=filter_rules,
compute_rules=compute_rules,
mode=mode
)
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
logger.error(f"[rjb] Transform节点执行失败: {str(e)}", exc_info=True)
return {
'output': None,
'status': 'failed',
'error': f'数据转换失败: {str(e)}'
}
elif node_type == 'loop' or node_type == 'foreach':
# 循环节点:对数组进行循环处理
node_data = node.get('data', {})
items_path = node_data.get('items_path', 'items') # 数组数据路径
item_variable = node_data.get('item_variable', 'item') # 循环变量名
# 从输入数据中获取数组
items = self._get_nested_value(input_data, items_path)
if not isinstance(items, list):
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type,
ValueError(f"路径 {items_path} 的值不是数组"), duration)
return {
'output': None,
'status': 'failed',
'error': f'路径 {items_path} 的值不是数组,当前类型: {type(items).__name__}'
}
if self.logger:
self.logger.info(f"循环节点开始处理 {len(items)} 个元素",
node_id=node_id, node_type=node_type,
data={"items_count": len(items)})
# 执行循环:对每个元素执行循环体
loop_results = []
for index, item in enumerate(items):
if self.logger:
self.logger.info(f"循环迭代 {index + 1}/{len(items)}",
node_id=node_id, node_type=node_type,
data={"index": index, "item": item})
# 准备循环体的输入数据
loop_input = {
**input_data, # 保留原始输入数据
item_variable: item, # 当前循环项
f'{item_variable}_index': index, # 索引
f'{item_variable}_total': len(items) # 总数
}
# 执行循环体(获取循环节点的子节点)
loop_body_result = await self._execute_loop_body(
node_id, loop_input, index
)
if loop_body_result.get('status') == 'success':
loop_results.append(loop_body_result.get('output', item))
else:
# 如果循环体执行失败,可以选择继续或停止
error_handling = node_data.get('error_handling', 'continue')
if error_handling == 'stop':
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type,
Exception(f"循环体执行失败,停止循环: {loop_body_result.get('error')}"), duration)
return {
'output': None,
'status': 'failed',
'error': f'循环体执行失败: {loop_body_result.get("error")}',
'completed_items': index,
'results': loop_results
}
else:
# continue: 继续执行,记录错误
if self.logger:
self.logger.warn(f"循环迭代 {index + 1} 失败,继续执行",
node_id=node_id, node_type=node_type,
data={"error": loop_body_result.get('error')})
loop_results.append(None)
exec_result = {
'output': loop_results,
'status': 'success',
'items_processed': len(items),
'results_count': len(loop_results)
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type,
{'results_count': len(loop_results)}, duration)
return exec_result
elif node_type == 'http' or node_type == 'request':
# HTTP请求节点发送HTTP请求
node_data = node.get('data', {})
url = node_data.get('url', '')
method = node_data.get('method', 'GET').upper()
headers = node_data.get('headers', {})
params = node_data.get('params', {})
body = node_data.get('body', {})
timeout = node_data.get('timeout', 30)
# 如果URL、headers、params、body中包含变量从input_data中替换
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
# 支持 {key} 或 ${key} 格式
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
# 替换URL中的变量
if url:
url = replace_variables(url, input_data)
# 替换headers中的变量
if isinstance(headers, dict):
headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
elif isinstance(headers, str):
try:
headers = json.loads(replace_variables(headers, input_data))
except:
headers = {}
# 替换params中的变量
if isinstance(params, dict):
params = {k: replace_variables(str(v), input_data) if isinstance(v, str) else v
for k, v in params.items()}
elif isinstance(params, str):
try:
params = json.loads(replace_variables(params, input_data))
except:
params = {}
# 替换body中的变量
if isinstance(body, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
body = replace_dict_vars(body, input_data)
elif isinstance(body, str):
body = replace_variables(body, input_data)
try:
body = json.loads(body)
except:
pass
try:
import httpx
async with httpx.AsyncClient(timeout=timeout) as client:
if method == 'GET':
response = await client.get(url, params=params, headers=headers)
elif method == 'POST':
response = await client.post(url, json=body, params=params, headers=headers)
elif method == 'PUT':
response = await client.put(url, json=body, params=params, headers=headers)
elif method == 'DELETE':
response = await client.delete(url, params=params, headers=headers)
elif method == 'PATCH':
response = await client.patch(url, json=body, params=params, headers=headers)
else:
raise ValueError(f"不支持的HTTP方法: {method}")
# 尝试解析JSON响应
try:
response_data = response.json()
except:
response_data = response.text
result = {
'output': {
'status_code': response.status_code,
'headers': dict(response.headers),
'data': response_data
},
'status': 'success'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'HTTP请求失败: {str(e)}'
}
elif node_type == 'database' or node_type == 'db':
# 数据库操作节点:执行数据库操作
node_data = node.get('data', {})
data_source_id = node_data.get('data_source_id')
operation = node_data.get('operation', 'query') # query/insert/update/delete
sql = node_data.get('sql', '')
table = node_data.get('table', '')
data = node_data.get('data', {})
where = node_data.get('where', {})
# 如果SQL中包含变量从input_data中替换
if sql and isinstance(sql, str):
import re
def replace_sql_vars(text: str, data: Dict[str, Any]) -> str:
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
if value is None:
return match.group(0)
# 如果是字符串需要转义SQL注入
if isinstance(value, str):
# 简单转义,实际应该使用参数化查询
escaped_value = value.replace("'", "''")
return f"'{escaped_value}'"
return str(value)
return re.sub(pattern, replacer, text)
sql = replace_sql_vars(sql, input_data)
try:
# 从数据库加载数据源配置
if not self.db:
raise ValueError("数据库会话未提供,无法执行数据库操作")
from app.models.data_source import DataSource
from app.services.data_source_connector import create_connector
data_source = self.db.query(DataSource).filter(
DataSource.id == data_source_id
).first()
if not data_source:
raise ValueError(f"数据源不存在: {data_source_id}")
connector = create_connector(data_source.type, data_source.config)
if operation == 'query':
# 查询操作
if not sql:
raise ValueError("查询操作需要提供SQL语句")
query_params = {'query': sql}
result_data = connector.query(query_params)
result = {'output': result_data, 'status': 'success'}
elif operation == 'insert':
# 插入操作
if not table:
raise ValueError("插入操作需要提供表名")
# 构建INSERT SQL
columns = ', '.join(data.keys())
# 处理字符串值,转义单引号
def escape_value(v):
if isinstance(v, str):
escaped = v.replace("'", "''")
return f"'{escaped}'"
return str(v)
values = ', '.join([escape_value(v) for v in data.values()])
insert_sql = f"INSERT INTO {table} ({columns}) VALUES ({values})"
query_params = {'query': insert_sql}
result_data = connector.query(query_params)
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
elif operation == 'update':
# 更新操作
if not table or not where:
raise ValueError("更新操作需要提供表名和WHERE条件")
set_clause = ', '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in data.items()])
where_clause = ' AND '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in where.items()])
update_sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
query_params = {'query': update_sql}
result_data = connector.query(query_params)
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
elif operation == 'delete':
# 删除操作
if not table or not where:
raise ValueError("删除操作需要提供表名和WHERE条件")
# 处理字符串值,转义单引号
def escape_sql_value(k, v):
if isinstance(v, str):
escaped = v.replace("'", "''")
return f"{k} = '{escaped}'"
return f"{k} = {v}"
where_clause = ' AND '.join([escape_sql_value(k, v) for k, v in where.items()])
delete_sql = f"DELETE FROM {table} WHERE {where_clause}"
query_params = {'query': delete_sql}
result_data = connector.query(query_params)
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
else:
raise ValueError(f"不支持的数据库操作: {operation}")
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'数据库操作失败: {str(e)}'
}
elif node_type == 'file' or node_type == 'file_operation':
# 文件操作节点:文件读取、写入、上传、下载
node_data = node.get('data', {})
operation = node_data.get('operation', 'read') # read/write/upload/download
file_path = node_data.get('file_path', '')
content = node_data.get('content', '')
encoding = node_data.get('encoding', 'utf-8')
# 替换文件路径和内容中的变量
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
if file_path:
file_path = replace_variables(file_path, input_data)
if isinstance(content, str):
content = replace_variables(content, input_data)
try:
import os
import json
import base64
from pathlib import Path
if operation == 'read':
# 读取文件
if not file_path:
raise ValueError("读取操作需要提供文件路径")
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
# 根据文件扩展名决定读取方式
file_ext = Path(file_path).suffix.lower()
if file_ext == '.json':
with open(file_path, 'r', encoding=encoding) as f:
data = json.load(f)
elif file_ext in ['.txt', '.md', '.log']:
with open(file_path, 'r', encoding=encoding) as f:
data = f.read()
else:
# 二进制文件返回base64编码
with open(file_path, 'rb') as f:
data = base64.b64encode(f.read()).decode('utf-8')
result = {'output': data, 'status': 'success'}
elif operation == 'write':
# 写入文件
if not file_path:
raise ValueError("写入操作需要提供文件路径")
# 确保目录存在
os.makedirs(os.path.dirname(file_path) if os.path.dirname(file_path) else '.', exist_ok=True)
# 如果content是字典或列表转换为JSON
if isinstance(content, (dict, list)):
content = json.dumps(content, ensure_ascii=False, indent=2)
# 根据文件扩展名决定写入方式
file_ext = Path(file_path).suffix.lower()
if file_ext == '.json':
with open(file_path, 'w', encoding=encoding) as f:
json.dump(json.loads(content) if isinstance(content, str) else content, f, ensure_ascii=False, indent=2)
else:
with open(file_path, 'w', encoding=encoding) as f:
f.write(str(content))
result = {'output': {'file_path': file_path, 'message': '文件写入成功'}, 'status': 'success'}
elif operation == 'upload':
# 文件上传从base64或URL上传
upload_type = node_data.get('upload_type', 'base64') # base64/url
target_path = node_data.get('target_path', '')
if upload_type == 'base64':
# 从输入数据中获取base64编码的文件内容
file_data = input_data.get('file_data') or input_data.get('content')
if not file_data:
raise ValueError("上传操作需要提供file_data或content字段")
# 解码base64
if isinstance(file_data, str):
file_bytes = base64.b64decode(file_data)
else:
file_bytes = file_data
# 写入目标路径
if not target_path:
raise ValueError("上传操作需要提供target_path")
os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
with open(target_path, 'wb') as f:
f.write(file_bytes)
result = {'output': {'file_path': target_path, 'message': '文件上传成功'}, 'status': 'success'}
else:
# URL上传下载后保存
import httpx
url = node_data.get('url', '')
if not url:
raise ValueError("URL上传需要提供url")
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
if not target_path:
# 从URL提取文件名
target_path = os.path.basename(url) or 'downloaded_file'
os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
with open(target_path, 'wb') as f:
f.write(response.content)
result = {'output': {'file_path': target_path, 'message': '文件下载并保存成功'}, 'status': 'success'}
elif operation == 'download':
# 文件下载返回base64编码或文件URL
download_format = node_data.get('download_format', 'base64') # base64/url
if not file_path:
raise ValueError("下载操作需要提供文件路径")
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
if download_format == 'base64':
# 返回base64编码
with open(file_path, 'rb') as f:
file_bytes = f.read()
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
result = {'output': {'file_name': os.path.basename(file_path), 'content': file_base64, 'format': 'base64'}, 'status': 'success'}
else:
# 返回文件路径实际应用中可能需要生成临时URL
result = {'output': {'file_path': file_path, 'format': 'path'}, 'status': 'success'}
else:
raise ValueError(f"不支持的文件操作: {operation}")
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'文件操作失败: {str(e)}'
}
elif node_type == 'webhook':
# Webhook节点发送Webhook请求到外部系统
node_data = node.get('data', {})
url = node_data.get('url', '')
method = node_data.get('method', 'POST').upper()
headers = node_data.get('headers', {})
body = node_data.get('body', {})
timeout = node_data.get('timeout', 30)
# 如果URL、headers、body中包含变量从input_data中替换
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
# 替换URL中的变量
if url:
url = replace_variables(url, input_data)
# 替换headers中的变量
if isinstance(headers, dict):
headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
elif isinstance(headers, str):
try:
headers = json.loads(replace_variables(headers, input_data))
except:
headers = {}
# 替换body中的变量
if isinstance(body, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
body = replace_dict_vars(body, input_data)
elif isinstance(body, str):
body = replace_variables(body, input_data)
try:
body = json.loads(body)
except:
pass
# 如果没有配置body默认使用input_data作为body
if not body:
body = input_data
try:
import httpx
async with httpx.AsyncClient(timeout=timeout) as client:
if method == 'GET':
response = await client.get(url, headers=headers)
elif method == 'POST':
response = await client.post(url, json=body, headers=headers)
elif method == 'PUT':
response = await client.put(url, json=body, headers=headers)
elif method == 'PATCH':
response = await client.patch(url, json=body, headers=headers)
else:
raise ValueError(f"Webhook不支持HTTP方法: {method}")
# 尝试解析JSON响应
try:
response_data = response.json()
except:
response_data = response.text
result = {
'output': {
'status_code': response.status_code,
'headers': dict(response.headers),
'data': response_data
},
'status': 'success'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'Webhook请求失败: {str(e)}'
}
elif node_type == 'schedule' or node_type == 'delay' or node_type == 'timer':
# 定时任务节点:延迟执行或定时执行
node_data = node.get('data', {})
delay_type = node_data.get('delay_type', 'fixed') # fixed: 固定延迟, cron: cron表达式
delay_value = node_data.get('delay_value', 0) # 延迟值(秒)
delay_unit = node_data.get('delay_unit', 'seconds') # seconds, minutes, hours
# 计算实际延迟时间(毫秒)
if delay_unit == 'seconds':
delay_ms = int(delay_value * 1000)
elif delay_unit == 'minutes':
delay_ms = int(delay_value * 60 * 1000)
elif delay_unit == 'hours':
delay_ms = int(delay_value * 60 * 60 * 1000)
else:
delay_ms = int(delay_value * 1000)
# 如果延迟时间大于0则等待
if delay_ms > 0:
if self.logger:
self.logger.info(
f"定时任务节点等待 {delay_value} {delay_unit}",
node_id=node_id,
node_type=node_type,
data={'delay_ms': delay_ms, 'delay_value': delay_value, 'delay_unit': delay_unit}
)
await asyncio.sleep(delay_ms / 1000.0)
# 返回输入数据(定时节点只是延迟,不改变数据)
result = {'output': input_data, 'status': 'success', 'delay_ms': delay_ms}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
elif node_type == 'email' or node_type == 'mail':
# 邮件节点:发送邮件通知
node_data = node.get('data', {})
smtp_host = node_data.get('smtp_host', '')
smtp_port = node_data.get('smtp_port', 587)
smtp_user = node_data.get('smtp_user', '')
smtp_password = node_data.get('smtp_password', '')
use_tls = node_data.get('use_tls', True)
from_email = node_data.get('from_email', '')
to_email = node_data.get('to_email', '')
cc_email = node_data.get('cc_email', '')
bcc_email = node_data.get('bcc_email', '')
subject = node_data.get('subject', '')
body = node_data.get('body', '')
body_type = node_data.get('body_type', 'text') # text/html
attachments = node_data.get('attachments', []) # 附件列表
# 替换变量
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
# 替换所有配置中的变量
smtp_host = replace_variables(smtp_host, input_data)
smtp_user = replace_variables(smtp_user, input_data)
smtp_password = replace_variables(smtp_password, input_data)
from_email = replace_variables(from_email, input_data)
to_email = replace_variables(to_email, input_data)
cc_email = replace_variables(cc_email, input_data)
bcc_email = replace_variables(bcc_email, input_data)
subject = replace_variables(subject, input_data)
body = replace_variables(body, input_data)
# 验证必需参数
if not smtp_host:
raise ValueError("邮件节点需要配置SMTP服务器地址")
if not from_email:
raise ValueError("邮件节点需要配置发件人邮箱")
if not to_email:
raise ValueError("邮件节点需要配置收件人邮箱")
if not subject:
raise ValueError("邮件节点需要配置邮件主题")
try:
import aiosmtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.base import MIMEBase
from email import encoders
import base64
import os
# 创建邮件消息
msg = MIMEMultipart('alternative')
msg['From'] = from_email
msg['To'] = to_email
if cc_email:
msg['Cc'] = cc_email
msg['Subject'] = subject
# 添加邮件正文
if body_type == 'html':
msg.attach(MIMEText(body, 'html', 'utf-8'))
else:
msg.attach(MIMEText(body, 'plain', 'utf-8'))
# 处理附件
for attachment in attachments:
if isinstance(attachment, dict):
file_path = attachment.get('file_path', '')
file_name = attachment.get('file_name', '')
file_content = attachment.get('file_content', '') # base64编码的内容
# 替换变量
file_path = replace_variables(file_path, input_data)
file_name = replace_variables(file_name, input_data)
if file_path and os.path.exists(file_path):
# 从文件路径读取
with open(file_path, 'rb') as f:
file_data = f.read()
if not file_name:
file_name = os.path.basename(file_path)
elif file_content:
# 从base64内容读取
file_data = base64.b64decode(file_content)
if not file_name:
file_name = 'attachment'
else:
continue
# 添加附件
part = MIMEBase('application', 'octet-stream')
part.set_payload(file_data)
encoders.encode_base64(part)
part.add_header(
'Content-Disposition',
f'attachment; filename= {file_name}'
)
msg.attach(part)
# 发送邮件
recipients = [to_email]
if cc_email:
recipients.extend([email.strip() for email in cc_email.split(',')])
if bcc_email:
recipients.extend([email.strip() for email in bcc_email.split(',')])
async with aiosmtplib.SMTP(hostname=smtp_host, port=smtp_port) as smtp:
if use_tls:
await smtp.starttls()
if smtp_user and smtp_password:
await smtp.login(smtp_user, smtp_password)
await smtp.send_message(msg, recipients=recipients)
result = {
'output': {
'message': '邮件发送成功',
'from': from_email,
'to': to_email,
'subject': subject,
'recipients_count': len(recipients)
},
'status': 'success'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'邮件发送失败: {str(e)}'
}
elif node_type == 'message_queue' or node_type == 'mq' or node_type == 'rabbitmq' or node_type == 'kafka':
# 消息队列节点发送消息到RabbitMQ或Kafka
node_data = node.get('data', {})
queue_type = node_data.get('queue_type', 'rabbitmq') # rabbitmq/kafka
# 替换变量
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
try:
if queue_type == 'rabbitmq':
# RabbitMQ实现
import aio_pika
import json
# 获取RabbitMQ配置
host = replace_variables(node_data.get('host', 'localhost'), input_data)
port = node_data.get('port', 5672)
username = replace_variables(node_data.get('username', 'guest'), input_data)
password = replace_variables(node_data.get('password', 'guest'), input_data)
exchange = replace_variables(node_data.get('exchange', ''), input_data)
routing_key = replace_variables(node_data.get('routing_key', ''), input_data)
queue_name = replace_variables(node_data.get('queue_name', ''), input_data)
message = node_data.get('message', input_data)
# 如果message是字符串尝试替换变量
if isinstance(message, str):
message = replace_variables(message, input_data)
try:
message = json.loads(message)
except:
pass
elif isinstance(message, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
message = replace_dict_vars(message, input_data)
# 如果没有配置message使用input_data
if not message:
message = input_data
# 连接RabbitMQ
connection_url = f"amqp://{username}:{password}@{host}:{port}/"
connection = await aio_pika.connect_robust(connection_url)
channel = await connection.channel()
# 发送消息
message_body = json.dumps(message, ensure_ascii=False).encode('utf-8')
if exchange:
# 使用exchange和routing_key
await channel.default_exchange.publish(
aio_pika.Message(message_body),
routing_key=routing_key or queue_name
)
elif queue_name:
# 直接发送到队列
queue = await channel.declare_queue(queue_name, durable=True)
await channel.default_exchange.publish(
aio_pika.Message(message_body),
routing_key=queue_name
)
else:
raise ValueError("RabbitMQ节点需要配置exchange或queue_name")
await connection.close()
result = {
'output': {
'message': '消息已发送到RabbitMQ',
'queue_type': 'rabbitmq',
'exchange': exchange,
'routing_key': routing_key or queue_name,
'queue_name': queue_name,
'message_size': len(message_body)
},
'status': 'success'
}
elif queue_type == 'kafka':
# Kafka实现
from kafka import KafkaProducer
import json
# 获取Kafka配置
bootstrap_servers = replace_variables(node_data.get('bootstrap_servers', 'localhost:9092'), input_data)
topic = replace_variables(node_data.get('topic', ''), input_data)
message = node_data.get('message', input_data)
# 如果message是字符串尝试替换变量
if isinstance(message, str):
message = replace_variables(message, input_data)
try:
message = json.loads(message)
except:
pass
elif isinstance(message, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
message = replace_dict_vars(message, input_data)
# 如果没有配置message使用input_data
if not message:
message = input_data
if not topic:
raise ValueError("Kafka节点需要配置topic")
# 创建Kafka生产者注意kafka-python是同步的需要在线程池中运行
from concurrent.futures import ThreadPoolExecutor
def send_kafka_message():
producer = KafkaProducer(
bootstrap_servers=bootstrap_servers.split(','),
value_serializer=lambda v: json.dumps(v, ensure_ascii=False).encode('utf-8')
)
future = producer.send(topic, message)
record_metadata = future.get(timeout=10)
producer.close()
return record_metadata
# 在线程池中执行同步操作
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
record_metadata = await loop.run_in_executor(executor, send_kafka_message)
result = {
'output': {
'message': '消息已发送到Kafka',
'queue_type': 'kafka',
'topic': topic,
'partition': record_metadata.partition,
'offset': record_metadata.offset
},
'status': 'success'
}
else:
raise ValueError(f"不支持的消息队列类型: {queue_type}")
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'消息队列发送失败: {str(e)}'
}
elif node_type == 'switch':
# Switch节点多分支路由
logger.info(f"[rjb] 执行Switch节点: node_id={node_id}, node_type={node_type}, input_data keys={list(input_data.keys()) if isinstance(input_data, dict) else 'not_dict'}")
node_data = node.get('data', {})
field = node_data.get('field', '')
cases = node_data.get('cases', {})
default_case = node_data.get('default', 'default')
logger.info(f"[rjb] Switch节点配置: field={field}, cases={cases}, default={default_case}")
# 处理输入数据尝试解析JSON字符串递归处理所有字段
def parse_json_recursively(data, merge_parsed=True):
"""
递归解析JSON字符串
Args:
data: 要处理的数据
merge_parsed: 是否将解析后的字典内容合并到父级(用于方便字段提取)
"""
import json
if isinstance(data, str):
# 如果是字符串尝试解析为JSON
try:
parsed = json.loads(data)
# 如果解析成功,递归处理解析后的数据
if isinstance(parsed, (dict, list)):
return parse_json_recursively(parsed, merge_parsed)
return parsed
except:
# 不是JSON返回原字符串
return data
elif isinstance(data, dict):
# 如果是字典,递归处理每个值
result = {}
for key, value in data.items():
parsed_value = parse_json_recursively(value, merge_parsed=False)
result[key] = parsed_value
# 如果merge_parsed为True且解析后的值是字典将其内容合并到当前层级方便字段提取
if merge_parsed and isinstance(parsed_value, dict):
# 合并时避免覆盖已有的键
for k, v in parsed_value.items():
if k not in result:
result[k] = v
return result
elif isinstance(data, list):
# 如果是列表,递归处理每个元素
return [parse_json_recursively(item, merge_parsed) for item in data]
else:
# 其他类型,直接返回
return data
processed_input = parse_json_recursively(input_data, merge_parsed=True)
# 从处理后的输入数据中获取字段值
field_value = self._get_nested_value(processed_input, field)
field_value_str = str(field_value) if field_value is not None else ''
# 查找匹配的case
matched_case = default_case
if field_value_str in cases:
matched_case = cases[field_value_str]
elif field_value in cases:
matched_case = cases[field_value]
# 记录详细的匹配信息(同时输出到控制台和数据库)
match_info = {
'field': field,
'field_value': field_value,
'field_value_str': field_value_str,
'matched_case': matched_case,
'processed_input_keys': list(processed_input.keys()) if isinstance(processed_input, dict) else 'not_dict',
'cases_keys': list(cases.keys())
}
logger.info(f"[rjb] Switch节点匹配: node_id={node_id}, {match_info}")
if self.logger:
self.logger.info(
f"Switch节点匹配: field={field}, field_value={field_value}, matched_case={matched_case}",
node_id=node_id,
node_type=node_type,
data=match_info
)
exec_result = {
'output': processed_input,
'status': 'success',
'branch': matched_case,
'matched_value': field_value
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, {'branch': matched_case, 'value': field_value}, duration)
return exec_result
elif node_type == 'merge':
# Merge节点合并多个分支的数据流
node_data = node.get('data', {})
mode = node_data.get('mode', 'merge_all') # merge_all, merge_first, merge_last
strategy = node_data.get('strategy', 'array') # array, object, concat
# 获取所有上游节点的输出通过input_data中的特殊字段
# 如果input_data包含多个分支的数据合并它们
merged_data = {}
if strategy == 'array':
# 数组策略:将所有输入数据作为数组元素
if isinstance(input_data, list):
merged_data = input_data
elif isinstance(input_data, dict):
# 如果包含多个分支数据,提取为数组
branch_data = []
for key, value in input_data.items():
if not key.startswith('_'):
branch_data.append(value)
merged_data = branch_data if branch_data else [input_data]
else:
merged_data = [input_data]
elif strategy == 'object':
# 对象策略:合并所有字段
if isinstance(input_data, dict):
merged_data = input_data.copy()
else:
merged_data = {'data': input_data}
elif strategy == 'concat':
# 连接策略:将所有数据连接为字符串
if isinstance(input_data, list):
merged_data = '\n'.join(str(item) for item in input_data)
elif isinstance(input_data, dict):
merged_data = '\n'.join(f"{k}: {v}" for k, v in input_data.items() if not k.startswith('_'))
else:
merged_data = str(input_data)
exec_result = {'output': merged_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, merged_data, duration)
return exec_result
elif node_type == 'wait':
# Wait节点等待条件满足
node_data = node.get('data', {})
wait_type = node_data.get('wait_type', 'condition') # condition, time, event
condition = node_data.get('condition', '')
timeout = node_data.get('timeout', 300) # 默认5分钟
poll_interval = node_data.get('poll_interval', 5) # 默认5秒
if wait_type == 'condition':
# 等待条件满足
start_wait = time.time()
while time.time() - start_wait < timeout:
try:
result = condition_parser.evaluate_condition(condition, input_data)
if result:
exec_result = {'output': input_data, 'status': 'success', 'waited': time.time() - start_wait}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, exec_result, duration)
return exec_result
except Exception as e:
logger.warning(f"Wait节点条件评估失败: {str(e)}")
await asyncio.sleep(poll_interval)
# 超时
exec_result = {
'output': input_data,
'status': 'failed',
'error': f'等待条件超时: {timeout}'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, Exception("等待超时"), duration)
return exec_result
elif wait_type == 'time':
# 等待固定时间
wait_seconds = node_data.get('wait_seconds', 0)
await asyncio.sleep(wait_seconds)
exec_result = {'output': input_data, 'status': 'success', 'waited': wait_seconds}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, exec_result, duration)
return exec_result
else:
# 其他类型暂不支持
exec_result = {'output': input_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, exec_result, duration)
return exec_result
elif node_type == 'json':
# JSON处理节点
node_data = node.get('data', {})
operation = node_data.get('operation', 'parse') # parse, stringify, extract, validate
path = node_data.get('path', '')
schema = node_data.get('schema', {})
try:
if operation == 'parse':
# 解析JSON字符串
if isinstance(input_data, str):
result = json_module.loads(input_data)
elif isinstance(input_data, dict) and 'data' in input_data:
# 如果包含data字段尝试解析
if isinstance(input_data['data'], str):
result = json_module.loads(input_data['data'])
else:
result = input_data['data']
else:
result = input_data
elif operation == 'stringify':
# 转换为JSON字符串
result = json_module.dumps(input_data, ensure_ascii=False, indent=2)
elif operation == 'extract':
# 使用JSONPath提取数据简化实现
if path and isinstance(input_data, dict):
# 简单的路径提取,支持 $.key 格式
path = path.replace('$.', '').replace('$', '')
keys = path.split('.')
result = input_data
for key in keys:
if key.endswith('[*]'):
# 数组提取
array_key = key[:-3]
if isinstance(result, dict) and array_key in result:
result = result[array_key]
elif isinstance(result, dict) and key in result:
result = result[key]
else:
result = None
break
else:
result = input_data
elif operation == 'validate':
# JSON Schema验证简化实现
# 这里只做基本验证完整实现需要使用jsonschema库
if schema:
# 简单类型检查
if 'type' in schema:
expected_type = schema['type']
actual_type = type(input_data).__name__
if expected_type == 'object' and actual_type != 'dict':
raise ValueError(f"期望类型 {expected_type},实际类型 {actual_type}")
elif expected_type == 'array' and actual_type != 'list':
raise ValueError(f"期望类型 {expected_type},实际类型 {actual_type}")
result = input_data
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'JSON处理失败: {str(e)}'
}
elif node_type == 'text':
# 文本处理节点
node_data = node.get('data', {})
operation = node_data.get('operation', 'split') # split, join, extract, replace, format
delimiter = node_data.get('delimiter', '\n')
regex = node_data.get('regex', '')
template = node_data.get('template', '')
try:
# 获取输入文本
input_text = input_data
if isinstance(input_data, dict):
# 尝试从字典中提取文本
for key in ['text', 'content', 'message', 'input', 'output']:
if key in input_data and isinstance(input_data[key], str):
input_text = input_data[key]
break
if isinstance(input_text, dict):
input_text = str(input_text)
elif not isinstance(input_text, str):
input_text = str(input_text)
if operation == 'split':
# 拆分文本
result = input_text.split(delimiter)
elif operation == 'join':
# 合并文本(需要输入是数组)
if isinstance(input_data, list):
result = delimiter.join(str(item) for item in input_data)
else:
result = input_text
elif operation == 'extract':
# 使用正则表达式提取
if regex:
import re
matches = re.findall(regex, input_text)
result = matches if len(matches) > 1 else (matches[0] if matches else '')
else:
result = input_text
elif operation == 'replace':
# 替换文本
old_text = node_data.get('old_text', '')
new_text = node_data.get('new_text', '')
if regex:
import re
result = re.sub(regex, new_text, input_text)
else:
result = input_text.replace(old_text, new_text)
elif operation == 'format':
# 格式化文本(使用模板)
if template:
# 支持 {key} 格式的变量替换
result = template
if isinstance(input_data, dict):
for key, value in input_data.items():
result = result.replace(f'{{{key}}}', str(value))
else:
result = result.replace('{value}', str(input_data))
else:
result = input_text
else:
result = input_text
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'文本处理失败: {str(e)}'
}
elif node_type == 'cache':
# 缓存节点
node_data = node.get('data', {})
operation = node_data.get('operation', 'get') # get, set, delete, clear
key = node_data.get('key', '')
ttl = node_data.get('ttl', 3600) # 默认1小时
# 默认优先使用redis如果配置了否则memory
backend = node_data.get('backend') or ('redis' if getattr(settings, 'REDIS_URL', None) else 'memory') # redis, memory
default_value = node_data.get('default_value', '{}')
value_template = node_data.get('value', '')
# 使用Redis作为持久化缓存如果可用否则使用内存缓存
# 注意:内存缓存在单次执行会话内有效,跨执行不会保留
use_redis = False
redis_client = None
# 默认尝试使用Redis如果配置了除非明确指定使用memory
if backend != 'memory':
try:
from app.core.redis_client import get_redis_client
redis_client = get_redis_client()
if redis_client:
use_redis = True
logger.info(f"[rjb] Cache节点 {node_id} 使用Redis缓存")
except Exception as e:
logger.warning(f"Redis不可用: {str(e)},使用内存缓存")
# 内存缓存(单次执行会话内有效)
if not hasattr(self, '_cache_store'):
self._cache_store = {}
self._cache_timestamps = {}
try:
# 替换key中的变量
if isinstance(input_data, dict):
# 首先处理 {{variable}} 格式
import re
double_brace_vars = re.findall(r'\{\{(\w+)\}\}', key)
for var_name in double_brace_vars:
if var_name in input_data:
key = key.replace(f'{{{{{var_name}}}}}', str(input_data[var_name]))
else:
# 如果变量不存在,使用默认值
if var_name == 'user_id':
# 尝试从输入数据中提取user_id如果没有则使用"default"
user_id = input_data.get('user_id') or input_data.get('USER_ID') or 'default'
key = key.replace(f'{{{{{var_name}}}}}', str(user_id))
else:
key = key.replace(f'{{{{{var_name}}}}}', 'default')
# 然后处理 {key} 格式
for k, v in input_data.items():
key = key.replace(f'{{{k}}}', str(v))
# 如果key中还有未替换的变量使用默认值
if '{' in key:
key = key.replace('{user_id}', 'default').replace('{{user_id}}', 'default')
# 清理其他未替换的变量
key = re.sub(r'\{[^}]+\}', 'default', key)
logger.info(f"[rjb] Cache节点 {node_id} 处理后的key: {key}")
if operation == 'get':
# 获取缓存
result = None
cache_hit = False
if use_redis and redis_client:
# 从Redis获取
try:
cached_data = redis_client.get(key)
if cached_data:
result = json_module.loads(cached_data)
cache_hit = True
except Exception as e:
logger.warning(f"从Redis获取缓存失败: {str(e)}")
if result is None:
# 从内存缓存获取
if key in self._cache_store:
# 检查是否过期
if key in self._cache_timestamps:
if time.time() - self._cache_timestamps[key] > ttl:
# 过期,删除
del self._cache_store[key]
del self._cache_timestamps[key]
else:
result = self._cache_store[key]
cache_hit = True
else:
result = self._cache_store[key]
cache_hit = True
# 如果缓存未命中使用default_value
if result is None:
try:
if isinstance(default_value, str):
result = json_module.loads(default_value) if default_value else {}
else:
result = default_value
except:
result = {}
cache_hit = False
logger.info(f"[rjb] Cache节点 {node_id} cache miss使用default_value: {result}")
# 合并输入数据和缓存结果
output = input_data.copy() if isinstance(input_data, dict) else {}
if isinstance(result, dict):
output.update(result)
else:
output['memory'] = result
exec_result = {'output': output, 'status': 'success', 'cache_hit': cache_hit, 'memory': result}
elif operation == 'set':
# 设置缓存
# 处理value模板
if value_template:
# 处理模板语法 {{variable}}
import re
value_str = value_template
# 替换 {{variable}} 格式的变量
# 注意:只替换 memory.* 路径的变量user_input、output、timestamp 等变量在Python表达式执行阶段处理
template_vars = re.findall(r'\{\{(\w+(?:\.\w+)*)\}\}', value_str)
for var_path in template_vars:
# 跳过 user_input、output、timestamp 等变量这些在Python表达式执行阶段处理
if var_path in ['user_input', 'output', 'timestamp']:
continue
# 支持嵌套路径,如 memory.conversation_history
var_parts = var_path.split('.')
var_value = input_data
try:
for part in var_parts:
if isinstance(var_value, dict) and part in var_value:
var_value = var_value[part]
else:
var_value = None
break
if var_value is not None:
# 替换模板变量
replacement = json_module.dumps(var_value, ensure_ascii=False) if isinstance(var_value, (dict, list)) else str(var_value)
value_str = value_str.replace(f'{{{{{var_path}}}}}', replacement)
else:
# 变量不存在,根据路径使用合适的默认值
if 'conversation_history' in var_path:
value_str = value_str.replace(f'{{{{{var_path}}}}}', '[]')
elif 'user_profile' in var_path or 'context' in var_path:
value_str = value_str.replace(f'{{{{{var_path}}}}}', '{}')
else:
# 对于其他变量保留原样让Python表达式执行阶段处理
pass
except Exception as e:
logger.warning(f"处理模板变量 {var_path} 失败: {str(e)}")
# 替换 {key} 格式的变量(但不要替换 {{variable}} 格式的)
for k, v in input_data.items():
placeholder = f'{{{k}}}'
# 确保不是 {{variable}} 格式
if placeholder in value_str and f'{{{{{k}}}}}' not in value_str:
replacement = json_module.dumps(v, ensure_ascii=False) if isinstance(v, (dict, list)) else str(v)
value_str = value_str.replace(placeholder, replacement)
# 解析处理后的value可能是JSON字符串或Python表达式
try:
# 尝试作为JSON解析
value = json_module.loads(value_str)
except:
# 如果不是有效的JSON尝试作为Python表达式执行安全限制
try:
# 准备安全的环境变量
from datetime import datetime
memory = input_data.get('memory', {})
if not isinstance(memory, dict):
memory = {}
# 确保memory中有必要的字段
if 'conversation_history' not in memory:
memory['conversation_history'] = []
if 'user_profile' not in memory:
memory['user_profile'] = {}
if 'context' not in memory:
memory['context'] = {}
# 获取user_input优先从query或USER_INPUT获取
user_input = input_data.get('query') or input_data.get('USER_INPUT') or input_data.get('user_input') or ''
# 获取output从right字段获取如果是dict则提取right子字段
output = input_data.get('right', '')
if isinstance(output, dict):
output = output.get('right', '') or output.get('content', '') or str(output)
if not output:
output = ''
timestamp = datetime.now().isoformat()
# 在Python表达式执行前替换 {{user_input}}、{{output}}、{{timestamp}}
# 注意:模板中已经有引号了,所以需要转义字符串中的特殊字符,然后直接插入
# 使用json.dumps来正确转义但去掉外层的引号因为模板中已经有引号了
user_input_escaped = json_module.dumps(user_input, ensure_ascii=False)[1:-1] # 去掉首尾引号
output_escaped = json_module.dumps(output, ensure_ascii=False)[1:-1]
timestamp_escaped = json_module.dumps(timestamp, ensure_ascii=False)[1:-1]
value_str = value_str.replace('{{user_input}}', user_input_escaped)
value_str = value_str.replace('{{output}}', output_escaped)
value_str = value_str.replace('{{timestamp}}', timestamp_escaped)
# 只允许基本的字典和列表操作
safe_dict = {
'memory': memory,
'user_input': user_input,
'output': output,
'timestamp': timestamp
}
logger.info(f"[rjb] Cache节点 {node_id} 执行value模板")
logger.info(f"[rjb] value_str前300字符: {value_str[:300]}")
logger.info(f"[rjb] user_input: {user_input[:50]}, output: {str(output)[:50]}, timestamp: {timestamp}")
value = eval(value_str, {"__builtins__": {}}, safe_dict)
logger.info(f"[rjb] Cache节点 {node_id} value模板执行成功类型: {type(value)}")
if isinstance(value, dict):
logger.info(f"[rjb] keys: {list(value.keys())}")
if 'conversation_history' in value:
logger.info(f"[rjb] conversation_history: {len(value['conversation_history'])}")
if value['conversation_history']:
logger.info(f"[rjb] 第一条: {value['conversation_history'][0]}")
except Exception as e:
logger.error(f"Cache节点 {node_id} value模板执行失败: {str(e)}")
logger.error(f"value_str: {value_str[:500]}")
logger.error(f"safe_dict: {safe_dict}")
import traceback
logger.error(f"traceback: {traceback.format_exc()}")
# 如果都失败,使用原始输入数据
value = input_data
else:
# 没有value模板使用输入数据
value = input_data
if isinstance(input_data, dict) and 'value' in input_data:
value = input_data['value']
# 存储到缓存
if use_redis and redis_client:
try:
redis_client.setex(key, ttl, json_module.dumps(value, ensure_ascii=False))
logger.info(f"[rjb] Cache节点 {node_id} 已存储到Redis: key={key}")
except Exception as e:
logger.warning(f"存储到Redis失败: {str(e)}")
# 同时存储到内存缓存
self._cache_store[key] = value
self._cache_timestamps[key] = time.time()
logger.info(f"[rjb] Cache节点 {node_id} 已存储: key={key}, value类型={type(value)}")
exec_result = {'output': input_data, 'status': 'success', 'cached_value': value}
elif operation == 'delete':
# 删除缓存
if key in self._cache_store:
del self._cache_store[key]
if key in self._cache_timestamps:
del self._cache_timestamps[key]
exec_result = {'output': input_data, 'status': 'success'}
elif operation == 'clear':
# 清空缓存
self._cache_store.clear()
self._cache_timestamps.clear()
exec_result = {'output': input_data, 'status': 'success'}
else:
exec_result = {'output': input_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, exec_result.get('output'), duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'缓存操作失败: {str(e)}'
}
elif node_type == 'vector_db':
# 向量数据库节点向量存储、相似度搜索、RAG检索
node_data = node.get('data', {})
operation = node_data.get('operation', 'search') # search, upsert, delete
collection = node_data.get('collection', 'default')
query_vector = node_data.get('query_vector', '')
top_k = node_data.get('top_k', 5)
# 简化的内存向量存储实现实际生产环境应使用ChromaDB、Pinecone等
if not hasattr(self, '_vector_store'):
self._vector_store = {}
try:
if operation == 'search':
# 向量相似度搜索(简化实现:使用余弦相似度)
if collection not in self._vector_store:
self._vector_store[collection] = []
# 获取查询向量
if isinstance(query_vector, str):
# 尝试从input_data中获取
query_vec = self._get_nested_value(input_data, query_vector.replace('{', '').replace('}', ''))
else:
query_vec = query_vector
if not isinstance(query_vec, list):
# 如果输入数据包含embedding字段使用它
if isinstance(input_data, dict) and 'embedding' in input_data:
query_vec = input_data['embedding']
else:
raise ValueError("无法获取查询向量")
# 计算相似度并排序
results = []
for item in self._vector_store[collection]:
if 'vector' in item:
# 计算余弦相似度
import math
vec1 = query_vec
vec2 = item['vector']
if len(vec1) != len(vec2):
continue
dot_product = sum(a * b for a, b in zip(vec1, vec2))
magnitude1 = math.sqrt(sum(a * a for a in vec1))
magnitude2 = math.sqrt(sum(a * a for a in vec2))
if magnitude1 == 0 or magnitude2 == 0:
similarity = 0
else:
similarity = dot_product / (magnitude1 * magnitude2)
results.append({
'id': item.get('id'),
'text': item.get('text', ''),
'metadata': item.get('metadata', {}),
'similarity': similarity
})
# 按相似度排序并返回top_k
results.sort(key=lambda x: x['similarity'], reverse=True)
result = results[:top_k]
elif operation == 'upsert':
# 插入或更新向量
if collection not in self._vector_store:
self._vector_store[collection] = []
# 从输入数据中提取向量和文本
vector = input_data.get('embedding') or input_data.get('vector')
text = input_data.get('text') or input_data.get('content', '')
metadata = input_data.get('metadata', {})
doc_id = input_data.get('id') or f"doc_{len(self._vector_store[collection])}"
# 查找是否已存在
existing_index = None
for i, item in enumerate(self._vector_store[collection]):
if item.get('id') == doc_id:
existing_index = i
break
doc_item = {
'id': doc_id,
'vector': vector,
'text': text,
'metadata': metadata
}
if existing_index is not None:
self._vector_store[collection][existing_index] = doc_item
else:
self._vector_store[collection].append(doc_item)
result = {'id': doc_id, 'status': 'upserted'}
elif operation == 'delete':
# 删除向量
if collection in self._vector_store:
doc_id = node_data.get('doc_id') or input_data.get('id')
if doc_id:
self._vector_store[collection] = [
item for item in self._vector_store[collection]
if item.get('id') != doc_id
]
result = {'id': doc_id, 'status': 'deleted'}
else:
# 删除整个集合
del self._vector_store[collection]
result = {'collection': collection, 'status': 'deleted'}
else:
result = {'status': 'not_found'}
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'向量数据库操作失败: {str(e)}'
}
elif node_type == 'log':
# 日志节点:记录日志、调试输出、性能监控
node_data = node.get('data', {})
level = node_data.get('level', 'info') # debug, info, warning, error
message = node_data.get('message', '')
include_data = node_data.get('include_data', True)
try:
# 格式化消息
if message:
# 替换变量
if isinstance(input_data, dict):
for key, value in input_data.items():
message = message.replace(f'{{{key}}}', str(value))
# 构建日志内容
log_data = {
'message': message or '节点执行',
'node_id': node_id,
'node_type': node_type,
'timestamp': time.time()
}
if include_data:
log_data['data'] = input_data
# 记录日志
log_message = f"[{node_id}] {log_data['message']}"
if include_data:
log_message += f" | 数据: {json_module.dumps(input_data, ensure_ascii=False)[:200]}"
if level == 'debug':
logger.debug(log_message)
elif level == 'info':
logger.info(log_message)
elif level == 'warning':
logger.warning(log_message)
elif level == 'error':
logger.error(log_message)
else:
logger.info(log_message)
# 如果使用执行日志记录器,也记录
if self.logger:
self.logger.info(log_data['message'], node_id=node_id, node_type=node_type, data=input_data if include_data else None)
exec_result = {'output': input_data, 'status': 'success', 'log': log_data}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, exec_result.get('output'), duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': input_data,
'status': 'failed',
'error': f'日志记录失败: {str(e)}'
}
elif node_type == 'error_handler':
# 错误处理节点:捕获错误、错误重试、错误通知
# 注意:这个节点需要特殊处理,因为它应该包装其他节点的执行
# 这里我们实现一个简化版本,主要用于错误重试和通知
node_data = node.get('data', {})
retry_count = node_data.get('retry_count', 3)
retry_delay = node_data.get('retry_delay', 1000) # 毫秒
on_error = node_data.get('on_error', 'notify') # notify, retry, stop
error_handler_workflow = node_data.get('error_handler_workflow', '')
# 这个节点通常用于包装其他节点,但在这里我们只处理输入数据中的错误
try:
# 检查输入数据中是否有错误
if isinstance(input_data, dict) and input_data.get('status') == 'failed':
error = input_data.get('error', '未知错误')
if on_error == 'retry' and retry_count > 0:
# 重试逻辑(这里简化处理,实际应该重新执行前一个节点)
logger.warning(f"错误处理节点检测到错误,将重试: {error}")
# 注意:实际重试需要重新执行前一个节点,这里只记录
exec_result = {
'output': input_data,
'status': 'retry',
'retry_count': retry_count,
'error': error
}
elif on_error == 'notify':
# 通知错误(记录日志)
logger.error(f"错误处理节点捕获错误: {error}")
exec_result = {
'output': input_data,
'status': 'error_handled',
'error': error,
'notified': True
}
else:
# 停止执行
exec_result = {
'output': input_data,
'status': 'failed',
'error': error,
'stopped': True
}
else:
# 没有错误,正常通过
exec_result = {'output': input_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, exec_result.get('output'), duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': input_data,
'status': 'failed',
'error': f'错误处理失败: {str(e)}'
}
elif node_type == 'csv':
# CSV处理节点CSV解析、生成、转换
node_data = node.get('data', {})
operation = node_data.get('operation', 'parse') # parse, generate, convert
delimiter = node_data.get('delimiter', ',')
headers = node_data.get('headers', True)
encoding = node_data.get('encoding', 'utf-8')
try:
import csv
import io
if operation == 'parse':
# 解析CSV
csv_text = input_data
if isinstance(input_data, dict):
# 尝试从字典中提取CSV文本
for key in ['csv', 'data', 'content', 'text']:
if key in input_data and isinstance(input_data[key], str):
csv_text = input_data[key]
break
if not isinstance(csv_text, str):
csv_text = str(csv_text)
# 解析CSV
csv_reader = csv.DictReader(io.StringIO(csv_text), delimiter=delimiter) if headers else csv.reader(io.StringIO(csv_text), delimiter=delimiter)
if headers:
result = list(csv_reader)
else:
# 没有表头,返回数组的数组
result = list(csv_reader)
elif operation == 'generate':
# 生成CSV
data = input_data
if isinstance(input_data, dict) and 'data' in input_data:
data = input_data['data']
if not isinstance(data, list):
data = [data]
output = io.StringIO()
if data and isinstance(data[0], dict):
# 字典列表使用DictWriter
fieldnames = data[0].keys()
writer = csv.DictWriter(output, fieldnames=fieldnames, delimiter=delimiter)
if headers:
writer.writeheader()
writer.writerows(data)
else:
# 数组列表使用writer
writer = csv.writer(output, delimiter=delimiter)
if headers and data and isinstance(data[0], list):
# 假设第一行是表头
writer.writerow(data[0])
writer.writerows(data[1:])
else:
writer.writerows(data)
result = output.getvalue()
elif operation == 'convert':
# 转换CSV格式改变分隔符等
csv_text = input_data
if isinstance(input_data, dict):
for key in ['csv', 'data', 'content']:
if key in input_data and isinstance(input_data[key], str):
csv_text = input_data[key]
break
if not isinstance(csv_text, str):
csv_text = str(csv_text)
# 读取并重新写入
reader = csv.reader(io.StringIO(csv_text))
output = io.StringIO()
writer = csv.writer(output, delimiter=delimiter)
writer.writerows(reader)
result = output.getvalue()
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'CSV处理失败: {str(e)}'
}
elif node_type == 'object_storage':
# 对象存储节点:文件上传、下载、删除、列表
node_data = node.get('data', {})
provider = node_data.get('provider', 's3') # oss, s3, cos
operation = node_data.get('operation', 'upload')
bucket = node_data.get('bucket', '')
key = node_data.get('key', '')
file_data = node_data.get('file', '')
try:
# 简化实现实际生产环境需要使用boto3AWS S3、oss2阿里云OSS
# 这里提供一个接口框架实际使用时需要安装相应的SDK
if operation == 'upload':
# 上传文件
if not file_data:
# 从input_data中获取文件数据
if isinstance(input_data, dict):
file_data = input_data.get('file') or input_data.get('data') or input_data.get('content')
else:
file_data = input_data
# 替换key中的变量
if isinstance(input_data, dict):
for k, v in input_data.items():
key = key.replace(f'{{{k}}}', str(v))
# 这里只是模拟上传实际需要调用相应的SDK
logger.info(f"对象存储上传: provider={provider}, bucket={bucket}, key={key}")
result = {
'provider': provider,
'bucket': bucket,
'key': key,
'status': 'uploaded',
'url': f"{provider}://{bucket}/{key}" # 模拟URL
}
elif operation == 'download':
# 下载文件
if isinstance(input_data, dict):
for k, v in input_data.items():
key = key.replace(f'{{{k}}}', str(v))
logger.info(f"对象存储下载: provider={provider}, bucket={bucket}, key={key}")
# 这里只是模拟下载实际需要调用相应的SDK
result = {
'provider': provider,
'bucket': bucket,
'key': key,
'status': 'downloaded',
'data': '模拟文件内容' # 实际应该是文件内容
}
elif operation == 'delete':
# 删除文件
if isinstance(input_data, dict):
for k, v in input_data.items():
key = key.replace(f'{{{k}}}', str(v))
logger.info(f"对象存储删除: provider={provider}, bucket={bucket}, key={key}")
result = {
'provider': provider,
'bucket': bucket,
'key': key,
'status': 'deleted'
}
elif operation == 'list':
# 列出文件
prefix = node_data.get('prefix', '')
logger.info(f"对象存储列表: provider={provider}, bucket={bucket}, prefix={prefix}")
result = {
'provider': provider,
'bucket': bucket,
'prefix': prefix,
'files': [] # 实际应该是文件列表
}
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'对象存储操作失败: {str(e)}。注意实际使用需要安装相应的SDK如boto3、oss2等'
}
elif node_type == 'slack':
# Slack节点发送消息、创建频道、获取消息
node_data = node.get('data', {})
operation = node_data.get('operation', 'send_message')
token = node_data.get('token', '')
channel = node_data.get('channel', '')
message = node_data.get('message', '')
attachments = node_data.get('attachments', [])
try:
import httpx
# 替换消息中的变量
if isinstance(input_data, dict):
for key, value in input_data.items():
message = message.replace(f'{{{key}}}', str(value))
channel = channel.replace(f'{{{key}}}', str(value))
if operation == 'send_message':
# 发送消息
url = 'https://slack.com/api/chat.postMessage'
headers = {
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json'
}
payload = {
'channel': channel,
'text': message
}
if attachments:
payload['attachments'] = attachments
# 注意实际使用时需要安装httpx库这里提供接口框架
# async with httpx.AsyncClient() as client:
# response = await client.post(url, headers=headers, json=payload)
# result = response.json()
# 模拟响应
logger.info(f"Slack发送消息: channel={channel}, message={message[:50]}")
result = {
'ok': True,
'channel': channel,
'ts': str(time.time()),
'message': {'text': message}
}
elif operation == 'create_channel':
# 创建频道
url = 'https://slack.com/api/conversations.create'
headers = {
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json'
}
payload = {'name': channel}
logger.info(f"Slack创建频道: channel={channel}")
result = {'ok': True, 'channel': {'name': channel, 'id': f'C{int(time.time())}'}}
elif operation == 'get_messages':
# 获取消息
url = f'https://slack.com/api/conversations.history'
headers = {
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json'
}
params = {'channel': channel}
logger.info(f"Slack获取消息: channel={channel}")
result = {'ok': True, 'messages': []}
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'Slack操作失败: {str(e)}。注意需要配置有效的Slack Token'
}
elif node_type == 'dingtalk' or node_type == 'dingding':
# 钉钉节点:发送消息、创建群组、获取消息
node_data = node.get('data', {})
operation = node_data.get('operation', 'send_message')
webhook_url = node_data.get('webhook_url', '')
access_token = node_data.get('access_token', '')
chat_id = node_data.get('chat_id', '')
message = node_data.get('message', '')
try:
import httpx
# 替换消息中的变量
if isinstance(input_data, dict):
for key, value in input_data.items():
message = message.replace(f'{{{key}}}', str(value))
chat_id = chat_id.replace(f'{{{key}}}', str(value))
if operation == 'send_message':
# 发送消息通过Webhook或API
if webhook_url:
# 使用Webhook
payload = {
'msgtype': 'text',
'text': {'content': message}
}
# async with httpx.AsyncClient() as client:
# response = await client.post(webhook_url, json=payload)
# result = response.json()
logger.info(f"钉钉发送消息Webhook: message={message[:50]}")
result = {'errcode': 0, 'errmsg': 'ok'}
else:
# 使用API
url = f'https://oapi.dingtalk.com/chat/send'
headers = {
'Content-Type': 'application/json'
}
payload = {
'access_token': access_token,
'chatid': chat_id,
'msg': {
'msgtype': 'text',
'text': {'content': message}
}
}
logger.info(f"钉钉发送消息API: chat_id={chat_id}, message={message[:50]}")
result = {'errcode': 0, 'errmsg': 'ok'}
elif operation == 'create_group':
# 创建群组
url = 'https://oapi.dingtalk.com/chat/create'
headers = {'Content-Type': 'application/json'}
payload = {
'access_token': access_token,
'name': chat_id,
'owner': node_data.get('owner', '')
}
logger.info(f"钉钉创建群组: name={chat_id}")
result = {'errcode': 0, 'errmsg': 'ok', 'chatid': f'chat_{int(time.time())}'}
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'钉钉操作失败: {str(e)}。注意需要配置有效的Webhook URL或Access Token'
}
elif node_type == 'wechat_work' or node_type == 'wecom':
# 企业微信节点:发送消息、创建群组、获取消息
node_data = node.get('data', {})
operation = node_data.get('operation', 'send_message')
corp_id = node_data.get('corp_id', '')
corp_secret = node_data.get('corp_secret', '')
agent_id = node_data.get('agent_id', '')
chat_id = node_data.get('chat_id', '')
message = node_data.get('message', '')
try:
import httpx
# 替换消息中的变量
if isinstance(input_data, dict):
for key, value in input_data.items():
message = message.replace(f'{{{key}}}', str(value))
chat_id = chat_id.replace(f'{{{key}}}', str(value))
if operation == 'send_message':
# 先获取access_token
token_url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken'
token_params = {
'corpid': corp_id,
'corpsecret': corp_secret
}
# async with httpx.AsyncClient() as client:
# token_response = await client.get(token_url, params=token_params)
# token_data = token_response.json()
# access_token = token_data.get('access_token')
# 模拟获取token
access_token = 'mock_token'
# 发送消息
url = f'https://qyapi.weixin.qq.com/cgi-bin/message/send'
params = {'access_token': access_token}
payload = {
'touser': chat_id or '@all',
'msgtype': 'text',
'agentid': agent_id,
'text': {'content': message}
}
logger.info(f"企业微信发送消息: chat_id={chat_id}, message={message[:50]}")
result = {'errcode': 0, 'errmsg': 'ok'}
elif operation == 'create_group':
# 创建群组
url = 'https://qyapi.weixin.qq.com/cgi-bin/appchat/create'
params = {'access_token': access_token}
payload = {
'name': chat_id,
'owner': node_data.get('owner', ''),
'userlist': node_data.get('userlist', [])
}
logger.info(f"企业微信创建群组: name={chat_id}")
result = {'errcode': 0, 'errmsg': 'ok', 'chatid': f'chat_{int(time.time())}'}
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'企业微信操作失败: {str(e)}。注意需要配置有效的Corp ID和Secret'
}
elif node_type == 'sms':
# 短信节点SMS发送短信、批量发送、短信模板
node_data = node.get('data', {})
provider = node_data.get('provider', 'aliyun') # aliyun, tencent, twilio
operation = node_data.get('operation', 'send')
phone = node_data.get('phone', '')
template = node_data.get('template', '')
sign = node_data.get('sign', '')
access_key = node_data.get('access_key', '')
access_secret = node_data.get('access_secret', '')
try:
# 替换模板中的变量
if isinstance(input_data, dict):
for key, value in input_data.items():
template = template.replace(f'{{{key}}}', str(value))
phone = phone.replace(f'{{{key}}}', str(value))
if operation == 'send':
# 发送短信
if provider == 'aliyun':
# 阿里云短信需要安装alibabacloud-dysmsapi20170525
logger.info(f"阿里云短信发送: phone={phone}, template={template[:50]}")
result = {
'provider': 'aliyun',
'phone': phone,
'status': 'sent',
'message_id': f'sms_{int(time.time())}'
}
elif provider == 'tencent':
# 腾讯云短信需要安装tencentcloud-sdk-python
logger.info(f"腾讯云短信发送: phone={phone}, template={template[:50]}")
result = {
'provider': 'tencent',
'phone': phone,
'status': 'sent',
'message_id': f'sms_{int(time.time())}'
}
elif provider == 'twilio':
# Twilio短信需要安装twilio
logger.info(f"Twilio短信发送: phone={phone}, template={template[:50]}")
result = {
'provider': 'twilio',
'phone': phone,
'status': 'sent',
'message_id': f'sms_{int(time.time())}'
}
else:
result = {'error': f'不支持的短信提供商: {provider}'}
elif operation == 'batch_send':
# 批量发送
phones = node_data.get('phones', [])
if isinstance(phones, str):
phones = [p.strip() for p in phones.split(',')]
logger.info(f"批量发送短信: phones={len(phones)}, provider={provider}")
result = {
'provider': provider,
'phones': phones,
'status': 'sent',
'count': len(phones)
}
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'短信发送失败: {str(e)}。注意需要安装相应的SDK如alibabacloud-dysmsapi20170525、tencentcloud-sdk-python、twilio'
}
elif node_type == 'pdf':
# PDF处理节点PDF解析、生成、合并、拆分
node_data = node.get('data', {})
operation = node_data.get('operation', 'extract_text') # extract_text, generate, merge, split
pages = node_data.get('pages', '')
template = node_data.get('template', '')
try:
# 注意需要安装PyPDF2或pdfplumber库
# pip install PyPDF2 pdfplumber
if operation == 'extract_text':
# 提取文本
pdf_data = input_data
if isinstance(input_data, dict):
pdf_data = input_data.get('pdf') or input_data.get('data') or input_data.get('file')
# 这里只是接口框架,实际需要:
# from PyPDF2 import PdfReader
# reader = PdfReader(io.BytesIO(pdf_data))
# text = ""
# for page in reader.pages:
# text += page.extract_text()
logger.info(f"PDF提取文本: pages={pages}")
result = {
'text': 'PDF文本提取结果需要安装PyPDF2或pdfplumber',
'pages': pages or 'all'
}
elif operation == 'generate':
# 生成PDF
content = input_data
if isinstance(input_data, dict):
content = input_data.get('content') or input_data.get('text') or input_data.get('data')
# 这里只是接口框架,实际需要:
# from reportlab.pdfgen import canvas
# 或使用其他PDF生成库
logger.info(f"PDF生成: template={template}")
result = {
'pdf': 'PDF生成结果需要安装reportlab或其他PDF生成库',
'template': template
}
elif operation == 'merge':
# 合并PDF
pdfs = input_data
if isinstance(input_data, dict):
pdfs = input_data.get('pdfs') or input_data.get('files')
if not isinstance(pdfs, list):
pdfs = [pdfs]
# 这里只是接口框架,实际需要:
# from PyPDF2 import PdfMerger
# merger = PdfMerger()
# for pdf in pdfs:
# merger.append(pdf)
# result_pdf = merger.write()
logger.info(f"PDF合并: count={len(pdfs)}")
result = {
'merged_pdf': '合并后的PDF需要安装PyPDF2',
'count': len(pdfs)
}
elif operation == 'split':
# 拆分PDF
pdf_data = input_data
if isinstance(input_data, dict):
pdf_data = input_data.get('pdf') or input_data.get('file')
# 这里只是接口框架,实际需要:
# from PyPDF2 import PdfReader, PdfWriter
# reader = PdfReader(pdf_data)
# writer = PdfWriter()
# for page_num in range(start_page, end_page):
# writer.add_page(reader.pages[page_num])
logger.info(f"PDF拆分: pages={pages}")
result = {
'split_pdfs': ['拆分后的PDF列表需要安装PyPDF2'],
'pages': pages
}
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'PDF处理失败: {str(e)}。注意需要安装PyPDF2或pdfplumber库pip install PyPDF2 pdfplumber'
}
elif node_type == 'image':
# 图像处理节点图像缩放、裁剪、格式转换、OCR识别
node_data = node.get('data', {})
operation = node_data.get('operation', 'resize') # resize, crop, convert, ocr
width = node_data.get('width', 800)
height = node_data.get('height', 600)
format_type = node_data.get('format', 'png')
try:
# 注意需要安装Pillow库
# pip install Pillow
# OCR需要安装pytesseract和tesseract-ocr
# pip install pytesseract
image_data = input_data
if isinstance(input_data, dict):
image_data = input_data.get('image') or input_data.get('data') or input_data.get('file')
if operation == 'resize':
# 缩放图像
# 这里只是接口框架,实际需要:
# from PIL import Image
# import io
# img = Image.open(io.BytesIO(image_data))
# img_resized = img.resize((width, height))
# output = io.BytesIO()
# img_resized.save(output, format=format_type.upper())
# result = output.getvalue()
logger.info(f"图像缩放: {width}x{height}, format={format_type}")
result = {
'image': '缩放后的图像数据需要安装Pillow',
'width': width,
'height': height,
'format': format_type
}
elif operation == 'crop':
# 裁剪图像
x = node_data.get('x', 0)
y = node_data.get('y', 0)
crop_width = node_data.get('crop_width', width)
crop_height = node_data.get('crop_height', height)
# 这里只是接口框架,实际需要:
# from PIL import Image
# img = Image.open(io.BytesIO(image_data))
# img_cropped = img.crop((x, y, x + crop_width, y + crop_height))
logger.info(f"图像裁剪: ({x}, {y}, {crop_width}, {crop_height})")
result = {
'image': '裁剪后的图像数据需要安装Pillow',
'crop_box': (x, y, crop_width, crop_height)
}
elif operation == 'convert':
# 格式转换
target_format = node_data.get('target_format', format_type)
# 这里只是接口框架,实际需要:
# from PIL import Image
# img = Image.open(io.BytesIO(image_data))
# output = io.BytesIO()
# img.save(output, format=target_format.upper())
# result = output.getvalue()
logger.info(f"图像格式转换: {format_type} -> {target_format}")
result = {
'image': f'转换后的图像数据需要安装Pillow',
'format': target_format
}
elif operation == 'ocr':
# OCR识别
# 这里只是接口框架,实际需要:
# from PIL import Image
# import pytesseract
# img = Image.open(io.BytesIO(image_data))
# text = pytesseract.image_to_string(img, lang='chi_sim+eng')
logger.info(f"OCR识别")
result = {
'text': 'OCR识别结果需要安装pytesseract和tesseract-ocr',
'confidence': 0.95
}
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'图像处理失败: {str(e)}。注意需要安装Pillow库pip install PillowOCR需要pytesseract和tesseract-ocr'
}
elif node_type == 'excel':
# Excel处理节点Excel读取、写入、格式转换、公式计算
node_data = node.get('data', {})
operation = node_data.get('operation', 'read') # read, write, convert, formula
sheet = node_data.get('sheet', 'Sheet1')
range_str = node_data.get('range', '')
format_type = node_data.get('format', 'xlsx') # xlsx, xls, csv
try:
# 注意需要安装openpyxl或pandas库
# pip install openpyxl pandas
if operation == 'read':
# 读取Excel
excel_data = input_data
if isinstance(input_data, dict):
excel_data = input_data.get('excel') or input_data.get('file') or input_data.get('data')
# 这里只是接口框架,实际需要:
# import pandas as pd
# df = pd.read_excel(io.BytesIO(excel_data), sheet_name=sheet)
# if range_str:
# # 解析范围,如 "A1:C10"
# df = df.loc[range_start:range_end]
# result = df.to_dict('records')
logger.info(f"Excel读取: sheet={sheet}, range={range_str}")
result = {
'data': [{'列1': '值1', '列2': '值2'}],
'sheet': sheet,
'range': range_str
}
elif operation == 'write':
# 写入Excel
data = input_data
if isinstance(input_data, dict):
data = input_data.get('data') or input_data.get('rows')
if not isinstance(data, list):
data = [data]
# 这里只是接口框架,实际需要:
# import pandas as pd
# df = pd.DataFrame(data)
# output = io.BytesIO()
# df.to_excel(output, sheet_name=sheet, index=False)
# result = output.getvalue()
logger.info(f"Excel写入: sheet={sheet}, rows={len(data)}")
result = {
'excel': '生成的Excel数据需要安装openpyxl或pandas',
'sheet': sheet,
'rows': len(data)
}
elif operation == 'convert':
# 格式转换
target_format = node_data.get('target_format', 'csv')
excel_data = input_data
if isinstance(input_data, dict):
excel_data = input_data.get('excel') or input_data.get('file')
# 这里只是接口框架,实际需要:
# import pandas as pd
# df = pd.read_excel(io.BytesIO(excel_data))
# if target_format == 'csv':
# result = df.to_csv(index=False)
# elif target_format == 'json':
# result = df.to_json(orient='records')
logger.info(f"Excel格式转换: {format_type} -> {target_format}")
result = {
'data': '转换后的数据需要安装pandas',
'format': target_format
}
elif operation == 'formula':
# 公式计算
formula = node_data.get('formula', '')
data = input_data
if isinstance(input_data, dict):
data = input_data.get('data')
# 这里只是接口框架,实际需要:
# import pandas as pd
# df = pd.DataFrame(data)
# # 使用eval或更安全的方式计算公式
# result = df.eval(formula)
logger.info(f"Excel公式计算: formula={formula}")
result = {
'result': '公式计算结果需要安装pandas',
'formula': formula
}
else:
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'Excel处理失败: {str(e)}。注意需要安装openpyxl或pandas库pip install openpyxl pandas'
}
elif node_type == 'subworkflow':
# 子工作流节点:调用其他工作流
node_data = node.get('data', {})
workflow_id = node_data.get('workflow_id', '')
input_mapping = node_data.get('input_mapping', {})
try:
# 将当前输入根据映射转换为子工作流输入
sub_input = {}
if isinstance(input_mapping, dict):
for k, v in input_mapping.items():
if isinstance(v, str) and isinstance(input_data, dict):
sub_input[k] = input_data.get(v) or input_data.get(v.strip('{}')) or v
else:
sub_input[k] = v
else:
sub_input = input_data
# 实际调用子工作流的执行,这里简化为回传映射后的输入
# TODO: 集成 WorkflowEngine 执行指定 workflow_id
result = {
'workflow_id': workflow_id,
'input': sub_input,
'status': 'success',
'note': '子工作流执行框架占位,需集成实际调用'
}
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'子工作流执行失败: {str(e)}'
}
elif node_type == 'code':
# 代码执行节点支持简单的Python/JavaScript片段执行注意安全
node_data = node.get('data', {})
language = node_data.get('language', 'python')
code = node_data.get('code', '')
timeout = node_data.get('timeout', 30)
try:
if language.lower() == 'python':
# 受限执行环境
local_vars = {'input_data': input_data, 'result': None}
exec(code, {'__builtins__': {}}, local_vars) # 注意:生产环境需更严格沙箱
result = local_vars.get('result', local_vars.get('output', input_data))
elif language.lower() == 'javascript':
# JS 执行需要外部运行时,这里仅占位
result = {
'status': 'not_implemented',
'message': 'JavaScript执行需集成运行时'
}
else:
result = {'status': 'failed', 'error': f'不支持的语言: {language}'}
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'代码执行失败: {str(e)}'
}
elif node_type == 'oauth':
# OAuth 节点:获取/刷新 Token
node_data = node.get('data', {})
provider = node_data.get('provider', 'google')
client_id = node_data.get('client_id', '')
client_secret = node_data.get('client_secret', '')
scopes = node_data.get('scopes', [])
try:
# 简化占位实现,返回模拟 token
token_data = {
'access_token': f'mock_access_token_{provider}',
'expires_in': 3600,
'token_type': 'Bearer',
'scope': scopes
}
exec_result = {'output': token_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, token_data, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'OAuth处理失败: {str(e)}'
}
elif node_type == 'validator':
# 数据验证节点基于简化的schema检查
node_data = node.get('data', {})
schema = node_data.get('schema', {})
on_error = node_data.get('on_error', 'reject') # reject, continue, transform
try:
# 简单类型检查
if 'type' in schema:
expected_type = schema['type']
actual_type = type(input_data).__name__
if expected_type == 'object' and not isinstance(input_data, dict):
raise ValueError(f'期望类型object实际类型{actual_type}')
if expected_type == 'array' and not isinstance(input_data, list):
raise ValueError(f'期望类型array实际类型{actual_type}')
result = input_data
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if on_error == 'continue':
return {'output': input_data, 'status': 'success', 'warning': str(e)}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'数据验证失败: {str(e)}'
}
elif node_type == 'batch':
# 批处理节点:数据分批处理
node_data = node.get('data', {})
batch_size = node_data.get('batch_size', 100)
mode = node_data.get('mode', 'split') # split, group, aggregate
wait_for_completion = node_data.get('wait_for_completion', True)
try:
data_list = input_data if isinstance(input_data, list) else [input_data]
if mode == 'split':
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
result = batches
elif mode == 'group':
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
result = batches
elif mode == 'aggregate':
result = {
'count': len(data_list),
'samples': data_list[:min(3, len(data_list))]
}
else:
result = data_list
exec_result = {'output': result, 'status': 'success', 'wait': wait_for_completion}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'批处理失败: {str(e)}'
}
elif node_type == 'output' or node_type == 'end':
# 输出节点:返回最终结果
# 读取节点配置中的输出格式设置
node_data = node.get('data', {})
output_format = node_data.get('output_format', 'text') # 默认纯文本
logger.debug(f"[rjb] End节点处理: node_id={node_id}, output_format={output_format}, input_data={input_data}, input_data type={type(input_data)}")
final_output = input_data
# 如果配置为JSON格式直接返回原始数据或格式化的JSON
if output_format == 'json':
# 如果是字典直接返回JSON格式
if isinstance(input_data, dict):
final_output = json_module.dumps(input_data, ensure_ascii=False, indent=2)
elif isinstance(input_data, str):
# 尝试解析为JSON如果成功则格式化否则直接返回
try:
parsed = json_module.loads(input_data)
final_output = json_module.dumps(parsed, ensure_ascii=False, indent=2)
except:
final_output = input_data
else:
final_output = json_module.dumps({'output': input_data}, ensure_ascii=False, indent=2)
else:
# 默认纯文本格式:递归解包,提取实际的文本内容
if isinstance(input_data, dict):
# 优先提取 'output' 字段LLM节点的标准输出格式
if 'output' in input_data and isinstance(input_data['output'], str):
final_output = input_data['output']
# 如果只有一个 key 且是 'input',提取其值
elif len(input_data) == 1 and 'input' in input_data:
final_output = input_data['input']
# 如果包含 'solution' 字段,提取其值
elif 'solution' in input_data and isinstance(input_data['solution'], str):
final_output = input_data['solution']
# 如果input_data是字符串类型的字典JSON字符串尝试解析
elif isinstance(input_data, str):
try:
parsed = json_module.loads(input_data)
if isinstance(parsed, dict) and 'output' in parsed:
final_output = parsed['output']
elif isinstance(parsed, str):
final_output = parsed
except:
final_output = input_data
logger.debug(f"[rjb] End节点提取第一层: final_output={final_output}, type={type(final_output)}")
# 如果提取的值仍然是字典且只有一个 'input' key继续提取
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
final_output = final_output['input']
logger.debug(f"[rjb] End节点提取第二层: final_output={final_output}, type={type(final_output)}")
# 确保最终输出是字符串(对于人机交互场景)
# 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串
if not isinstance(final_output, str):
if isinstance(final_output, dict):
# 如果是字典尝试提取文本内容或转换为JSON字符串
# 优先查找常见的文本字段
if 'text' in final_output:
final_output = str(final_output['text'])
elif 'content' in final_output:
final_output = str(final_output['content'])
elif 'message' in final_output:
final_output = str(final_output['message'])
elif 'response' in final_output:
final_output = str(final_output['response'])
elif len(final_output) == 1:
# 如果只有一个key直接使用其值
final_output = str(list(final_output.values())[0])
else:
# 否则转换为纯文本不是JSON
# 尝试提取所有文本字段并组合,但排除系统字段和用户查询字段
text_parts = []
exclude_keys = {'status', 'error', 'timestamp', 'node_id', 'execution_time', 'query', 'USER_INPUT', 'user_input', 'user_query'}
# 优先使用input字段LLM的实际输出
if 'input' in final_output and isinstance(final_output['input'], str):
final_output = final_output['input']
else:
for key, value in final_output.items():
if key in exclude_keys:
continue
if isinstance(value, str) and value.strip():
# 如果值本身已经包含 "key: " 格式,直接使用
if value.strip().startswith(f"{key}:"):
text_parts.append(value.strip())
else:
text_parts.append(value.strip())
elif isinstance(value, (int, float, bool)):
text_parts.append(f"{key}: {value}")
if text_parts:
final_output = "\n".join(text_parts)
else:
final_output = str(final_output)
else:
final_output = str(final_output)
# 清理输出文本:移除常见的字段前缀(如 "input: ", "query: " 等)
if isinstance(final_output, str):
import re
# 移除行首的 "input: ", "query: ", "output: " 等前缀
lines = final_output.split('\n')
cleaned_lines = []
for line in lines:
# 匹配行首的 "字段名: " 格式并移除
# 但保留内容本身
line = re.sub(r'^(input|query|output|result|response|message|content|text):\s*', '', line, flags=re.IGNORECASE)
if line.strip(): # 只保留非空行
cleaned_lines.append(line)
# 如果清理后还有内容,使用清理后的版本
if cleaned_lines:
final_output = '\n'.join(cleaned_lines)
# 如果清理后为空,使用原始输出(避免丢失所有内容)
elif final_output.strip():
# 如果原始输出不为空,但清理后为空,说明可能格式特殊,尝试更宽松的清理
# 只移除明显的 "input: " 和 "query: " 前缀,保留其他内容
final_output = re.sub(r'^(input|query):\s*', '', final_output, flags=re.IGNORECASE | re.MULTILINE)
if not final_output.strip():
final_output = str(input_data) # 如果还是空,使用原始输入
logger.debug(f"[rjb] End节点最终输出: output_format={output_format}, final_output={final_output[:100] if isinstance(final_output, str) else type(final_output)}")
result = {'output': final_output, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, final_output, duration)
return result
else:
# 未知节点类型
logger.warning(f"[rjb] 未知节点类型: node_id={node_id}, node_type={node_type}, node keys={list(node.keys())}")
return {
'output': input_data,
'status': 'success',
'message': f'节点类型 {node_type} 暂未实现'
}
except Exception as e:
logger.error(f"节点执行失败: {node_id} ({node_type}) - {str(e)}", exc_info=True)
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': str(e),
'node_id': node_id,
'node_type': node_type
}
async def execute(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行完整工作流
Args:
input_data: 初始输入数据
Returns:
执行结果
"""
# 记录工作流开始执行
if self.logger:
self.logger.info("工作流开始执行", data={"input": input_data})
# 初始化节点输出
self.node_outputs = {}
active_edges = self.edges.copy() # 活跃的边列表
executed_nodes = set() # 已执行的节点
# 按拓扑顺序执行节点(动态构建执行图)
results = {}
while True:
# 构建当前活跃的执行图
execution_order = self.build_execution_graph(active_edges)
logger.debug(f"[rjb] 当前执行图: {execution_order}, 活跃边数: {len(active_edges)}, 已执行节点: {executed_nodes}")
# 找到下一个要执行的节点未执行且入度为0
next_node_id = None
for node_id in execution_order:
if node_id not in executed_nodes:
# 检查所有前置节点是否已执行
can_execute = True
incoming_edges = [e for e in active_edges if e['target'] == node_id]
if not incoming_edges:
# 没有入边,可能是起始节点或孤立节点
if node_id not in [n['id'] for n in self.nodes.values() if n.get('type') == 'start']:
# 不是起始节点,但有入边被过滤了,不应该执行
logger.debug(f"[rjb] 节点 {node_id} 没有入边,跳过执行")
continue
for edge in incoming_edges:
if edge['source'] not in executed_nodes:
can_execute = False
logger.debug(f"[rjb] 节点 {node_id} 的前置节点 {edge['source']} 未执行,不能执行")
break
if can_execute:
next_node_id = node_id
logger.info(f"[rjb] 选择执行节点: {next_node_id}, 类型: {self.nodes[next_node_id].get('type')}, 入边数: {len(incoming_edges)}")
break
if not next_node_id:
break # 没有更多节点可执行
node = self.nodes[next_node_id]
executed_nodes.add(next_node_id)
# 调试:检查节点数据结构
if node.get('type') == 'llm':
logger.debug(f"[rjb] 执行LLM节点: node_id={next_node_id}, node keys={list(node.keys())}, data keys={list(node.get('data', {}).keys()) if node.get('data') else []}")
# 获取节点输入(使用活跃的边)
node_input = self.get_node_input(next_node_id, self.node_outputs, active_edges)
# 如果是起始节点,使用初始输入
if node.get('type') == 'start' and not node_input:
node_input = input_data
logger.info(f"[rjb] Start节点使用初始输入: node_id={next_node_id}, node_input={node_input}")
# 调试:记录节点输入数据
if node.get('type') == 'llm':
logger.info(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}")
if 'start-1' in self.node_outputs:
logger.info(f"[rjb] Start节点输出内容: {self.node_outputs['start-1']}")
# 执行节点
result = await self.execute_node(node, node_input)
results[next_node_id] = result
# 保存节点输出
if result.get('status') == 'success':
output_value = result.get('output', {})
self.node_outputs[next_node_id] = output_value
if node.get('type') == 'start':
logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}")
# 如果是条件节点或Switch节点根据分支结果过滤边
if node.get('type') == 'condition':
branch = result.get('branch', 'false')
logger.info(f"[rjb] 条件节点分支过滤: node_id={next_node_id}, branch={branch}")
# 移除不符合条件的边
# 只保留1) 不是从条件节点出发的边,或 2) 从条件节点出发且sourceHandle匹配分支的边
edges_to_remove = []
edges_to_keep = []
for edge in active_edges:
if edge['source'] == next_node_id:
# 这是从条件节点出发的边
edge_handle = edge.get('sourceHandle')
if edge_handle == branch:
# sourceHandle匹配分支保留
edges_to_keep.append(edge)
logger.info(f"[rjb] 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
else:
# sourceHandle不匹配或为None移除
edges_to_remove.append(edge)
logger.info(f"[rjb] 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch})")
else:
# 不是从条件节点出发的边,保留
edges_to_keep.append(edge)
active_edges = edges_to_keep
elif node.get('type') == 'switch':
branch = result.get('branch', 'default')
logger.info(f"[rjb] Switch节点分支过滤: node_id={next_node_id}, branch={branch}")
# 记录过滤前的边信息
edges_before = [e for e in active_edges if e['source'] == next_node_id]
logger.info(f"[rjb] Switch节点过滤前: 从节点出发的边有{len(edges_before)}")
for edge in edges_before:
logger.info(f"[rjb] 边 {edge.get('id')}: sourceHandle={edge.get('sourceHandle')}, target={edge.get('target')}")
# 移除不匹配的边
edges_to_keep = []
edges_removed_count = 0
removed_source_nodes = set() # 记录被移除边的源节点
for edge in active_edges:
if edge['source'] == next_node_id:
# 这是从Switch节点出发的边
edge_handle = edge.get('sourceHandle')
if edge_handle == branch:
# sourceHandle匹配分支保留
edges_to_keep.append(edge)
logger.info(f"[rjb] ✅ 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
else:
# sourceHandle不匹配移除
edges_removed_count += 1
target_id = edge.get('target')
removed_source_nodes.add(target_id) # 记录目标节点(这些节点将不再可达)
logger.info(f"[rjb] ❌ 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch}, target={target_id})")
else:
# 不是从Switch节点出发的边保留
edges_to_keep.append(edge)
# 重要移除那些指向被过滤节点的边这些边来自被过滤的LLM节点
# 例如如果llm-question被过滤了那么llm-question → merge-response的边也应该被移除
additional_removed = 0
for edge in list(edges_to_keep): # 使用list副本因为我们要修改原列表
if edge['source'] in removed_source_nodes:
# 这条边来自被过滤的节点,也应该被移除
edges_to_keep.remove(edge)
additional_removed += 1
logger.info(f"[rjb] ❌ 移除来自被过滤节点的边: {edge.get('id')} ({edge.get('source')}{edge.get('target')})")
edges_removed_count += additional_removed
active_edges = edges_to_keep
filter_info = {
'branch': branch,
'edges_before': len(edges_before),
'edges_kept': len([e for e in edges_to_keep if e['source'] == next_node_id]),
'edges_removed': edges_removed_count
}
logger.info(f"[rjb] Switch节点过滤后: 保留{len(active_edges)}条边其中从Switch节点出发的{filter_info['edges_kept']}条),移除{edges_removed_count}条边")
# 记录过滤后的活跃边
remaining_switch_edges = [e for e in active_edges if e['source'] == next_node_id]
logger.info(f"[rjb] Switch节点过滤后剩余的边: {[e.get('id') + '->' + e.get('target') for e in remaining_switch_edges]}")
# 重要:找出那些不再可达的节点(这些节点只通过被移除的边连接)
removed_targets = set()
for edge in edges_before:
if edge not in edges_to_keep:
target_id = edge.get('target')
removed_targets.add(target_id)
logger.info(f"[rjb] ❌ 节点 {target_id} 的边已被移除,该节点将不会被执行")
# 关键修复:立即重新构建执行图,确保不再可达的节点不在执行图中
# 这样在下次循环时,这些节点就不会被选择执行
logger.info(f"[rjb] Switch节点过滤后重新构建执行图排除 {len(removed_targets)} 个不再可达的节点)")
# 同时记录到数据库
if self.logger:
self.logger.info(
f"Switch节点分支过滤: branch={branch}, 保留{filter_info['edges_kept']}条边,移除{edges_removed_count}条边",
node_id=next_node_id,
node_type='switch',
data=filter_info
)
# 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行)
if node.get('type') in ['loop', 'foreach']:
# 标记循环体的节点为已执行(简化处理)
for edge in active_edges[:]: # 使用切片复制列表
if edge.get('source') == next_node_id:
target_id = edge.get('target')
if target_id in self.nodes:
# 检查是否是循环结束节点
target_node = self.nodes[target_id]
if target_node.get('type') not in ['loop_end', 'end']:
# 标记为已执行(循环体已在循环节点内部执行)
executed_nodes.add(target_id)
# 继续查找循环体内的节点
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
else:
# 执行失败,停止工作流
error_msg = result.get('error', '未知错误')
node_type = node.get('type', 'unknown')
logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}")
raise WorkflowExecutionError(
detail=error_msg,
node_id=next_node_id
)
# 返回最终结果(最后一个执行的节点的输出)
if executed_nodes:
# 找到最后一个节点(没有出边的节点)
last_node_id = None
for node_id in executed_nodes:
has_outgoing = any(edge['source'] == node_id for edge in active_edges)
if not has_outgoing:
last_node_id = node_id
break
if not last_node_id:
# 如果没有找到,使用最后一个执行的节点
last_node_id = list(executed_nodes)[-1]
# 获取最终结果
final_output = self.node_outputs.get(last_node_id)
# 如果最终输出是字典且只有一个 'input' key提取其值
# 这样可以确保最终结果不是重复包装的格式
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
final_output = final_output['input']
# 如果提取的值仍然是字典且只有一个 'input' key继续提取
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
final_output = final_output['input']
# 确保最终结果是字符串(对于人机交互场景)
# 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串
if not isinstance(final_output, str):
if isinstance(final_output, dict):
# 如果是字典尝试提取文本内容或转换为JSON字符串
# 优先查找常见的文本字段
if 'text' in final_output:
final_output = str(final_output['text'])
elif 'content' in final_output:
final_output = str(final_output['content'])
elif 'message' in final_output:
final_output = str(final_output['message'])
elif 'response' in final_output:
final_output = str(final_output['response'])
elif len(final_output) == 1:
# 如果只有一个key直接使用其值
final_output = str(list(final_output.values())[0])
else:
# 否则转换为JSON字符串
import json as json_module
final_output = json_module.dumps(final_output, ensure_ascii=False)
else:
final_output = str(final_output)
final_result = {
'status': 'completed',
'result': final_output,
'node_results': results
}
# 记录工作流执行完成
if self.logger:
self.logger.info("工作流执行完成", data={"result": final_result.get('result')})
return final_result
if self.logger:
self.logger.warn("工作流执行完成,但没有执行任何节点")
return {'status': 'completed', 'result': None}