298 lines
10 KiB
Python
298 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
从「知你客服13号」复制为「知你客服14号」:
|
||
|
||
- **画布**:与 13 号脚本相同(去自环/重复边、分层布局、统一左右锚点)。
|
||
- **工具**:在 13 号(http_request、file_read、file_write、system_info)基础上,增加平台已注册的内置工具:
|
||
text_analyze、datetime、math_calculate、json_process、database_query、adb_log(与 `tools_bootstrap` 对齐)。
|
||
- **提示词**:在 13 号提示词后追加 14 号扩展工具说明与纪律。
|
||
|
||
若已存在同名 Agent「知你客服14号」,则仅更新其 workflow + 描述(不新建)。
|
||
|
||
用法:
|
||
cd backend && .\\venv\\Scripts\\python.exe scripts/create_zhini_kefu_14.py
|
||
|
||
环境变量: PLATFORM_BASE_URL, PLATFORM_USERNAME, PLATFORM_PASSWORD,
|
||
SOURCE_AGENT_NAME(默认 知你客服13号), TARGET_NAME(默认 知你客服14号)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import copy
|
||
import json
|
||
import os
|
||
import sys
|
||
from collections import defaultdict
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import requests
|
||
|
||
# Connection settings and agent names; every value can be overridden through
# the environment variable named in the first argument.
# NOTE(review): PLATFORM_PASSWORD falls back to a hard-coded default. That is
# convenient for a local dev instance; set the variable for anything shared.
BASE = os.environ.get("PLATFORM_BASE_URL", "http://127.0.0.1:8037").rstrip("/")  # no trailing "/": URLs are built as f"{BASE}/api/..."
USER = os.environ.get("PLATFORM_USERNAME", "admin")
PWD = os.environ.get("PLATFORM_PASSWORD", "123456")
SOURCE_NAME = os.environ.get("SOURCE_AGENT_NAME", "知你客服13号")
TARGET_NAME = os.environ.get("TARGET_NAME", "知你客服14号")
|
||
|
||
# Full set of built-in tool names enabled for the v14 agent. Keep this list in
# sync with the registration list in
# app.core.tools_bootstrap.ensure_builtin_tools_registered.
TOOLS_V14: List[str] = """
    http_request
    file_read
    file_write
    text_analyze
    datetime
    math_calculate
    system_info
    json_process
    database_query
    adb_log
""".split()
|
||
|
||
# Section header of the v14 extension. _patch_llm_unified looks for it to avoid
# appending the extension twice.
PROMPT_V14_MARKER = "【知你客服 14 号 · 扩展工具】"

# Text appended after the inherited 13-series prompt: how to use the extra
# tools, followed by usage rules. Built line by line; "" entries are blank lines
# (the two leading ones put an empty line before the marker).
PROMPT_V14_EXTRA = "\n".join(
    [
        "",
        "",
        PROMPT_V14_MARKER,
        "在 13 号既有能力与纪律之上,可使用下列额外工具(按需调用,避免无关刷屏;仍以 **单行 JSON** 收尾):",
        "",
        "【text_analyze】文本分析:`text` 为正文,`operation` 为 `count`(字数/行数等统计)、`keywords`(简单词频)、`summary`(取前几句摘要)。",
        "",
        "【datetime】日期时间:`operation` 常用 `now`;`format` 为 strftime 格式串(可选)。",
        "",
        "【math_calculate】数学计算:`expression` 为安全算术表达式(如 `2+2*3`、`sqrt(16)`),勿编造结果,以工具返回为准。",
        "",
        "【json_process】JSON 处理:`json_string` + `operation` 为 `parse` | `stringify` | `validate`。",
        "",
        "【database_query】只读 SQL:**仅允许 SELECT**。未指定数据源时使用平台默认库;若需指定外部数据源可传 `data_source_id`。不得编造查询结果;大表注意 `timeout`(秒)。",
        "",
        "【adb_log】Android 日志:依赖运行环境已安装 **adb** 且设备可用;`command` 等参数按工具 schema。仅在用户明确需要拉取/分析设备日志时使用,避免滥用。",
        "",
        "【纪律】",
        "- 继承 13 号:同轮避免无故重复 `file_write`;勿在正文中刷屏 DSML。",
        "- `database_query` 禁止非 SELECT;`adb_log` 需环境与权限,失败时如实说明工具返回。",
        "",
    ]
)
|
||
|
||
|
||
def _sanitize_edges(edges: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||
"""去掉自环、按 (source,target) 去重,统一左右锚点。"""
|
||
seen: set = set()
|
||
out: List[Dict[str, Any]] = []
|
||
for e in edges or []:
|
||
s, t = e.get("source"), e.get("target")
|
||
if not s or not t:
|
||
continue
|
||
if s == t:
|
||
continue
|
||
key = (s, t)
|
||
if key in seen:
|
||
continue
|
||
seen.add(key)
|
||
ne = dict(e)
|
||
ne["sourceHandle"] = "right"
|
||
ne["targetHandle"] = "left"
|
||
if not ne.get("id"):
|
||
ne["id"] = f"edge_{s}_{t}"
|
||
out.append(ne)
|
||
return out
|
||
|
||
|
||
def _find_start_node_ids(nodes: List[Dict[str, Any]]) -> List[str]:
|
||
ids: List[str] = []
|
||
for n in nodes or []:
|
||
nid = n.get("id") or ""
|
||
nt = (n.get("type") or (n.get("data") or {}).get("type") or "").lower()
|
||
if nt == "start" or nid in ("start", "start-1") or str(nid).startswith("start-"):
|
||
ids.append(nid)
|
||
return ids
|
||
|
||
|
||
def _compute_ranks(
    nodes: List[Dict[str, Any]], edges: List[Dict[str, Any]]
) -> Dict[str, int]:
    """Assign every node a column index (rank) for a left-to-right layered layout.

    - Start nodes (from ``_find_start_node_ids``) get rank 0. When there are none,
      nodes with no incoming edge are used instead, and failing that the first
      node.
    - Every other node reachable from a start node gets the longest path length
      found to it. This is computed by repeated relaxation, capped at
      ``max(len(nodes), 8) + 5`` passes so cyclic graphs cannot loop forever.
      On a reachable cycle, ranks grow until the cap.
    - Nodes never reached are appended after the deepest rank, one column each.

    Self-loops and edges whose endpoints are not ids of ``nodes`` are ignored, so
    dangling edges cannot create phantom ranks.

    Returns a mapping ``node_id -> rank`` containing exactly the (truthy) node ids.
    """
    nodes = nodes or []
    node_ids = [n["id"] for n in nodes if n.get("id")]
    known = set(node_ids)

    # Real links between existing nodes only.
    links: List[Tuple[str, str]] = []
    for e in edges or []:
        s, t = e.get("source"), e.get("target")
        if s in known and t in known and s != t:
            links.append((s, t))

    start_ids = [sid for sid in _find_start_node_ids(nodes) if sid in known]
    if not start_ids:
        # No explicit start node: use in-degree-0 nodes, else the first node.
        has_incoming = {t for _, t in links}
        start_ids = [nid for nid in node_ids if nid not in has_incoming] or node_ids[:1]

    rank: Dict[str, int] = {sid: 0 for sid in start_ids}
    # Longest-path relaxation; the pass cap bounds the work on cyclic graphs.
    for _ in range(max(len(nodes), 8) + 5):
        updated = False
        for s, t in links:
            if s not in rank:
                continue
            candidate = rank[s] + 1
            if t not in rank or rank[t] < candidate:
                rank[t] = candidate
                updated = True
        if not updated:
            break

    # Unreached nodes each get their own trailing column.
    next_col = max(rank.values(), default=0)
    for nid in node_ids:
        if nid not in rank:
            next_col += 1
            rank[nid] = next_col
    return rank
|
||
|
||
|
||
def _apply_layered_positions(nodes: List[Dict[str, Any]], ranks: Dict[str, int]) -> None:
|
||
layers: Dict[int, List[str]] = defaultdict(list)
|
||
for nid, r in ranks.items():
|
||
layers[r].append(nid)
|
||
for r in layers:
|
||
layers[r].sort()
|
||
|
||
x0, y0 = 80.0, 140.0
|
||
x_step = 300.0
|
||
y_step = 110.0
|
||
|
||
for r in sorted(layers.keys()):
|
||
ids = layers[r]
|
||
nlen = len(ids)
|
||
y_base = y0 - (nlen - 1) * y_step / 2.0
|
||
for j, nid in enumerate(ids):
|
||
for node in nodes:
|
||
if node.get("id") != nid:
|
||
continue
|
||
pos = node.setdefault("position", {})
|
||
pos["x"] = x0 + r * x_step
|
||
pos["y"] = y_base + j * y_step
|
||
break
|
||
|
||
|
||
def improve_workflow_layout_and_edges(wf: Dict[str, Any]) -> Tuple[int, int]:
    """Clean up ``wf["edges"]`` and re-layout ``wf["nodes"]`` in place.

    The edges are replaced by the output of ``_sanitize_edges``. Node positions
    are recomputed from ``_compute_ranks`` and written into the node dicts.

    Returns ``(self_loops_removed, duplicate_edges_removed)``. Edges missing a
    source or target are dropped as well, but are counted in neither number.
    Previously they were miscounted as duplicates.
    """
    nodes = wf.get("nodes") or []
    raw_edges = wf.get("edges") or []

    valid = [e for e in raw_edges if e.get("source") and e.get("target")]
    loops = sum(1 for e in valid if e.get("source") == e.get("target"))

    clean = _sanitize_edges(raw_edges)
    # Valid non-loop edges that did not survive were dropped as duplicates.
    removed_dup = len(valid) - loops - len(clean)

    wf["edges"] = clean

    ranks = _compute_ranks(nodes, clean)
    _apply_layered_positions(nodes, ranks)
    return loops, max(0, removed_dup)
|
||
|
||
|
||
def _patch_llm_unified(wf: dict, base_prompt: Optional[str] = None) -> None:
    """Upgrade the first node with id ``"llm-unified"`` in ``wf`` to the v14 setup, in place.

    The starting prompt is ``base_prompt`` when it is truthy, otherwise the
    node's current ``data["prompt"]``. ``PROMPT_V14_EXTRA`` is appended unless
    ``PROMPT_V14_MARKER`` is already present, which keeps reruns idempotent.
    Tools are then enabled, and both ``tools`` and ``selected_tools`` are set to
    fresh copies of ``TOOLS_V14``.

    A missing or non-dict ``data`` (e.g. ``None``) is replaced with an empty
    dict instead of crashing. If no such node exists, a warning is printed to
    stderr.
    """
    for node in wf.get("nodes") or []:
        if node.get("id") != "llm-unified":
            continue
        data = node.get("data")
        if not isinstance(data, dict):
            data = node["data"] = {}
        prompt = base_prompt or data.get("prompt") or ""
        if PROMPT_V14_MARKER not in prompt:
            prompt = (prompt.rstrip() + "\n" + PROMPT_V14_EXTRA).strip()
        data["prompt"] = prompt
        data["enable_tools"] = True
        # Separate list copies so later edits to one field do not alias the other.
        data["tools"] = list(TOOLS_V14)
        data["selected_tools"] = list(TOOLS_V14)
        return
    print("警告: 未找到节点 llm-unified", file=sys.stderr)
|
||
|
||
|
||
def _find_agent_id_by_name(h: Dict[str, str], name: str) -> Optional[str]:
    """Return the id of the agent whose name is exactly ``name``, or None.

    Queries ``GET {BASE}/api/v1/agents`` with ``search=name`` (asking for up to
    50 hits), then requires an exact ``name`` match among the results.

    None means "not found". It is also returned when the request fails
    (non-200), the body is not JSON, or the body is not a list. Those failures
    are reported on stderr so they are not mistaken for "agent does not exist".
    In ``main`` that mistake would create a second copy of the target.
    """
    r = requests.get(
        f"{BASE}/api/v1/agents",
        params={"search": name, "limit": 50},
        headers=h,
        timeout=30,
    )
    if r.status_code != 200:
        print(f"警告: 查询 Agent「{name}」失败:", r.status_code, r.text[:300], file=sys.stderr)
        return None
    try:
        payload = r.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        print(f"警告: 查询 Agent「{name}」返回非 JSON", file=sys.stderr)
        return None
    if not payload:
        return None
    if not isinstance(payload, list):
        print(
            f"警告: 查询 Agent「{name}」返回格式不是列表: {type(payload).__name__}",
            file=sys.stderr,
        )
        return None
    for agent in payload:
        if isinstance(agent, dict) and agent.get("name") == name:
            return agent.get("id")
    return None
|
||
|
||
|
||
def _login() -> Optional[str]:
    """Log in as USER/PWD and return the access token, or None after printing why."""
    r = requests.post(
        f"{BASE}/api/v1/auth/login",
        data={"username": USER, "password": PWD},
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        timeout=15,
    )
    if r.status_code != 200:
        print("登录失败:", r.status_code, r.text[:500], file=sys.stderr)
        return None
    try:
        body = r.json()
    except ValueError:  # a 200 with a non-JSON body is treated as "no token"
        body = None
    token = body.get("access_token") if isinstance(body, dict) else None
    if not token:
        print("无 access_token", file=sys.stderr)
        return None
    return token


def _load_or_duplicate_target(
    h: Dict[str, str], src_id: str
) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Return ``(target_id, agent_json)`` for TARGET_NAME, creating it if needed.

    If an agent named TARGET_NAME already exists, it is fetched and reused.
    Otherwise the source agent is duplicated under TARGET_NAME; the duplicate
    endpoint is expected to answer 201. Returns None, after printing the error,
    on any failed request.
    """
    existing = _find_agent_id_by_name(h, TARGET_NAME)
    if existing:
        print("已存在", TARGET_NAME, "-> 仅更新工作流", existing)
        g = requests.get(f"{BASE}/api/v1/agents/{existing}", headers=h, timeout=30)
        if g.status_code != 200:
            print("读取失败:", g.status_code, g.text[:800], file=sys.stderr)
            return None
        return existing, g.json()

    dup = requests.post(
        f"{BASE}/api/v1/agents/{src_id}/duplicate",
        headers=h,
        json={"name": TARGET_NAME},
        timeout=60,
    )
    if dup.status_code != 201:
        print("复制失败:", dup.status_code, dup.text[:800], file=sys.stderr)
        return None
    agent = dup.json()
    new_id = agent["id"]
    print("已创建副本:", new_id, TARGET_NAME)
    return new_id, agent


def _fetch_source_llm_prompt(h: Dict[str, str], src_id: str) -> Optional[str]:
    """Best-effort: return the source agent's ``llm-unified`` prompt, or None.

    On failure a warning goes to stderr and None is returned. The caller then
    keeps the target's existing prompt. Previously these failures were
    swallowed silently.
    """
    g = requests.get(f"{BASE}/api/v1/agents/{src_id}", headers=h, timeout=30)
    if g.status_code != 200:
        print("警告: 读取源 Agent 失败, 将沿用目标现有提示词:", g.status_code, file=sys.stderr)
        return None
    try:
        workflow = g.json().get("workflow_config") or {}
        for n in workflow.get("nodes") or []:
            if n.get("id") == "llm-unified":
                return (n.get("data") or {}).get("prompt")
    except (ValueError, AttributeError, TypeError) as exc:
        # Non-JSON body or an unexpected shape somewhere in the agent document.
        print("警告: 解析源 Agent 提示词失败:", exc, file=sys.stderr)
    return None


def main() -> int:
    """Create or refresh agent TARGET_NAME from SOURCE_NAME; return 0 on success, 1 on failure.

    Steps:
    1. Log in.
    2. Find the source agent.
    3. Reuse or duplicate the target agent.
    4. Clean the target's edges and re-layout its nodes.
    5. Patch its ``llm-unified`` node with the source prompt plus the v14
       extension and tools.
    6. PUT the new description and workflow back.

    NOTE: if the TARGET_NAME lookup returns None (including on a failed
    request), a new duplicate is created.
    """
    token = _login()
    if not token:
        return 1
    h = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    src_id = _find_agent_id_by_name(h, SOURCE_NAME)
    if not src_id:
        print(f"未找到源 Agent: {SOURCE_NAME}", file=sys.stderr)
        return 1

    target = _load_or_duplicate_target(h, src_id)
    if target is None:
        return 1
    new_id, agent = target

    workflow = agent.get("workflow_config") if isinstance(agent, dict) else None
    if not isinstance(workflow, dict):
        print("目标 Agent 缺少 workflow_config:", new_id, file=sys.stderr)
        return 1
    # Work on a copy so the fetched agent JSON stays untouched.
    wf = copy.deepcopy(workflow)
    loops, dup_edges = improve_workflow_layout_and_edges(wf)
    print(f"连线整理: 去掉自环 {loops} 条, 合并重复边 {dup_edges} 条")

    _patch_llm_unified(wf, base_prompt=_fetch_source_llm_prompt(h, src_id))

    desc = (
        "在知你客服13号基础上:扩展内置工具为全量(含 text_analyze、datetime、math_calculate、"
        "json_process、database_query、adb_log 等);画布与 13 号一致整理;输出仍为单行 JSON。"
    )

    up = requests.put(
        f"{BASE}/api/v1/agents/{new_id}",
        headers=h,
        json={"description": desc, "workflow_config": wf},
        timeout=120,
    )
    if up.status_code != 200:
        print("更新失败:", up.status_code, up.text[:1200], file=sys.stderr)
        return 1
    print("已写入工具:", ", ".join(TOOLS_V14))
    print("Agent ID:", new_id)
    print(json.dumps({"id": new_id, "name": TARGET_NAME}, ensure_ascii=False))
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Use main()'s return value (0 on success, 1 on failure) as the process exit status.
    sys.exit(main())
|