处理agent答非所问的问题
This commit is contained in:
169
test_workflow_data_flow.py
Executable file
169
test_workflow_data_flow.py
Executable file
@@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
工作流数据流转测试脚本
|
||||
用于诊断"答非所问"问题
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
|
||||
|
||||
from app.services.workflow_engine import WorkflowEngine
|
||||
from app.core.database import SessionLocal
|
||||
|
||||
|
||||
def print_section(title):
|
||||
"""打印分隔线"""
|
||||
print("\n" + "=" * 80)
|
||||
print(f" {title}")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
def print_data(label, data, indent=0):
|
||||
"""格式化打印数据"""
|
||||
prefix = " " * indent
|
||||
print(f"{prefix}{label}:")
|
||||
if isinstance(data, dict):
|
||||
print(f"{prefix} {json.dumps(data, ensure_ascii=False, indent=2)}")
|
||||
else:
|
||||
print(f"{prefix} {data}")
|
||||
|
||||
|
||||
async def test_workflow_data_flow():
|
||||
"""测试工作流数据流转"""
|
||||
print_section("工作流数据流转测试")
|
||||
|
||||
# 模拟一个简单的工作流
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start-1",
|
||||
"type": "start",
|
||||
"position": {"x": 100, "y": 100},
|
||||
"data": {
|
||||
"label": "开始"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "llm-1",
|
||||
"type": "llm",
|
||||
"position": {"x": 300, "y": 100},
|
||||
"data": {
|
||||
"label": "LLM处理",
|
||||
"provider": "deepseek",
|
||||
"model": "deepseek-chat",
|
||||
"prompt": "请处理用户请求。",
|
||||
"temperature": 0.5,
|
||||
"max_tokens": 1500
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "end-1",
|
||||
"type": "end",
|
||||
"position": {"x": 500, "y": 100},
|
||||
"data": {
|
||||
"label": "结束",
|
||||
"output_format": "text"
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{"id": "e1", "source": "start-1", "target": "llm-1"},
|
||||
{"id": "e2", "source": "llm-1", "target": "end-1"}
|
||||
]
|
||||
}
|
||||
|
||||
# 模拟前端发送的输入数据
|
||||
input_data = {
|
||||
"query": "苹果英语怎么讲?",
|
||||
"USER_INPUT": "苹果英语怎么讲?"
|
||||
}
|
||||
|
||||
print_section("1. 初始输入数据")
|
||||
print_data("input_data", input_data)
|
||||
|
||||
# 创建引擎(不使用logger,避免数据库依赖)
|
||||
engine = WorkflowEngine("test-workflow", workflow_data)
|
||||
|
||||
# 重写get_node_input方法,添加详细日志
|
||||
original_get_node_input = engine.get_node_input
|
||||
|
||||
def logged_get_node_input(node_id, node_outputs, active_edges=None):
|
||||
print_section(f"获取节点输入: {node_id}")
|
||||
print_data("node_outputs", node_outputs)
|
||||
result = original_get_node_input(node_id, node_outputs, active_edges)
|
||||
print_data(f"返回的input_data (for {node_id})", result)
|
||||
return result
|
||||
|
||||
engine.get_node_input = logged_get_node_input
|
||||
|
||||
# 重写execute_node方法,添加详细日志
|
||||
original_execute_node = engine.execute_node
|
||||
|
||||
async def logged_execute_node(node, input_data):
|
||||
node_id = node.get('id')
|
||||
node_type = node.get('type')
|
||||
|
||||
print_section(f"执行节点: {node_id} ({node_type})")
|
||||
print_data("节点配置", node.get('data', {}))
|
||||
print_data("输入数据", input_data)
|
||||
|
||||
result = await original_execute_node(node, input_data)
|
||||
|
||||
print_data("执行结果", result)
|
||||
|
||||
# 如果是LLM节点,特别关注prompt和输出
|
||||
if node_type == 'llm':
|
||||
print_section(f"LLM节点详细分析: {node_id}")
|
||||
node_data = node.get('data', {})
|
||||
prompt = node_data.get('prompt', '')
|
||||
print_data("原始prompt", prompt)
|
||||
print_data("输入数据", input_data)
|
||||
|
||||
# 模拟prompt格式化逻辑
|
||||
if isinstance(input_data, dict):
|
||||
# 检查是否有嵌套的input字段
|
||||
nested_input = input_data.get('input')
|
||||
if isinstance(nested_input, dict):
|
||||
print("⚠️ 发现嵌套的input字段!")
|
||||
print_data("嵌套input内容", nested_input)
|
||||
# 尝试提取user_query
|
||||
user_query = None
|
||||
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||||
if key in nested_input:
|
||||
user_query = nested_input[key]
|
||||
print(f"✅ 从嵌套input中提取到user_query: {key} = {user_query}")
|
||||
break
|
||||
else:
|
||||
# 从顶层提取
|
||||
user_query = None
|
||||
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||||
if key in input_data:
|
||||
value = input_data[key]
|
||||
if isinstance(value, str):
|
||||
user_query = value
|
||||
print(f"✅ 从顶层提取到user_query: {key} = {user_query}")
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
engine.execute_node = logged_execute_node
|
||||
|
||||
# 执行工作流
|
||||
print_section("开始执行工作流")
|
||||
try:
|
||||
result = await engine.execute(input_data)
|
||||
print_section("工作流执行完成")
|
||||
print_data("最终结果", result)
|
||||
except Exception as e:
|
||||
print_section("执行出错")
|
||||
print(f"错误: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_workflow_data_flow())
|
||||
Reference in New Issue
Block a user