170 lines
5.6 KiB
Python
170 lines
5.6 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
工作流数据流转测试脚本
|
|||
|
|
用于诊断"答非所问"问题
|
|||
|
|
"""
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
# 添加项目路径
|
|||
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
|
|||
|
|
|
|||
|
|
from app.services.workflow_engine import WorkflowEngine
|
|||
|
|
from app.core.database import SessionLocal
|
|||
|
|
|
|||
|
|
|
|||
|
|
def print_section(title):
|
|||
|
|
"""打印分隔线"""
|
|||
|
|
print("\n" + "=" * 80)
|
|||
|
|
print(f" {title}")
|
|||
|
|
print("=" * 80)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def print_data(label, data, indent=0):
|
|||
|
|
"""格式化打印数据"""
|
|||
|
|
prefix = " " * indent
|
|||
|
|
print(f"{prefix}{label}:")
|
|||
|
|
if isinstance(data, dict):
|
|||
|
|
print(f"{prefix} {json.dumps(data, ensure_ascii=False, indent=2)}")
|
|||
|
|
else:
|
|||
|
|
print(f"{prefix} {data}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def test_workflow_data_flow():
|
|||
|
|
"""测试工作流数据流转"""
|
|||
|
|
print_section("工作流数据流转测试")
|
|||
|
|
|
|||
|
|
# 模拟一个简单的工作流
|
|||
|
|
workflow_data = {
|
|||
|
|
"nodes": [
|
|||
|
|
{
|
|||
|
|
"id": "start-1",
|
|||
|
|
"type": "start",
|
|||
|
|
"position": {"x": 100, "y": 100},
|
|||
|
|
"data": {
|
|||
|
|
"label": "开始"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"id": "llm-1",
|
|||
|
|
"type": "llm",
|
|||
|
|
"position": {"x": 300, "y": 100},
|
|||
|
|
"data": {
|
|||
|
|
"label": "LLM处理",
|
|||
|
|
"provider": "deepseek",
|
|||
|
|
"model": "deepseek-chat",
|
|||
|
|
"prompt": "请处理用户请求。",
|
|||
|
|
"temperature": 0.5,
|
|||
|
|
"max_tokens": 1500
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"id": "end-1",
|
|||
|
|
"type": "end",
|
|||
|
|
"position": {"x": 500, "y": 100},
|
|||
|
|
"data": {
|
|||
|
|
"label": "结束",
|
|||
|
|
"output_format": "text"
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"edges": [
|
|||
|
|
{"id": "e1", "source": "start-1", "target": "llm-1"},
|
|||
|
|
{"id": "e2", "source": "llm-1", "target": "end-1"}
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# 模拟前端发送的输入数据
|
|||
|
|
input_data = {
|
|||
|
|
"query": "苹果英语怎么讲?",
|
|||
|
|
"USER_INPUT": "苹果英语怎么讲?"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
print_section("1. 初始输入数据")
|
|||
|
|
print_data("input_data", input_data)
|
|||
|
|
|
|||
|
|
# 创建引擎(不使用logger,避免数据库依赖)
|
|||
|
|
engine = WorkflowEngine("test-workflow", workflow_data)
|
|||
|
|
|
|||
|
|
# 重写get_node_input方法,添加详细日志
|
|||
|
|
original_get_node_input = engine.get_node_input
|
|||
|
|
|
|||
|
|
def logged_get_node_input(node_id, node_outputs, active_edges=None):
|
|||
|
|
print_section(f"获取节点输入: {node_id}")
|
|||
|
|
print_data("node_outputs", node_outputs)
|
|||
|
|
result = original_get_node_input(node_id, node_outputs, active_edges)
|
|||
|
|
print_data(f"返回的input_data (for {node_id})", result)
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
engine.get_node_input = logged_get_node_input
|
|||
|
|
|
|||
|
|
# 重写execute_node方法,添加详细日志
|
|||
|
|
original_execute_node = engine.execute_node
|
|||
|
|
|
|||
|
|
async def logged_execute_node(node, input_data):
|
|||
|
|
node_id = node.get('id')
|
|||
|
|
node_type = node.get('type')
|
|||
|
|
|
|||
|
|
print_section(f"执行节点: {node_id} ({node_type})")
|
|||
|
|
print_data("节点配置", node.get('data', {}))
|
|||
|
|
print_data("输入数据", input_data)
|
|||
|
|
|
|||
|
|
result = await original_execute_node(node, input_data)
|
|||
|
|
|
|||
|
|
print_data("执行结果", result)
|
|||
|
|
|
|||
|
|
# 如果是LLM节点,特别关注prompt和输出
|
|||
|
|
if node_type == 'llm':
|
|||
|
|
print_section(f"LLM节点详细分析: {node_id}")
|
|||
|
|
node_data = node.get('data', {})
|
|||
|
|
prompt = node_data.get('prompt', '')
|
|||
|
|
print_data("原始prompt", prompt)
|
|||
|
|
print_data("输入数据", input_data)
|
|||
|
|
|
|||
|
|
# 模拟prompt格式化逻辑
|
|||
|
|
if isinstance(input_data, dict):
|
|||
|
|
# 检查是否有嵌套的input字段
|
|||
|
|
nested_input = input_data.get('input')
|
|||
|
|
if isinstance(nested_input, dict):
|
|||
|
|
print("⚠️ 发现嵌套的input字段!")
|
|||
|
|
print_data("嵌套input内容", nested_input)
|
|||
|
|
# 尝试提取user_query
|
|||
|
|
user_query = None
|
|||
|
|
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
|||
|
|
if key in nested_input:
|
|||
|
|
user_query = nested_input[key]
|
|||
|
|
print(f"✅ 从嵌套input中提取到user_query: {key} = {user_query}")
|
|||
|
|
break
|
|||
|
|
else:
|
|||
|
|
# 从顶层提取
|
|||
|
|
user_query = None
|
|||
|
|
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
|||
|
|
if key in input_data:
|
|||
|
|
value = input_data[key]
|
|||
|
|
if isinstance(value, str):
|
|||
|
|
user_query = value
|
|||
|
|
print(f"✅ 从顶层提取到user_query: {key} = {user_query}")
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
engine.execute_node = logged_execute_node
|
|||
|
|
|
|||
|
|
# 执行工作流
|
|||
|
|
print_section("开始执行工作流")
|
|||
|
|
try:
|
|||
|
|
result = await engine.execute(input_data)
|
|||
|
|
print_section("工作流执行完成")
|
|||
|
|
print_data("最终结果", result)
|
|||
|
|
except Exception as e:
|
|||
|
|
print_section("执行出错")
|
|||
|
|
print(f"错误: {str(e)}")
|
|||
|
|
import traceback
|
|||
|
|
traceback.print_exc()
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
asyncio.run(test_workflow_data_flow())
|