100 lines
3.7 KiB
Python
100 lines
3.7 KiB
Python
"""一次调用「知你客服12号」创建仓库根下 123.md(相对路径 123.md)。默认不重启 Celery。"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
import uuid
|
||
from pathlib import Path
|
||
|
||
# Directory layout: the script sits in a direct subdirectory of the backend
# directory (hence parents[1]); the backend directory's parent is the repo root.
BACKEND_DIR: Path = Path(__file__).resolve().parents[1]
REPO_ROOT: Path = BACKEND_DIR.parent

# Target API server and the agent to drive; both can be overridden via env vars.
API_BASE: str = os.getenv("API_BASE", "http://127.0.0.1:8037")
AGENT_NAME: str = os.getenv("E2E_AGENT_NAME", "知你客服12号")

# File the agent is asked to create (path relative to the root) and the exact
# text it is asked to write; the second line doubles as the verification marker.
REL_PATH: str = "123.md"
FILE_CONTENT: str = "# 123\ne2e zhini12 123.md marker\n"
|
||
|
||
|
||
def _lookup_agent_and_user() -> tuple[object, object, object] | None:
    """Look up AGENT_NAME and the user to authenticate as.

    Returns ``(agent_id, user_id, username)``, or ``None`` after printing an
    error to stderr when the agent, or any user at all, is missing.

    The backend ORM is imported lazily, so ``main()`` must already have put
    BACKEND_DIR on ``sys.path``. Plain values are copied out and the session
    is closed before returning, so no pooled DB connection (and, depending on
    how SessionLocal is configured, no open transaction) is held during the
    potentially minutes-long HTTP polling that follows.
    """
    from app.core.database import SessionLocal
    from app.models.agent import Agent
    from app.models.user import User

    db = SessionLocal()
    try:
        agent = db.query(Agent).filter(Agent.name == AGENT_NAME).first()
        if not agent:
            print(f"未找到「{AGENT_NAME}」", file=sys.stderr)
            return None
        # Prefer the agent's owner; otherwise fall back to an arbitrary user
        # (no ORDER BY, so whichever row the database returns first).
        owner = db.query(User).filter(User.id == agent.user_id).first()
        user = owner or db.query(User).first()
        if not user:
            print("无用户", file=sys.stderr)
            return None
        return agent.id, user.id, user.username
    finally:
        db.close()


def _poll_execution(client: httpx.Client, eid: str, timeout: float = 300.0) -> dict:
    """Poll ``GET /api/v1/executions/{eid}`` until the execution finishes.

    ``client`` must already carry the base URL and auth headers. Returns the
    execution JSON once its ``status`` is ``"completed"``.

    Raises:
        RuntimeError: the execution reported ``status == "failed"`` (its
            ``error_message`` is printed to stderr first).
        TimeoutError: no completed/failed status within ``timeout`` seconds.
        httpx.HTTPStatusError: a poll request returned an HTTP error status.
    """
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        resp = client.get(f"/api/v1/executions/{eid}")
        resp.raise_for_status()
        data = resp.json()
        status = data.get("status")
        if status == "completed":
            return data
        if status == "failed":
            print("failed:", data.get("error_message"), file=sys.stderr)
            raise RuntimeError("执行失败")
        # NOTE(review): every other status (in-progress, or a terminal state
        # this script does not know about) keeps polling until the deadline.
        time.sleep(1.5)
    raise TimeoutError("超时")


def _mtime_ns(path: Path) -> int | None:
    """Return *path*'s modification time in nanoseconds, or ``None`` if it is not a regular file."""
    return path.stat().st_mtime_ns if path.is_file() else None


def main() -> int:
    """Ask AGENT_NAME once (via the HTTP API) to create REL_PATH, then verify it on disk.

    Side effects: changes the CWD to BACKEND_DIR and prepends it to
    ``sys.path`` so the backend's ``app`` package can be imported.

    Returns:
        0 on success (a missing marker string only prints a warning);
        1 if the agent or any user cannot be found;
        2 if the file is missing afterwards, or was not rewritten by this run.

    Raises:
        RuntimeError, TimeoutError, httpx.HTTPStatusError: propagated from
        submitting or polling the execution (see ``_poll_execution``).
    """
    # The backend modules imported below (and presumably their config
    # loading) expect to run from, and be importable from, BACKEND_DIR.
    os.chdir(BACKEND_DIR)
    sys.path.insert(0, str(BACKEND_DIR))

    import httpx
    from app.core.security import create_access_token

    identity = _lookup_agent_and_user()
    if identity is None:
        return 1
    agent_id, user_id, username = identity

    # NOTE(review): RFC 7519 defines "sub" as a string; if user.id is not a
    # str, confirm create_access_token and the auth layer accept that.
    token = create_access_token(data={"sub": user_id, "username": username})
    headers = {"Authorization": f"Bearer {token}"}

    # Fresh per-run id sent as input_data.user_id; unrelated to the
    # authenticated user above.
    uid = f"e2e123_{uuid.uuid4().hex[:10]}"
    q = (
        f"创建 123.md。请用 file_write:相对路径 {REL_PATH}(工作区根下),"
        f"content 为 {json.dumps(FILE_CONTENT, ensure_ascii=False)},mode 为 w。"
        "reply 中写出 file_write 返回的真实 JSON。最后一行单行 JSON:intent、reply、user_profile。"
    )
    abs_file = (REPO_ROOT / REL_PATH).resolve()

    print(f"agent={agent_id} ({AGENT_NAME}) user_id={uid}")
    print(f"目标文件: {abs_file}")

    # Snapshot before submitting so that a file left over from an earlier run
    # (same fixed content and marker) is not mistaken for this run's output.
    mtime_before = _mtime_ns(abs_file)

    with httpx.Client(base_url=API_BASE, timeout=300.0, headers=headers) as client:
        r = client.post(
            "/api/v1/executions",
            json={"agent_id": str(agent_id), "input_data": {"query": q, "user_id": uid}},
        )
        if r.status_code >= 400:
            print(r.text, file=sys.stderr)
        r.raise_for_status()
        eid = r.json()["id"]
        print(f"execution={eid}")
        out = _poll_execution(client, eid)

    od = out.get("output_data") or {}
    # Only unwrap "result" when output_data is a dict; print anything else as-is.
    result = od.get("result", od) if isinstance(od, dict) else od
    print("\n--- API result (截断 2000 字符) ---\n")
    print(str(result)[:2000])

    mtime_after = _mtime_ns(abs_file)
    if mtime_after is None:
        print(f"\n[FAIL] 磁盘未找到: {abs_file}", file=sys.stderr)
        return 2
    if mtime_after == mtime_before:
        print(f"\n[FAIL] 文件未被本次执行写入(修改时间未变): {abs_file}", file=sys.stderr)
        return 2

    body = abs_file.read_text(encoding="utf-8", errors="replace")
    print(f"\n[OK] 文件: {abs_file}")
    # Concatenate rather than pass two args: print's default sep would add a
    # stray space before the file content.
    print("--- 内容 ---\n" + body[:800])
    if "e2e zhini12 123.md marker" not in body:
        print("\n[WARN] 未找到预期标记字符串", file=sys.stderr)
    print("\n完成")
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())
|