Files
aiagent/scripts/tools/run_agent_test_cases.py

280 lines
8.8 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
读取 agent_test_cases.json --cases 指定批量执行 Agent 并做简单断言
规范见(红头)agent测试用例文档.md
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from typing import Any, Dict, List, Optional, Tuple
import requests
# Case file used when neither --cases nor the AGENT_TEST_CASES environment
# variable supplies a path.
DEFAULT_CASES_FILE = "agent_test_cases.json"
def _ensure_utf8_stdio() -> None:
if sys.platform != "win32":
return
for name in ("stdout", "stderr"):
stream = getattr(sys, name, None)
if stream is not None and hasattr(stream, "reconfigure"):
try:
stream.reconfigure(encoding="utf-8", errors="replace")
except Exception:
pass
_ensure_utf8_stdio()
def _login(base_url: str, username: str, password: str, timeout: float) -> Optional[Dict[str, str]]:
    """Log in and build the Authorization header used by later API calls.

    Posts form-encoded credentials to ``{base_url}/api/v1/auth/login`` and
    reads ``access_token`` from the JSON response.

    Args:
        base_url: API root without a trailing slash.
        username: Login name sent in the form body.
        password: Password sent in the form body.
        timeout: Timeout passed to ``requests.post``.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` on success. ``None`` after
        printing a ``[FAIL]`` line when the request cannot be sent, the
        status is not 200, the body is not a JSON object, or it carries no
        ``access_token``.
    """
    try:
        r = requests.post(
            f"{base_url}/api/v1/auth/login",
            data={"username": username, "password": password},
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Connection refused, DNS failure, timeout...: report it like any
        # other login failure instead of surfacing a traceback.
        print(f"[FAIL] 登录请求异常: {exc}")
        return None
    if r.status_code != 200:
        print(f"[FAIL] 登录 {r.status_code}: {r.text[:500]}")
        return None
    try:
        payload = r.json()
    except ValueError:
        # JSON decode errors raised by Response.json() derive from ValueError.
        print(f"[FAIL] 登录响应不是合法 JSON: {r.text[:500]}")
        return None
    token = payload.get("access_token") if isinstance(payload, dict) else None
    if not token:
        print("[FAIL] 登录响应无 access_token")
        return None
    return {"Authorization": f"Bearer {token}"}
def _resolve_agent_id(
base_url: str,
headers: Dict[str, str],
agent: Dict[str, Any],
timeout: float,
) -> Optional[str]:
if agent.get("id"):
return str(agent["id"])
name = agent.get("name")
if not name:
print("[FAIL] agent 需包含 id 或 name")
return None
r = requests.get(
f"{base_url}/api/v1/agents",
headers=headers,
params={"search": name, "limit": 100},
timeout=timeout,
)
if r.status_code != 200:
print(f"[FAIL] 查找 Agent {r.status_code}: {r.text[:500]}")
return None
agents: List[Dict[str, Any]] = r.json() or []
exact = [a for a in agents if (a.get("name") or "").strip() == name]
pick = exact[0] if exact else (agents[0] if agents else None)
if not pick:
print(f"[FAIL] 未找到 Agent: {name}")
return None
print(f"[OK] Agent: {pick.get('name')} ({pick['id']}) status={pick.get('status')}")
return str(pick["id"])
def _extract_output_text(output_data: Any) -> str:
if output_data is None:
return ""
if isinstance(output_data, str):
return output_data
if isinstance(output_data, dict):
for key in ("result", "output", "text", "content"):
v = output_data.get(key)
if v is not None:
return v if isinstance(v, str) else str(v)
return json.dumps(output_data, ensure_ascii=False)
return str(output_data)
def _poll_until_terminal(
    base_url: str,
    headers: Dict[str, str],
    execution_id: str,
    max_wait: float,
    poll_interval: float,
    timeout: float,
) -> Tuple[str, Optional[Dict[str, Any]]]:
    """Poll an execution until it stops or ``max_wait`` elapses, then fetch it.

    ``GET /api/v1/executions/{id}/status`` is polled until the reported
    status is one of ``completed``, ``failed``, ``cancelled`` or
    ``awaiting_approval``. A failed poll (network error, non-200 status,
    unusable body) prints a ``[WARN]`` line and is retried. Whatever the
    outcome, ``GET /api/v1/executions/{id}`` is then requested once.

    Args:
        base_url: API root without a trailing slash.
        headers: Headers sent with every request.
        execution_id: Id of the execution to watch.
        max_wait: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between polls.
        timeout: Timeout passed to each ``requests.get``.

    Returns:
        ``(last_status, detail)``. ``last_status`` is the last status seen,
        or ``"unknown"`` if no poll succeeded. ``detail`` is the execution
        JSON object, or ``None`` if it could not be fetched.
    """
    terminal_states = ("completed", "failed", "cancelled", "awaiting_approval")
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + max_wait
    last_status = "unknown"
    while time.monotonic() < deadline:
        body: Any = None
        try:
            sr = requests.get(
                f"{base_url}/api/v1/executions/{execution_id}/status",
                headers=headers,
                timeout=timeout,
            )
        except requests.RequestException as exc:
            # Treat a transient network error like a non-200 reply: warn and retry.
            print(f"[WARN] status 请求异常: {exc}")
        else:
            if sr.status_code != 200:
                print(f"[WARN] status {sr.status_code}: {sr.text[:300]}")
            else:
                try:
                    body = sr.json()
                except ValueError:
                    print(f"[WARN] status 响应不是合法 JSON: {sr.text[:300]}")
        if isinstance(body, dict):
            last_status = str(body.get("status") or "")
            if last_status in terminal_states:
                break
        # Sleep no longer than the time left so max_wait is not overshot.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(poll_interval, remaining))
    try:
        dr = requests.get(
            f"{base_url}/api/v1/executions/{execution_id}",
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        print(f"[FAIL] 获取执行详情请求异常: {exc}")
        return last_status, None
    if dr.status_code != 200:
        print(f"[FAIL] 获取执行详情 {dr.status_code}: {dr.text[:500]}")
        return last_status, None
    try:
        detail = dr.json()
    except ValueError:
        print(f"[FAIL] 获取执行详情响应不是合法 JSON: {dr.text[:500]}")
        return last_status, None
    # Callers expect a dict or None; any other JSON value is unusable.
    return last_status, detail if isinstance(detail, dict) else None
def _check_expect(text: str, status: str, detail: Optional[Dict[str, Any]], expect: Dict[str, Any]) -> List[str]:
errors: List[str] = []
want_status = expect.get("status", "completed")
if status != want_status:
errors.append(f"状态期望 {want_status!r},实际 {status!r}")
if detail and status != "completed":
em = detail.get("error_message")
if em:
errors.append(f"error_message: {em[:500]}")
if not expect:
return errors
ci = bool(expect.get("case_insensitive"))
hay = text if not ci else text.lower()
for sub in expect.get("output_contains") or []:
s = sub if not ci else sub.lower()
if s not in hay:
errors.append(f"输出应包含 {sub!r}")
for sub in expect.get("output_not_contains") or []:
s = sub if not ci else sub.lower()
if s in hay:
errors.append(f"输出不应包含 {sub!r}")
return errors
def _print_output_preview(text: str, limit: int) -> None:
    """Print up to ``limit`` characters of ``text``, with ``...`` if truncated."""
    if not text:
        return
    print("[OUTPUT_PREVIEW]")
    print(text[:limit] + ("..." if len(text) > limit else ""))


def _execute_case(
    base_url: str,
    headers: Dict[str, str],
    case: Dict[str, Any],
    req_timeout: float,
    max_wait: float,
    poll_iv: float,
) -> bool:
    """Resolve the agent, start one execution, wait for it and check it."""
    agent_id = _resolve_agent_id(base_url, headers, case.get("agent") or {}, req_timeout)
    if not agent_id:
        return False
    message = case.get("message")
    if message is None:
        print("[FAIL] 缺少 message")
        return False
    # The message is sent under both keys. input_extra can add fields but,
    # being merged first, cannot override these two.
    input_data: Dict[str, Any] = {"query": message, "USER_INPUT": message}
    extra = case.get("input_extra")
    if isinstance(extra, dict):
        input_data = {**extra, **input_data}
    er = requests.post(
        f"{base_url}/api/v1/executions",
        headers=headers,
        json={"agent_id": agent_id, "input_data": input_data},
        timeout=req_timeout,
    )
    if er.status_code != 201:
        print(f"[FAIL] 创建执行 {er.status_code}: {er.text[:800]}")
        return False
    try:
        ex = er.json()
    except ValueError:
        ex = None
    eid = ex.get("id") if isinstance(ex, dict) else None
    if eid is None:
        print(f"[FAIL] 创建执行响应缺少 id: {er.text[:800]}")
        return False
    print(f"[OK] execution_id={eid}")
    st, detail = _poll_until_terminal(base_url, headers, eid, max_wait, poll_iv, req_timeout)
    text = _extract_output_text((detail or {}).get("output_data"))
    errs = _check_expect(text, st, detail, case.get("expect") or {})
    if errs:
        for e in errs:
            print(f"[FAIL] {e}")
        _print_output_preview(text, 2000)
        return False
    print(f"[OK] 通过 status={st}")
    _print_output_preview(text, 1200)
    return True


def _run_one_case(
    base_url: str,
    headers: Dict[str, str],
    defaults: Dict[str, Any],
    case: Dict[str, Any],
) -> bool:
    """Run a single test case and report whether it passed.

    Prints a header for the case, then resolves the agent, creates an
    execution with the case's ``message``, waits for it to finish and checks
    the result against the case's ``expect`` block. ``request_timeout_sec``,
    ``max_wait_sec`` and ``poll_interval_sec`` are read from the case, then
    from ``defaults``, then fall back to 120, 300 and 2.

    Args:
        base_url: API root without a trailing slash.
        headers: Headers sent with every request.
        defaults: The case file's ``defaults`` object.
        case: One entry of the case file's ``cases`` list.

    Returns:
        ``True`` if the case passed. ``False`` after printing ``[FAIL]``
        lines otherwise, including when an HTTP request raised.
    """
    cid = case.get("id", "(no-id)")
    title = case.get("name", "")
    print("\n" + "-" * 60)
    # Separate id and title so they do not run together in the log.
    print(f"CASE {cid}" + (f" - {title}" if title else ""))
    req_timeout = float(case.get("request_timeout_sec", defaults.get("request_timeout_sec", 120)))
    max_wait = float(case.get("max_wait_sec", defaults.get("max_wait_sec", 300)))
    poll_iv = float(case.get("poll_interval_sec", defaults.get("poll_interval_sec", 2)))
    try:
        return _execute_case(base_url, headers, case, req_timeout, max_wait, poll_iv)
    except requests.RequestException as exc:
        # One unreachable endpoint or timeout must fail this case only,
        # not abort the remaining cases.
        print(f"[FAIL] 请求异常: {exc}")
        return False
def main() -> int:
    """Parse arguments, log in, run every enabled case and print a summary.

    The API base URL comes from ``--base-url``, then ``defaults.base_url``
    in the case file, then the ``API_BASE_URL`` environment variable, then
    ``http://localhost:8037``. Credentials come from the command line, then
    from the case file's ``defaults``.

    Returns:
        Process exit status: ``0`` if no case failed, ``1`` if at least one
        failed, ``2`` if the case file is missing or unreadable, ``3`` if
        login failed, ``4`` if the file contains no cases.
    """
    ap = argparse.ArgumentParser(description="批量运行 Agent 测试用例JSON")
    ap.add_argument(
        "--cases",
        default=os.environ.get("AGENT_TEST_CASES", DEFAULT_CASES_FILE),
        help=f"用例 JSON 路径(默认 {DEFAULT_CASES_FILE})",
    )
    ap.add_argument("--username", default=None)
    ap.add_argument("--password", default=None)
    ap.add_argument("--base-url", default=None, help="覆盖 defaults.base_url / API_BASE_URL")
    args = ap.parse_args()

    path = args.cases
    if not os.path.isfile(path):
        print(f"[FAIL] 找不到用例文件: {path}")
        print("请先按 (红头)agent测试用例文档.md 创建 JSON或复制示例为 agent_test_cases.json")
        return 2
    try:
        with open(path, encoding="utf-8") as f:
            spec = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers both JSON syntax errors and undecodable bytes.
        print(f"[FAIL] 无法读取用例文件 {path}: {exc}")
        return 2
    if not isinstance(spec, dict):
        print(f"[FAIL] 用例文件顶层应为 JSON 对象: {path}")
        return 2

    defaults = spec.get("defaults") or {}
    base_url = (
        args.base_url
        or defaults.get("base_url")
        or os.environ.get("API_BASE_URL", "http://localhost:8037")
    )
    base_url = base_url.rstrip("/")
    # NOTE(review): hard-coded fallback credentials; confirm they are only
    # ever valid for a local development server.
    username = args.username or defaults.get("username", "admin")
    password = args.password or defaults.get("password", "123456")
    req_timeout = float(defaults.get("request_timeout_sec", 120))
    print(f"API: {base_url}")
    print(f"用例文件: {path}")

    try:
        headers = _login(base_url, username, password, req_timeout)
    except requests.RequestException as exc:
        print(f"[FAIL] 登录请求异常: {exc}")
        return 3
    if not headers:
        return 3

    cases: List[Dict[str, Any]] = spec.get("cases") or []
    if not cases:
        print("[FAIL] cases 为空")
        return 4

    ok, skip, fail = 0, 0, 0
    for case in cases:
        if case.get("enabled") is False:
            print(f"\n[SKIP] {case.get('id', '?')}")
            skip += 1
            continue
        try:
            passed = _run_one_case(base_url, headers, defaults, case)
        except Exception as exc:
            # Batch boundary: an unexpected error fails this case but must
            # not stop the remaining cases or suppress the summary.
            print(f"[FAIL] {case.get('id', '?')} 执行异常: {type(exc).__name__}: {exc}")
            passed = False
        if passed:
            ok += 1
        else:
            fail += 1
    print("\n" + "=" * 60)
    print(f"汇总: 通过 {ok} 失败 {fail} 跳过 {skip}")
    return 0 if fail == 0 else 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())