Files
aiagent/backend/app/services/workflow_engine.py
2026-01-19 00:09:36 +08:00

1667 lines
82 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工作流执行引擎
"""
from typing import Dict, Any, List, Optional
import asyncio
from collections import defaultdict, deque
import json
import logging
from app.services.llm_service import llm_service
from app.services.condition_parser import condition_parser
from app.services.data_transformer import data_transformer
from app.core.exceptions import WorkflowExecutionError
from app.core.database import SessionLocal
from app.models.agent import Agent
logger = logging.getLogger(__name__)
class WorkflowEngine:
"""工作流执行引擎"""
def __init__(self, workflow_id: str, workflow_data: Dict[str, Any], logger=None, db=None):
"""
初始化工作流引擎
Args:
workflow_id: 工作流ID
workflow_data: 工作流数据包含nodes和edges
logger: 执行日志记录器(可选)
db: 数据库会话可选用于Agent节点加载Agent配置
"""
self.workflow_id = workflow_id
self.nodes = {node['id']: node for node in workflow_data.get('nodes', [])}
self.edges = workflow_data.get('edges', [])
self.execution_graph = None
self.node_outputs = {}
self.logger = logger
self.db = db
def build_execution_graph(self, active_edges: Optional[List[Dict[str, Any]]] = None) -> List[str]:
"""
构建执行图DAG并返回拓扑排序结果
Args:
active_edges: 活跃的边列表(用于条件分支过滤)
Returns:
拓扑排序后的节点ID列表
"""
# 使用活跃的边,如果没有提供则使用所有边
edges_to_use = active_edges if active_edges is not None else self.edges
# 构建邻接表和入度表
graph = defaultdict(list)
in_degree = defaultdict(int)
# 初始化所有节点的入度
for node_id in self.nodes.keys():
in_degree[node_id] = 0
# 构建图
for edge in edges_to_use:
source = edge['source']
target = edge['target']
graph[source].append(target)
in_degree[target] += 1
# 拓扑排序Kahn算法
queue = deque()
result = []
# 找到所有入度为0的节点起始节点
for node_id in self.nodes.keys():
if in_degree[node_id] == 0:
queue.append(node_id)
while queue:
node_id = queue.popleft()
result.append(node_id)
# 处理该节点的所有出边
for neighbor in graph[node_id]:
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
# 检查是否有环(只检查可达节点)
reachable_nodes = set(result)
if len(reachable_nodes) < len(self.nodes):
# 有些节点不可达,这是正常的(条件分支)
pass
self.execution_graph = result
return result
def get_node_input(self, node_id: str, node_outputs: Dict[str, Any], active_edges: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
"""
获取节点的输入数据
Args:
node_id: 节点ID
node_outputs: 所有节点的输出数据
active_edges: 活跃的边列表(用于条件分支过滤)
Returns:
节点的输入数据
"""
# 使用活跃的边,如果没有提供则使用所有边
edges_to_use = active_edges if active_edges is not None else self.edges
# 找到所有指向该节点的边
input_data = {}
for edge in edges_to_use:
if edge['target'] == node_id:
source_id = edge['source']
source_output = node_outputs.get(source_id, {})
logger.debug(f"[rjb] 获取节点输入: target={node_id}, source={source_id}, source_output={source_output}, sourceHandle={edge.get('sourceHandle')}")
# 如果有sourceHandle使用它作为key
if 'sourceHandle' in edge and edge['sourceHandle']:
input_data[edge['sourceHandle']] = source_output
else:
# 否则合并所有输入
if isinstance(source_output, dict):
input_data.update(source_output)
else:
input_data['input'] = source_output
logger.debug(f"[rjb] 节点输入结果: node_id={node_id}, input_data={input_data}")
return input_data
def _get_nested_value(self, data: Dict[str, Any], path: str) -> Any:
"""
从嵌套字典中获取值(支持点号路径和数组索引)
Args:
data: 数据字典
path: 路径,如 "user.name""items[0].price"
Returns:
路径对应的值
"""
if not path:
return data
parts = path.split('.')
result = data
for part in parts:
if '[' in part and ']' in part:
# 处理数组索引,如 "items[0]"
key = part[:part.index('[')]
index_str = part[part.index('[') + 1:part.index(']')]
if isinstance(result, dict):
result = result.get(key)
elif isinstance(result, list):
try:
result = result[int(index_str)]
except (ValueError, IndexError):
return None
else:
return None
if result is None:
return None
else:
# 普通键访问
if isinstance(result, dict):
result = result.get(part)
else:
return None
if result is None:
return None
return result
async def _execute_loop_body(self, loop_node_id: str, loop_input: Dict[str, Any], iteration_index: int) -> Dict[str, Any]:
    """
    Execute the body of a loop node for a single iteration.

    Walks a single linear chain of nodes starting at the loop node's first
    child, feeding each node's output forward as the next node's input.
    This is a deliberately simplified traversal: only the first outgoing
    edge chain is followed (no branching inside the body), 'loop_end'/'end'
    nodes terminate the walk, and revisiting a node aborts the walk.

    Args:
        loop_node_id: Id of the loop node whose body is being executed.
        loop_input: Input data for this iteration (already carries the
            item/index variables injected by the loop node).
        iteration_index: Zero-based iteration index. NOTE(review): not
            used inside the walk itself.

    Returns:
        The result dict of the last body node executed; a failed node's
        result is returned immediately as-is; a success result wrapping
        *loop_input* when the body is empty.
    """
    # Find the loop node's direct children (the loop body entry nodes).
    loop_body_start_nodes = []
    for edge in self.edges:
        if edge.get('source') == loop_node_id:
            target_id = edge.get('target')
            if target_id and target_id in self.nodes:
                loop_body_start_nodes.append(target_id)
    if not loop_body_start_nodes:
        # No body: pass the input through unchanged.
        return {'output': loop_input, 'status': 'success'}
    # Simplified handling: only the first child chain is executed.
    executed_in_loop = set()
    loop_results = {}
    current_node_id = loop_body_start_nodes[0]
    # Guard against runaway chains inside the body.
    max_iterations = 100
    iteration = 0
    while current_node_id and iteration < max_iterations:
        iteration += 1
        if current_node_id in executed_in_loop:
            break  # avoid cycling inside the loop body
        if current_node_id not in self.nodes:
            break
        node = self.nodes[current_node_id]
        executed_in_loop.add(current_node_id)
        # A loop-end (or workflow-end) node terminates the body walk.
        if node.get('type') == 'loop_end' or node.get('type') == 'end':
            break
        # Execute the body node; a failure result short-circuits the walk.
        result = await self.execute_node(node, loop_input)
        loop_results[current_node_id] = result
        if result.get('status') != 'success':
            return result
        # Fold the node output into the rolling input for the next node:
        # dict outputs are merged, anything else lands under 'result'.
        if result.get('output'):
            if isinstance(result.get('output'), dict):
                loop_input = {**loop_input, **result.get('output')}
            else:
                loop_input = {**loop_input, 'result': result.get('output')}
        # Find the next node (simplified: first unvisited child that is
        # not the loop node itself).
        next_node_id = None
        for edge in self.edges:
            if edge.get('source') == current_node_id:
                target_id = edge.get('target')
                if target_id and target_id in self.nodes and target_id not in executed_in_loop:
                    if target_id != loop_node_id:
                        next_node_id = target_id
                        break
        current_node_id = next_node_id
    # Return the output of the last executed body node, if any.
    if loop_results:
        last_result = list(loop_results.values())[-1]
        return last_result
    return {'output': loop_input, 'status': 'success'}
def _mark_loop_body_executed(self, node_id: str, executed_nodes: set, active_edges: List[Dict[str, Any]]):
"""
递归标记循环体内的节点为已执行
Args:
node_id: 当前节点ID
executed_nodes: 已执行节点集合
active_edges: 活跃的边列表
"""
if node_id in executed_nodes:
return
executed_nodes.add(node_id)
# 查找所有子节点
for edge in active_edges:
if edge.get('source') == node_id:
target_id = edge.get('target')
if target_id in self.nodes:
target_node = self.nodes[target_id]
# 如果是循环结束节点,停止递归
if target_node.get('type') in ['loop_end', 'end']:
continue
# 递归标记子节点
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
async def execute_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行单个节点
Args:
node: 节点配置
input_data: 输入数据
Returns:
节点执行结果
"""
# 确保可以访问全局的 json 模块
import json as json_module
node_type = node.get('type', 'unknown')
node_id = node.get('id')
import time
start_time = time.time()
# 记录节点开始执行
if self.logger:
self.logger.log_node_start(node_id, node_type, input_data)
try:
if node_type == 'start':
# 起始节点:返回输入数据
logger.debug(f"[rjb] 开始节点执行: node_id={node_id}, input_data={input_data}")
result = {'output': input_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
logger.debug(f"[rjb] 开始节点输出: node_id={node_id}, output={result.get('output')}")
return result
elif node_type == 'input':
# 输入节点:处理输入数据
result = {'output': input_data, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
elif node_type == 'llm' or node_type == 'template':
# LLM节点调用AI模型
node_data = node.get('data', {})
logger.debug(f"[rjb] LLM节点执行: node_id={node_id}, input_data={input_data}, input_data type={type(input_data)}")
logger.debug(f"[rjb] LLM节点数据: node_id={node_id}, node_data keys={list(node_data.keys())}, api_key={'已配置' if node_data.get('api_key') else '未配置'}")
prompt = node_data.get('prompt', '')
# 如果prompt为空使用默认提示词
if not prompt:
prompt = "请处理以下输入数据:\n{input}"
# 格式化prompt替换变量
try:
# 将input_data转换为字符串用于格式化
if isinstance(input_data, dict):
# 如果prompt中包含变量尝试格式化
if '{' in prompt and '}' in prompt:
# 尝试格式化所有input_data中的键
formatted_prompt = prompt
for key, value in input_data.items():
placeholder = f'{{{key}}}'
if placeholder in formatted_prompt:
formatted_prompt = formatted_prompt.replace(
placeholder,
json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value)
)
# 如果还有{input}占位符替换为整个input_data
if '{input}' in formatted_prompt:
formatted_prompt = formatted_prompt.replace(
'{input}',
json_module.dumps(input_data, ensure_ascii=False)
)
prompt = formatted_prompt
else:
# 如果没有占位符将input_data作为JSON附加到prompt
prompt = f"{prompt}\n\n输入数据:\n{json_module.dumps(input_data, ensure_ascii=False)}"
else:
# 如果input_data不是dict直接转换为字符串
if '{input}' in prompt:
prompt = prompt.replace('{input}', str(input_data))
else:
prompt = f"{prompt}\n\n输入:{str(input_data)}"
except Exception as e:
# 格式化失败使用原始prompt和input_data
try:
prompt = f"{prompt}\n\n输入数据:\n{json_module.dumps(input_data, ensure_ascii=False)}"
except:
prompt = f"{prompt}\n\n输入数据:{str(input_data)}"
# 获取LLM配置
provider = node_data.get('provider', 'openai')
model = node_data.get('model', 'gpt-3.5-turbo')
temperature = node_data.get('temperature', 0.7)
max_tokens = node_data.get('max_tokens')
# 不传递 api_key 和 base_url让 LLM 服务使用系统默认配置(与节点测试保持一致)
api_key = None
base_url = None
# 调用LLM服务
try:
if self.logger:
logger.debug(f"[rjb] LLM节点配置: provider={provider}, model={model}, 使用系统默认API Key配置")
self.logger.info(f"调用LLM服务: {provider}/{model}", node_id=node_id, node_type=node_type)
result = await llm_service.call_llm(
prompt=prompt,
provider=provider,
model=model,
temperature=temperature,
max_tokens=max_tokens
# 不传递 api_key 和 base_url使用系统默认配置
)
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
# LLM调用失败返回错误
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'LLM调用失败: {str(e)}'
}
elif node_type == 'condition':
# 条件节点:判断分支
condition = node.get('data', {}).get('condition', '')
if not condition:
# 如果没有条件表达式默认返回False
return {
'output': False,
'status': 'success',
'branch': 'false'
}
# 使用条件解析器评估表达式
try:
result = condition_parser.evaluate_condition(condition, input_data)
exec_result = {
'output': result,
'status': 'success',
'branch': 'true' if result else 'false'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, {'result': result, 'branch': exec_result['branch']}, duration)
return exec_result
except Exception as e:
# 条件评估失败
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': False,
'status': 'failed',
'error': f'条件评估失败: {str(e)}',
'branch': 'false'
}
elif node_type == 'data' or node_type == 'transform':
# 数据转换节点
node_data = node.get('data', {})
mapping = node_data.get('mapping', {})
filter_rules = node_data.get('filter_rules', [])
compute_rules = node_data.get('compute_rules', {})
mode = node_data.get('mode', 'mapping')
try:
result = data_transformer.transform_data(
input_data=input_data,
mapping=mapping,
filter_rules=filter_rules,
compute_rules=compute_rules,
mode=mode
)
exec_result = {'output': result, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result, duration)
return exec_result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'数据转换失败: {str(e)}'
}
elif node_type == 'loop' or node_type == 'foreach':
# 循环节点:对数组进行循环处理
node_data = node.get('data', {})
items_path = node_data.get('items_path', 'items') # 数组数据路径
item_variable = node_data.get('item_variable', 'item') # 循环变量名
# 从输入数据中获取数组
items = self._get_nested_value(input_data, items_path)
if not isinstance(items, list):
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type,
ValueError(f"路径 {items_path} 的值不是数组"), duration)
return {
'output': None,
'status': 'failed',
'error': f'路径 {items_path} 的值不是数组,当前类型: {type(items).__name__}'
}
if self.logger:
self.logger.info(f"循环节点开始处理 {len(items)} 个元素",
node_id=node_id, node_type=node_type,
data={"items_count": len(items)})
# 执行循环:对每个元素执行循环体
loop_results = []
for index, item in enumerate(items):
if self.logger:
self.logger.info(f"循环迭代 {index + 1}/{len(items)}",
node_id=node_id, node_type=node_type,
data={"index": index, "item": item})
# 准备循环体的输入数据
loop_input = {
**input_data, # 保留原始输入数据
item_variable: item, # 当前循环项
f'{item_variable}_index': index, # 索引
f'{item_variable}_total': len(items) # 总数
}
# 执行循环体(获取循环节点的子节点)
loop_body_result = await self._execute_loop_body(
node_id, loop_input, index
)
if loop_body_result.get('status') == 'success':
loop_results.append(loop_body_result.get('output', item))
else:
# 如果循环体执行失败,可以选择继续或停止
error_handling = node_data.get('error_handling', 'continue')
if error_handling == 'stop':
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type,
Exception(f"循环体执行失败,停止循环: {loop_body_result.get('error')}"), duration)
return {
'output': None,
'status': 'failed',
'error': f'循环体执行失败: {loop_body_result.get("error")}',
'completed_items': index,
'results': loop_results
}
else:
# continue: 继续执行,记录错误
if self.logger:
self.logger.warn(f"循环迭代 {index + 1} 失败,继续执行",
node_id=node_id, node_type=node_type,
data={"error": loop_body_result.get('error')})
loop_results.append(None)
exec_result = {
'output': loop_results,
'status': 'success',
'items_processed': len(items),
'results_count': len(loop_results)
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type,
{'results_count': len(loop_results)}, duration)
return exec_result
elif node_type == 'http' or node_type == 'request':
# HTTP请求节点发送HTTP请求
node_data = node.get('data', {})
url = node_data.get('url', '')
method = node_data.get('method', 'GET').upper()
headers = node_data.get('headers', {})
params = node_data.get('params', {})
body = node_data.get('body', {})
timeout = node_data.get('timeout', 30)
# 如果URL、headers、params、body中包含变量从input_data中替换
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
# 支持 {key} 或 ${key} 格式
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
# 替换URL中的变量
if url:
url = replace_variables(url, input_data)
# 替换headers中的变量
if isinstance(headers, dict):
headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
elif isinstance(headers, str):
try:
headers = json.loads(replace_variables(headers, input_data))
except:
headers = {}
# 替换params中的变量
if isinstance(params, dict):
params = {k: replace_variables(str(v), input_data) if isinstance(v, str) else v
for k, v in params.items()}
elif isinstance(params, str):
try:
params = json.loads(replace_variables(params, input_data))
except:
params = {}
# 替换body中的变量
if isinstance(body, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
body = replace_dict_vars(body, input_data)
elif isinstance(body, str):
body = replace_variables(body, input_data)
try:
body = json.loads(body)
except:
pass
try:
import httpx
async with httpx.AsyncClient(timeout=timeout) as client:
if method == 'GET':
response = await client.get(url, params=params, headers=headers)
elif method == 'POST':
response = await client.post(url, json=body, params=params, headers=headers)
elif method == 'PUT':
response = await client.put(url, json=body, params=params, headers=headers)
elif method == 'DELETE':
response = await client.delete(url, params=params, headers=headers)
elif method == 'PATCH':
response = await client.patch(url, json=body, params=params, headers=headers)
else:
raise ValueError(f"不支持的HTTP方法: {method}")
# 尝试解析JSON响应
try:
response_data = response.json()
except:
response_data = response.text
result = {
'output': {
'status_code': response.status_code,
'headers': dict(response.headers),
'data': response_data
},
'status': 'success'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'HTTP请求失败: {str(e)}'
}
elif node_type == 'database' or node_type == 'db':
# 数据库操作节点:执行数据库操作
node_data = node.get('data', {})
data_source_id = node_data.get('data_source_id')
operation = node_data.get('operation', 'query') # query/insert/update/delete
sql = node_data.get('sql', '')
table = node_data.get('table', '')
data = node_data.get('data', {})
where = node_data.get('where', {})
# 如果SQL中包含变量从input_data中替换
if sql and isinstance(sql, str):
import re
def replace_sql_vars(text: str, data: Dict[str, Any]) -> str:
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
if value is None:
return match.group(0)
# 如果是字符串需要转义SQL注入
if isinstance(value, str):
# 简单转义,实际应该使用参数化查询
escaped_value = value.replace("'", "''")
return f"'{escaped_value}'"
return str(value)
return re.sub(pattern, replacer, text)
sql = replace_sql_vars(sql, input_data)
try:
# 从数据库加载数据源配置
if not self.db:
raise ValueError("数据库会话未提供,无法执行数据库操作")
from app.models.data_source import DataSource
from app.services.data_source_connector import create_connector
data_source = self.db.query(DataSource).filter(
DataSource.id == data_source_id
).first()
if not data_source:
raise ValueError(f"数据源不存在: {data_source_id}")
connector = create_connector(data_source.type, data_source.config)
if operation == 'query':
# 查询操作
if not sql:
raise ValueError("查询操作需要提供SQL语句")
query_params = {'query': sql}
result_data = connector.query(query_params)
result = {'output': result_data, 'status': 'success'}
elif operation == 'insert':
# 插入操作
if not table:
raise ValueError("插入操作需要提供表名")
# 构建INSERT SQL
columns = ', '.join(data.keys())
# 处理字符串值,转义单引号
def escape_value(v):
if isinstance(v, str):
escaped = v.replace("'", "''")
return f"'{escaped}'"
return str(v)
values = ', '.join([escape_value(v) for v in data.values()])
insert_sql = f"INSERT INTO {table} ({columns}) VALUES ({values})"
query_params = {'query': insert_sql}
result_data = connector.query(query_params)
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
elif operation == 'update':
# 更新操作
if not table or not where:
raise ValueError("更新操作需要提供表名和WHERE条件")
set_clause = ', '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in data.items()])
where_clause = ' AND '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in where.items()])
update_sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
query_params = {'query': update_sql}
result_data = connector.query(query_params)
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
elif operation == 'delete':
# 删除操作
if not table or not where:
raise ValueError("删除操作需要提供表名和WHERE条件")
# 处理字符串值,转义单引号
def escape_sql_value(k, v):
if isinstance(v, str):
escaped = v.replace("'", "''")
return f"{k} = '{escaped}'"
return f"{k} = {v}"
where_clause = ' AND '.join([escape_sql_value(k, v) for k, v in where.items()])
delete_sql = f"DELETE FROM {table} WHERE {where_clause}"
query_params = {'query': delete_sql}
result_data = connector.query(query_params)
result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
else:
raise ValueError(f"不支持的数据库操作: {operation}")
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'数据库操作失败: {str(e)}'
}
elif node_type == 'file' or node_type == 'file_operation':
# 文件操作节点:文件读取、写入、上传、下载
node_data = node.get('data', {})
operation = node_data.get('operation', 'read') # read/write/upload/download
file_path = node_data.get('file_path', '')
content = node_data.get('content', '')
encoding = node_data.get('encoding', 'utf-8')
# 替换文件路径和内容中的变量
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
if file_path:
file_path = replace_variables(file_path, input_data)
if isinstance(content, str):
content = replace_variables(content, input_data)
try:
import os
import json
import base64
from pathlib import Path
if operation == 'read':
# 读取文件
if not file_path:
raise ValueError("读取操作需要提供文件路径")
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
# 根据文件扩展名决定读取方式
file_ext = Path(file_path).suffix.lower()
if file_ext == '.json':
with open(file_path, 'r', encoding=encoding) as f:
data = json.load(f)
elif file_ext in ['.txt', '.md', '.log']:
with open(file_path, 'r', encoding=encoding) as f:
data = f.read()
else:
# 二进制文件返回base64编码
with open(file_path, 'rb') as f:
data = base64.b64encode(f.read()).decode('utf-8')
result = {'output': data, 'status': 'success'}
elif operation == 'write':
# 写入文件
if not file_path:
raise ValueError("写入操作需要提供文件路径")
# 确保目录存在
os.makedirs(os.path.dirname(file_path) if os.path.dirname(file_path) else '.', exist_ok=True)
# 如果content是字典或列表转换为JSON
if isinstance(content, (dict, list)):
content = json.dumps(content, ensure_ascii=False, indent=2)
# 根据文件扩展名决定写入方式
file_ext = Path(file_path).suffix.lower()
if file_ext == '.json':
with open(file_path, 'w', encoding=encoding) as f:
json.dump(json.loads(content) if isinstance(content, str) else content, f, ensure_ascii=False, indent=2)
else:
with open(file_path, 'w', encoding=encoding) as f:
f.write(str(content))
result = {'output': {'file_path': file_path, 'message': '文件写入成功'}, 'status': 'success'}
elif operation == 'upload':
# 文件上传从base64或URL上传
upload_type = node_data.get('upload_type', 'base64') # base64/url
target_path = node_data.get('target_path', '')
if upload_type == 'base64':
# 从输入数据中获取base64编码的文件内容
file_data = input_data.get('file_data') or input_data.get('content')
if not file_data:
raise ValueError("上传操作需要提供file_data或content字段")
# 解码base64
if isinstance(file_data, str):
file_bytes = base64.b64decode(file_data)
else:
file_bytes = file_data
# 写入目标路径
if not target_path:
raise ValueError("上传操作需要提供target_path")
os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
with open(target_path, 'wb') as f:
f.write(file_bytes)
result = {'output': {'file_path': target_path, 'message': '文件上传成功'}, 'status': 'success'}
else:
# URL上传下载后保存
import httpx
url = node_data.get('url', '')
if not url:
raise ValueError("URL上传需要提供url")
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
if not target_path:
# 从URL提取文件名
target_path = os.path.basename(url) or 'downloaded_file'
os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
with open(target_path, 'wb') as f:
f.write(response.content)
result = {'output': {'file_path': target_path, 'message': '文件下载并保存成功'}, 'status': 'success'}
elif operation == 'download':
# 文件下载返回base64编码或文件URL
download_format = node_data.get('download_format', 'base64') # base64/url
if not file_path:
raise ValueError("下载操作需要提供文件路径")
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
if download_format == 'base64':
# 返回base64编码
with open(file_path, 'rb') as f:
file_bytes = f.read()
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
result = {'output': {'file_name': os.path.basename(file_path), 'content': file_base64, 'format': 'base64'}, 'status': 'success'}
else:
# 返回文件路径实际应用中可能需要生成临时URL
result = {'output': {'file_path': file_path, 'format': 'path'}, 'status': 'success'}
else:
raise ValueError(f"不支持的文件操作: {operation}")
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'文件操作失败: {str(e)}'
}
elif node_type == 'webhook':
# Webhook节点发送Webhook请求到外部系统
node_data = node.get('data', {})
url = node_data.get('url', '')
method = node_data.get('method', 'POST').upper()
headers = node_data.get('headers', {})
body = node_data.get('body', {})
timeout = node_data.get('timeout', 30)
# 如果URL、headers、body中包含变量从input_data中替换
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
# 替换URL中的变量
if url:
url = replace_variables(url, input_data)
# 替换headers中的变量
if isinstance(headers, dict):
headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
elif isinstance(headers, str):
try:
headers = json.loads(replace_variables(headers, input_data))
except:
headers = {}
# 替换body中的变量
if isinstance(body, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
body = replace_dict_vars(body, input_data)
elif isinstance(body, str):
body = replace_variables(body, input_data)
try:
body = json.loads(body)
except:
pass
# 如果没有配置body默认使用input_data作为body
if not body:
body = input_data
try:
import httpx
async with httpx.AsyncClient(timeout=timeout) as client:
if method == 'GET':
response = await client.get(url, headers=headers)
elif method == 'POST':
response = await client.post(url, json=body, headers=headers)
elif method == 'PUT':
response = await client.put(url, json=body, headers=headers)
elif method == 'PATCH':
response = await client.patch(url, json=body, headers=headers)
else:
raise ValueError(f"Webhook不支持HTTP方法: {method}")
# 尝试解析JSON响应
try:
response_data = response.json()
except:
response_data = response.text
result = {
'output': {
'status_code': response.status_code,
'headers': dict(response.headers),
'data': response_data
},
'status': 'success'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'Webhook请求失败: {str(e)}'
}
elif node_type == 'schedule' or node_type == 'delay' or node_type == 'timer':
# 定时任务节点:延迟执行或定时执行
node_data = node.get('data', {})
delay_type = node_data.get('delay_type', 'fixed') # fixed: 固定延迟, cron: cron表达式
delay_value = node_data.get('delay_value', 0) # 延迟值(秒)
delay_unit = node_data.get('delay_unit', 'seconds') # seconds, minutes, hours
# 计算实际延迟时间(毫秒)
if delay_unit == 'seconds':
delay_ms = int(delay_value * 1000)
elif delay_unit == 'minutes':
delay_ms = int(delay_value * 60 * 1000)
elif delay_unit == 'hours':
delay_ms = int(delay_value * 60 * 60 * 1000)
else:
delay_ms = int(delay_value * 1000)
# 如果延迟时间大于0则等待
if delay_ms > 0:
if self.logger:
self.logger.info(
f"定时任务节点等待 {delay_value} {delay_unit}",
node_id=node_id,
node_type=node_type,
data={'delay_ms': delay_ms, 'delay_value': delay_value, 'delay_unit': delay_unit}
)
await asyncio.sleep(delay_ms / 1000.0)
# 返回输入数据(定时节点只是延迟,不改变数据)
result = {'output': input_data, 'status': 'success', 'delay_ms': delay_ms}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
elif node_type == 'email' or node_type == 'mail':
# 邮件节点:发送邮件通知
node_data = node.get('data', {})
smtp_host = node_data.get('smtp_host', '')
smtp_port = node_data.get('smtp_port', 587)
smtp_user = node_data.get('smtp_user', '')
smtp_password = node_data.get('smtp_password', '')
use_tls = node_data.get('use_tls', True)
from_email = node_data.get('from_email', '')
to_email = node_data.get('to_email', '')
cc_email = node_data.get('cc_email', '')
bcc_email = node_data.get('bcc_email', '')
subject = node_data.get('subject', '')
body = node_data.get('body', '')
body_type = node_data.get('body_type', 'text') # text/html
attachments = node_data.get('attachments', []) # 附件列表
# 替换变量
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
# 替换所有配置中的变量
smtp_host = replace_variables(smtp_host, input_data)
smtp_user = replace_variables(smtp_user, input_data)
smtp_password = replace_variables(smtp_password, input_data)
from_email = replace_variables(from_email, input_data)
to_email = replace_variables(to_email, input_data)
cc_email = replace_variables(cc_email, input_data)
bcc_email = replace_variables(bcc_email, input_data)
subject = replace_variables(subject, input_data)
body = replace_variables(body, input_data)
# 验证必需参数
if not smtp_host:
raise ValueError("邮件节点需要配置SMTP服务器地址")
if not from_email:
raise ValueError("邮件节点需要配置发件人邮箱")
if not to_email:
raise ValueError("邮件节点需要配置收件人邮箱")
if not subject:
raise ValueError("邮件节点需要配置邮件主题")
try:
import aiosmtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.base import MIMEBase
from email import encoders
import base64
import os
# 创建邮件消息
msg = MIMEMultipart('alternative')
msg['From'] = from_email
msg['To'] = to_email
if cc_email:
msg['Cc'] = cc_email
msg['Subject'] = subject
# 添加邮件正文
if body_type == 'html':
msg.attach(MIMEText(body, 'html', 'utf-8'))
else:
msg.attach(MIMEText(body, 'plain', 'utf-8'))
# 处理附件
for attachment in attachments:
if isinstance(attachment, dict):
file_path = attachment.get('file_path', '')
file_name = attachment.get('file_name', '')
file_content = attachment.get('file_content', '') # base64编码的内容
# 替换变量
file_path = replace_variables(file_path, input_data)
file_name = replace_variables(file_name, input_data)
if file_path and os.path.exists(file_path):
# 从文件路径读取
with open(file_path, 'rb') as f:
file_data = f.read()
if not file_name:
file_name = os.path.basename(file_path)
elif file_content:
# 从base64内容读取
file_data = base64.b64decode(file_content)
if not file_name:
file_name = 'attachment'
else:
continue
# 添加附件
part = MIMEBase('application', 'octet-stream')
part.set_payload(file_data)
encoders.encode_base64(part)
part.add_header(
'Content-Disposition',
f'attachment; filename= {file_name}'
)
msg.attach(part)
# 发送邮件
recipients = [to_email]
if cc_email:
recipients.extend([email.strip() for email in cc_email.split(',')])
if bcc_email:
recipients.extend([email.strip() for email in bcc_email.split(',')])
async with aiosmtplib.SMTP(hostname=smtp_host, port=smtp_port) as smtp:
if use_tls:
await smtp.starttls()
if smtp_user and smtp_password:
await smtp.login(smtp_user, smtp_password)
await smtp.send_message(msg, recipients=recipients)
result = {
'output': {
'message': '邮件发送成功',
'from': from_email,
'to': to_email,
'subject': subject,
'recipients_count': len(recipients)
},
'status': 'success'
}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'邮件发送失败: {str(e)}'
}
elif node_type == 'message_queue' or node_type == 'mq' or node_type == 'rabbitmq' or node_type == 'kafka':
# 消息队列节点发送消息到RabbitMQ或Kafka
node_data = node.get('data', {})
queue_type = node_data.get('queue_type', 'rabbitmq') # rabbitmq/kafka
# 替换变量
import re
def replace_variables(text: str, data: Dict[str, Any]) -> str:
"""替换字符串中的变量占位符"""
if not isinstance(text, str):
return text
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
def replacer(match):
key = match.group(1) or match.group(2)
value = self._get_nested_value(data, key)
return str(value) if value is not None else match.group(0)
return re.sub(pattern, replacer, text)
try:
if queue_type == 'rabbitmq':
# RabbitMQ实现
import aio_pika
import json
# 获取RabbitMQ配置
host = replace_variables(node_data.get('host', 'localhost'), input_data)
port = node_data.get('port', 5672)
username = replace_variables(node_data.get('username', 'guest'), input_data)
password = replace_variables(node_data.get('password', 'guest'), input_data)
exchange = replace_variables(node_data.get('exchange', ''), input_data)
routing_key = replace_variables(node_data.get('routing_key', ''), input_data)
queue_name = replace_variables(node_data.get('queue_name', ''), input_data)
message = node_data.get('message', input_data)
# 如果message是字符串尝试替换变量
if isinstance(message, str):
message = replace_variables(message, input_data)
try:
message = json.loads(message)
except:
pass
elif isinstance(message, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
message = replace_dict_vars(message, input_data)
# 如果没有配置message使用input_data
if not message:
message = input_data
# 连接RabbitMQ
connection_url = f"amqp://{username}:{password}@{host}:{port}/"
connection = await aio_pika.connect_robust(connection_url)
channel = await connection.channel()
# 发送消息
message_body = json.dumps(message, ensure_ascii=False).encode('utf-8')
if exchange:
# 使用exchange和routing_key
await channel.default_exchange.publish(
aio_pika.Message(message_body),
routing_key=routing_key or queue_name
)
elif queue_name:
# 直接发送到队列
queue = await channel.declare_queue(queue_name, durable=True)
await channel.default_exchange.publish(
aio_pika.Message(message_body),
routing_key=queue_name
)
else:
raise ValueError("RabbitMQ节点需要配置exchange或queue_name")
await connection.close()
result = {
'output': {
'message': '消息已发送到RabbitMQ',
'queue_type': 'rabbitmq',
'exchange': exchange,
'routing_key': routing_key or queue_name,
'queue_name': queue_name,
'message_size': len(message_body)
},
'status': 'success'
}
elif queue_type == 'kafka':
# Kafka实现
from kafka import KafkaProducer
import json
# 获取Kafka配置
bootstrap_servers = replace_variables(node_data.get('bootstrap_servers', 'localhost:9092'), input_data)
topic = replace_variables(node_data.get('topic', ''), input_data)
message = node_data.get('message', input_data)
# 如果message是字符串尝试替换变量
if isinstance(message, str):
message = replace_variables(message, input_data)
try:
message = json.loads(message)
except:
pass
elif isinstance(message, dict):
# 递归替换字典中的变量
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
new_k = replace_variables(k, data)
if isinstance(v, dict):
result[new_k] = replace_dict_vars(v, data)
elif isinstance(v, str):
result[new_k] = replace_variables(v, data)
else:
result[new_k] = v
return result
message = replace_dict_vars(message, input_data)
# 如果没有配置message使用input_data
if not message:
message = input_data
if not topic:
raise ValueError("Kafka节点需要配置topic")
# 创建Kafka生产者注意kafka-python是同步的需要在线程池中运行
import asyncio
from concurrent.futures import ThreadPoolExecutor
def send_kafka_message():
producer = KafkaProducer(
bootstrap_servers=bootstrap_servers.split(','),
value_serializer=lambda v: json.dumps(v, ensure_ascii=False).encode('utf-8')
)
future = producer.send(topic, message)
record_metadata = future.get(timeout=10)
producer.close()
return record_metadata
# 在线程池中执行同步操作
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
record_metadata = await loop.run_in_executor(executor, send_kafka_message)
result = {
'output': {
'message': '消息已发送到Kafka',
'queue_type': 'kafka',
'topic': topic,
'partition': record_metadata.partition,
'offset': record_metadata.offset
},
'status': 'success'
}
else:
raise ValueError(f"不支持的消息队列类型: {queue_type}")
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
return result
except Exception as e:
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': f'消息队列发送失败: {str(e)}'
}
elif node_type == 'output' or node_type == 'end':
# 输出节点:返回最终结果
# 对于人机交互场景End节点应该返回纯文本字符串而不是JSON
logger.debug(f"[rjb] End节点处理: node_id={node_id}, input_data={input_data}, input_data type={type(input_data)}")
final_output = input_data
# 递归解包,提取实际的文本内容
if isinstance(input_data, dict):
# 如果只有一个 key 且是 'input',提取其值
if len(input_data) == 1 and 'input' in input_data:
final_output = input_data['input']
logger.debug(f"[rjb] End节点提取第一层: final_output={final_output}, type={type(final_output)}")
# 如果提取的值仍然是字典且只有一个 'input' key继续提取
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
final_output = final_output['input']
logger.debug(f"[rjb] End节点提取第二层: final_output={final_output}, type={type(final_output)}")
# 确保最终输出是字符串(对于人机交互场景)
# 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串
if not isinstance(final_output, str):
if isinstance(final_output, dict):
# 如果是字典尝试提取文本内容或转换为JSON字符串
# 优先查找常见的文本字段
if 'text' in final_output:
final_output = str(final_output['text'])
elif 'content' in final_output:
final_output = str(final_output['content'])
elif 'message' in final_output:
final_output = str(final_output['message'])
elif 'response' in final_output:
final_output = str(final_output['response'])
elif len(final_output) == 1:
# 如果只有一个key直接使用其值
final_output = str(list(final_output.values())[0])
else:
# 否则转换为JSON字符串
final_output = json_module.dumps(final_output, ensure_ascii=False)
else:
final_output = str(final_output)
logger.debug(f"[rjb] End节点最终输出: final_output={final_output}, type={type(final_output)}")
result = {'output': final_output, 'status': 'success'}
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_complete(node_id, node_type, final_output, duration)
return result
else:
# 未知节点类型
return {
'output': input_data,
'status': 'success',
'message': f'节点类型 {node_type} 暂未实现'
}
except Exception as e:
logger.error(f"节点执行失败: {node_id} ({node_type}) - {str(e)}", exc_info=True)
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)
return {
'output': None,
'status': 'failed',
'error': str(e),
'node_id': node_id,
'node_type': node_type
}
async def execute(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    执行完整工作流

    Executes nodes in dependency order. The topological order is rebuilt on
    every iteration because condition nodes can deactivate edges, changing
    which downstream nodes remain reachable.

    Args:
        input_data: 初始输入数据 (fed to the start node)

    Returns:
        执行结果 dict with keys:
            'status'       - 'completed'
            'result'       - final output coerced to a plain string
                             (None when no node was executed)
            'node_results' - per-node raw execution results
            (the no-node case omits 'node_results')

    Raises:
        WorkflowExecutionError: when any node reports a failed status.
    """

    def _stringify(value: Any) -> str:
        """Coerce a node output to a plain string for chat-style consumers.

        Mirrors the End-node output handling: unwrap pass-through
        ``{'input': ...}`` wrappers, then prefer common text fields,
        falling back to JSON serialization.
        """
        # Unwrap up to two layers of {'input': ...} wrapping produced by
        # pass-through nodes. (If the first check fails the second cannot
        # succeed, so a bounded loop is equivalent to the nested checks.)
        for _ in range(2):
            if isinstance(value, dict) and len(value) == 1 and 'input' in value:
                value = value['input']
        if isinstance(value, str):
            return value
        if isinstance(value, dict):
            # Prefer well-known text fields, in priority order.
            for key in ('text', 'content', 'message', 'response'):
                if key in value:
                    return str(value[key])
            if len(value) == 1:
                # Single-key dict: use its value directly.
                return str(next(iter(value.values())))
            # Otherwise serialize the whole dict as JSON.
            return json_module.dumps(value, ensure_ascii=False)
        return str(value)

    # 记录工作流开始执行
    if self.logger:
        self.logger.info("工作流开始执行", data={"input": input_data})

    # 初始化节点输出
    self.node_outputs = {}
    active_edges = self.edges.copy()   # edges still eligible for traversal
    executed_nodes = set()             # executed node ids (O(1) membership)
    execution_sequence = []            # executed node ids in actual order
    results = {}

    # 按拓扑顺序执行节点(动态构建执行图)
    while True:
        # Rebuild the topological order over the currently active edges.
        execution_order = self.build_execution_graph(active_edges)

        # 找到下一个要执行的节点:未执行且所有活跃前置节点均已执行
        next_node_id = None
        for node_id in execution_order:
            if node_id in executed_nodes:
                continue
            ready = all(
                edge['source'] in executed_nodes
                for edge in active_edges
                if edge['target'] == node_id
            )
            if ready:
                next_node_id = node_id
                break

        if not next_node_id:
            break  # 没有更多节点可执行

        node = self.nodes[next_node_id]
        executed_nodes.add(next_node_id)
        execution_sequence.append(next_node_id)

        # 调试:检查节点数据结构
        if node.get('type') == 'llm':
            logger.debug(f"[rjb] 执行LLM节点: node_id={next_node_id}, node keys={list(node.keys())}, data keys={list(node.get('data', {}).keys()) if node.get('data') else []}")

        # 获取节点输入(使用活跃的边)
        node_input = self.get_node_input(next_node_id, self.node_outputs, active_edges)
        # 如果是起始节点,使用初始输入
        if node.get('type') == 'start' and not node_input:
            node_input = input_data

        # 调试:记录节点输入数据
        if node.get('type') == 'llm':
            logger.debug(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}")

        # 执行节点
        result = await self.execute_node(node, node_input)
        results[next_node_id] = result

        if result.get('status') == 'success':
            # 保存节点输出
            self.node_outputs[next_node_id] = result.get('output', {})

            # 条件节点:移除未选中分支的边,使该分支不再可达
            if node.get('type') == 'condition':
                branch = result.get('branch', 'false')
                active_edges = [
                    edge for edge in active_edges
                    if not (edge['source'] == next_node_id and edge.get('sourceHandle') != branch)
                ]

            # 循环节点:循环体已在节点内部执行,标记循环体节点为已执行
            if node.get('type') in ['loop', 'foreach']:
                for edge in active_edges[:]:  # iterate over a copy
                    if edge.get('source') == next_node_id:
                        target_id = edge.get('target')
                        if target_id in self.nodes:
                            # 检查是否是循环结束节点
                            target_node = self.nodes[target_id]
                            if target_node.get('type') not in ['loop_end', 'end']:
                                executed_nodes.add(target_id)
                                # 继续查找循环体内的节点
                                self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
        else:
            # 执行失败,停止工作流
            error_msg = result.get('error', '未知错误')
            node_type = node.get('type', 'unknown')
            logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}")
            raise WorkflowExecutionError(
                detail=error_msg,
                node_id=next_node_id
            )

    # 返回最终结果(终端节点的输出)
    if execution_sequence:
        # 终端节点 = 没有活跃出边的节点,按实际执行顺序取第一个命中者。
        # FIX: the original fell back to list(executed_nodes)[-1], an
        # arbitrary element of a set; using the ordered execution sequence
        # makes both the search and the fallback deterministic.
        last_node_id = next(
            (nid for nid in execution_sequence
             if not any(edge['source'] == nid for edge in active_edges)),
            execution_sequence[-1]
        )

        # 确保最终结果是字符串(对于人机交互场景)
        final_output = _stringify(self.node_outputs.get(last_node_id))

        final_result = {
            'status': 'completed',
            'result': final_output,
            'node_results': results
        }
        # 记录工作流执行完成
        if self.logger:
            self.logger.info("工作流执行完成", data={"result": final_result.get('result')})
        return final_result

    if self.logger:
        self.logger.warn("工作流执行完成,但没有执行任何节点")
    return {'status': 'completed', 'result': None}