Files
aiagent/test_workflow_data_flow.py
2026-01-20 09:40:16 +08:00

170 lines
5.6 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
工作流数据流转测试脚本
用于诊断"答非所问"问题
"""
import asyncio
import json
import sys
import os
# 添加项目路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
from app.services.workflow_engine import WorkflowEngine
from app.core.database import SessionLocal
def print_section(title):
    """Print a section banner: a blank line, a rule, the indented title, a rule.

    Each rule is 80 '=' characters; the title is preceded by one space.
    """
    rule = "=" * 80
    # sep="\n" with a leading empty string reproduces the blank line that
    # precedes the first rule.
    print("", rule, f" {title}", rule, sep="\n")
def print_data(label, data, indent=0):
    """Print *label* on its own line, followed by *data*.

    Args:
        label: Heading printed before the value, followed by a colon.
        data: Value to show. Dicts are rendered as pretty JSON
            (``ensure_ascii=False``, ``indent=2``); anything else is
            formatted with its default string representation.
        indent: Number of spaces prefixed to every output line.

    This is diagnostic output called from inside the traced engine, so it
    must never raise. If a dict cannot be encoded exactly, it is still shown:
    unsupported value types are stringified, and if encoding still fails
    (non-string keys such as tuples, circular references), ``str(data)`` is
    printed instead.
    """
    prefix = " " * indent
    print(f"{prefix}{label}:")
    if isinstance(data, dict):
        try:
            # default=str is deliberate: for debug output an approximate
            # rendering of e.g. datetimes or model objects beats a TypeError
            # that would abort the workflow being traced.
            text = json.dumps(data, ensure_ascii=False, indent=2, default=str)
        except (TypeError, ValueError):
            # TypeError: keys json cannot encode; ValueError: circular data.
            text = str(data)
    else:
        text = f"{data}"
    # Indent continuation lines too, so multi-line output stays inside the
    # requested indentation (no-op when indent == 0).
    text = text.replace("\n", "\n" + prefix)
    print(f"{prefix} {text}")
# Keys, in priority order, under which the user's question may appear.
_USER_QUERY_KEYS = ('query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT')


def _build_sample_workflow():
    """Return a minimal start -> llm -> end workflow definition used as the fixture."""
    return {
        "nodes": [
            {
                "id": "start-1",
                "type": "start",
                "position": {"x": 100, "y": 100},
                "data": {
                    "label": "开始"
                }
            },
            {
                "id": "llm-1",
                "type": "llm",
                "position": {"x": 300, "y": 100},
                "data": {
                    "label": "LLM处理",
                    "provider": "deepseek",
                    "model": "deepseek-chat",
                    "prompt": "请处理用户请求。",
                    "temperature": 0.5,
                    "max_tokens": 1500
                }
            },
            {
                "id": "end-1",
                "type": "end",
                "position": {"x": 500, "y": 100},
                "data": {
                    "label": "结束",
                    "output_format": "text"
                }
            }
        ],
        "edges": [
            {"id": "e1", "source": "start-1", "target": "llm-1"},
            {"id": "e2", "source": "llm-1", "target": "end-1"}
        ]
    }


def _find_user_query(data, *, require_str):
    """Return the first ``(key, value)`` in *data* whose key is in ``_USER_QUERY_KEYS``.

    When *require_str* is true, entries whose value is not a ``str`` are
    skipped. Returns ``(None, None)`` when nothing matches.
    """
    for key in _USER_QUERY_KEYS:
        if key not in data:
            continue
        value = data[key]
        if require_str and not isinstance(value, str):
            continue
        return key, value
    return None, None


def _report_llm_input(node, input_data):
    """Print an LLM node's prompt and whether a user query is findable in its input."""
    print_section(f"LLM节点详细分析: {node.get('id')}")
    node_data = node.get('data', {})
    print_data("原始prompt", node_data.get('prompt', ''))
    print_data("输入数据", input_data)
    if not isinstance(input_data, dict):
        return
    # Simulated prompt-formatting lookup: a nested 'input' dict is searched
    # on its own (any value type); otherwise only top-level string values are
    # accepted. NOTE(review): presumably mirrors WorkflowEngine's LLM prompt
    # formatting - confirm against the engine.
    nested_input = input_data.get('input')
    if isinstance(nested_input, dict):
        print("⚠️ 发现嵌套的input字段")
        print_data("嵌套input内容", nested_input)
        key, user_query = _find_user_query(nested_input, require_str=False)
        if key is not None:
            print(f"✅ 从嵌套input中提取到user_query: {key} = {user_query}")
    else:
        key, user_query = _find_user_query(input_data, require_str=True)
        if key is not None:
            print(f"✅ 从顶层提取到user_query: {key} = {user_query}")
    if key is None:
        # The case this script exists to catch: the LLM node never sees the question.
        print("❌ 未能提取到user_query，LLM节点可能收不到用户问题")


def _install_logging_hooks(engine):
    """Wrap engine.get_node_input and engine.execute_node so every call is printed."""
    original_get_node_input = engine.get_node_input
    original_execute_node = engine.execute_node

    def logged_get_node_input(node_id, node_outputs, *args, **kwargs):
        print_section(f"获取节点输入: {node_id}")
        print_data("node_outputs", node_outputs)
        # Forward extra arguments untouched so the wrapper does not depend on
        # the exact signature of the engine's method.
        result = original_get_node_input(node_id, node_outputs, *args, **kwargs)
        print_data(f"返回的input_data (for {node_id})", result)
        return result

    async def logged_execute_node(node, input_data, *args, **kwargs):
        node_id = node.get('id')
        node_type = node.get('type')
        print_section(f"执行节点: {node_id} ({node_type})")
        print_data("节点配置", node.get('data', {}))
        print_data("输入数据", input_data)
        result = await original_execute_node(node, input_data, *args, **kwargs)
        print_data("执行结果", result)
        if node_type == 'llm':
            _report_llm_input(node, input_data)
        return result

    # Instance attributes take precedence over class-level methods, so calls
    # the engine makes through self.<name> reach these wrappers (assumes the
    # engine dispatches via self - TODO confirm).
    engine.get_node_input = logged_get_node_input
    engine.execute_node = logged_execute_node


async def test_workflow_data_flow():
    """Run a start -> llm -> end workflow and trace the data passed between nodes.

    Diagnostic for the "answer does not match the question" problem. Every
    call to the engine's get_node_input and execute_node is printed, and for
    LLM nodes the script also checks whether a user query can be located in
    the node's input.

    Returns:
        The value returned by ``engine.execute`` on success, or ``None`` if
        execution raised (the error and traceback are printed instead).
    """
    print_section("工作流数据流转测试")
    workflow_data = _build_sample_workflow()
    # Simulated payload as sent by the frontend; the same question appears
    # under two different keys.
    input_data = {
        "query": "苹果英语怎么讲?",
        "USER_INPUT": "苹果英语怎么讲?"
    }
    print_section("1. 初始输入数据")
    print_data("input_data", input_data)
    # Create the engine without a logger to avoid a database dependency.
    engine = WorkflowEngine("test-workflow", workflow_data)
    _install_logging_hooks(engine)
    print_section("开始执行工作流")
    try:
        result = await engine.execute(input_data)
        print_section("工作流执行完成")
        print_data("最终结果", result)
    except Exception as e:
        # Top-level boundary of a diagnostic script: report, don't crash.
        import traceback
        print_section("执行出错")
        print(f"错误: {str(e)}")
        traceback.print_exc()
        return None
    return result
if __name__ == "__main__":
    # Run the async diagnostic on a fresh event loop when executed as a script.
    coroutine = test_workflow_data_flow()
    asyncio.run(coroutine)