Files
aiagent/test_workflow_tool.py
2026-01-23 09:49:45 +08:00

530 lines
18 KiB
Python
Executable File
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
工作流测试工具
支持通过Agent名称和用户输入来测试工作流执行
"""
import requests
import json
import time
import sys
import argparse
# API基础URL
BASE_URL = "http://localhost:8037"
def print_section(title):
"""打印分隔线"""
print("\n" + "=" * 80)
print(f" {title}")
print("=" * 80)
def print_info(message):
"""打印信息"""
print(f" {message}")
def print_success(message):
"""打印成功信息"""
print(f"{message}")
def print_error(message):
"""打印错误信息"""
print(f"{message}")
def print_warning(message):
"""打印警告信息"""
print(f"⚠️ {message}")
def login(username="admin", password="123456"):
"""
用户登录
Returns:
tuple: (success: bool, token: str or None, headers: dict or None)
"""
print_section("1. 用户登录")
login_data = {
"username": username,
"password": password
}
try:
response = requests.post(f"{BASE_URL}/api/v1/auth/login", data=login_data)
if response.status_code != 200:
print_error(f"登录失败: {response.status_code}")
print(f"响应: {response.text}")
return False, None, None
token = response.json().get("access_token")
if not token:
print_error("登录失败: 未获取到token")
return False, None, None
print_success(f"登录成功 (用户: {username})")
headers = {"Authorization": f"Bearer {token}"}
return True, token, headers
except requests.exceptions.ConnectionError:
print_error("无法连接到后端服务,请确保后端服务正在运行")
print_info(f"后端服务地址: {BASE_URL}")
return False, None, None
except Exception as e:
print_error(f"登录异常: {str(e)}")
return False, None, None
def find_agent_by_name(agent_name, headers):
"""
通过名称查找Agent
Args:
agent_name: Agent名称
headers: 请求头包含token
Returns:
tuple: (success: bool, agent: dict or None)
"""
print_section("2. 查找Agent")
print_info(f"搜索Agent: {agent_name}")
try:
# 搜索Agent
response = requests.get(
f"{BASE_URL}/api/v1/agents",
headers=headers,
params={"search": agent_name, "limit": 100}
)
if response.status_code != 200:
print_error(f"获取Agent列表失败: {response.status_code}")
print(f"响应: {response.text}")
return False, None
agents = response.json()
# 精确匹配名称
exact_match = None
for agent in agents:
if agent.get("name") == agent_name:
exact_match = agent
break
if exact_match:
agent_id = exact_match["id"]
agent_status = exact_match.get("status", "unknown")
print_success(f"找到Agent: {agent_name} (ID: {agent_id}, 状态: {agent_status})")
# 检查状态
if agent_status not in ["published", "running"]:
print_warning(f"Agent状态为 '{agent_status}',可能无法执行")
print_info("只有 'published''running' 状态的Agent可以执行")
return True, exact_match
# 如果没有精确匹配,显示相似的结果
if agents:
print_warning(f"未找到名称为 '{agent_name}' 的Agent")
print_info("找到以下相似的Agent:")
for agent in agents[:5]: # 只显示前5个
print(f" - {agent.get('name')} (ID: {agent.get('id')}, 状态: {agent.get('status')})")
else:
print_error(f"未找到任何Agent")
print_info("请检查Agent名称是否正确或先创建一个Agent")
return False, None
except Exception as e:
print_error(f"查找Agent异常: {str(e)}")
return False, None
def execute_agent(agent_id, user_input, headers):
"""
执行Agent工作流
Args:
agent_id: Agent ID
user_input: 用户输入内容
headers: 请求头
Returns:
tuple: (success: bool, execution_id: str or None)
"""
print_section("3. 执行Agent工作流")
print_info(f"用户输入: {user_input}")
input_data = {
"query": user_input,
"USER_INPUT": user_input
}
execution_data = {
"agent_id": agent_id,
"input_data": input_data
}
try:
response = requests.post(
f"{BASE_URL}/api/v1/executions",
headers=headers,
json=execution_data
)
if response.status_code != 201:
print_error(f"创建执行任务失败: {response.status_code}")
print(f"响应: {response.text}")
return False, None
execution = response.json()
execution_id = execution["id"]
status = execution.get("status")
print_success(f"执行任务已创建")
print_info(f"执行ID: {execution_id}")
print_info(f"状态: {status}")
return True, execution_id
except Exception as e:
print_error(f"创建执行任务异常: {str(e)}")
return False, None
def wait_for_completion(execution_id, headers, max_wait_time=300, poll_interval=2):
"""
等待执行完成
Args:
execution_id: 执行ID
headers: 请求头
max_wait_time: 最大等待时间(秒)
poll_interval: 轮询间隔(秒)
Returns:
tuple: (success: bool, status: str or None)
"""
print_section("4. 等待执行完成")
print_info(f"最大等待时间: {max_wait_time}")
print_info(f"轮询间隔: {poll_interval}")
start_time = time.time()
last_node = None
while True:
elapsed_time = time.time() - start_time
if elapsed_time > max_wait_time:
print_error(f"执行超时(超过{max_wait_time}秒)")
return False, "timeout"
try:
# 获取执行状态
status_response = requests.get(
f"{BASE_URL}/api/v1/executions/{execution_id}/status",
headers=headers
)
if status_response.status_code == 200:
status = status_response.json()
current_status = status.get("status")
progress = status.get("progress", 0)
current_node = status.get("current_node")
# 显示当前执行的节点
if current_node:
node_id = current_node.get("node_id", "unknown")
node_name = current_node.get("node_name", "unknown")
if node_id != last_node:
print_info(f"当前节点: {node_id} ({node_name})")
last_node = node_id
# 显示进度
elapsed_str = f"{int(elapsed_time)}"
print(f"⏳ 执行中... 状态: {current_status}, 进度: {progress}%, 耗时: {elapsed_str}", end="\r")
if current_status == "completed":
print() # 换行
print_success("执行完成!")
return True, "completed"
elif current_status == "failed":
print() # 换行
print_error("执行失败")
error = status.get("error", "未知错误")
print_error(f"错误信息: {error}")
return False, "failed"
time.sleep(poll_interval)
except KeyboardInterrupt:
print() # 换行
print_warning("用户中断执行")
return False, "interrupted"
except Exception as e:
print_error(f"获取执行状态异常: {str(e)}")
time.sleep(poll_interval)
def get_execution_result(execution_id, headers):
"""
获取执行结果
Args:
execution_id: 执行ID
headers: 请求头
Returns:
tuple: (success: bool, result: dict or None)
"""
print_section("5. 获取执行结果")
try:
response = requests.get(
f"{BASE_URL}/api/v1/executions/{execution_id}",
headers=headers
)
if response.status_code != 200:
print_error(f"获取执行结果失败: {response.status_code}")
print(f"响应: {response.text}")
return False, None
execution = response.json()
status = execution.get("status")
output_data = execution.get("output_data")
execution_time = execution.get("execution_time")
print_info(f"执行状态: {status}")
if execution_time:
print_info(f"执行时间: {execution_time}ms ({execution_time/1000:.2f}秒)")
print()
print("=" * 80)
print("输出结果:")
print("=" * 80)
if output_data:
if isinstance(output_data, dict):
# 如果 result 字段是字符串尝试解析它类似JSON节点的parse操作
if "result" in output_data and isinstance(output_data["result"], str):
try:
# 尝试使用 ast.literal_eval 解析Python字典字符串
import ast
parsed_result = ast.literal_eval(output_data["result"])
output_data = parsed_result
except:
# 如果解析失败尝试作为JSON解析
try:
parsed_result = json.loads(output_data["result"])
output_data = parsed_result
except:
pass
# 使用类似JSON节点的extract操作来提取文本
def json_extract(data, path):
"""类似JSON节点的extract操作使用路径提取数据"""
if not path or not isinstance(data, dict):
return None
# 支持 $.right.right.right 格式的路径
path = path.replace('$.', '').replace('$', '')
keys = path.split('.')
result = data
for key in keys:
if isinstance(result, dict) and key in result:
result = result[key]
else:
return None
return result
# 尝试使用路径提取:递归查找 right 字段直到找到字符串
def extract_text_by_path(data, depth=0, max_depth=10):
"""递归提取嵌套在right字段中的文本"""
if depth > max_depth:
return None
if isinstance(data, str):
# 如果是字符串且不是JSON格式返回它
if len(data) > 10 and not data.strip().startswith('{') and not data.strip().startswith('['):
return data
return None
if isinstance(data, dict):
# 优先查找 right 字段
if "right" in data:
right_value = data["right"]
# 如果 right 的值是字符串,直接返回
if isinstance(right_value, str) and len(right_value) > 10:
return right_value
# 否则递归查找
result = extract_text_by_path(right_value, depth + 1, max_depth)
if result:
return result
# 查找其他常见的输出字段
for key in ["output", "text", "content"]:
if key in data:
result = extract_text_by_path(data[key], depth + 1, max_depth)
if result:
return result
return None
return None
# 优先检查 result 字段JSON节点提取后的结果
if "result" in output_data and isinstance(output_data["result"], str):
text_output = output_data["result"]
else:
# 先尝试使用路径提取类似JSON节点的extract操作
# 尝试多个可能的路径
paths_to_try = [
"right.right.right", # 最常见的嵌套路径
"right.right",
"right",
"output",
"text",
"content"
]
text_output = None
for path in paths_to_try:
extracted = json_extract(output_data, f"$.{path}")
if extracted and isinstance(extracted, str) and len(extracted) > 10:
text_output = extracted
break
# 如果路径提取失败,使用递归提取
if not text_output:
text_output = extract_text_by_path(output_data)
if text_output and isinstance(text_output, str):
print(text_output)
print()
print_info(f"回答长度: {len(text_output)} 字符")
else:
# 如果无法提取显示格式化的JSON
print(json.dumps(output_data, ensure_ascii=False, indent=2))
else:
print(output_data)
else:
print("(无输出数据)")
print("=" * 80)
return True, execution
except Exception as e:
print_error(f"获取执行结果异常: {str(e)}")
return False, None
def main():
"""主函数"""
parser = argparse.ArgumentParser(
description="工作流测试工具 - 通过Agent名称和用户输入测试工作流执行",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 使用默认参数(交互式输入)
python3 test_workflow_tool.py
# 指定Agent名称和用户输入
python3 test_workflow_tool.py -a "智能需求分析与解决方案生成器" -i "生成一个导出androidlog的脚本"
# 指定用户名和密码
python3 test_workflow_tool.py -u admin -p 123456 -a "Agent名称" -i "用户输入"
"""
)
parser.add_argument(
"-a", "--agent-name",
type=str,
help="Agent名称如果不指定将交互式输入"
)
parser.add_argument(
"-i", "--input",
type=str,
help="用户输入内容(如果不指定,将交互式输入)"
)
parser.add_argument(
"-u", "--username",
type=str,
default="admin",
help="登录用户名(默认: admin"
)
parser.add_argument(
"-p", "--password",
type=str,
default="123456",
help="登录密码(默认: 123456"
)
parser.add_argument(
"--max-wait",
type=int,
default=300,
help="最大等待时间(秒,默认: 300"
)
parser.add_argument(
"--poll-interval",
type=float,
default=2.0,
help="轮询间隔(秒,默认: 2.0"
)
args = parser.parse_args()
# 打印标题
print("=" * 80)
print(" 工作流测试工具")
print("=" * 80)
# 1. 登录
success, token, headers = login(args.username, args.password)
if not success:
sys.exit(1)
# 2. 获取Agent名称
agent_name = args.agent_name
if not agent_name:
agent_name = input("\n请输入Agent名称: ").strip()
if not agent_name:
print_error("Agent名称不能为空")
sys.exit(1)
# 3. 查找Agent
success, agent = find_agent_by_name(agent_name, headers)
if not success or not agent:
sys.exit(1)
agent_id = agent["id"]
# 4. 获取用户输入
user_input = args.input
if not user_input:
user_input = input("\n请输入用户输入内容: ").strip()
if not user_input:
print_error("用户输入不能为空")
sys.exit(1)
# 5. 执行Agent
success, execution_id = execute_agent(agent_id, user_input, headers)
if not success:
sys.exit(1)
# 6. 等待执行完成
success, status = wait_for_completion(
execution_id,
headers,
max_wait_time=args.max_wait,
poll_interval=args.poll_interval
)
if not success:
if status == "timeout":
print_warning("执行超时,但可能仍在后台运行")
print_info(f"执行ID: {execution_id}")
print_info("可以通过API查询执行状态")
sys.exit(1)
# 7. 获取执行结果
success, result = get_execution_result(execution_id, headers)
if not success:
sys.exit(1)
# 完成
print_section("测试完成")
print_success("工作流测试成功完成!")
print_info(f"执行ID: {execution_id}")
print_info(f"Agent: {agent_name}")
if __name__ == "__main__":
main()