Files
aiagent/backend/scripts/e2e_zhini12_file_test.py

143 lines
5.1 KiB
Python
Raw Normal View History

"""
知你客服12号：一轮对话触发 file_write（相对路径 user_data/e2e_12.md）。
可选重启 Celery（E2E_RESTART_CELERY，默认 1；设为 0/false/no 则跳过重启）。
用法: cd backend && .\\venv\\Scripts\\python.exe scripts/e2e_zhini12_file_test.py
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
import time
import uuid
from pathlib import Path
# This script lives in backend/scripts/, so two levels up is backend/.
BACKEND_DIR = Path(__file__).resolve().parent.parent
# Windows-layout venv interpreter used to launch the Celery worker.
VENV_PY = BACKEND_DIR.joinpath("venv", "Scripts", "python.exe")

# Every knob below can be overridden through an environment variable.
API_BASE = os.getenv("API_BASE", "http://127.0.0.1:8037")
AGENT_NAME = os.getenv("E2E_AGENT_NAME", "知你客服12号")
REL_PATH = os.getenv("E2E_REL_FILE", "user_data/e2e_12.md")
FILE_CONTENT = os.getenv("E2E_FILE_CONTENT", "e2e zhini12 ok\n")
def _restart_celery() -> None:
    """Stop running ``celery_app`` processes and start a fresh Celery worker.

    Best-effort helper:
    - The kill step shells out to PowerShell, and its result is ignored.
    - If ``powershell`` is not installed, a warning is printed and the kill
      step is skipped.
    - If the venv interpreter ``VENV_PY`` is missing, no worker is started.
    - The new worker is launched in the background and never waited on.
      "Ready" is approximated by a fixed sleep.
    """
    # The PowerShell process's own command line contains 'celery_app' (it is
    # part of this very script), so without excluding $PID it would match
    # and kill itself, possibly before reaching the real worker processes.
    ps = (
        "Get-CimInstance Win32_Process | "
        "Where-Object { $_.ProcessId -ne $PID -and $_.CommandLine -match 'celery_app' } | "
        "ForEach-Object { Stop-Process -Id $_.ProcessId -Force -ErrorAction SilentlyContinue }"
    )
    try:
        subprocess.run(
            ["powershell", "-NoProfile", "-Command", ps],
            cwd=str(BACKEND_DIR),
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        # Non-Windows host (or PowerShell not on PATH): skip the kill step
        # instead of crashing the whole e2e run.
        print("未找到 powershell，跳过停止旧 Celery 进程", file=sys.stderr)
    # Give the killed processes a moment to exit before starting a new one.
    time.sleep(2)
    if not VENV_PY.is_file():
        print("未找到 venv Python跳过启动 Celery", file=sys.stderr)
        return
    popen_kwargs: dict = {
        "cwd": str(BACKEND_DIR),
        "stdout": subprocess.DEVNULL,
        "stderr": subprocess.STDOUT,
    }
    if sys.platform == "win32":
        # Presumably so Ctrl+C in this console does not reach the worker.
        popen_kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP  # type: ignore[attr-defined]
    subprocess.Popen(
        [
            str(VENV_PY),
            "-m",
            "celery",
            "-A",
            "app.core.celery_app",
            "worker",
            "--loglevel=info",
            "--pool=threads",
            "--concurrency=8",
        ],
        **popen_kwargs,
    )
    print("已启动 Celery等待就绪…")
    # Fixed grace period; no readiness probe is performed.
    time.sleep(4)
def main() -> int:
os.chdir(BACKEND_DIR)
sys.path.insert(0, str(BACKEND_DIR))
if os.environ.get("E2E_RESTART_CELERY", "1").strip().lower() not in ("0", "false", "no"):
_restart_celery()
import httpx
from app.core.database import SessionLocal
from app.core.security import create_access_token
from app.models.agent import Agent
from app.models.user import User
db = SessionLocal()
try:
agent = db.query(Agent).filter(Agent.name == AGENT_NAME).first()
if not agent:
print(f"未找到「{AGENT_NAME}", file=sys.stderr)
return 1
owner = db.query(User).filter(User.id == agent.user_id).first()
user = owner or db.query(User).first()
if not user:
print("无用户", file=sys.stderr)
return 1
token = create_access_token(data={"sub": user.id, "username": user.username})
headers = {"Authorization": f"Bearer {token}"}
uid = f"e2e12_{uuid.uuid4().hex[:10]}"
q = (
f"请调用 file_writefile_path 用相对路径 {REL_PATH}content 用 {json.dumps(FILE_CONTENT, ensure_ascii=False)}"
"mode 用 w。完成后在 reply 里写出 file_write 返回的原始 JSON 字符串(不要编造)。"
"最终只输出一行 JSONintent、reply、user_profile。"
)
print(f"agent={agent.id} user_id={uid}\nQ: {q[:200]}...")
def poll(client: httpx.Client, eid: str, timeout: float = 300.0) -> dict:
t0 = time.time()
while time.time() - t0 < timeout:
r = client.get(f"/api/v1/executions/{eid}", headers=headers)
r.raise_for_status()
d = r.json()
st = d.get("status")
if st == "completed":
return d
if st == "failed":
print("failed:", d.get("error_message"), file=sys.stderr)
raise RuntimeError("执行失败")
time.sleep(1.5)
raise TimeoutError("超时")
with httpx.Client(base_url=API_BASE, timeout=300.0) as client:
r = client.post(
"/api/v1/executions",
json={"agent_id": str(agent.id), "input_data": {"query": q, "user_id": uid}},
headers=headers,
)
if r.status_code >= 400:
print(r.text, file=sys.stderr)
r.raise_for_status()
eid = r.json()["id"]
print(f"execution={eid}")
out = poll(client, eid)
od = out.get("output_data") or {}
result = od.get("result", od)
print("--- API result (截断) ---")
print(str(result)[:1200])
root = BACKEND_DIR.parent
abs_file = (root / REL_PATH.replace("/", os.sep)).resolve()
if abs_file.is_file():
body = abs_file.read_text(encoding="utf-8", errors="replace")
print(f"\n磁盘文件存在: {abs_file}\n内容:\n{body!r}")
else:
print(f"\n磁盘未找到: {abs_file}", file=sys.stderr)
return 2
finally:
db.close()
print("\n完成")
return 0
if __name__ == "__main__":
raise SystemExit(main())