- 前端 WorkflowEditor/ModelConfigs/NodeTemplates:deepseek-v4-flash、v4-pro,弃用提示 - llm_service 默认 deepseek-v4-flash;workflow_engine 等与模型配置注入 - 作业管理脚本支持 AGENT_NAME 与 v4-pro;新增 compare_homework_agents 脚本 - 文档重命名为 (红头)项目核心文档汇总.md 并更新 DeepSeek 说明 Made-with: Cursor
157 lines
4.7 KiB
Python
157 lines
4.7 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
对比「学生作业管理助手」与「学生作业管理助手2号」在相同输入下的节点测试结果。
|
||
调用 POST /api/v1/nodes/test(同步),需后端运行且 DEEPSEEK_API_KEY 有效。
|
||
|
||
用法: cd backend && .\\venv\\Scripts\\python.exe scripts/compare_homework_agents.py
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import requests
|
||
|
||
# Connection settings; each one can be overridden through the environment.
# Any trailing "/" is stripped so endpoint paths can be appended directly.
BACKEND = os.environ.get("PLATFORM_BASE_URL", "http://127.0.0.1:8037").rstrip("/")
USER = os.environ.get("PLATFORM_USERNAME", "admin")
# NOTE(review): the fallback password is a hard-coded default; set
# PLATFORM_PASSWORD explicitly outside of local development.
PWD = os.environ.get("PLATFORM_PASSWORD", "123456")

# Exact names of the two agents whose LLM nodes are compared.
NAMES = [
    "学生作业管理助手",
    "学生作业管理助手2号",
]

# The same prompt is sent to both agents so their outputs are comparable.
TEST_QUERY = (
    "帮我记一项作业:语文摘抄名著段落3处并批注,截止周五下午5点前。"
    "请只回复一条简短清单(科目、要点、截止时间),不要超过120字。"
)
|
||
|
||
|
||
def _login() -> str:
    """Log in with the configured credentials and return the access token.

    The username and password are sent as a form-encoded POST to
    ``/api/v1/auth/login``.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        RuntimeError: If the JSON response contains no ``access_token``.
    """
    credentials = {"username": USER, "password": PWD}
    form_headers = {"Content-Type": "application/x-www-form-urlencoded"}
    resp = requests.post(
        f"{BACKEND}/api/v1/auth/login",
        data=credentials,
        headers=form_headers,
        timeout=15,
    )
    resp.raise_for_status()
    token = resp.json().get("access_token")
    if token:
        return token
    raise RuntimeError("无 access_token")
|
||
|
||
|
||
def _find_llm_node(wf: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||
for n in wf.get("nodes") or []:
|
||
if n.get("type") == "llm":
|
||
return n
|
||
return None
|
||
|
||
|
||
def _get_agents(h: Dict[str, str]) -> List[Dict[str, Any]]:
    """Query the agent list endpoint for the homework-management agents.

    Sends ``search="学生作业管理"`` and ``limit=50`` as query parameters.

    Args:
        h: HTTP headers to send with the request.

    Returns:
        The decoded JSON response as a list; an empty list when the response
        body is empty or falsy.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    query = {"search": "学生作业管理", "limit": 50}
    resp = requests.get(
        f"{BACKEND}/api/v1/agents",
        params=query,
        headers=h,
        timeout=30,
    )
    resp.raise_for_status()
    payload = resp.json()
    return list(payload) if payload else []
|
||
|
||
|
||
def _fetch_agent_detail(h: Dict[str, str], aid: str) -> Dict[str, Any]:
    """Fetch one agent by id and return the decoded JSON response.

    Args:
        h: HTTP headers to send with the request.
        aid: Identifier of the agent, interpolated into the URL path.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    url = f"{BACKEND}/api/v1/agents/{aid}"
    resp = requests.get(url, headers=h, timeout=30)
    resp.raise_for_status()
    return resp.json()
|
||
|
||
|
||
def _output_to_text(out: Any) -> str:
    """Render the ``output`` field of a node-test response as a string."""
    if out is None:
        return ""
    if not isinstance(out, dict):
        return str(out)
    picked = out.get("output") or out.get("text")
    if isinstance(picked, str) and picked:
        return picked
    # Either nothing usable was found (dump the whole dict) or the value is
    # structured (dump just that value); the result is always a str so the
    # caller can slice it safely.
    return json.dumps(picked or out, ensure_ascii=False)


def _test_node(h: Dict[str, str], node: Dict[str, Any]) -> tuple[int, str, Any, Any]:
    """Run one node through the synchronous node-test endpoint.

    Posts ``node`` together with ``TEST_QUERY`` to ``/api/v1/nodes/test`` and
    measures the wall-clock duration of the request.

    Args:
        h: HTTP headers to send with the request.
        node: The workflow node definition to test.

    Returns:
        A tuple ``(elapsed_ms, text, error, status)``:
        ``elapsed_ms`` is the request duration in milliseconds; ``text`` is
        the node output as a string, truncated to 4000 characters; ``error``
        is the response's ``error_message`` or a locally built description of
        a transport/format failure; ``status`` is the response's ``status``
        field, or ``None`` when the request itself failed.
    """
    body = {"node": node, "input_data": {"query": TEST_QUERY}}
    t0 = time.perf_counter()
    try:
        r = requests.post(
            f"{BACKEND}/api/v1/nodes/test",
            headers=h,
            json=body,
            timeout=240,
        )
    except requests.RequestException as exc:
        # Timeouts and connection errors are reported through the same error
        # tuple as HTTP failures so the caller can still show other results.
        elapsed_ms = int((time.perf_counter() - t0) * 1000)
        return elapsed_ms, "", f"请求失败: {exc}", None
    elapsed_ms = int((time.perf_counter() - t0) * 1000)
    if r.status_code != 200:
        return elapsed_ms, "", f"HTTP {r.status_code}: {r.text[:800]}", None
    try:
        data = r.json()
    except ValueError:
        # A 200 response whose body is not JSON.
        return elapsed_ms, "", f"响应不是合法 JSON: {r.text[:800]}", None
    if not isinstance(data, dict):
        return elapsed_ms, "", f"响应格式异常: {r.text[:800]}", None
    text = _output_to_text(data.get("output"))
    return elapsed_ms, text[:4000], data.get("error_message"), data.get("status")
|
||
|
||
|
||
def _compare_one(
    h: Dict[str, str], name: str, by_name: Dict[Any, Dict[str, Any]]
) -> Dict[str, Any]:
    """Run the node test for one agent name and return its result record."""
    a = by_name.get(name)
    if not a:
        return {"name": name, "error": f"未找到名为「{name}」的 Agent"}
    try:
        detail = _fetch_agent_detail(h, a["id"])
        wf = detail.get("workflow_config") or {}
        node = _find_llm_node(wf)
        if not node:
            return {"name": name, "id": a["id"], "error": "工作流中无 LLM 节点"}
        elapsed_ms, text, err, st = _test_node(h, node)
    except (requests.RequestException, ValueError) as e:
        # A failure for one agent is recorded instead of aborting the run, so
        # the remaining agents are still tested and reported.
        return {"name": name, "id": a["id"], "error": f"请求失败: {e}"}
    data = node.get("data") or {}
    return {
        "name": name,
        "id": a["id"],
        "provider": data.get("provider"),
        "model": data.get("model"),
        "elapsed_ms": elapsed_ms,
        "status": st,
        "api_error": err,
        "output_excerpt": text[:2000],
    }


def _print_report(results: List[Dict[str, Any]]) -> None:
    """Print the comparison report for all collected result records."""
    print("=== 对比测试(同步节点测试 API)===")
    print("输入:", TEST_QUERY)
    print()
    for r in results:
        print(f"【{r['name']}】")
        if r.get("error"):
            print(" ", r["error"])
            print()
            continue
        print(f" id: {r['id']}")
        print(f" provider/model: {r.get('provider')} / {r.get('model')}")
        print(f" 耗时: {r['elapsed_ms']} ms status: {r.get('status')}")
        if r.get("api_error"):
            print(f" error_message: {r['api_error']}")
        print(f" 输出节选:\n{r.get('output_excerpt', '')}\n")


def main() -> int:
    """Log in, test the LLM node of each agent in ``NAMES`` and print a report.

    Returns:
        ``1`` when login or fetching the agent list fails (the reason is
        written to stderr); ``0`` once the report has been printed, even if
        individual agents could not be tested (those are listed in the report).
    """
    if hasattr(sys.stdout, "reconfigure"):
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            # Best effort only: if stdout cannot be switched to UTF-8 the
            # script still runs with the existing encoding.
            pass
    try:
        token = _login()
    except Exception as e:
        print("登录失败:", e, file=sys.stderr)
        return 1
    h = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    try:
        agents = _get_agents(h)
    except (requests.RequestException, ValueError) as e:
        # Handled the same way as a login failure: message on stderr, exit 1.
        print("获取 Agent 列表失败:", e, file=sys.stderr)
        return 1
    # Entries that are not dicts cannot carry a name and are ignored.
    by_name = {a.get("name"): a for a in agents if isinstance(a, dict)}

    results = [_compare_one(h, name, by_name) for name in NAMES]
    _print_report(results)
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    sys.exit(main())
|