Files
aiagent/backend/scripts/e2e_zhini12_file_test.py

143 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
知你客服12号一轮对话触发 file_write相对路径 user_data/e2e_12.md
可选重启 CeleryE2E_RESTART_CELERY=1默认 1
用法: cd backend && .\\venv\\Scripts\\python.exe scripts/e2e_zhini12_file_test.py
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
import time
import uuid
from pathlib import Path
# Backend root directory: this script lives in <backend>/scripts/.
BACKEND_DIR = Path(__file__).resolve().parents[1]
# Interpreter of the backend's Windows-layout virtualenv (used to launch Celery).
VENV_PY = BACKEND_DIR.joinpath("venv", "Scripts", "python.exe")
# Run settings; each one may be overridden through its environment variable.
API_BASE = os.getenv("API_BASE", "http://127.0.0.1:8037")
AGENT_NAME = os.getenv("E2E_AGENT_NAME", "知你客服12号")
REL_PATH = os.getenv("E2E_REL_FILE", "user_data/e2e_12.md")
FILE_CONTENT = os.getenv("E2E_FILE_CONTENT", "e2e zhini12 ok\n")
def _restart_celery() -> None:
ps = (
"Get-CimInstance Win32_Process | "
"Where-Object { $_.CommandLine -match 'celery_app' } | "
"ForEach-Object { Stop-Process -Id $_.ProcessId -Force -ErrorAction SilentlyContinue }"
)
subprocess.run(
["powershell", "-NoProfile", "-Command", ps],
cwd=str(BACKEND_DIR),
capture_output=True,
text=True,
)
time.sleep(2)
if not VENV_PY.is_file():
print("未找到 venv Python跳过启动 Celery", file=sys.stderr)
return
kw: dict = {"cwd": str(BACKEND_DIR), "stdout": subprocess.DEVNULL, "stderr": subprocess.STDOUT}
if sys.platform == "win32":
kw["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore[attr-defined]
subprocess.Popen(
[
str(VENV_PY),
"-m",
"celery",
"-A",
"app.core.celery_app",
"worker",
"--loglevel=info",
"--pool=threads",
"--concurrency=8",
],
**kw,
)
print("已启动 Celery等待就绪…")
time.sleep(4)
def _load_agent_and_token() -> tuple[str, str] | None:
    """Look up agent ``AGENT_NAME`` and mint an access token for its owner.

    Falls back to the first user in the table when the owner row is missing.
    Returns ``(agent_id, token)``, or ``None`` after printing the reason to
    stderr. The DB session is closed before returning, so it is not held open
    while the execution is polled.
    """
    from app.core.database import SessionLocal
    from app.core.security import create_access_token
    from app.models.agent import Agent
    from app.models.user import User

    db = SessionLocal()
    try:
        agent = db.query(Agent).filter(Agent.name == AGENT_NAME).first()
        if not agent:
            print(f"未找到「{AGENT_NAME}」", file=sys.stderr)
            return None
        owner = db.query(User).filter(User.id == agent.user_id).first()
        user = owner or db.query(User).first()
        if not user:
            print("无用户", file=sys.stderr)
            return None
        # NOTE(review): JWT "sub" is conventionally a string; confirm that
        # create_access_token copes with the raw user.id type.
        token = create_access_token(data={"sub": user.id, "username": user.username})
        return str(agent.id), token
    finally:
        db.close()


def _build_query() -> str:
    """Build the prompt asking the agent to call file_write with REL_PATH / FILE_CONTENT."""
    content = json.dumps(FILE_CONTENT, ensure_ascii=False)
    # Separators matter: without them the tool name and argument names run
    # together ("file_writefile_path"), which the model may misread.
    return (
        f"请调用 file_write, file_path 用相对路径 {REL_PATH}, content 用 {content}, "
        "mode 用 w。完成后在 reply 里写出 file_write 返回的原始 JSON 字符串(不要编造)。"
        "最终只输出一行 JSON: intent、reply、user_profile。"
    )


def _poll_execution(
    client: httpx.Client, eid: str, headers: dict[str, str], timeout: float = 300.0
) -> dict:
    """Poll ``/api/v1/executions/{eid}`` every 1.5 s until the execution completes.

    Returns:
        The execution JSON once its ``status`` is ``"completed"``.

    Raises:
        RuntimeError: the execution reports ``status == "failed"`` (its
            ``error_message`` is printed to stderr first).
        TimeoutError: neither status was seen within *timeout* seconds.
        httpx.HTTPStatusError: a poll request returned an HTTP error status.
    """
    # Monotonic clock: unaffected by wall-clock adjustments during a long wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        r = client.get(f"/api/v1/executions/{eid}", headers=headers)
        r.raise_for_status()
        d = r.json()
        st = d.get("status")
        if st == "completed":
            return d
        if st == "failed":
            print("failed:", d.get("error_message"), file=sys.stderr)
            raise RuntimeError("执行失败")
        time.sleep(1.5)
    raise TimeoutError("超时")


def main() -> int:
    """Run one E2E conversation turn and check that the agent wrote the expected file.

    Steps:
      1. Optionally restart the Celery worker (E2E_RESTART_CELERY, default on).
      2. Resolve the agent by name and mint a token for its owner (or first user).
      3. POST an execution asking the agent to call file_write, then poll it.
      4. Check that REL_PATH, resolved against the parent of BACKEND_DIR, was
         written during this run.

    Returns:
        0 on success. 1 if the agent or any user is missing. 2 if the file is
        missing, or is unchanged since before the run (a stale file from an
        earlier run).

    Raises:
        httpx.HTTPStatusError: creating or polling the execution failed.
        RuntimeError: the execution finished with status "failed".
        TimeoutError: the execution did not finish in time.
    """
    os.chdir(BACKEND_DIR)
    # Make the backend's ``app`` package importable regardless of the caller's cwd.
    sys.path.insert(0, str(BACKEND_DIR))
    if os.environ.get("E2E_RESTART_CELERY", "1").strip().lower() not in ("0", "false", "no"):
        _restart_celery()
    import httpx

    loaded = _load_agent_and_token()
    if loaded is None:
        return 1
    agent_id, token = loaded
    headers = {"Authorization": f"Bearer {token}"}
    uid = f"e2e12_{uuid.uuid4().hex[:10]}"
    q = _build_query()
    print(f"agent={agent_id} user_id={uid}\nQ: {q[:200]}...")

    # pathlib accepts "/" separators on every platform, so REL_PATH is used as-is.
    abs_file = (BACKEND_DIR.parent / REL_PATH).resolve()
    # Snapshot the mtime so a file left by an earlier run is not taken as
    # proof that this run wrote it.
    mtime_before = abs_file.stat().st_mtime_ns if abs_file.is_file() else None

    with httpx.Client(base_url=API_BASE, timeout=300.0) as client:
        r = client.post(
            "/api/v1/executions",
            json={"agent_id": agent_id, "input_data": {"query": q, "user_id": uid}},
            headers=headers,
        )
        if r.status_code >= 400:
            print(r.text, file=sys.stderr)
        r.raise_for_status()
        eid = r.json()["id"]
        print(f"execution={eid}")
        out = _poll_execution(client, eid, headers)

    od = out.get("output_data") or {}
    result = od.get("result", od)
    print("--- API result (截断) ---")
    print(str(result)[:1200])

    if not abs_file.is_file():
        print(f"\n磁盘未找到: {abs_file}", file=sys.stderr)
        return 2
    if mtime_before is not None and abs_file.stat().st_mtime_ns == mtime_before:
        print(f"\n磁盘文件未被本次执行更新 (可能是旧文件): {abs_file}", file=sys.stderr)
        return 2
    body = abs_file.read_text(encoding="utf-8", errors="replace")
    print(f"\n磁盘文件存在: {abs_file}\n内容:\n{body!r}")
    if body != FILE_CONTENT:
        # Warning only: the agent/tool may legitimately normalize the content.
        print(f"警告: 文件内容与期望不一致, 期望 {FILE_CONTENT!r}", file=sys.stderr)
    print("\n完成")
    return 0
if __name__ == "__main__":
raise SystemExit(main())