处理agent答非所问的问题
This commit is contained in:
@@ -116,7 +116,7 @@ class WorkflowEngine:
|
||||
if edge['target'] == node_id:
|
||||
source_id = edge['source']
|
||||
source_output = node_outputs.get(source_id, {})
|
||||
logger.debug(f"[rjb] 获取节点输入: target={node_id}, source={source_id}, source_output={source_output}, sourceHandle={edge.get('sourceHandle')}")
|
||||
logger.info(f"[rjb] 获取节点输入: target={node_id}, source={source_id}, source_output={source_output}, source_output_type={type(source_output)}, sourceHandle={edge.get('sourceHandle')}")
|
||||
|
||||
# 如果有sourceHandle,使用它作为key
|
||||
if 'sourceHandle' in edge and edge['sourceHandle']:
|
||||
@@ -133,9 +133,13 @@ class WorkflowEngine:
|
||||
if key != 'output':
|
||||
input_data[key] = value
|
||||
else:
|
||||
# 直接展开source_output的内容
|
||||
input_data.update(source_output)
|
||||
logger.info(f"[rjb] 展开source_output后: input_data={input_data}")
|
||||
else:
|
||||
# 如果source_output不是字典,包装到input字段
|
||||
input_data['input'] = source_output
|
||||
logger.info(f"[rjb] source_output不是字典,包装到input字段: input_data={input_data}")
|
||||
|
||||
# 如果input_data中没有query字段,尝试从所有已执行的节点中查找(特别是start节点)
|
||||
if 'query' not in input_data:
|
||||
@@ -427,9 +431,13 @@ class WorkflowEngine:
|
||||
# 支持两种格式的变量:{key} 和 {{key}}
|
||||
formatted_prompt = prompt
|
||||
has_unfilled_variables = False
|
||||
has_any_placeholder = False
|
||||
|
||||
import re
|
||||
# 检查是否有任何占位符
|
||||
has_any_placeholder = bool(re.search(r'\{\{?\w+\}?\}', prompt))
|
||||
|
||||
# 首先处理 {{variable}} 格式(模板节点常用)
|
||||
import re
|
||||
double_brace_vars = re.findall(r'\{\{(\w+)\}\}', prompt)
|
||||
for var_name in double_brace_vars:
|
||||
if var_name in input_data:
|
||||
@@ -456,17 +464,108 @@ class WorkflowEngine:
|
||||
json_module.dumps(input_data, ensure_ascii=False)
|
||||
)
|
||||
|
||||
# 如果仍有未填充的变量({{variable}}格式),将用户输入作为上下文附加
|
||||
if has_unfilled_variables or re.search(r'\{\{(\w+)\}\}', formatted_prompt):
|
||||
# 提取用户的实际查询内容
|
||||
user_query = input_data.get('query', input_data.get('input', input_data.get('text', '')))
|
||||
if not user_query and isinstance(input_data, dict):
|
||||
# 如果没有明确的query字段,尝试从整个input_data中提取文本内容
|
||||
user_query = json_module.dumps(input_data, ensure_ascii=False)
|
||||
# 提取用户的实际查询内容(优先提取)
|
||||
user_query = None
|
||||
logger.info(f"[rjb] 开始提取user_query: input_data={input_data}, input_data_type={type(input_data)}")
|
||||
if isinstance(input_data, dict):
|
||||
# 首先检查是否有嵌套的input字段
|
||||
nested_input = input_data.get('input')
|
||||
logger.info(f"[rjb] 检查嵌套input: nested_input={nested_input}, nested_input_type={type(nested_input) if nested_input else None}")
|
||||
if isinstance(nested_input, dict):
|
||||
# 从嵌套的input中提取
|
||||
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||||
if key in nested_input:
|
||||
user_query = nested_input[key]
|
||||
logger.info(f"[rjb] 从嵌套input中提取到user_query: key={key}, user_query={user_query}")
|
||||
break
|
||||
|
||||
# 如果还没有,从顶层提取
|
||||
if not user_query:
|
||||
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||||
if key in input_data:
|
||||
value = input_data[key]
|
||||
logger.info(f"[rjb] 从顶层提取: key={key}, value={value}, value_type={type(value)}")
|
||||
# 如果值是字符串,直接使用
|
||||
if isinstance(value, str):
|
||||
user_query = value
|
||||
logger.info(f"[rjb] 提取到字符串user_query: {user_query}")
|
||||
break
|
||||
# 如果值是字典,尝试从中提取
|
||||
elif isinstance(value, dict):
|
||||
for sub_key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||||
if sub_key in value:
|
||||
user_query = value[sub_key]
|
||||
logger.info(f"[rjb] 从字典值中提取到user_query: sub_key={sub_key}, user_query={user_query}")
|
||||
break
|
||||
if user_query:
|
||||
break
|
||||
|
||||
# 如果还是没有,使用整个input_data(但排除系统字段)
|
||||
if not user_query:
|
||||
filtered_data = {k: v for k, v in input_data.items() if not k.startswith('_')}
|
||||
logger.info(f"[rjb] 使用filtered_data: filtered_data={filtered_data}")
|
||||
if filtered_data:
|
||||
# 如果只有一个字段且是字符串,直接使用
|
||||
if len(filtered_data) == 1:
|
||||
single_value = list(filtered_data.values())[0]
|
||||
if isinstance(single_value, str):
|
||||
user_query = single_value
|
||||
logger.info(f"[rjb] 从单个字符串字段提取到user_query: {user_query}")
|
||||
elif isinstance(single_value, dict):
|
||||
# 从字典中提取第一个字符串值
|
||||
for v in single_value.values():
|
||||
if isinstance(v, str):
|
||||
user_query = v
|
||||
logger.info(f"[rjb] 从字典的单个字段中提取到user_query: {user_query}")
|
||||
break
|
||||
if not user_query:
|
||||
user_query = json_module.dumps(filtered_data, ensure_ascii=False) if len(filtered_data) > 1 else str(list(filtered_data.values())[0])
|
||||
logger.info(f"[rjb] 使用JSON或字符串转换: user_query={user_query}")
|
||||
|
||||
logger.info(f"[rjb] 最终提取的user_query: {user_query}")
|
||||
|
||||
# 如果prompt中没有占位符,或者仍有未填充的变量,将用户输入附加到prompt
|
||||
is_generic_instruction = False # 初始化变量
|
||||
if not has_any_placeholder:
|
||||
# 如果prompt中没有占位符,将用户输入作为主要内容
|
||||
if user_query:
|
||||
# 判断是否是通用指令:简短且不包含具体任务描述
|
||||
prompt_stripped = prompt.strip()
|
||||
is_generic_instruction = (
|
||||
len(prompt_stripped) < 30 or # 简短提示词
|
||||
prompt_stripped in [
|
||||
"请处理用户请求。", "请处理用户请求",
|
||||
"请处理以下输入数据:", "请处理以下输入数据",
|
||||
"请处理输入。", "请处理输入",
|
||||
"处理用户请求", "处理请求",
|
||||
"请回答用户问题", "请回答用户问题。",
|
||||
"请帮助用户", "请帮助用户。"
|
||||
] or
|
||||
# 检查是否只包含通用指令关键词
|
||||
(len(prompt_stripped) < 50 and any(keyword in prompt_stripped for keyword in [
|
||||
"请处理", "处理", "请回答", "回答", "请帮助", "帮助", "请执行", "执行"
|
||||
]) and not any(specific in prompt_stripped for specific in [
|
||||
"翻译", "生成", "分析", "总结", "提取", "转换", "计算"
|
||||
]))
|
||||
)
|
||||
|
||||
if is_generic_instruction:
|
||||
# 如果是通用指令,直接使用用户输入作为prompt
|
||||
formatted_prompt = str(user_query)
|
||||
logger.info(f"[rjb] 检测到通用指令,直接使用用户输入作为prompt: {user_query[:50] if user_query else 'None'}")
|
||||
else:
|
||||
# 否则,将用户输入附加到prompt
|
||||
formatted_prompt = f"{formatted_prompt}\n\n{user_query}"
|
||||
logger.info(f"[rjb] 非通用指令,将用户输入附加到prompt")
|
||||
else:
|
||||
# 如果没有提取到用户查询,附加整个input_data
|
||||
formatted_prompt = f"{formatted_prompt}\n\n{json_module.dumps(input_data, ensure_ascii=False)}"
|
||||
elif has_unfilled_variables or re.search(r'\{\{(\w+)\}\}', formatted_prompt):
|
||||
# 如果有占位符但未填充,附加用户需求说明
|
||||
if user_query:
|
||||
formatted_prompt = f"{formatted_prompt}\n\n用户需求:{user_query}\n\n请根据以上用户需求,忽略未填充的变量占位符(如{{{{variable}}}}),直接基于用户需求来完成任务。"
|
||||
|
||||
logger.info(f"[rjb] LLM节点prompt格式化: node_id={node_id}, original_prompt='{prompt[:50] if len(prompt) > 50 else prompt}', has_any_placeholder={has_any_placeholder}, user_query={user_query}, is_generic_instruction={is_generic_instruction}, final_prompt前200字符='{formatted_prompt[:200] if len(formatted_prompt) > 200 else formatted_prompt}'")
|
||||
prompt = formatted_prompt
|
||||
else:
|
||||
# 如果input_data不是dict,直接转换为字符串
|
||||
@@ -510,6 +609,9 @@ class WorkflowEngine:
|
||||
api_key = None
|
||||
base_url = None
|
||||
|
||||
# 记录实际发送给LLM的prompt
|
||||
logger.info(f"[rjb] 准备调用LLM: node_id={node_id}, provider={provider}, model={model}, prompt前200字符='{prompt[:200] if len(prompt) > 200 else prompt}'")
|
||||
|
||||
# 调用LLM服务
|
||||
try:
|
||||
if self.logger:
|
||||
@@ -1677,24 +1779,28 @@ class WorkflowEngine:
|
||||
final_output = str(list(final_output.values())[0])
|
||||
else:
|
||||
# 否则转换为纯文本(不是JSON)
|
||||
# 尝试提取所有文本字段并组合,但排除系统字段
|
||||
# 尝试提取所有文本字段并组合,但排除系统字段和用户查询字段
|
||||
text_parts = []
|
||||
exclude_keys = {'status', 'error', 'timestamp', 'node_id', 'execution_time'}
|
||||
for key, value in final_output.items():
|
||||
if key in exclude_keys:
|
||||
continue
|
||||
if isinstance(value, str) and value.strip():
|
||||
# 如果值本身已经包含 "key: " 格式,直接使用
|
||||
if value.strip().startswith(f"{key}:"):
|
||||
text_parts.append(value.strip())
|
||||
else:
|
||||
text_parts.append(value.strip())
|
||||
elif isinstance(value, (int, float, bool)):
|
||||
text_parts.append(f"{key}: {value}")
|
||||
if text_parts:
|
||||
final_output = "\n".join(text_parts)
|
||||
exclude_keys = {'status', 'error', 'timestamp', 'node_id', 'execution_time', 'query', 'USER_INPUT', 'user_input', 'user_query'}
|
||||
# 优先使用input字段(LLM的实际输出)
|
||||
if 'input' in final_output and isinstance(final_output['input'], str):
|
||||
final_output = final_output['input']
|
||||
else:
|
||||
final_output = str(final_output)
|
||||
for key, value in final_output.items():
|
||||
if key in exclude_keys:
|
||||
continue
|
||||
if isinstance(value, str) and value.strip():
|
||||
# 如果值本身已经包含 "key: " 格式,直接使用
|
||||
if value.strip().startswith(f"{key}:"):
|
||||
text_parts.append(value.strip())
|
||||
else:
|
||||
text_parts.append(value.strip())
|
||||
elif isinstance(value, (int, float, bool)):
|
||||
text_parts.append(f"{key}: {value}")
|
||||
if text_parts:
|
||||
final_output = "\n".join(text_parts)
|
||||
else:
|
||||
final_output = str(final_output)
|
||||
else:
|
||||
final_output = str(final_output)
|
||||
|
||||
@@ -1808,10 +1914,13 @@ class WorkflowEngine:
|
||||
# 如果是起始节点,使用初始输入
|
||||
if node.get('type') == 'start' and not node_input:
|
||||
node_input = input_data
|
||||
logger.info(f"[rjb] Start节点使用初始输入: node_id={next_node_id}, node_input={node_input}")
|
||||
|
||||
# 调试:记录节点输入数据
|
||||
if node.get('type') == 'llm':
|
||||
logger.debug(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}")
|
||||
logger.info(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}")
|
||||
if 'start-1' in self.node_outputs:
|
||||
logger.info(f"[rjb] Start节点输出内容: {self.node_outputs['start-1']}")
|
||||
|
||||
# 执行节点
|
||||
result = await self.execute_node(node, node_input)
|
||||
@@ -1819,7 +1928,10 @@ class WorkflowEngine:
|
||||
|
||||
# 保存节点输出
|
||||
if result.get('status') == 'success':
|
||||
self.node_outputs[next_node_id] = result.get('output', {})
|
||||
output_value = result.get('output', {})
|
||||
self.node_outputs[next_node_id] = output_value
|
||||
if node.get('type') == 'start':
|
||||
logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}")
|
||||
|
||||
# 如果是条件节点,根据分支结果过滤边
|
||||
if node.get('type') == 'condition':
|
||||
@@ -1898,6 +2010,7 @@ class WorkflowEngine:
|
||||
final_output = str(list(final_output.values())[0])
|
||||
else:
|
||||
# 否则转换为JSON字符串
|
||||
import json as json_module
|
||||
final_output = json_module.dumps(final_output, ensure_ascii=False)
|
||||
else:
|
||||
final_output = str(final_output)
|
||||
|
||||
Reference in New Issue
Block a user