#!/usr/bin/env python3
"""
Batch runner for Agent test cases.

Reads ``agent_test_cases.json`` (or the file given by ``--cases`` /
``AGENT_TEST_CASES``), creates one execution per enabled case through the
HTTP API, polls it until it reaches a terminal status, and applies simple
assertions on the final status and output text.

Case format spec: (红头)agent测试用例文档.md

Exit codes returned by ``main``:
    0  all executed cases passed
    1  at least one case failed
    2  cases file missing, unreadable, or not a JSON object
    3  login failed
    4  cases list is empty
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from typing import Any, Dict, List, Optional, Tuple

import requests

DEFAULT_CASES_FILE = "agent_test_cases.json"

# Polling stops once the execution reports one of these statuses.
# "awaiting_approval" is treated as terminal as well (presumably the run cannot
# progress without an external decision).
_TERMINAL_STATUSES = frozenset({"completed", "failed", "cancelled", "awaiting_approval"})


def _ensure_utf8_stdio() -> None:
    """On Windows, switch stdout/stderr to UTF-8 so non-ASCII messages print.

    Best-effort: streams without ``reconfigure`` (or that reject it) are left
    untouched.
    """
    if sys.platform != "win32":
        return
    for name in ("stdout", "stderr"):
        stream = getattr(sys, name, None)
        if stream is not None and hasattr(stream, "reconfigure"):
            try:
                stream.reconfigure(encoding="utf-8", errors="replace")
            except Exception:
                # Deliberately best-effort: a console we cannot reconfigure
                # must not prevent the test run.
                pass


_ensure_utf8_stdio()


def _login(base_url: str, username: str, password: str, timeout: float) -> Optional[Dict[str, str]]:
    """Log in with form-encoded credentials and build the auth header.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` on success, or ``None`` after
        printing a failure message (non-200 response or no ``access_token``).

    Raises:
        requests.RequestException: on network-level failures.
    """
    r = requests.post(
        f"{base_url}/api/v1/auth/login",
        data={"username": username, "password": password},
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        timeout=timeout,
    )
    if r.status_code != 200:
        print(f"[FAIL] 登录 {r.status_code}: {r.text[:500]}")
        return None
    body = r.json()
    token = body.get("access_token") if isinstance(body, dict) else None
    if not token:
        print("[FAIL] 登录响应无 access_token")
        return None
    return {"Authorization": f"Bearer {token}"}


def _resolve_agent_id(
    base_url: str,
    headers: Dict[str, str],
    agent: Dict[str, Any],
    timeout: float,
) -> Optional[str]:
    """Return the agent id for a case's ``agent`` spec.

    An explicit ``id`` wins. Otherwise the agents list is searched by
    ``name``. An exact (whitespace-trimmed) name match is preferred. If there
    is none, the first search hit is used and a warning is printed.
    Returns ``None`` (after printing why) when no agent can be determined.
    """
    if agent.get("id"):
        return str(agent["id"])
    name = str(agent.get("name") or "").strip()
    if not name:
        print("[FAIL] agent 需包含 id 或 name")
        return None
    r = requests.get(
        f"{base_url}/api/v1/agents",
        headers=headers,
        params={"search": name, "limit": 100},
        timeout=timeout,
    )
    if r.status_code != 200:
        print(f"[FAIL] 查找 Agent {r.status_code}: {r.text[:500]}")
        return None
    payload = r.json() or []
    if not isinstance(payload, list):
        print(f"[FAIL] 查找 Agent 响应格式异常: {r.text[:500]}")
        return None
    agents: List[Dict[str, Any]] = [a for a in payload if isinstance(a, dict)]
    exact = [a for a in agents if str(a.get("name") or "").strip() == name]
    if exact:
        pick = exact[0]
    elif agents:
        # The search endpoint may match fuzzily; make the substitution visible.
        pick = agents[0]
        print(f"[WARN] 未找到同名 Agent {name!r},使用搜索结果第一个")
    else:
        print(f"[FAIL] 未找到 Agent: {name}")
        return None
    print(f"[OK] Agent: {pick.get('name')} ({pick['id']}) status={pick.get('status')}")
    return str(pick["id"])


def _extract_output_text(output_data: Any) -> str:
    """Flatten an execution's ``output_data`` into text used by assertions.

    - ``None`` -> ``""``; a string is returned unchanged.
    - dict: the first non-``None`` value among ``result`` / ``output`` /
      ``text`` / ``content`` is used. Strings are returned as-is, and nested
      dicts/lists are rendered as JSON. If none of those keys is set, the
      whole dict is rendered as JSON.
    - list -> JSON; anything else -> ``str(output_data)``.
    """
    if output_data is None:
        return ""
    if isinstance(output_data, str):
        return output_data
    if isinstance(output_data, dict):
        for key in ("result", "output", "text", "content"):
            value = output_data.get(key)
            if value is None:
                continue
            if isinstance(value, str):
                return value
            if isinstance(value, (dict, list)):
                # JSON (not Python repr) so text matches what the API returned.
                return json.dumps(value, ensure_ascii=False)
            return str(value)
        return json.dumps(output_data, ensure_ascii=False)
    if isinstance(output_data, list):
        return json.dumps(output_data, ensure_ascii=False)
    return str(output_data)


def _poll_until_terminal(
    base_url: str,
    headers: Dict[str, str],
    execution_id: str,
    max_wait: float,
    poll_interval: float,
    timeout: float,
) -> Tuple[str, Optional[Dict[str, Any]]]:
    """Poll an execution's status, then fetch its full detail.

    Polling stops when a status in ``_TERMINAL_STATUSES`` is seen or
    ``max_wait`` seconds elapse. Non-200 responses and network/JSON errors
    while polling are logged and retried until the deadline.

    Returns:
        ``(last_status, detail)``. ``last_status`` is the last status seen
        (``"unknown"`` if none was read). ``detail`` is the parsed execution
        detail, or ``None`` if fetching it returned non-200.

    Raises:
        requests.RequestException: if the final detail request fails at the
            network level.
    """
    # Monotonic clock: wall-clock adjustments must not stretch or cut the wait.
    deadline = time.monotonic() + max_wait
    last_status = "unknown"
    reached_terminal = False
    while time.monotonic() < deadline:
        try:
            sr = requests.get(
                f"{base_url}/api/v1/executions/{execution_id}/status",
                headers=headers,
                timeout=timeout,
            )
            if sr.status_code != 200:
                print(f"[WARN] status {sr.status_code}: {sr.text[:300]}")
                time.sleep(poll_interval)
                continue
            body = sr.json()
        except (requests.RequestException, ValueError) as exc:
            print(f"[WARN] status 请求异常: {exc}")
            time.sleep(poll_interval)
            continue
        status = body.get("status") if isinstance(body, dict) else None
        last_status = str(status or "")
        if last_status in _TERMINAL_STATUSES:
            reached_terminal = True
            break
        time.sleep(poll_interval)
    if not reached_terminal:
        print(f"[WARN] 等待超时 {max_wait:g}s,最后状态 {last_status!r}")
    dr = requests.get(
        f"{base_url}/api/v1/executions/{execution_id}",
        headers=headers,
        timeout=timeout,
    )
    if dr.status_code != 200:
        print(f"[FAIL] 获取执行详情 {dr.status_code}: {dr.text[:500]}")
        return last_status, None
    return last_status, dr.json()


def _as_str_list(value: Any) -> List[str]:
    """Normalize an ``output_contains``-style value to a list of strings.

    ``None``/empty -> ``[]``. A single string -> ``[value]`` (instead of being
    iterated character by character). A list/tuple -> each item via ``str``.
    Any other scalar -> ``[str(value)]``.
    """
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple)):
        return [str(v) for v in value]
    return [str(value)]


def _check_expect(text: str, status: str, detail: Optional[Dict[str, Any]], expect: Dict[str, Any]) -> List[str]:
    """Compare an execution's outcome with a case's ``expect`` block.

    Supported keys: ``status`` (default ``"completed"``),
    ``output_contains``, ``output_not_contains`` (string or list of strings)
    and ``case_insensitive``.

    Returns:
        Human-readable error messages; an empty list means the case passed.
    """
    errors: List[str] = []
    want_status = expect.get("status", "completed")
    if status != want_status:
        errors.append(f"状态期望 {want_status!r},实际 {status!r}")
        # Surface the server-side error only when the status is unexpected:
        # a case that expects "failed" must not fail merely because the
        # execution carries an error_message.
        em = (detail or {}).get("error_message")
        if em:
            errors.append(f"error_message: {str(em)[:500]}")

    ci = bool(expect.get("case_insensitive"))

    def _norm(s: str) -> str:
        return s.casefold() if ci else s

    hay = _norm(text)
    for sub in _as_str_list(expect.get("output_contains")):
        if _norm(sub) not in hay:
            errors.append(f"输出应包含 {sub!r}")
    for sub in _as_str_list(expect.get("output_not_contains")):
        if _norm(sub) in hay:
            errors.append(f"输出不应包含 {sub!r}")
    return errors


def _print_preview(text: str, limit: int) -> None:
    """Print at most ``limit`` characters of ``text`` under a preview header.

    Does nothing when ``text`` is empty.
    """
    if not text:
        return
    print("[OUTPUT_PREVIEW]")
    print(text[:limit] + ("…" if len(text) > limit else ""))


def _run_one_case(
    base_url: str,
    headers: Dict[str, str],
    defaults: Dict[str, Any],
    case: Dict[str, Any],
) -> bool:
    """Run a single test case end to end and report the result.

    Per-case ``request_timeout_sec`` / ``max_wait_sec`` / ``poll_interval_sec``
    override the file-level ``defaults``.

    Returns:
        ``True`` if the case passed, ``False`` otherwise.

    Raises:
        requests.RequestException / ValueError: on network or response-parsing
            failures (handled by the caller).
    """
    cid = case.get("id", "(no-id)")
    title = case.get("name", "")
    print("\n" + "-" * 60)
    print(f"CASE {cid}" + (f" — {title}" if title else ""))

    req_timeout = float(case.get("request_timeout_sec", defaults.get("request_timeout_sec", 120)))
    max_wait = float(case.get("max_wait_sec", defaults.get("max_wait_sec", 300)))
    poll_iv = float(case.get("poll_interval_sec", defaults.get("poll_interval_sec", 2)))

    agent_id = _resolve_agent_id(base_url, headers, case.get("agent") or {}, req_timeout)
    if not agent_id:
        return False

    message = case.get("message")
    if message is None:
        print("[FAIL] 缺少 message")
        return False

    input_data: Dict[str, Any] = {"query": message, "USER_INPUT": message}
    extra = case.get("input_extra")
    if isinstance(extra, dict):
        # The message keys always win over same-named keys in input_extra.
        input_data = {**extra, **input_data}

    er = requests.post(
        f"{base_url}/api/v1/executions",
        headers=headers,
        json={"agent_id": agent_id, "input_data": input_data},
        timeout=req_timeout,
    )
    if er.status_code != 201:
        print(f"[FAIL] 创建执行 {er.status_code}: {er.text[:800]}")
        return False
    ex = er.json()
    eid = ex.get("id") if isinstance(ex, dict) else None
    if not eid:
        print(f"[FAIL] 创建执行响应无 id: {er.text[:800]}")
        return False
    print(f"[OK] execution_id={eid}")

    st, detail = _poll_until_terminal(base_url, headers, eid, max_wait, poll_iv, req_timeout)
    text = _extract_output_text((detail or {}).get("output_data"))

    errs = _check_expect(text, st, detail, case.get("expect") or {})
    if errs:
        for e in errs:
            print(f"[FAIL] {e}")
        _print_preview(text, 2000)
        return False

    print(f"[OK] 通过 status={st}")
    _print_preview(text, 1200)
    return True


def main() -> int:
    """CLI entry point; returns the process exit code (see module docstring)."""
    ap = argparse.ArgumentParser(description="批量运行 Agent 测试用例(JSON)")
    ap.add_argument(
        "--cases",
        default=os.environ.get("AGENT_TEST_CASES", DEFAULT_CASES_FILE),
        help=f"用例 JSON 路径(默认 {DEFAULT_CASES_FILE})",
    )
    ap.add_argument("--username", default=None)
    ap.add_argument("--password", default=None)
    ap.add_argument("--base-url", default=None, help="覆盖 defaults.base_url / API_BASE_URL")
    args = ap.parse_args()

    path = args.cases
    if not os.path.isfile(path):
        print(f"[FAIL] 找不到用例文件: {path}")
        print("请先按 (红头)agent测试用例文档.md 创建 JSON,或复制示例为 agent_test_cases.json")
        return 2

    try:
        with open(path, encoding="utf-8") as f:
            spec = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers both invalid JSON and undecodable (non-UTF-8) bytes.
        print(f"[FAIL] 无法读取用例文件 {path}: {exc}")
        return 2
    if not isinstance(spec, dict):
        print(f"[FAIL] 用例文件顶层应为 JSON 对象: {path}")
        return 2

    defaults = spec.get("defaults") or {}
    base_url = (
        args.base_url
        or defaults.get("base_url")
        or os.environ.get("API_BASE_URL", "http://localhost:8037")
    ).rstrip("/")
    username = args.username or defaults.get("username", "admin")
    password = args.password or defaults.get("password", "123456")
    req_timeout = float(defaults.get("request_timeout_sec", 120))

    print(f"API: {base_url}")
    print(f"用例文件: {path}")

    try:
        headers = _login(base_url, username, password, req_timeout)
    except (requests.RequestException, ValueError) as exc:
        print(f"[FAIL] 登录请求异常: {exc}")
        return 3
    if not headers:
        return 3

    cases: List[Dict[str, Any]] = spec.get("cases") or []
    if not cases:
        print("[FAIL] cases 为空")
        return 4

    ok, skip, fail = 0, 0, 0
    for case in cases:
        if case.get("enabled") is False:
            print(f"\n[SKIP] {case.get('id', '?')}")
            skip += 1
            continue
        try:
            passed = _run_one_case(base_url, headers, defaults, case)
        except (requests.RequestException, ValueError) as exc:
            # One unreachable or malformed case must not abort the whole batch.
            print(f"[FAIL] 用例执行异常: {exc!r}")
            passed = False
        if passed:
            ok += 1
        else:
            fail += 1

    print("\n" + "=" * 60)
    print(f"汇总: 通过 {ok}  失败 {fail}  跳过 {skip}")
    return 0 if fail == 0 else 1


if __name__ == "__main__":
    sys.exit(main())