154 lines
5.1 KiB
Python
154 lines
5.1 KiB
Python
"""
|
||
知你客服11号 E2E:普通对话 + 要求拉取 URL(触发 http_request 工具)。

依赖:API 服务、Celery Worker、LLM,以及本机可访问的外网测试 URL。
可选环境变量:API_BASE、E2E_AGENT_NAME、E2E_TEST_URL。

默认会先重启本机 Celery Worker(与 e2e_zhini7 一致),以加载包含 code 节点 re/hashlib 注入的引擎代码。
跳过重启:设置环境变量 E2E_RESTART_CELERY=0

用法: cd backend && .\\venv\\Scripts\\python.exe scripts/e2e_zhini11_test.py
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import subprocess
|
||
import sys
|
||
import time
|
||
import uuid
|
||
from pathlib import Path
|
||
|
||
# Backend root: this script lives in <backend>/scripts/.
BACKEND_DIR = Path(__file__).resolve().parents[1]

# Interpreter of the backend virtualenv. The venv layout differs per OS
# (Windows: venv\Scripts\python.exe, POSIX: venv/bin/python).
if sys.platform == "win32":
    VENV_PY = BACKEND_DIR / "venv" / "Scripts" / "python.exe"
else:
    VENV_PY = BACKEND_DIR / "venv" / "bin" / "python"

# Base URL of the backend API under test.
API_BASE = os.environ.get("API_BASE", "http://127.0.0.1:8037")

# Name of the agent (looked up in the DB by name) that the test talks to.
AGENT_NAME = os.environ.get("E2E_AGENT_NAME", "知你客服11号")

# Small JSON document, suitable for exercising a GET through the agent's tool.
TEST_URL = os.environ.get(
    "E2E_TEST_URL",
    "https://jsonplaceholder.typicode.com/posts/1",
)
|
||
|
||
|
||
def _stop_celery_workers() -> None:
    """Best-effort kill of every local process whose command line mentions ``celery_app``.

    Failures (no matching process, missing tool) are not fatal; a missing
    tool is reported on stderr.
    """
    if sys.platform == "win32":
        # Exclude $PID: this PowerShell process's own command line contains the
        # literal 'celery_app' (it is part of the -Command script), so without
        # the guard it matches itself and may Stop-Process itself before it
        # reaches the actual worker processes.
        ps = (
            "Get-CimInstance Win32_Process | "
            "Where-Object { $_.ProcessId -ne $PID -and $_.CommandLine -match 'celery_app' } | "
            "ForEach-Object { Stop-Process -Id $_.ProcessId -Force -ErrorAction SilentlyContinue }"
        )
        cmd = ["powershell", "-NoProfile", "-Command", ps]
    else:
        # pkill never matches its own process.
        cmd = ["pkill", "-f", "celery_app"]
    try:
        subprocess.run(cmd, cwd=str(BACKEND_DIR), capture_output=True, text=True)
    except FileNotFoundError:
        print(f"未找到 {cmd[0]},跳过停止旧 Celery Worker", file=sys.stderr)


def _start_celery_worker() -> None:
    """Spawn a detached ``celery -A app.core.celery_app worker`` from the venv.

    The worker's stdout/stderr are discarded.
    """
    popen_kw: dict = {
        "cwd": str(BACKEND_DIR),
        "stdout": subprocess.DEVNULL,
        "stderr": subprocess.STDOUT,
    }
    if sys.platform == "win32":
        # Root of a new process group: Ctrl+C in this console is not delivered
        # to the worker (see CREATE_NEW_PROCESS_GROUP in the subprocess docs).
        popen_kw["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP  # type: ignore[attr-defined]
    else:
        # POSIX counterpart: detach from this terminal's process group.
        popen_kw["start_new_session"] = True
    subprocess.Popen(
        [
            str(VENV_PY),
            "-m",
            "celery",
            "-A",
            "app.core.celery_app",
            "worker",
            "--loglevel=info",
            "--pool=threads",
            "--concurrency=8",
        ],
        **popen_kw,
    )


def _restart_celery() -> None:
    """Restart the local Celery worker so it loads the current engine code.

    Does nothing (apart from a stderr notice) when the venv interpreter is
    missing. Readiness of the new worker is not probed; a fixed sleep is used.
    """
    # Check before killing anything: stopping the running worker without being
    # able to start a replacement would leave every execution queued until the
    # poll timeout.
    if not VENV_PY.is_file():
        print("未找到 venv Python,跳过重启 Celery", file=sys.stderr)
        return
    _stop_celery_workers()
    time.sleep(2)  # let the killed processes exit before starting a new one
    _start_celery_worker()
    print("已启动新 Celery Worker,等待就绪…")
    time.sleep(4)
|
||
|
||
|
||
def _poll_execution(
    client: httpx.Client,
    eid: str,
    headers: dict[str, str],
    timeout: float = 420.0,
) -> dict:
    """Poll ``GET /api/v1/executions/{eid}`` until the execution completes.

    Args:
        client: httpx client whose ``base_url`` points at the API.
        eid: Execution id returned by ``POST /api/v1/executions``.
        headers: Request headers (carry the bearer token).
        timeout: Maximum number of seconds to wait.

    Returns:
        The execution JSON once its ``status`` is ``"completed"``.

    Raises:
        RuntimeError: The execution's status became ``"failed"``.
        TimeoutError: The execution did not complete within *timeout* seconds.
        httpx.HTTPStatusError: The status endpoint returned an error response.
    """
    # monotonic() is immune to wall-clock adjustments during the long wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        resp = client.get(f"/api/v1/executions/{eid}", headers=headers)
        resp.raise_for_status()
        data = resp.json()
        status = data.get("status")
        if status == "completed":
            return data
        if status == "failed":
            print("failed:", data.get("error_message"), file=sys.stderr)
            raise RuntimeError("执行失败")
        time.sleep(1.5)
    raise TimeoutError("超时")


def _reply_text(out: dict) -> str:
    """Return the agent's reply from an execution JSON, truncated to 800 chars.

    Uses ``output_data["result"]`` when it is a string; otherwise falls back to
    the JSON dump of ``output_data`` (``{}`` when it is missing or falsy).
    """
    od = out.get("output_data") or {}
    if isinstance(od, dict):
        result = od.get("result")
        if isinstance(result, str):
            return result[:800]
    return json.dumps(od, ensure_ascii=False)[:800]


def _load_agent_and_token() -> tuple[str, str] | None:
    """Look up the agent by ``AGENT_NAME`` and mint an access token for a user.

    The token belongs to the agent's owner, falling back to the first user in
    the DB. The DB session is closed before returning so it is not held open
    during the long HTTP phase.

    Returns:
        ``(agent_id, token)``, or ``None`` (after printing to stderr) when the
        agent or any user is missing.
    """
    # Project imports are deferred until main() has put BACKEND_DIR on sys.path.
    from app.core.database import SessionLocal
    from app.core.security import create_access_token
    from app.models.agent import Agent
    from app.models.user import User

    db = SessionLocal()
    try:
        agent = db.query(Agent).filter(Agent.name == AGENT_NAME).first()
        if not agent:
            print(f"未找到「{AGENT_NAME}」", file=sys.stderr)
            return None
        owner = db.query(User).filter(User.id == agent.user_id).first()
        user = owner or db.query(User).first()
        if not user:
            print("无用户", file=sys.stderr)
            return None
        token = create_access_token(data={"sub": user.id, "username": user.username})
        return str(agent.id), token
    finally:
        db.close()


def main() -> int:
    """Run a three-round conversation against ``AGENT_NAME`` through the API.

    Round 2 asks the agent to fetch ``TEST_URL``; rounds 1 and 3 check
    conversational memory under a fresh ``user_id``. Replies are printed, not
    asserted.

    Returns:
        0 on success, 1 when the agent or a user cannot be found. Execution
        failures, timeouts and HTTP errors propagate as exceptions.
    """
    os.chdir(BACKEND_DIR)
    sys.path.insert(0, str(BACKEND_DIR))

    import httpx

    # DB lookup first: no point restarting Celery if the agent/user is missing.
    info = _load_agent_and_token()
    if info is None:
        return 1
    agent_id, token = info

    if os.environ.get("E2E_RESTART_CELERY", "1").strip().lower() not in ("0", "false", "no"):
        _restart_celery()

    headers = {"Authorization": f"Bearer {token}"}
    uid = f"e2e_z11_{uuid.uuid4().hex[:10]}"
    print(f"agent={agent_id} user_id={uid} test_url={TEST_URL}\n")

    rounds = [
        "我的名字叫测试员",
        f"请用工具访问这个网址并简要说明返回里 title 或主要内容是什么(只回答要点):{TEST_URL}",
        "我叫什么名字?",
    ]

    with httpx.Client(base_url=API_BASE, timeout=420.0) as client:
        for i, q in enumerate(rounds, 1):
            r = client.post(
                "/api/v1/executions",
                json={"agent_id": agent_id, "input_data": {"query": q, "user_id": uid}},
                headers=headers,
            )
            if r.status_code >= 400:
                print(r.text, file=sys.stderr)
            r.raise_for_status()
            eid = r.json()["id"]
            print(f"--- 第{i}轮 execution={eid} ---")
            # Show the question before the (possibly long) wait; mark truncation only when it happens.
            preview = q if len(q) <= 120 else f"{q[:120]}..."
            print(f"Q: {preview}")
            out = _poll_execution(client, eid, headers)
            print(f"A: {_reply_text(out)}\n")

    print("完成")
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())
|