280 lines
8.8 KiB
Python
280 lines
8.8 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
读取 agent_test_cases.json(或 --cases 指定),批量执行 Agent 并做简单断言。
|
|||
|
|
规范见:(红头)agent测试用例文档.md
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import time
|
|||
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|||
|
|
|
|||
|
|
import requests
|
|||
|
|
|
|||
|
|
# Cases file used when neither --cases nor the AGENT_TEST_CASES environment variable is given.
DEFAULT_CASES_FILE = "agent_test_cases.json"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _ensure_utf8_stdio() -> None:
|
|||
|
|
if sys.platform != "win32":
|
|||
|
|
return
|
|||
|
|
for name in ("stdout", "stderr"):
|
|||
|
|
stream = getattr(sys, name, None)
|
|||
|
|
if stream is not None and hasattr(stream, "reconfigure"):
|
|||
|
|
try:
|
|||
|
|
stream.reconfigure(encoding="utf-8", errors="replace")
|
|||
|
|
except Exception:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Runs at import time so every later print() uses UTF-8 (with replacement) on Windows.
_ensure_utf8_stdio()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _login(base_url: str, username: str, password: str, timeout: float) -> Optional[Dict[str, str]]:
    """Log in with form-encoded credentials and return a Bearer auth header.

    POSTs ``username`` / ``password`` to ``{base_url}/api/v1/auth/login`` and
    reads ``access_token`` from the JSON response body.

    Args:
        base_url: API root without a trailing slash.
        username: Login user name.
        password: Login password.
        timeout: Per-request timeout in seconds, passed to ``requests``.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` on success; otherwise ``None``
        after printing a ``[FAIL]`` line (network error, non-200 status,
        non-JSON body, or missing token).
    """
    try:
        r = requests.post(
            f"{base_url}/api/v1/auth/login",
            data={"username": username, "password": password},
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Connection refused / timeout: report it instead of crashing main().
        print(f"[FAIL] 登录请求异常: {exc}")
        return None
    if r.status_code != 200:
        print(f"[FAIL] 登录 {r.status_code}: {r.text[:500]}")
        return None
    try:
        body = r.json()
    except ValueError:
        # e.g. an HTML error page from a proxy served with status 200.
        print(f"[FAIL] 登录响应不是合法 JSON: {r.text[:500]}")
        return None
    token = body.get("access_token") if isinstance(body, dict) else None
    if not token:
        print("[FAIL] 登录响应无 access_token")
        return None
    return {"Authorization": f"Bearer {token}"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _resolve_agent_id(
    base_url: str,
    headers: Dict[str, str],
    agent: Dict[str, Any],
    timeout: float,
) -> Optional[str]:
    """Return the id of the agent described by a case's ``agent`` mapping.

    If ``agent["id"]`` is set it is returned directly (as a string). Otherwise
    ``agent["name"]`` is searched via ``GET {base_url}/api/v1/agents`` (up to
    100 results). A result whose stripped name equals the stripped requested
    name wins. If there is none, the first result is used and a ``[WARN]`` line
    says so.

    Args:
        base_url: API root without a trailing slash.
        headers: Auth headers, as returned by ``_login``.
        agent: Mapping with ``id`` and/or ``name``.
        timeout: Per-request timeout in seconds.

    Returns:
        The agent id as a string, or ``None`` after printing a ``[FAIL]`` line.
    """
    if agent.get("id"):
        return str(agent["id"])
    raw_name = agent.get("name")
    # Strip the requested name too: candidates are compared stripped below.
    name = str(raw_name).strip() if raw_name else ""
    if not name:
        print("[FAIL] agent 需包含 id 或 name")
        return None
    r = requests.get(
        f"{base_url}/api/v1/agents",
        headers=headers,
        params={"search": name, "limit": 100},
        timeout=timeout,
    )
    if r.status_code != 200:
        print(f"[FAIL] 查找 Agent {r.status_code}: {r.text[:500]}")
        return None
    try:
        agents = r.json() or []
    except ValueError:
        print(f"[FAIL] 查找 Agent 响应不是合法 JSON: {r.text[:500]}")
        return None
    if not isinstance(agents, list):
        print(f"[FAIL] 查找 Agent 响应应为列表: {r.text[:500]}")
        return None
    # Ignore malformed entries; an entry without an id cannot be executed.
    candidates = [a for a in agents if isinstance(a, dict) and a.get("id")]
    exact = [a for a in candidates if str(a.get("name") or "").strip() == name]
    if exact:
        pick = exact[0]
    elif candidates:
        pick = candidates[0]
        # A fuzzy search hit may be a different agent; make that visible.
        print(f"[WARN] 未找到同名 Agent {name!r}, 使用第一个搜索结果 {pick.get('name')!r}")
    else:
        print(f"[FAIL] 未找到 Agent: {name}")
        return None
    print(f"[OK] Agent: {pick.get('name')} ({pick['id']}) status={pick.get('status')}")
    return str(pick["id"])
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_output_text(output_data: Any) -> str:
|
|||
|
|
if output_data is None:
|
|||
|
|
return ""
|
|||
|
|
if isinstance(output_data, str):
|
|||
|
|
return output_data
|
|||
|
|
if isinstance(output_data, dict):
|
|||
|
|
for key in ("result", "output", "text", "content"):
|
|||
|
|
v = output_data.get(key)
|
|||
|
|
if v is not None:
|
|||
|
|
return v if isinstance(v, str) else str(v)
|
|||
|
|
return json.dumps(output_data, ensure_ascii=False)
|
|||
|
|
return str(output_data)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _fetch_execution_status(url: str, headers: Dict[str, str], timeout: float) -> Optional[str]:
    """Perform one status poll; return the status string, or ``None`` if the poll failed."""
    try:
        sr = requests.get(url, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        print(f"[WARN] status 请求异常: {exc}")
        return None
    if sr.status_code != 200:
        print(f"[WARN] status {sr.status_code}: {sr.text[:300]}")
        return None
    try:
        body = sr.json()
    except ValueError:
        print(f"[WARN] status 响应不是合法 JSON: {sr.text[:300]}")
        return None
    status = body.get("status") if isinstance(body, dict) else None
    return str(status or "")


def _poll_until_terminal(
    base_url: str,
    headers: Dict[str, str],
    execution_id: str,
    max_wait: float,
    poll_interval: float,
    timeout: float,
) -> Tuple[str, Optional[Dict[str, Any]]]:
    """Wait for an execution to reach a terminal state, then fetch its details.

    Polls ``GET /api/v1/executions/{id}/status`` every ``poll_interval`` seconds
    until the status is one of ``completed``, ``failed``, ``cancelled`` or
    ``awaiting_approval``, or until ``max_wait`` seconds have passed. Failed
    polls (network error, non-200, non-JSON) are logged and retried. Then
    ``GET /api/v1/executions/{id}`` is fetched once.

    Returns:
        ``(last_status, detail)``. ``last_status`` is the last successfully
        polled status (``"unknown"`` if none succeeded). ``detail`` is the
        detail JSON object, or ``None`` if it could not be fetched.
    """
    terminal = ("completed", "failed", "cancelled", "awaiting_approval")
    status_url = f"{base_url}/api/v1/executions/{execution_id}/status"
    # monotonic() is immune to wall-clock adjustments that would distort the deadline.
    deadline = time.monotonic() + max_wait
    last_status = "unknown"
    while time.monotonic() < deadline:
        status = _fetch_execution_status(status_url, headers, timeout)
        if status is not None:
            last_status = status
            if status in terminal:
                break
        time.sleep(poll_interval)
    else:
        # Loop exhausted without a terminal status: make the timeout explicit.
        print(f"[WARN] 等待超时 {max_wait:g}s, 最后状态 {last_status!r}")

    try:
        dr = requests.get(
            f"{base_url}/api/v1/executions/{execution_id}",
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        print(f"[FAIL] 获取执行详情请求异常: {exc}")
        return last_status, None
    if dr.status_code != 200:
        print(f"[FAIL] 获取执行详情 {dr.status_code}: {dr.text[:500]}")
        return last_status, None
    try:
        detail = dr.json()
    except ValueError:
        print(f"[FAIL] 获取执行详情响应不是合法 JSON: {dr.text[:500]}")
        return last_status, None
    return last_status, (detail if isinstance(detail, dict) else None)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _check_expect(text: str, status: str, detail: Optional[Dict[str, Any]], expect: Dict[str, Any]) -> List[str]:
|
|||
|
|
errors: List[str] = []
|
|||
|
|
want_status = expect.get("status", "completed")
|
|||
|
|
if status != want_status:
|
|||
|
|
errors.append(f"状态期望 {want_status!r},实际 {status!r}")
|
|||
|
|
if detail and status != "completed":
|
|||
|
|
em = detail.get("error_message")
|
|||
|
|
if em:
|
|||
|
|
errors.append(f"error_message: {em[:500]}")
|
|||
|
|
if not expect:
|
|||
|
|
return errors
|
|||
|
|
ci = bool(expect.get("case_insensitive"))
|
|||
|
|
hay = text if not ci else text.lower()
|
|||
|
|
for sub in expect.get("output_contains") or []:
|
|||
|
|
s = sub if not ci else sub.lower()
|
|||
|
|
if s not in hay:
|
|||
|
|
errors.append(f"输出应包含 {sub!r}")
|
|||
|
|
for sub in expect.get("output_not_contains") or []:
|
|||
|
|
s = sub if not ci else sub.lower()
|
|||
|
|
if s in hay:
|
|||
|
|
errors.append(f"输出不应包含 {sub!r}")
|
|||
|
|
return errors
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _print_output_preview(text: str, limit: int) -> None:
    """Print up to ``limit`` characters of ``text`` (with an ellipsis if truncated)."""
    if text:
        print("[OUTPUT_PREVIEW]")
        print(text[:limit] + ("…" if len(text) > limit else ""))


def _execute_case(
    base_url: str,
    headers: Dict[str, str],
    defaults: Dict[str, Any],
    case: Dict[str, Any],
) -> bool:
    """Create the execution for ``case``, wait for it, and check its expectations."""
    # Per-case settings override the file-level defaults, which override built-ins.
    req_timeout = float(case.get("request_timeout_sec", defaults.get("request_timeout_sec", 120)))
    max_wait = float(case.get("max_wait_sec", defaults.get("max_wait_sec", 300)))
    poll_iv = float(case.get("poll_interval_sec", defaults.get("poll_interval_sec", 2)))

    agent_id = _resolve_agent_id(base_url, headers, case.get("agent") or {}, req_timeout)
    if not agent_id:
        return False

    message = case.get("message")
    if message is None:
        print("[FAIL] 缺少 message")
        return False

    # The message is sent under both "query" and "USER_INPUT"; these keys win
    # over anything with the same name in input_extra.
    input_data: Dict[str, Any] = {"query": message, "USER_INPUT": message}
    extra = case.get("input_extra")
    if isinstance(extra, dict):
        input_data = {**extra, **input_data}

    er = requests.post(
        f"{base_url}/api/v1/executions",
        headers=headers,
        json={"agent_id": agent_id, "input_data": input_data},
        timeout=req_timeout,
    )
    if er.status_code != 201:
        print(f"[FAIL] 创建执行 {er.status_code}: {er.text[:800]}")
        return False
    try:
        ex = er.json()
    except ValueError:
        print(f"[FAIL] 创建执行响应不是合法 JSON: {er.text[:800]}")
        return False
    eid = ex.get("id") if isinstance(ex, dict) else None
    if not eid:
        print(f"[FAIL] 创建执行响应无 id: {er.text[:800]}")
        return False
    print(f"[OK] execution_id={eid}")

    st, detail = _poll_until_terminal(base_url, headers, str(eid), max_wait, poll_iv, req_timeout)
    text = _extract_output_text((detail or {}).get("output_data"))
    expect = case.get("expect") or {}
    errs = _check_expect(text, st, detail, expect)
    if errs:
        for e in errs:
            print(f"[FAIL] {e}")
        # Failures get a longer preview to help diagnosis.
        _print_output_preview(text, 2000)
        return False
    print(f"[OK] 通过 status={st}")
    _print_output_preview(text, 1200)
    return True


def _run_one_case(
    base_url: str,
    headers: Dict[str, str],
    defaults: Dict[str, Any],
    case: Dict[str, Any],
) -> bool:
    """Run one test case end to end, printing progress and the verdict.

    The case is a mapping from the cases file:

    * ``id`` / ``name``: used in the header line.
    * ``agent``: mapping with ``id`` or ``name``.
    * ``message``: required.
    * ``input_extra``: optional extra inputs.
    * ``expect``: see ``_check_expect``.
    * ``request_timeout_sec`` / ``max_wait_sec`` / ``poll_interval_sec``:
      optional, falling back to ``defaults``.

    Returns:
        ``True`` if the case passed. Network errors fail only this case instead
        of aborting the whole batch.
    """
    cid = case.get("id", "(no-id)")
    title = case.get("name", "")
    print("\n" + "-" * 60)
    print(f"CASE {cid}" + (f" — {title}" if title else ""))
    try:
        return _execute_case(base_url, headers, defaults, case)
    except requests.RequestException as exc:
        print(f"[FAIL] 请求异常: {exc}")
        return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _load_spec(path: str) -> Optional[Dict[str, Any]]:
    """Read and parse the cases JSON file; print ``[FAIL]`` and return ``None`` on any problem."""
    if not os.path.isfile(path):
        print(f"[FAIL] 找不到用例文件: {path}")
        print("请先按 (红头)agent测试用例文档.md 创建 JSON,或复制示例为 agent_test_cases.json")
        return None
    try:
        # utf-8-sig also accepts files saved with a BOM (e.g. by Windows Notepad).
        with open(path, encoding="utf-8-sig") as f:
            spec = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[FAIL] 无法读取或解析用例文件 {path}: {exc}")
        return None
    if not isinstance(spec, dict):
        print(f"[FAIL] 用例文件顶层应为 JSON 对象: {path}")
        return None
    return spec


def main() -> int:
    """Command-line entry point: load cases, log in, and run every enabled case.

    Connection settings are resolved in this order:

    * ``base_url``: ``--base-url``, then ``defaults.base_url``, then the
      ``API_BASE_URL`` environment variable, then ``http://localhost:8037``.
    * Credentials: command-line flags, then ``defaults``, then built-ins.

    Returns:
        Process exit code:

        * 0: no case failed.
        * 1: at least one case failed.
        * 2: the cases file is missing, unreadable or invalid.
        * 3: login failed.
        * 4: the file contains no cases.
    """
    ap = argparse.ArgumentParser(description="批量运行 Agent 测试用例(JSON)")
    ap.add_argument(
        "--cases",
        default=os.environ.get("AGENT_TEST_CASES", DEFAULT_CASES_FILE),
        help=f"用例 JSON 路径(默认 {DEFAULT_CASES_FILE})",
    )
    ap.add_argument("--username", default=None)
    ap.add_argument("--password", default=None)
    ap.add_argument("--base-url", default=None, help="覆盖 defaults.base_url / API_BASE_URL")
    args = ap.parse_args()

    path = args.cases
    spec = _load_spec(path)
    if spec is None:
        return 2

    defaults = spec.get("defaults") or {}
    base_url = (
        args.base_url
        or defaults.get("base_url")
        or os.environ.get("API_BASE_URL", "http://localhost:8037")
    )
    base_url = base_url.rstrip("/")

    username = args.username or defaults.get("username", "admin")
    password = args.password or defaults.get("password", "123456")
    req_timeout = float(defaults.get("request_timeout_sec", 120))

    print(f"API: {base_url}")
    print(f"用例文件: {path}")

    try:
        headers = _login(base_url, username, password, req_timeout)
    except requests.RequestException as exc:
        print(f"[FAIL] 登录请求异常: {exc}")
        return 3
    if not headers:
        return 3

    cases: List[Dict[str, Any]] = spec.get("cases") or []
    if not cases:
        print("[FAIL] cases 为空")
        return 4

    ok = skip = fail = 0
    for case in cases:
        if not isinstance(case, dict):
            # A malformed entry fails on its own instead of crashing the run.
            print(f"\n[FAIL] 用例格式错误(应为对象): {str(case)[:200]}")
            fail += 1
            continue
        if case.get("enabled") is False:
            print(f"\n[SKIP] {case.get('id', '?')}")
            skip += 1
            continue
        try:
            passed = _run_one_case(base_url, headers, defaults, case)
        except requests.RequestException as exc:
            # Keep the batch going and still print the summary.
            print(f"[FAIL] 请求异常: {exc}")
            passed = False
        if passed:
            ok += 1
        else:
            fail += 1

    print("\n" + "=" * 60)
    print(f"汇总: 通过 {ok} 失败 {fail} 跳过 {skip}")
    return 0 if fail == 0 else 1
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # main() returns the process exit code; SystemExit propagates it to the shell.
    raise SystemExit(main())
|