Files
aiagent/backend/scripts/e2e_zhini12_123_md.py

100 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
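"""一次调用「知你客服12号」，创建仓库根下的 123.md（相对路径 123.md）。默认不重启 Celery。

Invoke the agent 「知你客服12号」 once so that it creates 123.md (relative path
``123.md``) under the repository root. Celery is not restarted by default.
"""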
from __future__ import annotations
import json
import os
import sys
import time
import uuid
from pathlib import Path
# Directory layout derived from this script's own location:
# <BACKEND_DIR>/scripts/<this file>; REPO_ROOT is the parent of BACKEND_DIR.
BACKEND_DIR = Path(__file__).resolve().parent.parent
REPO_ROOT = BACKEND_DIR.parent

# Both can be overridden through environment variables.
API_BASE = os.getenv("API_BASE", "http://127.0.0.1:8037")
AGENT_NAME = os.getenv("E2E_AGENT_NAME", "知你客服12号")

# File the agent is asked to create (relative to REPO_ROOT) and the exact
# content it is asked to write; the second line doubles as the marker
# checked after the run.
REL_PATH = "123.md"
FILE_CONTENT = "# 123\ne2e zhini12 123.md marker\n"
def _poll_execution(
    client,
    eid: str,
    headers: dict[str, str],
    timeout: float = 300.0,
    interval: float = 1.5,
) -> dict:
    """Poll ``GET /api/v1/executions/{eid}`` until the execution finishes.

    *client* is an ``httpx.Client`` whose ``base_url`` points at the API.

    Returns the execution payload once its ``status`` is ``"completed"``.
    Raises RuntimeError when the status becomes ``"failed"`` (the server's
    ``error_message`` is printed to stderr first), TimeoutError when
    *timeout* seconds elapse, and httpx.HTTPStatusError on a 4xx/5xx reply.
    """
    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        r = client.get(f"/api/v1/executions/{eid}", headers=headers)
        r.raise_for_status()
        d = r.json()
        st = d.get("status")
        if st == "completed":
            return d
        if st == "failed":
            print("failed:", d.get("error_message"), file=sys.stderr)
            raise RuntimeError("执行失败")
        time.sleep(interval)
    raise TimeoutError("超时")


def _file_snapshot(path: Path) -> tuple[int, bytes] | None:
    """Return ``(st_mtime_ns, raw bytes)`` of *path*, or None if it is not a regular file."""
    if not path.is_file():
        return None
    return path.stat().st_mtime_ns, path.read_bytes()


def main() -> int:
    """Ask agent AGENT_NAME to write FILE_CONTENT to REL_PATH, then verify it on disk.

    Steps:
    1. Look up the agent by name.
    2. Mint an access token for its owner, falling back to the first user
       with a warning.
    3. Submit one execution to ``/api/v1/executions`` and poll it until
       completion.
    4. Check that ``REPO_ROOT / REL_PATH`` exists and was written by this run.

    Returns the process exit code:
    - 0 on success (a missing marker only prints a warning);
    - 1 when the agent or any user cannot be found;
    - 2 when the file is absent, or unchanged since before the run.

    RuntimeError, TimeoutError and httpx errors from the API propagate.
    """
    # The backend package (app.*) is imported from BACKEND_DIR, and the
    # working directory is switched there before importing it.
    os.chdir(BACKEND_DIR)
    sys.path.insert(0, str(BACKEND_DIR))
    import httpx
    from app.core.database import SessionLocal
    from app.core.security import create_access_token
    from app.models.agent import Agent
    from app.models.user import User

    db = SessionLocal()
    try:
        agent = db.query(Agent).filter(Agent.name == AGENT_NAME).first()
        if not agent:
            print(f"未找到「{AGENT_NAME}」", file=sys.stderr)
            return 1
        owner = db.query(User).filter(User.id == agent.user_id).first()
        user = owner or db.query(User).first()
        if not user:
            print("无用户", file=sys.stderr)
            return 1
        if owner is None:
            # Authenticating as an unrelated user may be rejected by the API;
            # make the fallback visible instead of silent.
            print(f"[WARN] 未找到 agent 所有者，改用用户 {user.username}", file=sys.stderr)
        agent_id = agent.id
        token = create_access_token(data={"sub": user.id, "username": user.username})
    finally:
        # The session is only needed for the lookups above; do not keep a
        # connection open while the execution runs for up to several minutes.
        db.close()

    headers = {"Authorization": f"Bearer {token}"}
    # A fresh user_id per run, presumably so runs do not share agent-side
    # conversation state.
    uid = f"e2e123_{uuid.uuid4().hex[:10]}"
    q = (
        f"创建 123.md。请用 file_write相对路径 {REL_PATH}(工作区根下),"
        f"content 为 {json.dumps(FILE_CONTENT, ensure_ascii=False)}mode 为 w。"
        "reply 中写出 file_write 返回的真实 JSON。最后一行单行 JSONintent、reply、user_profile。"
    )
    abs_file = (REPO_ROOT / REL_PATH).resolve()
    print(f"agent={agent_id} ({AGENT_NAME}) user_id={uid}")
    print(f"目标文件: {abs_file}")

    # Snapshot the target before the run so a 123.md left over from an
    # earlier run is not mistaken for a file written by this execution.
    before = _file_snapshot(abs_file)

    with httpx.Client(base_url=API_BASE, timeout=300.0) as client:
        r = client.post(
            "/api/v1/executions",
            json={"agent_id": str(agent_id), "input_data": {"query": q, "user_id": uid}},
            headers=headers,
        )
        if r.status_code >= 400:
            print(r.text, file=sys.stderr)
        r.raise_for_status()
        eid = r.json()["id"]
        print(f"execution={eid}")
        out = _poll_execution(client, eid, headers)

    od = out.get("output_data") or {}
    # output_data is normally a dict, possibly wrapping the payload under
    # "result"; tolerate any other JSON value rather than crash on .get().
    result = od.get("result", od) if isinstance(od, dict) else od
    print("\n--- API result (截断 2000 字符) ---\n")
    print(str(result)[:2000])

    after = _file_snapshot(abs_file)
    if after is None:
        print(f"\n[FAIL] 磁盘未找到: {abs_file}", file=sys.stderr)
        return 2
    if after == before:
        print(f"\n[FAIL] 文件未被本次执行更新: {abs_file}", file=sys.stderr)
        return 2
    body = abs_file.read_text(encoding="utf-8", errors="replace")
    print(f"\n[OK] 文件: {abs_file}")
    print("--- 内容 ---\n", body[:800])
    if "e2e zhini12 123.md marker" not in body:
        print("\n[WARN] 未找到预期标记字符串", file=sys.stderr)
    print("\n完成")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())