Files
aiagent/run_agent_test_cases.py
renjianbo df4fab1e6e feat: Agent 批量测试、作业助手与上传预览;Windows 启动脚本与文档- 新增 run_agent_test_cases 与示例 JSON、(红头)agent测试用例文档
- 扩展 test_agent_execution(--homework、UTF-8 控制台)
- 后端:uploads 预览、file_read、工作流与对话落盘等
- 前端:AgentChatPreview 与设计器相关调整
- 忽略 redis二进制、agent_workspaces、uploads、tessdata 等本机产物

Made-with: Cursor
2026-04-13 20:17:18 +08:00

280 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Read agent_test_cases.json (or the file given by --cases / AGENT_TEST_CASES), run every Agent case in batch, and apply simple assertions.
Spec: (红头)agent测试用例文档.md
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from typing import Any, Dict, List, Optional, Tuple
import requests
# Default test-case file, used when neither --cases nor the AGENT_TEST_CASES
# environment variable is given.
DEFAULT_CASES_FILE = "agent_test_cases.json"
def _ensure_utf8_stdio() -> None:
if sys.platform != "win32":
return
for name in ("stdout", "stderr"):
stream = getattr(sys, name, None)
if stream is not None and hasattr(stream, "reconfigure"):
try:
stream.reconfigure(encoding="utf-8", errors="replace")
except Exception:
pass
# Runs at import time, before any of the functions below print anything.
_ensure_utf8_stdio()
def _login(base_url: str, username: str, password: str, timeout: float) -> Optional[Dict[str, str]]:
    """Log in and return the Authorization header for later API calls.

    Posts the credentials as a form to ``{base_url}/api/v1/auth/login``.

    Args:
        base_url: API root, without a trailing slash.
        username: Account name sent in the form body.
        password: Account password sent in the form body.
        timeout: Per-request timeout in seconds, passed to ``requests``.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` on success. ``None`` after
        printing a ``[FAIL]`` line when the request itself fails, the status
        is not 200, the body is not a JSON object, or it has no
        ``access_token``.
    """
    try:
        r = requests.post(
            f"{base_url}/api/v1/auth/login",
            data={"username": username, "password": password},
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Connection refused / timeout: report it like any other login
        # failure instead of letting a traceback escape.
        print(f"[FAIL] 登录请求异常: {exc}")
        return None
    if r.status_code != 200:
        print(f"[FAIL] 登录 {r.status_code}: {r.text[:500]}")
        return None
    try:
        body = r.json()
    except ValueError:
        # Response.json() raises a ValueError subclass on a non-JSON body.
        print(f"[FAIL] 登录响应不是合法 JSON: {r.text[:500]}")
        return None
    token = body.get("access_token") if isinstance(body, dict) else None
    if not token:
        print("[FAIL] 登录响应无 access_token")
        return None
    return {"Authorization": f"Bearer {token}"}
def _resolve_agent_id(
base_url: str,
headers: Dict[str, str],
agent: Dict[str, Any],
timeout: float,
) -> Optional[str]:
if agent.get("id"):
return str(agent["id"])
name = agent.get("name")
if not name:
print("[FAIL] agent 需包含 id 或 name")
return None
r = requests.get(
f"{base_url}/api/v1/agents",
headers=headers,
params={"search": name, "limit": 100},
timeout=timeout,
)
if r.status_code != 200:
print(f"[FAIL] 查找 Agent {r.status_code}: {r.text[:500]}")
return None
agents: List[Dict[str, Any]] = r.json() or []
exact = [a for a in agents if (a.get("name") or "").strip() == name]
pick = exact[0] if exact else (agents[0] if agents else None)
if not pick:
print(f"[FAIL] 未找到 Agent: {name}")
return None
print(f"[OK] Agent: {pick.get('name')} ({pick['id']}) status={pick.get('status')}")
return str(pick["id"])
def _extract_output_text(output_data: Any) -> str:
if output_data is None:
return ""
if isinstance(output_data, str):
return output_data
if isinstance(output_data, dict):
for key in ("result", "output", "text", "content"):
v = output_data.get(key)
if v is not None:
return v if isinstance(v, str) else str(v)
return json.dumps(output_data, ensure_ascii=False)
return str(output_data)
def _poll_until_terminal(
    base_url: str,
    headers: Dict[str, str],
    execution_id: str,
    max_wait: float,
    poll_interval: float,
    timeout: float,
) -> Tuple[str, Optional[Dict[str, Any]]]:
    """Poll an execution until it stops, then fetch its detail record.

    Polls ``GET /api/v1/executions/{id}/status`` every ``poll_interval``
    seconds until the status is one of ``completed``, ``failed``,
    ``cancelled`` or ``awaiting_approval``, or until ``max_wait`` seconds
    have elapsed. Failed polls (network error, non-200, unusable body) only
    print a ``[WARN]`` line and are retried within the same deadline.
    The detail record is fetched afterwards regardless of how the loop ended.

    Args:
        base_url: API root, without a trailing slash.
        headers: Request headers (carrying the Authorization token).
        execution_id: Id of the execution to watch.
        max_wait: Upper bound, in seconds, for the polling loop.
        poll_interval: Pause between polls, in seconds.
        timeout: Per-request timeout in seconds.

    Returns:
        ``(last_status, detail)``. ``last_status`` is ``"unknown"`` when no
        status response could be read; ``detail`` is ``None`` (after a
        ``[FAIL]`` line) when the detail record could not be fetched.
    """
    terminal = ("completed", "failed", "cancelled", "awaiting_approval")
    status_url = f"{base_url}/api/v1/executions/{execution_id}/status"
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + max_wait
    last_status = "unknown"
    while time.monotonic() < deadline:
        try:
            sr = requests.get(status_url, headers=headers, timeout=timeout)
        except requests.RequestException as exc:
            # A transient network error is handled like a non-200 poll: warn and retry.
            print(f"[WARN] status 请求异常: {exc}")
            time.sleep(poll_interval)
            continue
        if sr.status_code != 200:
            print(f"[WARN] status {sr.status_code}: {sr.text[:300]}")
            time.sleep(poll_interval)
            continue
        try:
            body = sr.json()
        except ValueError:
            body = None
        if not isinstance(body, dict):
            print(f"[WARN] status 响应不是 JSON 对象: {sr.text[:300]}")
            time.sleep(poll_interval)
            continue
        last_status = str(body.get("status") or "")
        if last_status in terminal:
            break
        time.sleep(poll_interval)

    try:
        dr = requests.get(
            f"{base_url}/api/v1/executions/{execution_id}",
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        print(f"[FAIL] 获取执行详情请求异常: {exc}")
        return last_status, None
    if dr.status_code != 200:
        print(f"[FAIL] 获取执行详情 {dr.status_code}: {dr.text[:500]}")
        return last_status, None
    try:
        detail = dr.json()
    except ValueError:
        detail = None
    if not isinstance(detail, dict):
        print(f"[FAIL] 执行详情不是 JSON 对象: {dr.text[:500]}")
        return last_status, None
    return last_status, detail
def _check_expect(text: str, status: str, detail: Optional[Dict[str, Any]], expect: Dict[str, Any]) -> List[str]:
errors: List[str] = []
want_status = expect.get("status", "completed")
if status != want_status:
errors.append(f"状态期望 {want_status!r},实际 {status!r}")
if detail and status != "completed":
em = detail.get("error_message")
if em:
errors.append(f"error_message: {em[:500]}")
if not expect:
return errors
ci = bool(expect.get("case_insensitive"))
hay = text if not ci else text.lower()
for sub in expect.get("output_contains") or []:
s = sub if not ci else sub.lower()
if s not in hay:
errors.append(f"输出应包含 {sub!r}")
for sub in expect.get("output_not_contains") or []:
s = sub if not ci else sub.lower()
if s in hay:
errors.append(f"输出不应包含 {sub!r}")
return errors
def _print_output_preview(text: str, limit: int) -> None:
    """Print at most ``limit`` characters of ``text``, marking truncation with ``...``."""
    if not text:
        return
    print("[OUTPUT_PREVIEW]")
    print(text[:limit] + ("..." if len(text) > limit else ""))


def _run_one_case(
    base_url: str,
    headers: Dict[str, str],
    defaults: Dict[str, Any],
    case: Dict[str, Any],
) -> bool:
    """Run one test case end to end and report whether it passed.

    Resolves the agent, creates an execution with the case's ``message``
    (sent as both ``query`` and ``USER_INPUT``; keys from ``input_extra`` are
    merged in but cannot override those two), waits for it to finish, and
    checks the result against the case's ``expect`` block.

    Args:
        base_url: API root, without a trailing slash.
        headers: Request headers (carrying the Authorization token).
        defaults: File-level defaults for ``request_timeout_sec``,
            ``max_wait_sec`` and ``poll_interval_sec``; the case's own values
            take precedence.
        case: The test-case definition.

    Returns:
        ``True`` if every expectation held, ``False`` otherwise (a ``[FAIL]``
        line is printed for each problem).
    """
    cid = case.get("id", "(no-id)")
    title = case.get("name", "")
    print("\n" + "-" * 60)
    # Separate id and title so they do not run together in the log.
    print(f"CASE {cid}" + (f" - {title}" if title else ""))

    req_timeout = float(case.get("request_timeout_sec", defaults.get("request_timeout_sec", 120)))
    max_wait = float(case.get("max_wait_sec", defaults.get("max_wait_sec", 300)))
    poll_iv = float(case.get("poll_interval_sec", defaults.get("poll_interval_sec", 2)))

    agent_id = _resolve_agent_id(base_url, headers, case.get("agent") or {}, req_timeout)
    if not agent_id:
        return False

    message = case.get("message")
    if message is None:
        print("[FAIL] 缺少 message")
        return False

    input_data: Dict[str, Any] = {"query": message, "USER_INPUT": message}
    extra = case.get("input_extra")
    if isinstance(extra, dict):
        # input_data comes last so query/USER_INPUT always win over input_extra.
        input_data = {**extra, **input_data}

    try:
        er = requests.post(
            f"{base_url}/api/v1/executions",
            headers=headers,
            json={"agent_id": agent_id, "input_data": input_data},
            timeout=req_timeout,
        )
    except requests.RequestException as exc:
        # A network failure fails this case only; it must not abort the batch.
        print(f"[FAIL] 创建执行请求异常: {exc}")
        return False
    if er.status_code != 201:
        print(f"[FAIL] 创建执行 {er.status_code}: {er.text[:800]}")
        return False
    try:
        eid = er.json()["id"]
    except (ValueError, KeyError, TypeError):
        print(f"[FAIL] 创建执行响应缺少 id: {er.text[:800]}")
        return False
    print(f"[OK] execution_id={eid}")

    st, detail = _poll_until_terminal(base_url, headers, eid, max_wait, poll_iv, req_timeout)
    text = _extract_output_text((detail or {}).get("output_data"))
    expect = case.get("expect") or {}
    errs = _check_expect(text, st, detail, expect)
    if errs:
        for e in errs:
            print(f"[FAIL] {e}")
        # Failures get a longer preview to help diagnose the mismatch.
        _print_output_preview(text, 2000)
        return False

    print(f"[OK] 通过 status={st}")
    _print_output_preview(text, 1200)
    return True
def main() -> int:
    """Parse arguments, load the cases file, log in and run every enabled case.

    The cases file comes from ``--cases``, else the ``AGENT_TEST_CASES``
    environment variable, else ``DEFAULT_CASES_FILE``. The API root comes
    from ``--base-url``, else ``defaults.base_url`` in the file, else the
    ``API_BASE_URL`` environment variable, else ``http://localhost:8037``.

    Returns:
        Process exit code: 0 when no case failed, 1 when at least one case
        failed, 2 when the cases file is missing or unusable, 3 when login
        failed, 4 when the file contains no cases.
    """
    ap = argparse.ArgumentParser(description="批量运行 Agent 测试用例JSON")
    ap.add_argument(
        "--cases",
        default=os.environ.get("AGENT_TEST_CASES", DEFAULT_CASES_FILE),
        help=f"用例 JSON 路径(默认 {DEFAULT_CASES_FILE})",
    )
    ap.add_argument("--username", default=None)
    ap.add_argument("--password", default=None)
    ap.add_argument("--base-url", default=None, help="覆盖 defaults.base_url / API_BASE_URL")
    args = ap.parse_args()

    path = args.cases
    if not os.path.isfile(path):
        print(f"[FAIL] 找不到用例文件: {path}")
        print("请先按 (红头)agent测试用例文档.md 创建 JSON或复制示例为 agent_test_cases.json")
        return 2
    try:
        with open(path, encoding="utf-8") as f:
            spec = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers both invalid JSON and undecodable bytes.
        print(f"[FAIL] 无法读取用例文件 {path}: {exc}")
        return 2
    if not isinstance(spec, dict):
        print(f"[FAIL] 用例文件顶层必须是 JSON 对象: {path}")
        return 2

    defaults = spec.get("defaults") or {}
    base_url = (
        args.base_url
        or defaults.get("base_url")
        or os.environ.get("API_BASE_URL", "http://localhost:8037")
    )
    base_url = base_url.rstrip("/")
    # NOTE(review): admin/123456 are hard-coded fallback credentials; fine for a
    # local dev instance, but should not be relied on anywhere else.
    username = args.username or defaults.get("username", "admin")
    password = args.password or defaults.get("password", "123456")
    req_timeout = float(defaults.get("request_timeout_sec", 120))

    print(f"API: {base_url}")
    print(f"用例文件: {path}")

    try:
        headers = _login(base_url, username, password, req_timeout)
    except requests.RequestException as exc:
        print(f"[FAIL] 登录请求异常: {exc}")
        headers = None
    if not headers:
        return 3

    cases: List[Dict[str, Any]] = spec.get("cases") or []
    if not cases:
        print("[FAIL] cases 为空")
        return 4

    ok, skip, fail = 0, 0, 0
    for case in cases:
        # Only an explicit `"enabled": false` skips a case; a missing key means enabled.
        if case.get("enabled") is False:
            print(f"\n[SKIP] {case.get('id', '?')}")
            skip += 1
            continue
        try:
            passed = _run_one_case(base_url, headers, defaults, case)
        except requests.RequestException as exc:
            # One case's network failure must not abort the batch or hide the summary.
            print(f"[FAIL] 请求异常: {exc}")
            passed = False
        if passed:
            ok += 1
        else:
            fail += 1

    print("\n" + "=" * 60)
    print(f"汇总: 通过 {ok} 失败 {fail} 跳过 {skip}")
    return 0 if fail == 0 else 1
# Script entry point: main()'s return value becomes the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())