170 lines
5.6 KiB
Python
Executable File
170 lines
5.6 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
工作流数据流转测试脚本
|
||
用于诊断"答非所问"问题
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import sys
|
||
import os
|
||
|
||
# 添加项目路径
|
||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
|
||
|
||
from app.services.workflow_engine import WorkflowEngine
|
||
from app.core.database import SessionLocal
|
||
|
||
|
||
def print_section(title):
|
||
"""打印分隔线"""
|
||
print("\n" + "=" * 80)
|
||
print(f" {title}")
|
||
print("=" * 80)
|
||
|
||
|
||
def print_data(label, data, indent=0):
|
||
"""格式化打印数据"""
|
||
prefix = " " * indent
|
||
print(f"{prefix}{label}:")
|
||
if isinstance(data, dict):
|
||
print(f"{prefix} {json.dumps(data, ensure_ascii=False, indent=2)}")
|
||
else:
|
||
print(f"{prefix} {data}")
|
||
|
||
|
||
async def test_workflow_data_flow():
|
||
"""测试工作流数据流转"""
|
||
print_section("工作流数据流转测试")
|
||
|
||
# 模拟一个简单的工作流
|
||
workflow_data = {
|
||
"nodes": [
|
||
{
|
||
"id": "start-1",
|
||
"type": "start",
|
||
"position": {"x": 100, "y": 100},
|
||
"data": {
|
||
"label": "开始"
|
||
}
|
||
},
|
||
{
|
||
"id": "llm-1",
|
||
"type": "llm",
|
||
"position": {"x": 300, "y": 100},
|
||
"data": {
|
||
"label": "LLM处理",
|
||
"provider": "deepseek",
|
||
"model": "deepseek-chat",
|
||
"prompt": "请处理用户请求。",
|
||
"temperature": 0.5,
|
||
"max_tokens": 1500
|
||
}
|
||
},
|
||
{
|
||
"id": "end-1",
|
||
"type": "end",
|
||
"position": {"x": 500, "y": 100},
|
||
"data": {
|
||
"label": "结束",
|
||
"output_format": "text"
|
||
}
|
||
}
|
||
],
|
||
"edges": [
|
||
{"id": "e1", "source": "start-1", "target": "llm-1"},
|
||
{"id": "e2", "source": "llm-1", "target": "end-1"}
|
||
]
|
||
}
|
||
|
||
# 模拟前端发送的输入数据
|
||
input_data = {
|
||
"query": "苹果英语怎么讲?",
|
||
"USER_INPUT": "苹果英语怎么讲?"
|
||
}
|
||
|
||
print_section("1. 初始输入数据")
|
||
print_data("input_data", input_data)
|
||
|
||
# 创建引擎(不使用logger,避免数据库依赖)
|
||
engine = WorkflowEngine("test-workflow", workflow_data)
|
||
|
||
# 重写get_node_input方法,添加详细日志
|
||
original_get_node_input = engine.get_node_input
|
||
|
||
def logged_get_node_input(node_id, node_outputs, active_edges=None):
|
||
print_section(f"获取节点输入: {node_id}")
|
||
print_data("node_outputs", node_outputs)
|
||
result = original_get_node_input(node_id, node_outputs, active_edges)
|
||
print_data(f"返回的input_data (for {node_id})", result)
|
||
return result
|
||
|
||
engine.get_node_input = logged_get_node_input
|
||
|
||
# 重写execute_node方法,添加详细日志
|
||
original_execute_node = engine.execute_node
|
||
|
||
async def logged_execute_node(node, input_data):
|
||
node_id = node.get('id')
|
||
node_type = node.get('type')
|
||
|
||
print_section(f"执行节点: {node_id} ({node_type})")
|
||
print_data("节点配置", node.get('data', {}))
|
||
print_data("输入数据", input_data)
|
||
|
||
result = await original_execute_node(node, input_data)
|
||
|
||
print_data("执行结果", result)
|
||
|
||
# 如果是LLM节点,特别关注prompt和输出
|
||
if node_type == 'llm':
|
||
print_section(f"LLM节点详细分析: {node_id}")
|
||
node_data = node.get('data', {})
|
||
prompt = node_data.get('prompt', '')
|
||
print_data("原始prompt", prompt)
|
||
print_data("输入数据", input_data)
|
||
|
||
# 模拟prompt格式化逻辑
|
||
if isinstance(input_data, dict):
|
||
# 检查是否有嵌套的input字段
|
||
nested_input = input_data.get('input')
|
||
if isinstance(nested_input, dict):
|
||
print("⚠️ 发现嵌套的input字段!")
|
||
print_data("嵌套input内容", nested_input)
|
||
# 尝试提取user_query
|
||
user_query = None
|
||
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||
if key in nested_input:
|
||
user_query = nested_input[key]
|
||
print(f"✅ 从嵌套input中提取到user_query: {key} = {user_query}")
|
||
break
|
||
else:
|
||
# 从顶层提取
|
||
user_query = None
|
||
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||
if key in input_data:
|
||
value = input_data[key]
|
||
if isinstance(value, str):
|
||
user_query = value
|
||
print(f"✅ 从顶层提取到user_query: {key} = {user_query}")
|
||
break
|
||
|
||
return result
|
||
|
||
engine.execute_node = logged_execute_node
|
||
|
||
# 执行工作流
|
||
print_section("开始执行工作流")
|
||
try:
|
||
result = await engine.execute(input_data)
|
||
print_section("工作流执行完成")
|
||
print_data("最终结果", result)
|
||
except Exception as e:
|
||
print_section("执行出错")
|
||
print(f"错误: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
asyncio.run(test_workflow_data_flow())
|