Files
aiagent/backend/scripts/create_zhini_kefu_14.py

298 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
从「知你客服13号」复制为「知你客服14号」
- **画布**:与 13 号脚本相同(去自环/重复边、分层布局、统一左右锚点)。
- **工具**:在 13 号http_request、file_read、file_write、system_info基础上增加平台已注册的内置工具
text_analyze、datetime、math_calculate、json_process、database_query、adb_log与 `tools_bootstrap` 对齐)。
- **提示词**:在 13 号提示词后追加 14 号扩展工具说明与纪律。
若已存在同名 Agent「知你客服14号」则仅更新其 workflow + 描述(不新建)。
用法:
cd backend && .\\venv\\Scripts\\python.exe scripts/create_zhini_kefu_14.py
环境变量: PLATFORM_BASE_URL, PLATFORM_USERNAME, PLATFORM_PASSWORD,
SOURCE_AGENT_NAME默认 知你客服13号, TARGET_NAME默认 知你客服14号
"""
from __future__ import annotations
import copy
import json
import os
import sys
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
import requests
# --- Connection settings, each overridable through an environment variable ---
# Base URL of the platform API; trailing slashes are stripped so paths can be
# appended with a leading "/".
BASE = os.getenv("PLATFORM_BASE_URL", "http://127.0.0.1:8037").rstrip("/")
# Credentials sent to the login endpoint by main().
USER = os.getenv("PLATFORM_USERNAME", "admin")
# NOTE(review): the fallback password is a weak hard-coded default; set
# PLATFORM_PASSWORD explicitly outside local development.
PWD = os.getenv("PLATFORM_PASSWORD", "123456")
# Name of the agent to copy from, and name of the agent to create/update.
SOURCE_NAME = os.getenv("SOURCE_AGENT_NAME", "知你客服13号")
TARGET_NAME = os.getenv("TARGET_NAME", "知你客服14号")
# Tool names written onto the "llm-unified" node. Per the original author's
# note, this matches the list registered in
# app.core.tools_bootstrap.ensure_builtin_tools_registered (the full set of
# built-in tools) — that module is not visible here, so keep them in sync.
TOOLS_V14: List[str] = [
"http_request",
"file_read",
"file_write",
"text_analyze",
"datetime",
"math_calculate",
"system_info",
"json_process",
"database_query",
"adb_log",
]
# Marker heading of the No. 14 prompt extension. _patch_llm_unified checks for
# it so the extension is not appended twice to the same prompt.
PROMPT_V14_MARKER = "【知你客服 14 号 · 扩展工具】"
# Prompt text appended after the No. 13 prompt: one entry per extra tool plus
# usage rules. This is runtime text sent to the model, not documentation.
# NOTE(review): a few separators look missing in this copy (e.g. between
# `now` and `format`, and before `adb_log` in the last rule) — confirm against
# the canonical prompt before relying on it.
PROMPT_V14_EXTRA = f"""
{PROMPT_V14_MARKER}
在 13 号既有能力与纪律之上,可使用下列额外工具(按需调用,避免无关刷屏;仍以 **单行 JSON** 收尾):
【text_analyze】文本分析`text` 为正文,`operation` 为 `count`(字数/行数等统计)、`keywords`(简单词频)、`summary`(取前几句摘要)。
【datetime】日期时间`operation` 常用 `now``format` 为 strftime 格式串(可选)。
【math_calculate】数学计算`expression` 为安全算术表达式(如 `2+2*3`、`sqrt(16)`),勿编造结果,以工具返回为准。
【json_process】JSON 处理:`json_string` + `operation` 为 `parse` | `stringify` | `validate`。
【database_query】只读 SQL**仅允许 SELECT**。未指定数据源时使用平台默认库;若需指定外部数据源可传 `data_source_id`。不得编造查询结果;大表注意 `timeout`(秒)。
【adb_log】Android 日志:依赖运行环境已安装 **adb** 且设备可用;`command` 等参数按工具 schema。仅在用户明确需要拉取/分析设备日志时使用,避免滥用。
【纪律】
- 继承 13 号:同轮避免无故重复 `file_write`;勿在正文中刷屏 DSML。
- `database_query` 禁止非 SELECT`adb_log` 需环境与权限,失败时如实说明工具返回。
"""
def _sanitize_edges(edges: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""去掉自环、按 (source,target) 去重,统一左右锚点。"""
seen: set = set()
out: List[Dict[str, Any]] = []
for e in edges or []:
s, t = e.get("source"), e.get("target")
if not s or not t:
continue
if s == t:
continue
key = (s, t)
if key in seen:
continue
seen.add(key)
ne = dict(e)
ne["sourceHandle"] = "right"
ne["targetHandle"] = "left"
if not ne.get("id"):
ne["id"] = f"edge_{s}_{t}"
out.append(ne)
return out
def _find_start_node_ids(nodes: List[Dict[str, Any]]) -> List[str]:
ids: List[str] = []
for n in nodes or []:
nid = n.get("id") or ""
nt = (n.get("type") or (n.get("data") or {}).get("type") or "").lower()
if nt == "start" or nid in ("start", "start-1") or str(nid).startswith("start-"):
ids.append(nid)
return ids
def _compute_ranks(
    nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]
) -> Dict[str, int]:
    """Assign a layer index (rank) to every node id.

    Roots get rank 0. Roots are the start nodes found by
    ``_find_start_node_ids``; when there are none, the nodes with no incoming
    edge; failing that, the first node. Ranks are then pushed along edges so
    a target sits at least one layer after each ranked source. The number of
    sweeps is capped at ``max(len(nodes), 8) + 5`` so cyclic graphs still
    terminate. Nodes never reached from a root are appended one per layer
    after the deepest rank.
    """
    known_ids = [node["id"] for node in nodes if node.get("id")]
    roots = _find_start_node_ids(nodes)

    # Usable links: both endpoints present and not a self-loop.
    links: List[Tuple[Any, Any]] = []
    for edge in edges:
        src, dst = edge.get("source"), edge.get("target")
        if src and dst and src != dst:
            links.append((src, dst))

    # In-degree per known node; only consulted when no start node exists.
    in_degree: Dict[str, int] = dict.fromkeys(known_ids, 0)
    for _, dst in links:
        if dst in in_degree:
            in_degree[dst] += 1

    if not roots:
        roots = [nid for nid in known_ids if in_degree[nid] == 0]
        if not roots and known_ids:
            roots = known_ids[:1]

    rank: Dict[str, int] = dict.fromkeys(roots, 0)
    sweeps = max(len(nodes), 8) + 5
    for _ in range(sweeps):
        changed = False
        for src, dst in links:
            if src not in rank:
                continue
            candidate = rank[src] + 1
            # candidate is always >= 1, so -1 stands in for "not ranked yet".
            if rank.get(dst, -1) < candidate:
                rank[dst] = candidate
                changed = True
        if not changed:
            break

    # Unreached nodes each take their own layer after the deepest one.
    next_rank = max(rank.values(), default=0) + 1
    for nid in known_ids:
        if nid not in rank:
            rank[nid] = next_rank
            next_rank += 1
    return rank
def _apply_layered_positions(nodes: List[Dict[str, Any]], ranks: Dict[str, int]) -> None:
layers: Dict[int, List[str]] = defaultdict(list)
for nid, r in ranks.items():
layers[r].append(nid)
for r in layers:
layers[r].sort()
x0, y0 = 80.0, 140.0
x_step = 300.0
y_step = 110.0
for r in sorted(layers.keys()):
ids = layers[r]
nlen = len(ids)
y_base = y0 - (nlen - 1) * y_step / 2.0
for j, nid in enumerate(ids):
for node in nodes:
if node.get("id") != nid:
continue
pos = node.setdefault("position", {})
pos["x"] = x0 + r * x_step
pos["y"] = y_base + j * y_step
break
def improve_workflow_layout_and_edges(wf: Dict[str, Any]) -> Tuple[int, int]:
    """Clean the workflow's edges and re-lay out its nodes, in place.

    ``wf["edges"]`` is replaced by the output of ``_sanitize_edges``; node
    positions are rewritten from the ranks computed on the cleaned edges.

    Returns:
        ``(self_loops_removed, duplicate_edges_removed)``. Edges dropped
        because they lack a source or target are counted separately and
        excluded, so the second number reflects real duplicates only.
    """
    nodes = wf.get("nodes") or []
    raw_edges = wf.get("edges") or []

    loops = 0
    incomplete = 0
    for edge in raw_edges:
        src, dst = edge.get("source"), edge.get("target")
        if not src or not dst:
            incomplete += 1
        elif src == dst:
            loops += 1

    clean = _sanitize_edges(raw_edges)
    # Whatever was dropped and is neither a self-loop nor incomplete was a
    # repeat of an earlier (source, target) pair.
    removed_dup = len(raw_edges) - len(clean) - loops - incomplete
    wf["edges"] = clean

    ranks = _compute_ranks(nodes, clean)
    _apply_layered_positions(nodes, ranks)
    return loops, max(0, removed_dup)
def _patch_llm_unified(wf: dict, base_prompt: Optional[str] = None) -> None:
    """Install the No. 14 prompt and tool list on the ``llm-unified`` node.

    The first node whose id is ``"llm-unified"`` is updated in place. Its
    prompt starts from *base_prompt* when that is non-empty, otherwise from
    the node's current prompt. ``PROMPT_V14_EXTRA`` is appended unless
    ``PROMPT_V14_MARKER`` is already present. Tools are enabled and both
    ``tools`` and ``selected_tools`` receive their own copy of ``TOOLS_V14``.
    A warning goes to stderr when no such node exists.
    """
    candidates = (
        node for node in wf.get("nodes") or [] if node.get("id") == "llm-unified"
    )
    target = next(candidates, None)
    if target is None:
        print("警告: 未找到节点 llm-unified", file=sys.stderr)
        return

    data = target.setdefault("data", {})
    text = base_prompt or data.get("prompt") or ""
    if PROMPT_V14_MARKER not in text:
        text = f"{text.rstrip()}\n{PROMPT_V14_EXTRA}".strip()
    data["prompt"] = text
    data["enable_tools"] = True
    # Separate list objects so the two keys never alias each other.
    for key in ("tools", "selected_tools"):
        data[key] = list(TOOLS_V14)
def _find_agent_id_by_name(h: Dict[str, str], name: str) -> Optional[str]:
    """Look up an agent by exact name and return its id.

    Queries ``GET {BASE}/api/v1/agents`` with *name* as the search term
    (limit 50) using headers *h*, then returns the ``id`` of the first entry
    whose ``name`` equals *name* exactly. Returns ``None`` on a non-200
    response or when nothing matches.
    """
    resp = requests.get(
        f"{BASE}/api/v1/agents",
        params={"search": name, "limit": 50},
        headers=h,
        timeout=30,
    )
    if resp.status_code == 200:
        # presumably the endpoint returns a JSON list of agent objects — TODO confirm
        matching_ids = (
            agent.get("id") for agent in resp.json() or [] if agent.get("name") == name
        )
        return next(matching_ids, None)
    return None
def _login() -> Optional[str]:
    """Log in with USER/PWD and return the bearer token.

    Returns ``None`` after printing the reason to stderr when the login
    request is rejected or the response carries no ``access_token``.
    """
    r = requests.post(
        f"{BASE}/api/v1/auth/login",
        data={"username": USER, "password": PWD},
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        timeout=15,
    )
    if r.status_code != 200:
        print("登录失败:", r.status_code, r.text[:500], file=sys.stderr)
        return None
    token = r.json().get("access_token")
    if not token:
        print("无 access_token", file=sys.stderr)
        return None
    return token


def _fetch_source_prompt(h: Dict[str, str], src_id: str) -> Optional[str]:
    """Best-effort read of the source agent's ``llm-unified`` prompt.

    Returns ``None`` when the agent cannot be read, the payload has an
    unexpected shape, or no ``llm-unified`` node exists. Failures are
    reported on stderr instead of being swallowed silently; they never abort
    the run, because the caller can fall back to the target's own prompt.
    """
    resp = requests.get(f"{BASE}/api/v1/agents/{src_id}", headers=h, timeout=30)
    if resp.status_code != 200:
        print("警告: 读取源 Agent 失败:", resp.status_code, file=sys.stderr)
        return None
    try:
        for n in resp.json().get("workflow_config", {}).get("nodes") or []:
            if n.get("id") == "llm-unified":
                return (n.get("data") or {}).get("prompt")
    except (ValueError, AttributeError, TypeError) as exc:
        # ValueError: body is not JSON; the others: payload is not shaped
        # like {"workflow_config": {"nodes": [ {...}, ... ]}}.
        print("警告: 解析源 Agent 工作流失败:", exc, file=sys.stderr)
    return None


def main() -> int:
    """Create or update the target agent from the source agent.

    Steps: log in; resolve the source agent by name; reuse the target agent
    if one with TARGET_NAME exists, otherwise duplicate the source; clean the
    workflow's edges and layout; install the No. 14 prompt and tools on the
    ``llm-unified`` node; PUT the description and workflow back.

    Returns:
        Process exit status: 0 on success, 1 on any reported failure.
    """
    token = _login()
    if token is None:
        return 1
    h = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    src_id = _find_agent_id_by_name(h, SOURCE_NAME)
    if not src_id:
        print(f"未找到源 Agent: {SOURCE_NAME}", file=sys.stderr)
        return 1

    existing = _find_agent_id_by_name(h, TARGET_NAME)
    if existing:
        print("已存在", TARGET_NAME, "-> 仅更新工作流", existing)
        new_id = existing
        g = requests.get(f"{BASE}/api/v1/agents/{new_id}", headers=h, timeout=30)
        if g.status_code != 200:
            print("读取失败:", g.text, file=sys.stderr)
            return 1
        agent = g.json()
    else:
        dup = requests.post(
            f"{BASE}/api/v1/agents/{src_id}/duplicate",
            headers=h,
            json={"name": TARGET_NAME},
            timeout=60,
        )
        if dup.status_code != 201:
            print("复制失败:", dup.status_code, dup.text[:800], file=sys.stderr)
            return 1
        # Parse the response body once and reuse it.
        agent = dup.json()
        new_id = agent["id"]
        print("已创建副本:", new_id, TARGET_NAME)

    # Fail with a message rather than a traceback when the agent payload has
    # no usable workflow.
    workflow = agent.get("workflow_config") if isinstance(agent, dict) else None
    if not isinstance(workflow, dict):
        print("workflow_config 缺失或格式异常", file=sys.stderr)
        return 1

    wf = copy.deepcopy(workflow)
    loops, dup_edges = improve_workflow_layout_and_edges(wf)
    print(f"连线整理: 去掉自环 {loops} 条, 合并重复边 {dup_edges}")

    # Prefer the source agent's prompt as the base, so re-running against an
    # existing target rebuilds from the source text.
    base_prompt = _fetch_source_prompt(h, src_id)
    _patch_llm_unified(wf, base_prompt=base_prompt)

    desc = (
        "在知你客服13号基础上扩展内置工具为全量含 text_analyze、datetime、math_calculate、"
        "json_process、database_query、adb_log 等);画布与 13 号一致整理;输出仍为单行 JSON。"
    )
    up = requests.put(
        f"{BASE}/api/v1/agents/{new_id}",
        headers=h,
        json={"description": desc, "workflow_config": wf},
        timeout=120,
    )
    if up.status_code != 200:
        print("更新失败:", up.status_code, up.text[:1200], file=sys.stderr)
        return 1

    print("已写入工具:", ", ".join(TOOLS_V14))
    print("Agent ID:", new_id)
    # Final line is machine-readable JSON for callers that parse stdout.
    print(json.dumps({"id": new_id, "name": TARGET_NAME}, ensure_ascii=False))
    return 0
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())