Files
aiagent/backend/scripts/e2e_zhini9_test.py

117 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
对「知你客服9号」做多轮 API 测试(不默认重启 Celery,避免打断本机 Worker)
轮次:①自我介绍姓名 ②陈述偏好(供摘要/关键词召回)③闲聊 ④问姓名 + 问偏好
用法:
cd backend && .\\venv\\Scripts\\python.exe scripts/e2e_zhini9_test.py
环境变量: API_BASE, E2E_AGENT_NAME(默认 知你客服9号)
"""
from __future__ import annotations
import json
import os
import sys
import time
import uuid
# backend/ directory: this script lives in backend/scripts/, so go up two levels.
# Resolved from an absolute path so it works however the script is invoked
# (a relative __file__ on older Pythons, or "Scripts" vs "scripts" casing on
# Windows would break a string split on the directory name).
BACKEND_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Base URL of the running backend API; override with the API_BASE env var.
API_BASE = os.environ.get("API_BASE", "http://127.0.0.1:8037")
# Name of the Agent row under test; override with the E2E_AGENT_NAME env var.
AGENT_NAME = os.environ.get("E2E_AGENT_NAME", "知你客服9号")
def _poll_execution(
    client,
    execution_id: str,
    headers: dict[str, str],
    timeout: float = 300.0,
) -> dict:
    """Poll an execution until it completes, fails, or ``timeout`` seconds pass.

    Args:
        client: an ``httpx.Client`` whose ``base_url`` points at the backend API.
        execution_id: id returned by ``POST /api/v1/executions``.
        headers: request headers carrying the bearer token.
        timeout: maximum number of seconds to keep polling.

    Returns:
        The execution JSON once its ``status`` is ``"completed"``.

    Raises:
        RuntimeError: the execution reported ``status == "failed"``
            (its ``error_message`` is printed to stderr first).
        TimeoutError: no terminal status was seen before ``timeout``.
        httpx.HTTPStatusError: a poll request returned an error status.
    """
    # time.monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        resp = client.get(f"/api/v1/executions/{execution_id}", headers=headers)
        resp.raise_for_status()
        data = resp.json()
        status = data.get("status")
        if status == "completed":
            return data
        if status == "failed":
            print("error:", data.get("error_message"), file=sys.stderr)
            raise RuntimeError("执行失败")
        # NOTE(review): any other status keeps polling until the deadline; if the
        # API has further terminal states (e.g. cancelled), confirm and handle them.
        time.sleep(1)
    raise TimeoutError("超时")


def _extract_reply(out: dict, limit: int = 500) -> str:
    """Return the agent's reply from an execution payload, truncated to ``limit`` chars.

    Uses ``output_data["result"]`` when it is a string. Otherwise it falls back
    to a JSON dump of ``output_data``, which is ``{}`` when the field is missing
    or empty.
    """
    od = out.get("output_data") or {}
    if isinstance(od, dict):
        result = od.get("result")
        if isinstance(result, str):
            return result[:limit]
    return json.dumps(od, ensure_ascii=False)[:limit]


def _print_redis_memory(uid: str) -> None:
    """Best-effort print of the Redis entry stored under ``user_memory_{uid}``.

    Reads ``settings.REDIS_URL`` and falls back to ``redis://localhost:6379/0``.
    If the key exists, its JSON value is decoded and the following are printed:
    the start of ``conversation_summary``, ``user_profile``, and the length of
    ``conversation_history``. Any error is printed and swallowed, so a Redis
    problem never fails the run.
    """
    try:
        from app.core.config import settings
        import redis as redis_lib

        url = getattr(settings, "REDIS_URL", None) or "redis://localhost:6379/0"
        rc = redis_lib.from_url(url, decode_responses=True)
        key = f"user_memory_{uid}"
        raw = rc.get(key)
        print(f"Redis {key}:", "存在" if raw else "不存在")
        if raw:
            mem = json.loads(raw)
            summary = str(mem.get("conversation_summary", ""))
            print("conversation_summary 前120字:", summary[:120])
            print("user_profile:", mem.get("user_profile"))
            print("history 条数:", len(mem.get("conversation_history") or []))
    except Exception as ex:
        # Diagnostic only: report and carry on rather than failing the E2E run.
        print("Redis 检查:", ex)


def main() -> int:
    """Run the multi-round memory E2E test against the agent named AGENT_NAME.

    Steps:
    - Look up the agent, and a user to authenticate as, in the database.
    - Mint an access token for that user.
    - Send each round's query through ``POST /api/v1/executions``, using a fresh
      synthetic ``user_id``.
    - Poll each execution to completion and print its reply.
    - Finally, print the Redis memory stored for that synthetic user.

    Returns:
        0 on success; 1 if the agent or any usable user cannot be found.
        HTTP errors, failed executions (RuntimeError) and timeouts
        (TimeoutError) propagate to the caller.
    """
    # Run from backend/ (presumably so relative config paths resolve) and make
    # the `app` package importable before importing from it.
    os.chdir(BACKEND_DIR)
    sys.path.insert(0, BACKEND_DIR)

    import httpx
    from app.core.database import SessionLocal
    from app.core.security import create_access_token
    from app.models.agent import Agent
    from app.models.user import User

    # Rounds: 1) state a name  2) state preferences  3) small talk
    #         4) ask the agent to recall the name and the preferences.
    rounds = [
        "我的名字叫阿九",
        "记住:我最爱吃火锅,不喜欢甜食。",
        "今天天气不错吧?",
        "我叫什么名字?你还记得我喜欢吃什么吗?",
    ]

    db = SessionLocal()
    try:
        agent = db.query(Agent).filter(Agent.name == AGENT_NAME).first()
        if not agent:
            print(f"数据库中未找到「{AGENT_NAME}」", file=sys.stderr)
            return 1
        # Authenticate as the agent's owner; fall back to any user if there is none.
        owner = db.query(User).filter(User.id == agent.user_id).first()
        user = owner or db.query(User).first()
        if not user:
            print("无可用用户", file=sys.stderr)
            return 1
        # NOTE(review): "sub" is passed as the raw user.id; confirm that
        # create_access_token and the token decoder accept its type
        # (RFC 7519 expects a string subject).
        token = create_access_token(data={"sub": user.id, "username": user.username})
        headers = {"Authorization": f"Bearer {token}"}
        # A random user_id per run, so the memory inspected at the end starts
        # empty (assuming memory is keyed by user_id, as user_memory_{uid} suggests).
        uid = f"e2e_z9_{uuid.uuid4().hex[:10]}"
        print(f"agent_id={agent.id} name={agent.name} user_id={uid}\n")

        with httpx.Client(base_url=API_BASE, timeout=300.0) as client:
            for i, q in enumerate(rounds, 1):
                payload = {"agent_id": str(agent.id), "input_data": {"query": q, "user_id": uid}}
                resp = client.post("/api/v1/executions", json=payload, headers=headers)
                if resp.status_code >= 400:
                    # Show the server's error body before raising.
                    print(resp.text, file=sys.stderr)
                resp.raise_for_status()
                out = _poll_execution(client, resp.json()["id"], headers)
                print(f"--- 第{i}轮 ---\nQ: {q}\nA: {_extract_reply(out)}\n")

        _print_redis_memory(uid)
    finally:
        db.close()
    print("完成")
    return 0
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    sys.exit(main())