- 扩展 test_agent_execution(--homework、UTF-8 控制台) - 后端:uploads 预览、file_read、工作流与对话落盘等 - 前端:AgentChatPreview 与设计器相关调整 - 忽略 redis二进制、agent_workspaces、uploads、tessdata 等本机产物 Made-with: Cursor
280 lines
8.8 KiB
Python
280 lines
8.8 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
读取 agent_test_cases.json(或 --cases 指定),批量执行 Agent 并做简单断言。
|
||
规范见:(红头)agent测试用例文档.md
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import requests
|
||
|
||
# Case file used when neither --cases nor the AGENT_TEST_CASES environment
# variable is given; a relative path resolves against the working directory.
DEFAULT_CASES_FILE: str = "agent_test_cases.json"


def _ensure_utf8_stdio() -> None:
|
||
if sys.platform != "win32":
|
||
return
|
||
for name in ("stdout", "stderr"):
|
||
stream = getattr(sys, name, None)
|
||
if stream is not None and hasattr(stream, "reconfigure"):
|
||
try:
|
||
stream.reconfigure(encoding="utf-8", errors="replace")
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
_ensure_utf8_stdio()
|
||
|
||
|
||
def _login(base_url: str, username: str, password: str, timeout: float) -> Optional[Dict[str, str]]:
    """Log in through the form-encoded ``/api/v1/auth/login`` endpoint.

    Args:
        base_url: API root without a trailing slash.
        username: Sent as the ``username`` form field.
        password: Sent as the ``password`` form field.
        timeout: Per-request timeout in seconds, passed to ``requests``.

    Returns:
        ``{"Authorization": "Bearer <token>"}`` on success. Returns ``None``
        after printing a ``[FAIL]`` line when the request itself errors, the
        status is not 200, the body is not JSON, or the body has no
        ``access_token``.
    """
    try:
        r = requests.post(
            f"{base_url}/api/v1/auth/login",
            data={"username": username, "password": password},
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Connection refused / DNS failure / timeout: report, don't traceback.
        print(f"[FAIL] 登录请求异常: {exc}")
        return None
    if r.status_code != 200:
        print(f"[FAIL] 登录 {r.status_code}: {r.text[:500]}")
        return None
    try:
        body = r.json()
    except ValueError:
        # requests' JSON decode error subclasses ValueError.
        print(f"[FAIL] 登录响应不是 JSON: {r.text[:500]}")
        return None
    token = body.get("access_token") if isinstance(body, dict) else None
    if not token:
        print("[FAIL] 登录响应无 access_token")
        return None
    return {"Authorization": f"Bearer {token}"}


def _resolve_agent_id(
|
||
base_url: str,
|
||
headers: Dict[str, str],
|
||
agent: Dict[str, Any],
|
||
timeout: float,
|
||
) -> Optional[str]:
|
||
if agent.get("id"):
|
||
return str(agent["id"])
|
||
name = agent.get("name")
|
||
if not name:
|
||
print("[FAIL] agent 需包含 id 或 name")
|
||
return None
|
||
r = requests.get(
|
||
f"{base_url}/api/v1/agents",
|
||
headers=headers,
|
||
params={"search": name, "limit": 100},
|
||
timeout=timeout,
|
||
)
|
||
if r.status_code != 200:
|
||
print(f"[FAIL] 查找 Agent {r.status_code}: {r.text[:500]}")
|
||
return None
|
||
agents: List[Dict[str, Any]] = r.json() or []
|
||
exact = [a for a in agents if (a.get("name") or "").strip() == name]
|
||
pick = exact[0] if exact else (agents[0] if agents else None)
|
||
if not pick:
|
||
print(f"[FAIL] 未找到 Agent: {name}")
|
||
return None
|
||
print(f"[OK] Agent: {pick.get('name')} ({pick['id']}) status={pick.get('status')}")
|
||
return str(pick["id"])
|
||
|
||
|
||
def _extract_output_text(output_data: Any) -> str:
|
||
if output_data is None:
|
||
return ""
|
||
if isinstance(output_data, str):
|
||
return output_data
|
||
if isinstance(output_data, dict):
|
||
for key in ("result", "output", "text", "content"):
|
||
v = output_data.get(key)
|
||
if v is not None:
|
||
return v if isinstance(v, str) else str(v)
|
||
return json.dumps(output_data, ensure_ascii=False)
|
||
return str(output_data)
|
||
|
||
|
||
def _poll_until_terminal(
    base_url: str,
    headers: Dict[str, str],
    execution_id: str,
    max_wait: float,
    poll_interval: float,
    timeout: float,
) -> Tuple[str, Optional[Dict[str, Any]]]:
    """Poll an execution's status until it settles, then fetch its details.

    Polls ``GET /api/v1/executions/{id}/status`` roughly every
    ``poll_interval`` seconds for at most ``max_wait`` seconds. It stops early
    on ``completed``, ``failed``, ``cancelled`` or ``awaiting_approval``.
    Non-200 answers and transient request/JSON errors are logged as
    ``[WARN]`` and retried until the deadline. Reaching the deadline is
    reported as well.

    Returns:
        ``(last_status, detail)``. ``last_status`` is the last status read
        (``"unknown"`` if none was ever read). ``detail`` is the JSON object
        returned by ``GET /api/v1/executions/{id}``, or ``None`` if that fetch
        failed.
    """
    terminal = frozenset(("completed", "failed", "cancelled", "awaiting_approval"))
    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + max_wait
    last_status = "unknown"
    while time.monotonic() < deadline:
        status = _fetch_execution_status(base_url, headers, execution_id, timeout)
        if status is not None:
            last_status = status
            if last_status in terminal:
                break
        # Never sleep past the deadline.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
    else:
        print(f"[WARN] 等待超时 {max_wait}s,最后状态 {last_status!r}")

    try:
        dr = requests.get(
            f"{base_url}/api/v1/executions/{execution_id}",
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        print(f"[FAIL] 获取执行详情请求异常: {exc}")
        return last_status, None
    if dr.status_code != 200:
        print(f"[FAIL] 获取执行详情 {dr.status_code}: {dr.text[:500]}")
        return last_status, None
    try:
        detail = dr.json()
    except ValueError:
        print(f"[FAIL] 执行详情不是 JSON: {dr.text[:500]}")
        return last_status, None
    if not isinstance(detail, dict):
        print(f"[FAIL] 执行详情不是 JSON 对象: {dr.text[:500]}")
        return last_status, None
    return last_status, detail


def _fetch_execution_status(
    base_url: str,
    headers: Dict[str, str],
    execution_id: str,
    timeout: float,
) -> Optional[str]:
    """Read one status value, or return None (after a [WARN]) if unreadable."""
    try:
        sr = requests.get(
            f"{base_url}/api/v1/executions/{execution_id}/status",
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # A transient network hiccup should not abort a long wait.
        print(f"[WARN] status 请求异常: {exc}")
        return None
    if sr.status_code != 200:
        print(f"[WARN] status {sr.status_code}: {sr.text[:300]}")
        return None
    try:
        body = sr.json()
    except ValueError:
        print(f"[WARN] status 响应不是 JSON: {sr.text[:300]}")
        return None
    if not isinstance(body, dict):
        print(f"[WARN] status 响应不是 JSON 对象: {sr.text[:300]}")
        return None
    return str(body.get("status") or "")


def _check_expect(text: str, status: str, detail: Optional[Dict[str, Any]], expect: Dict[str, Any]) -> List[str]:
|
||
errors: List[str] = []
|
||
want_status = expect.get("status", "completed")
|
||
if status != want_status:
|
||
errors.append(f"状态期望 {want_status!r},实际 {status!r}")
|
||
if detail and status != "completed":
|
||
em = detail.get("error_message")
|
||
if em:
|
||
errors.append(f"error_message: {em[:500]}")
|
||
if not expect:
|
||
return errors
|
||
ci = bool(expect.get("case_insensitive"))
|
||
hay = text if not ci else text.lower()
|
||
for sub in expect.get("output_contains") or []:
|
||
s = sub if not ci else sub.lower()
|
||
if s not in hay:
|
||
errors.append(f"输出应包含 {sub!r}")
|
||
for sub in expect.get("output_not_contains") or []:
|
||
s = sub if not ci else sub.lower()
|
||
if s in hay:
|
||
errors.append(f"输出不应包含 {sub!r}")
|
||
return errors
|
||
|
||
|
||
def _run_one_case(
    base_url: str,
    headers: Dict[str, str],
    defaults: Dict[str, Any],
    case: Dict[str, Any],
) -> bool:
    """Run a single test case end to end and report whether it passed.

    Prints a ``CASE`` header, then resolves the agent and creates an execution
    with ``query``/``USER_INPUT`` set to the case's ``message`` (merged over
    ``input_extra``). It then waits for the execution to settle and checks
    ``expect``. Per-case ``request_timeout_sec``, ``max_wait_sec`` and
    ``poll_interval_sec`` override the same keys in ``defaults``.

    Returns:
        True if every expectation held, False otherwise. Request errors,
        invalid JSON and malformed case values also yield False, after a
        ``[FAIL]`` line, so one bad case cannot abort the rest of the batch.
    """
    cid = case.get("id", "(no-id)")
    title = case.get("name", "")
    print("\n" + "-" * 60)
    print(f"CASE {cid}" + (f" — {title}" if title else ""))
    try:
        return _execute_case(base_url, headers, defaults, case)
    except (requests.RequestException, ValueError, KeyError) as exc:
        print(f"[FAIL] 用例执行异常 {type(exc).__name__}: {exc}")
        return False


def _execute_case(
    base_url: str,
    headers: Dict[str, str],
    defaults: Dict[str, Any],
    case: Dict[str, Any],
) -> bool:
    """Do the work of ``_run_one_case`` after its header has been printed."""
    req_timeout = float(case.get("request_timeout_sec", defaults.get("request_timeout_sec", 120)))
    max_wait = float(case.get("max_wait_sec", defaults.get("max_wait_sec", 300)))
    poll_iv = float(case.get("poll_interval_sec", defaults.get("poll_interval_sec", 2)))

    agent_id = _resolve_agent_id(base_url, headers, case.get("agent") or {}, req_timeout)
    if not agent_id:
        return False

    message = case.get("message")
    if message is None:
        print("[FAIL] 缺少 message")
        return False

    # The message goes under both keys; presumably different agents read one
    # or the other — confirm against the backend's input schema.
    input_data: Dict[str, Any] = {"query": message, "USER_INPUT": message}
    extra = case.get("input_extra")
    if isinstance(extra, dict):
        # query/USER_INPUT win over same-named keys in input_extra.
        input_data = {**extra, **input_data}

    er = requests.post(
        f"{base_url}/api/v1/executions",
        headers=headers,
        json={"agent_id": agent_id, "input_data": input_data},
        timeout=req_timeout,
    )
    if er.status_code != 201:
        print(f"[FAIL] 创建执行 {er.status_code}: {er.text[:800]}")
        return False
    ex = er.json()
    eid = ex.get("id") if isinstance(ex, dict) else None
    if eid is None:
        print(f"[FAIL] 创建执行响应无 id: {er.text[:800]}")
        return False
    print(f"[OK] execution_id={eid}")

    st, detail = _poll_until_terminal(base_url, headers, eid, max_wait, poll_iv, req_timeout)
    text = _extract_output_text((detail or {}).get("output_data"))
    errs = _check_expect(text, st, detail, case.get("expect") or {})
    if errs:
        for e in errs:
            print(f"[FAIL] {e}")
        # Failures get a longer preview to help diagnose them.
        _print_output_preview(text, 2000)
        return False

    print(f"[OK] 通过 status={st}")
    _print_output_preview(text, 1200)
    return True


def _print_output_preview(text: str, limit: int) -> None:
    """Print at most ``limit`` chars of ``text`` (nothing if empty), "…" if cut."""
    if text:
        print("[OUTPUT_PREVIEW]")
        print(text[:limit] + ("…" if len(text) > limit else ""))


def main() -> int:
    """Command-line entry point: run every enabled case and print a summary.

    The case file comes from ``--cases``, else ``AGENT_TEST_CASES``, else
    ``DEFAULT_CASES_FILE``. The base URL comes from ``--base-url``, else
    ``defaults.base_url``, else ``API_BASE_URL``, else
    ``http://localhost:8037``. Username and password come from their flags,
    else ``defaults``, else built-in fallbacks. Cases with ``"enabled": false``
    are skipped.

    Returns:
        Process exit code:
        - 0: every enabled case passed.
        - 1: at least one case failed.
        - 2: the case file is missing, unreadable or malformed.
        - 3: login failed.
        - 4: there are no usable cases.
    """
    ap = argparse.ArgumentParser(description="批量运行 Agent 测试用例(JSON)")
    ap.add_argument(
        "--cases",
        default=os.environ.get("AGENT_TEST_CASES", DEFAULT_CASES_FILE),
        help=f"用例 JSON 路径(默认 {DEFAULT_CASES_FILE})",
    )
    ap.add_argument("--username", default=None)
    ap.add_argument("--password", default=None)
    ap.add_argument("--base-url", default=None, help="覆盖 defaults.base_url / API_BASE_URL")
    args = ap.parse_args()

    path = args.cases
    if not os.path.isfile(path):
        print(f"[FAIL] 找不到用例文件: {path}")
        print("请先按 (红头)agent测试用例文档.md 创建 JSON,或复制示例为 agent_test_cases.json")
        return 2

    try:
        with open(path, encoding="utf-8") as f:
            spec = json.load(f)
    except (OSError, ValueError) as exc:
        # Permission errors, invalid UTF-8 and malformed JSON all land here.
        print(f"[FAIL] 无法解析用例文件 {path}: {exc}")
        return 2
    if not isinstance(spec, dict):
        print(f"[FAIL] 用例文件顶层必须是 JSON 对象: {path}")
        return 2

    defaults = spec.get("defaults") or {}
    if not isinstance(defaults, dict):
        print("[FAIL] defaults 必须是 JSON 对象")
        return 2
    base_url = (
        args.base_url
        or defaults.get("base_url")
        or os.environ.get("API_BASE_URL", "http://localhost:8037")
    )
    base_url = base_url.rstrip("/")

    username = args.username or defaults.get("username", "admin")
    password = args.password or defaults.get("password", "123456")
    req_timeout = float(defaults.get("request_timeout_sec", 120))

    print(f"API: {base_url}")
    print(f"用例文件: {path}")

    try:
        headers = _login(base_url, username, password, req_timeout)
    except requests.RequestException as exc:
        # Network failure during login maps to the login exit code, not a traceback.
        print(f"[FAIL] 登录请求异常: {exc}")
        headers = None
    if not headers:
        return 3

    cases = spec.get("cases") or []
    if not isinstance(cases, list):
        print("[FAIL] cases 必须是 JSON 数组")
        return 4
    if not cases:
        print("[FAIL] cases 为空")
        return 4

    ok, skip, fail = 0, 0, 0
    for case in cases:
        if not isinstance(case, dict):
            print(f"\n[FAIL] 用例格式错误(应为对象): {case!r}")
            fail += 1
            continue
        if case.get("enabled") is False:
            print(f"\n[SKIP] {case.get('id', '?')}")
            skip += 1
            continue
        if _run_one_case(base_url, headers, defaults, case):
            ok += 1
        else:
            fail += 1

    print("\n" + "=" * 60)
    print(f"汇总: 通过 {ok} 失败 {fail} 跳过 {skip}")
    return 0 if fail == 0 else 1


if __name__ == "__main__":
    sys.exit(main())