#!/usr/bin/env python3 """ 工作流测试工具 支持通过Agent名称和用户输入来测试工作流执行 """ import requests import json import time import sys import argparse # API基础URL BASE_URL = "http://localhost:8037" def print_section(title): """打印分隔线""" print("\n" + "=" * 80) print(f" {title}") print("=" * 80) def print_info(message): """打印信息""" print(f"ℹ️ {message}") def print_success(message): """打印成功信息""" print(f"✅ {message}") def print_error(message): """打印错误信息""" print(f"❌ {message}") def print_warning(message): """打印警告信息""" print(f"⚠️ {message}") def login(username="admin", password="123456"): """ 用户登录 Returns: tuple: (success: bool, token: str or None, headers: dict or None) """ print_section("1. 用户登录") login_data = { "username": username, "password": password } try: response = requests.post(f"{BASE_URL}/api/v1/auth/login", data=login_data) if response.status_code != 200: print_error(f"登录失败: {response.status_code}") print(f"响应: {response.text}") return False, None, None token = response.json().get("access_token") if not token: print_error("登录失败: 未获取到token") return False, None, None print_success(f"登录成功 (用户: {username})") headers = {"Authorization": f"Bearer {token}"} return True, token, headers except requests.exceptions.ConnectionError: print_error("无法连接到后端服务,请确保后端服务正在运行") print_info(f"后端服务地址: {BASE_URL}") return False, None, None except Exception as e: print_error(f"登录异常: {str(e)}") return False, None, None def find_agent_by_name(agent_name, headers): """ 通过名称查找Agent Args: agent_name: Agent名称 headers: 请求头(包含token) Returns: tuple: (success: bool, agent: dict or None) """ print_section("2. 查找Agent") print_info(f"搜索Agent: {agent_name}") try: # 搜索Agent response = requests.get( f"{BASE_URL}/api/v1/agents", headers=headers, params={"search": agent_name, "limit": 100} ) if response.status_code != 200: print_error(f"获取Agent列表失败: {response.status_code}") print(f"响应: {response.text}") return False, None agents = response.json() # 精确匹配名称 exact_match = None for agent in agents: if agent.get("name") == agent_name: exact_match = agent break if exact_match: agent_id = exact_match["id"] agent_status = exact_match.get("status", "unknown") print_success(f"找到Agent: {agent_name} (ID: {agent_id}, 状态: {agent_status})") # 检查状态 if agent_status not in ["published", "running"]: print_warning(f"Agent状态为 '{agent_status}',可能无法执行") print_info("只有 'published' 或 'running' 状态的Agent可以执行") return True, exact_match # 如果没有精确匹配,显示相似的结果 if agents: print_warning(f"未找到名称为 '{agent_name}' 的Agent") print_info("找到以下相似的Agent:") for agent in agents[:5]: # 只显示前5个 print(f" - {agent.get('name')} (ID: {agent.get('id')}, 状态: {agent.get('status')})") else: print_error(f"未找到任何Agent") print_info("请检查Agent名称是否正确,或先创建一个Agent") return False, None except Exception as e: print_error(f"查找Agent异常: {str(e)}") return False, None def execute_agent(agent_id, user_input, headers): """ 执行Agent工作流 Args: agent_id: Agent ID user_input: 用户输入内容 headers: 请求头 Returns: tuple: (success: bool, execution_id: str or None) """ print_section("3. 执行Agent工作流") print_info(f"用户输入: {user_input}") input_data = { "query": user_input, "USER_INPUT": user_input } execution_data = { "agent_id": agent_id, "input_data": input_data } try: response = requests.post( f"{BASE_URL}/api/v1/executions", headers=headers, json=execution_data ) if response.status_code != 201: print_error(f"创建执行任务失败: {response.status_code}") print(f"响应: {response.text}") return False, None execution = response.json() execution_id = execution["id"] status = execution.get("status") print_success(f"执行任务已创建") print_info(f"执行ID: {execution_id}") print_info(f"状态: {status}") return True, execution_id except Exception as e: print_error(f"创建执行任务异常: {str(e)}") return False, None def wait_for_completion(execution_id, headers, max_wait_time=300, poll_interval=2): """ 等待执行完成 Args: execution_id: 执行ID headers: 请求头 max_wait_time: 最大等待时间(秒) poll_interval: 轮询间隔(秒) Returns: tuple: (success: bool, status: str or None) """ print_section("4. 等待执行完成") print_info(f"最大等待时间: {max_wait_time}秒") print_info(f"轮询间隔: {poll_interval}秒") start_time = time.time() last_node = None while True: elapsed_time = time.time() - start_time if elapsed_time > max_wait_time: print_error(f"执行超时(超过{max_wait_time}秒)") return False, "timeout" try: # 获取执行状态 status_response = requests.get( f"{BASE_URL}/api/v1/executions/{execution_id}/status", headers=headers ) if status_response.status_code == 200: status = status_response.json() current_status = status.get("status") progress = status.get("progress", 0) current_node = status.get("current_node") # 显示当前执行的节点 if current_node: node_id = current_node.get("node_id", "unknown") node_name = current_node.get("node_name", "unknown") if node_id != last_node: print_info(f"当前节点: {node_id} ({node_name})") last_node = node_id # 显示进度 elapsed_str = f"{int(elapsed_time)}秒" print(f"⏳ 执行中... 状态: {current_status}, 进度: {progress}%, 耗时: {elapsed_str}", end="\r") if current_status == "completed": print() # 换行 print_success("执行完成!") return True, "completed" elif current_status == "failed": print() # 换行 print_error("执行失败") error = status.get("error", "未知错误") print_error(f"错误信息: {error}") return False, "failed" time.sleep(poll_interval) except KeyboardInterrupt: print() # 换行 print_warning("用户中断执行") return False, "interrupted" except Exception as e: print_error(f"获取执行状态异常: {str(e)}") time.sleep(poll_interval) def get_execution_result(execution_id, headers): """ 获取执行结果 Args: execution_id: 执行ID headers: 请求头 Returns: tuple: (success: bool, result: dict or None) """ print_section("5. 获取执行结果") try: response = requests.get( f"{BASE_URL}/api/v1/executions/{execution_id}", headers=headers ) if response.status_code != 200: print_error(f"获取执行结果失败: {response.status_code}") print(f"响应: {response.text}") return False, None execution = response.json() status = execution.get("status") output_data = execution.get("output_data") execution_time = execution.get("execution_time") print_info(f"执行状态: {status}") if execution_time: print_info(f"执行时间: {execution_time}ms ({execution_time/1000:.2f}秒)") print() print("=" * 80) print("输出结果:") print("=" * 80) if output_data: if isinstance(output_data, dict): # 如果 result 字段是字符串,尝试解析它(类似JSON节点的parse操作) if "result" in output_data and isinstance(output_data["result"], str): try: # 尝试使用 ast.literal_eval 解析Python字典字符串 import ast parsed_result = ast.literal_eval(output_data["result"]) output_data = parsed_result except: # 如果解析失败,尝试作为JSON解析 try: parsed_result = json.loads(output_data["result"]) output_data = parsed_result except: pass # 使用类似JSON节点的extract操作来提取文本 def json_extract(data, path): """类似JSON节点的extract操作,使用路径提取数据""" if not path or not isinstance(data, dict): return None # 支持 $.right.right.right 格式的路径 path = path.replace('$.', '').replace('$', '') keys = path.split('.') result = data for key in keys: if isinstance(result, dict) and key in result: result = result[key] else: return None return result # 尝试使用路径提取:递归查找 right 字段直到找到字符串 def extract_text_by_path(data, depth=0, max_depth=10): """递归提取嵌套在right字段中的文本""" if depth > max_depth: return None if isinstance(data, str): # 如果是字符串且不是JSON格式,返回它 if len(data) > 10 and not data.strip().startswith('{') and not data.strip().startswith('['): return data return None if isinstance(data, dict): # 优先查找 right 字段 if "right" in data: right_value = data["right"] # 如果 right 的值是字符串,直接返回 if isinstance(right_value, str) and len(right_value) > 10: return right_value # 否则递归查找 result = extract_text_by_path(right_value, depth + 1, max_depth) if result: return result # 查找其他常见的输出字段 for key in ["output", "text", "content"]: if key in data: result = extract_text_by_path(data[key], depth + 1, max_depth) if result: return result return None return None # 优先检查 result 字段(JSON节点提取后的结果) if "result" in output_data and isinstance(output_data["result"], str): text_output = output_data["result"] else: # 先尝试使用路径提取(类似JSON节点的extract操作) # 尝试多个可能的路径 paths_to_try = [ "right.right.right", # 最常见的嵌套路径 "right.right", "right", "output", "text", "content" ] text_output = None for path in paths_to_try: extracted = json_extract(output_data, f"$.{path}") if extracted and isinstance(extracted, str) and len(extracted) > 10: text_output = extracted break # 如果路径提取失败,使用递归提取 if not text_output: text_output = extract_text_by_path(output_data) if text_output and isinstance(text_output, str): print(text_output) print() print_info(f"回答长度: {len(text_output)} 字符") else: # 如果无法提取,显示格式化的JSON print(json.dumps(output_data, ensure_ascii=False, indent=2)) else: print(output_data) else: print("(无输出数据)") print("=" * 80) return True, execution except Exception as e: print_error(f"获取执行结果异常: {str(e)}") return False, None def main(): """主函数""" parser = argparse.ArgumentParser( description="工作流测试工具 - 通过Agent名称和用户输入测试工作流执行", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 示例: # 使用默认参数(交互式输入) python3 test_workflow_tool.py # 指定Agent名称和用户输入 python3 test_workflow_tool.py -a "智能需求分析与解决方案生成器" -i "生成一个导出androidlog的脚本" # 指定用户名和密码 python3 test_workflow_tool.py -u admin -p 123456 -a "Agent名称" -i "用户输入" """ ) parser.add_argument( "-a", "--agent-name", type=str, help="Agent名称(如果不指定,将交互式输入)" ) parser.add_argument( "-i", "--input", type=str, help="用户输入内容(如果不指定,将交互式输入)" ) parser.add_argument( "-u", "--username", type=str, default="admin", help="登录用户名(默认: admin)" ) parser.add_argument( "-p", "--password", type=str, default="123456", help="登录密码(默认: 123456)" ) parser.add_argument( "--max-wait", type=int, default=300, help="最大等待时间(秒,默认: 300)" ) parser.add_argument( "--poll-interval", type=float, default=2.0, help="轮询间隔(秒,默认: 2.0)" ) args = parser.parse_args() # 打印标题 print("=" * 80) print(" 工作流测试工具") print("=" * 80) # 1. 登录 success, token, headers = login(args.username, args.password) if not success: sys.exit(1) # 2. 获取Agent名称 agent_name = args.agent_name if not agent_name: agent_name = input("\n请输入Agent名称: ").strip() if not agent_name: print_error("Agent名称不能为空") sys.exit(1) # 3. 查找Agent success, agent = find_agent_by_name(agent_name, headers) if not success or not agent: sys.exit(1) agent_id = agent["id"] # 4. 获取用户输入 user_input = args.input if not user_input: user_input = input("\n请输入用户输入内容: ").strip() if not user_input: print_error("用户输入不能为空") sys.exit(1) # 5. 执行Agent success, execution_id = execute_agent(agent_id, user_input, headers) if not success: sys.exit(1) # 6. 等待执行完成 success, status = wait_for_completion( execution_id, headers, max_wait_time=args.max_wait, poll_interval=args.poll_interval ) if not success: if status == "timeout": print_warning("执行超时,但可能仍在后台运行") print_info(f"执行ID: {execution_id}") print_info("可以通过API查询执行状态") sys.exit(1) # 7. 获取执行结果 success, result = get_execution_result(execution_id, headers) if not success: sys.exit(1) # 完成 print_section("测试完成") print_success("工作流测试成功完成!") print_info(f"执行ID: {execution_id}") print_info(f"Agent: {agent_name}") if __name__ == "__main__": main()