Files
aiagent/test_workflow_data_flow.py

170 lines
5.6 KiB
Python
Raw Normal View History

2026-01-20 09:40:16 +08:00
#!/usr/bin/env python3
"""
工作流数据流转测试脚本
用于诊断"答非所问"问题
"""
import asyncio
import json
import sys
import os
# 添加项目路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
from app.services.workflow_engine import WorkflowEngine
from app.core.database import SessionLocal
def print_section(title):
    """Print a section banner.

    Output is a blank line, a rule of 80 '=' characters, the title (preceded by
    one space), and a closing rule of the same width.
    """
    rule = "=" * 80
    print("", rule, f" {title}", rule, sep="\n")
def print_data(label, data, indent=0):
    """Print *label* as a heading and *data* indented beneath it.

    Dicts are pretty-printed as JSON with non-ASCII characters kept as-is.
    Values JSON cannot encode natively are rendered with ``str()``. If the dict
    still cannot be encoded (non-string keys, circular references), its plain
    ``str()`` form is printed instead, so this debugging aid never aborts the
    run it is tracing. Every line of the rendered value is indented, not only
    the first.

    Args:
        label: Heading text, printed on its own line followed by a colon.
        data: Value to display.
        indent: Number of leading spaces for the heading; the value gets one more.
    """
    prefix = " " * indent
    print(f"{prefix}{label}:")
    if isinstance(data, dict):
        try:
            text = json.dumps(data, ensure_ascii=False, indent=2, default=str)
        except (TypeError, ValueError):
            # e.g. tuple keys (TypeError) or a self-referencing dict (ValueError):
            # show the Python representation rather than crashing the diagnostic.
            text = f"{data}"
    else:
        text = f"{data}"
    pad = prefix + " "
    # Indent continuation lines too, so nested output lines up under its label.
    print("\n".join(pad + line for line in text.split("\n")))
# Keys probed, in priority order, when looking for the user's question in an LLM node's input.
_QUERY_KEYS = ('query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT')


def _build_sample_workflow():
    """Return the minimal start -> llm -> end workflow definition used as the fixture."""
    return {
        "nodes": [
            {
                "id": "start-1",
                "type": "start",
                "position": {"x": 100, "y": 100},
                "data": {
                    "label": "开始"
                }
            },
            {
                "id": "llm-1",
                "type": "llm",
                "position": {"x": 300, "y": 100},
                "data": {
                    "label": "LLM处理",
                    "provider": "deepseek",
                    "model": "deepseek-chat",
                    "prompt": "请处理用户请求。",
                    "temperature": 0.5,
                    "max_tokens": 1500
                }
            },
            {
                "id": "end-1",
                "type": "end",
                "position": {"x": 500, "y": 100},
                "data": {
                    "label": "结束",
                    "output_format": "text"
                }
            }
        ],
        "edges": [
            {"id": "e1", "source": "start-1", "target": "llm-1"},
            {"id": "e2", "source": "llm-1", "target": "end-1"}
        ]
    }


def _find_user_query(input_data):
    """Locate the user's question in an LLM node's input.

    This is a local simulation of prompt formatting, used for diagnosis only.
    NOTE(review): confirm it still matches WorkflowEngine.

    If ``input_data['input']`` is a dict, only that nested dict is searched;
    otherwise the top level is searched. Keys are tried in ``_QUERY_KEYS``
    order, and only string values count as a query.

    Returns:
        ``(nested, key, value)`` for the first match, where *nested* is True if
        the value came from the nested ``input`` dict; ``None`` if nothing matches
        or *input_data* is not a dict.
    """
    if not isinstance(input_data, dict):
        return None
    nested_input = input_data.get('input')
    nested = isinstance(nested_input, dict)
    source = nested_input if nested else input_data
    for key in _QUERY_KEYS:
        value = source.get(key)
        if isinstance(value, str):
            return nested, key, value
    return None


def _print_llm_analysis(node, input_data):
    """Print an LLM node's raw prompt, its input, and where the user query was found.

    If no query can be found, a warning is printed.
    """
    print_section(f"LLM节点详细分析: {node.get('id')}")
    print_data("原始prompt", node.get('data', {}).get('prompt', ''))
    print_data("输入数据", input_data)
    nested_input = input_data.get('input') if isinstance(input_data, dict) else None
    if isinstance(nested_input, dict):
        print("⚠️ 发现嵌套的input字段")
        print_data("嵌套input内容", nested_input)
    found = _find_user_query(input_data)
    if found is None:
        # The mismatched-answer symptom this script diagnoses: make the miss visible.
        print("⚠️ 未提取到user_query,LLM可能收不到用户的问题")
        return
    nested, key, user_query = found
    where = "嵌套input中" if nested else "顶层"
    print(f"✅ 从{where}提取到user_query: {key} = {user_query}")


def _install_logging_hooks(engine):
    """Replace engine.get_node_input / engine.execute_node with tracing wrappers.

    The wrappers are assigned as instance attributes, so they only take effect
    if the engine calls these methods through ``self``.
    NOTE(review): confirm in WorkflowEngine.
    """
    original_get_node_input = engine.get_node_input
    original_execute_node = engine.execute_node

    def logged_get_node_input(node_id, node_outputs, active_edges=None):
        print_section(f"获取节点输入: {node_id}")
        print_data("node_outputs", node_outputs)
        result = original_get_node_input(node_id, node_outputs, active_edges)
        print_data(f"返回的input_data (for {node_id})", result)
        return result

    async def logged_execute_node(node, input_data):
        node_type = node.get('type')
        print_section(f"执行节点: {node.get('id')} ({node_type})")
        print_data("节点配置", node.get('data', {}))
        print_data("输入数据", input_data)
        try:
            result = await original_execute_node(node, input_data)
            print_data("执行结果", result)
            return result
        finally:
            # Also runs when the LLM call raises, so the prompt/input analysis
            # is still shown for failed executions.
            if node_type == 'llm':
                _print_llm_analysis(node, input_data)

    engine.get_node_input = logged_get_node_input
    engine.execute_node = logged_execute_node


async def test_workflow_data_flow():
    """Run a start -> llm -> end workflow with tracing, to diagnose answers that miss the question.

    Prints the input handed to each node, each node's result, and, for LLM
    nodes, whether the user's question can be found in the node input. An
    exception raised by the workflow is printed with its traceback rather than
    propagated.
    """
    print_section("工作流数据流转测试")
    workflow_data = _build_sample_workflow()
    # Simulates the input the frontend sends: the question under two keys.
    input_data = {
        "query": "苹果英语怎么讲?",
        "USER_INPUT": "苹果英语怎么讲?"
    }
    print_section("1. 初始输入数据")
    print_data("input_data", input_data)
    # Create the engine without a logger to avoid a database dependency.
    engine = WorkflowEngine("test-workflow", workflow_data)
    _install_logging_hooks(engine)
    print_section("开始执行工作流")
    try:
        result = await engine.execute(input_data)
    except Exception as e:
        # Top-level boundary of a diagnostic script: report the failure instead of raising.
        import traceback
        print_section("执行出错")
        print(f"错误: {e}")
        traceback.print_exc()
        return
    print_section("工作流执行完成")
    print_data("最终结果", result)
if __name__ == "__main__":
    # Only when executed as a script: drive the async diagnostic on a fresh event loop.
    diagnostic = test_workflow_data_flow()
    asyncio.run(diagnostic)