157 lines
4.7 KiB
Python
157 lines
4.7 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
对比「学生作业管理助手」与「学生作业管理助手2号」在相同输入下的节点测试结果。
|
|||
|
|
调用 POST /api/v1/nodes/test(同步),需后端运行且 DEEPSEEK_API_KEY 有效。
|
|||
|
|
|
|||
|
|
用法: cd backend && .\\venv\\Scripts\\python.exe scripts/compare_homework_agents.py
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import time
|
|||
|
|
from typing import Any, Dict, List, Optional
|
|||
|
|
|
|||
|
|
import requests
|
|||
|
|
|
|||
|
|
# Connection settings, each overridable through the environment.
# Trailing slashes are trimmed so f"{BACKEND}/api/..." never produces "//".
BACKEND = os.environ.get("PLATFORM_BASE_URL", "http://127.0.0.1:8037").rstrip("/")
# NOTE(review): these fallback credentials look like local development
# defaults; set PLATFORM_USERNAME / PLATFORM_PASSWORD for any other setup.
USER = os.environ.get("PLATFORM_USERNAME", "admin")
PWD = os.environ.get("PLATFORM_PASSWORD", "123456")

# Exact agent names to compare; the report follows this order.
NAMES = ["学生作业管理助手", "学生作业管理助手2号"]

# One identical prompt is sent to every agent so their outputs are comparable.
TEST_QUERY = "".join(
    (
        "帮我记一项作业:语文摘抄名著段落3处并批注,截止周五下午5点前。",
        "请只回复一条简短清单(科目、要点、截止时间),不要超过120字。",
    )
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _login() -> str:
    """Authenticate as USER/PWD and return the access token.

    The credentials are sent form-encoded to ``{BACKEND}/api/v1/auth/login``.

    Raises:
        requests.RequestException: the request failed, or (via
            ``raise_for_status``) the endpoint answered with an error status.
        RuntimeError: the response carried no, or an empty, ``access_token``.
    """
    form = {"username": USER, "password": PWD}
    resp = requests.post(
        f"{BACKEND}/api/v1/auth/login",
        data=form,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        timeout=15,
    )
    resp.raise_for_status()
    access_token = resp.json().get("access_token")
    if access_token:
        return access_token
    raise RuntimeError("无 access_token")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _find_llm_node(wf: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
|||
|
|
for n in wf.get("nodes") or []:
|
|||
|
|
if n.get("type") == "llm":
|
|||
|
|
return n
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _get_agents(h: Dict[str, str]) -> List[Dict[str, Any]]:
    """Query ``{BACKEND}/api/v1/agents`` with search "学生作业管理" (limit 50).

    Args:
        h: Request headers; expected to carry the bearer token.

    Returns:
        The decoded JSON body as a list. A null or empty body yields ``[]``.

    Raises:
        requests.RequestException: transport failure or error HTTP status.
    """
    query = {"search": "学生作业管理", "limit": 50}
    resp = requests.get(f"{BACKEND}/api/v1/agents", params=query, headers=h, timeout=30)
    resp.raise_for_status()
    payload = resp.json()
    # NOTE(review): assumes the endpoint returns a bare JSON array; a paginated
    # object here would become a list of its keys — confirm against the API.
    return list(payload) if payload else []
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _fetch_agent_detail(h: Dict[str, str], aid: str) -> Dict[str, Any]:
    """Return the decoded JSON detail record of the agent with id ``aid``.

    Raises:
        requests.RequestException: transport failure or error HTTP status.
    """
    detail_url = f"{BACKEND}/api/v1/agents/{aid}"
    resp = requests.get(detail_url, headers=h, timeout=30)
    resp.raise_for_status()
    return resp.json()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _test_node(h: Dict[str, str], node: Dict[str, Any]) -> "tuple[int, str, Optional[str], Any]":
    """Run ``node`` through the synchronous node-test endpoint and summarize it.

    Posts ``{"node": node, "input_data": {"query": TEST_QUERY}}`` to
    ``{BACKEND}/api/v1/nodes/test`` and times the round trip.

    Args:
        h: Request headers; expected to carry the bearer token.
        node: Workflow node dict, sent verbatim.

    Returns:
        ``(elapsed_ms, text, err, status)``:

        * ``elapsed_ms`` - wall-clock duration of the request in milliseconds.
        * ``text`` - the node output rendered as a string, truncated to 4000
          characters; ``""`` when the call failed or produced no output.
        * ``err`` - a transport / HTTP / decoding error description, otherwise
          the response's ``error_message`` field (may be ``None``).
        * ``status`` - the response's ``status`` field, or ``None`` when no
          JSON body was obtained.

    Transport errors and non-JSON responses are reported through ``err``
    rather than raised, so one failing agent does not abort the comparison.
    """
    body = {"node": node, "input_data": {"query": TEST_QUERY}}
    t0 = time.perf_counter()
    try:
        r = requests.post(
            f"{BACKEND}/api/v1/nodes/test",
            headers=h,
            json=body,
            timeout=240,
        )
    except requests.RequestException as e:
        # Timeouts / connection errors become a reportable result, not a crash.
        elapsed_ms = int((time.perf_counter() - t0) * 1000)
        return elapsed_ms, "", f"请求失败: {e}", None
    elapsed_ms = int((time.perf_counter() - t0) * 1000)

    if r.status_code != 200:
        return elapsed_ms, "", f"HTTP {r.status_code}: {r.text[:800]}", None
    try:
        data = r.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError.
        return elapsed_ms, "", f"响应不是合法 JSON: {r.text[:800]}", None

    out = data.get("output")
    if out is None:
        # Render a missing output as empty instead of the literal string "None".
        text = ""
    elif isinstance(out, dict):
        # Prefer the usual text fields; fall back to the whole dict.
        picked = out.get("output") or out.get("text") or out
        # Nested values may be dicts/lists, which cannot be truncated as text.
        text = picked if isinstance(picked, str) else json.dumps(picked, ensure_ascii=False)
    else:
        text = str(out)

    return elapsed_ms, text[:4000], data.get("error_message"), data.get("status")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _configure_utf8_output() -> None:
    """Best-effort switch of stdout and stderr to UTF-8 for the Chinese report."""
    for stream in (sys.stdout, sys.stderr):
        if not hasattr(stream, "reconfigure"):
            continue
        try:
            stream.reconfigure(encoding="utf-8")
        except Exception:
            # Deliberately best-effort: a stream that refuses reconfiguration
            # simply keeps its current encoding.
            pass


def _evaluate_agent(h: Dict[str, str], name: str, agent: Dict[str, Any]) -> Dict[str, Any]:
    """Load one agent's workflow, test its first LLM node, and return a report row."""
    aid = agent["id"]
    detail = _fetch_agent_detail(h, aid)
    node = _find_llm_node(detail.get("workflow_config") or {})
    if not node:
        return {"name": name, "id": aid, "error": "工作流中无 LLM 节点"}
    data = node.get("data") or {}
    elapsed_ms, text, err, st = _test_node(h, node)
    return {
        "name": name,
        "id": aid,
        "provider": data.get("provider"),
        "model": data.get("model"),
        "elapsed_ms": elapsed_ms,
        "status": st,
        "api_error": err,
        "output_excerpt": text[:2000],
    }


def _collect_results(h: Dict[str, str], agents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build one report row per entry of NAMES, in NAMES order."""
    by_name = {a.get("name"): a for a in agents}
    results: List[Dict[str, Any]] = []
    for name in NAMES:
        agent = by_name.get(name)
        if not agent:
            results.append({"name": name, "error": f"未找到名为「{name}」的 Agent"})
            continue
        try:
            results.append(_evaluate_agent(h, name, agent))
        except (requests.RequestException, ValueError) as e:
            # One unreachable or malformed agent must not abort the comparison
            # of the others; record the failure in its row instead.
            results.append({"name": name, "id": agent.get("id"), "error": f"请求失败: {e}"})
    return results


def _print_results(results: List[Dict[str, Any]]) -> None:
    """Print the comparison report for the collected rows."""
    print("=== 对比测试(同步节点测试 API)===")
    print("输入:", TEST_QUERY)
    print()
    for r in results:
        print(f"【{r['name']}】")
        if r.get("error"):
            print(" ", r["error"])
            print()
            continue
        print(f" id: {r['id']}")
        print(f" provider/model: {r.get('provider')} / {r.get('model')}")
        print(f" 耗时: {r['elapsed_ms']} ms status: {r.get('status')}")
        if r.get("api_error"):
            print(f" error_message: {r['api_error']}")
        print(f" 输出节选:\n{r.get('output_excerpt', '')}\n")


def main() -> int:
    """Log in, test each agent in NAMES with TEST_QUERY, and print a report.

    Returns:
        Process exit code: 1 when login or the agent listing fails, else 0.
        Per-agent failures are reported in the output rather than through the
        exit code.
    """
    _configure_utf8_output()
    try:
        token = _login()
    except Exception as e:
        print("登录失败:", e, file=sys.stderr)
        return 1
    h = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    try:
        agents = _get_agents(h)
    except (requests.RequestException, ValueError) as e:
        print("获取 Agent 列表失败:", e, file=sys.stderr)
        return 1

    _print_results(_collect_results(h, agents))
    return 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|