Files
aiagent/backend/scripts/create_zhini_kefu_9.py

244 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
从「知你客服8号」复制为「知你客服9号」强化「摘要 + 检索」可用性:
1. 上下文 code-build-context:更长近期轮次、注入 conversation_summary、向量结果 + 关键词从历史中捞相关句。
2. 摘要路径 code-build-memory-value:在原有摘要分支上合并进完整 conversation_history(追加而非仅 2 条),并写回 conversation_summary。
3. cache-update-summary:显式 value 为「memory」表达式,避免整包 input_data 写入 Redis。
4. 向量写入:为每条 turn 带 metadata.user_id引擎侧检索已按 user_id 过滤。
需平台可登录;默认源 8 号 ID 为上次创建结果,可用 ZHINI_8_AGENT_ID 覆盖。
部署后请重启 Celery/API 以加载引擎向量过滤逻辑。
"""
from __future__ import annotations
import json
import os
import sys
import requests
# Platform base URL; trailing slash stripped so the f-string path joins below stay clean.
BASE = os.getenv("PLATFORM_BASE_URL", "http://127.0.0.1:8037").rstrip("/")
# Agent to duplicate ("知你客服8号"). Default is the ID from the previous creation
# run; override with ZHINI_8_AGENT_ID (see module docstring).
SOURCE_AGENT_ID = os.getenv("ZHINI_8_AGENT_ID", "d7b64bf6-c8e3-4dc7-befc-03a98d5ff741")
# Login credentials for the platform API.
# NOTE(review): defaults look like development credentials — confirm this script
# is never pointed at a production deployment with these defaults.
USER = os.getenv("PLATFORM_USERNAME", "admin")
PWD = os.getenv("PLATFORM_PASSWORD", "123456")
# Display name for the duplicated agent.
NEW_NAME = "知你客服9号"
# Description saved onto the new agent (adjacent string literals concatenate).
NEW_DESC = (
"在知你客服8号基础上强化摘要与检索"
"远期要点写入 conversation_summary"
"当轮上下文含「近期对话 + 摘要 + 向量片段 + 关键词相关历史」;"
"向量库写入带 user_id 元数据,引擎检索按用户隔离。"
"仍依赖 MEMORY_PERSIST_DB_ENABLED 与固定 user_id。"
)
# Prompt installed into the "llm-unified" node by _patch_nodes(). The
# {{user_input}} / {{memory.*}} placeholders are template variables resolved by
# the workflow engine at runtime, not by this script.
LLM_PROMPT = """你是客服助手。根据「用户当前输入」「已知用户信息」「远期摘要」「相关历史(检索)」和「最近几轮」完成:
1判断意图
2生成一句自然、有帮助的回复
3【强制】用户说出或暗示姓名、昵称时必须在 user_profile.name 保存;合并已有字段勿丢失;
4用户问「我叫什么」等时必须依据 user_profile.name 与对话/摘要回答;已有 name 时禁止说「还不知道」;
5「远期摘要」概括更早话题「相关历史」可能含向量命中或关键词命中的旧轮次请结合使用。
只输出一行合法 JSON不要 markdown。格式示例
{"intent":"greeting","reply":"你好!","user_profile":{"name":"小明"}}
用户输入:{{user_input}}
已知用户信息:{{memory.user_profile}}
远期摘要:{{memory.conversation_summary}}
相关历史(检索到的):{{memory.relevant_from_retrieval}}
最近几轮:{{memory.recent_turns}}
要求reply 简洁自然200 字以内user_profile 为对象。"""
# Source for the "code-build-context" workflow node. The platform's code-node
# sandbox execs this with `input_data` in scope and reads its `result` variable.
# It assembles the LLM context: recent turns verbatim, the long-term summary,
# vector-search hits, and keyword-matched older history.
# Fix: the embedded code had lost all indentation (IndentationError when the
# node execs it); structure restored from the control-flow keywords.
CODE_BUILD_CONTEXT = r"""left = input_data.get('left') or {}
right = input_data.get('right') or []
if not isinstance(right, list):
    right = []
mem = left.get('memory') or {}
hist = mem.get('conversation_history') or []
if not isinstance(hist, list):
    hist = []
summary = mem.get('conversation_summary') or ''
# Last 16 messages are passed verbatim; older ones are only reachable via the
# summary or the keyword scan below.
recent_n = 16
recent = hist[-recent_n:] if len(hist) > recent_n else hist
recent_str = '\n'.join(f"{x.get('role', '')}: {x.get('content', '')}" for x in recent)
# Vector hits arrive in `right`; records may carry 'text' or 'content'.
vec_str = '\n'.join((rec.get('text') or rec.get('content') or '') for rec in right)
query = (left.get('user_input') or left.get('query') or '').strip()
older = hist[:-recent_n] if len(hist) > recent_n else []
def _tok(s):
    # Tokens = individual CJK characters + lowercase whitespace-separated
    # words, so Chinese and English queries both overlap usefully.
    s = str(s)
    ch = {c for c in s if '\u4e00' <= c <= '\u9fff'}
    wd = set(s.lower().replace('\n', ' ').split())
    return ch | wd
qt = _tok(query) if query else set()
scored = []
for m in older:
    c = str(m.get('content', ''))
    if not c:
        continue
    sc = len(qt & _tok(c)) if qt else 0
    if sc > 0:
        scored.append((sc, str(m.get('role', '')), c[:240]))
scored.sort(key=lambda x: -x[0])
kw_lines = [f"{role}: {text}" for _, role, text in scored[:6]]
kw_str = '\n'.join(kw_lines)
relevant_str = vec_str.strip()
if kw_str:
    if relevant_str:
        relevant_str = relevant_str + '\n---\n关键词相关历史\n' + kw_str
    else:
        relevant_str = '关键词相关历史:\n' + kw_str
result = {
    'user_input': left.get('user_input') or left.get('query') or '',
    'memory': {
        'user_profile': mem.get('user_profile') or {},
        'conversation_summary': summary,
        'relevant_from_retrieval': relevant_str,
        'recent_turns': recent_str,
    },
    'query': left.get('query') or '',
    'user_id': left.get('user_id'),
}
"""
# Source for the "code-build-memory-value" node: merges the summarizer output
# (`input_data['right']`) with the previous memory, appends the current
# user/assistant turn to conversation_history (capped at 40 messages), and
# exposes the whole package as `result` for the cache-update node.
# NOTE(review): the embedded code uses `datetime` without importing it —
# presumably the sandbox pre-imports it; confirm against the engine.
# Fix: restored the embedded code's stripped indentation.
CODE_BUILD_MEMORY_VALUE = r"""left = input_data.get('left') or {}
right_out = input_data.get('right') or {}
summary = ''
if isinstance(right_out, dict):
    summary = right_out.get('output') or right_out.get('result') or ''
if not isinstance(summary, str):
    summary = str(summary or '')
summary = summary.strip()
mem = left.get('memory') or {}
user_input = left.get('user_input') or left.get('query') or ''
reply = left.get('right') or ''
if isinstance(reply, dict):
    reply = reply.get('right') or reply.get('content') or str(reply)
profile_update = left.get('user_profile_update') or {}
if not isinstance(profile_update, dict):
    profile_update = {}
# Merge, with the new extraction winning over previously stored fields.
user_profile = dict(mem.get('user_profile') or {}, **profile_update)
ts = datetime.now().isoformat()
old_hist = mem.get('conversation_history') or []
if not isinstance(old_hist, list):
    old_hist = []
new_hist = old_hist + [
    {'role': 'user', 'content': user_input, 'timestamp': ts},
    {'role': 'assistant', 'content': str(reply or ''), 'timestamp': ts},
]
max_len = 40
if len(new_hist) > max_len:
    new_hist = new_hist[-max_len:]
# Keep the previous summary when this turn produced none.
prev_sum = (mem.get('conversation_summary') or '').strip()
conversation_summary = summary if summary else prev_sum
memory_value = {
    'conversation_summary': conversation_summary,
    'conversation_history': new_hist,
    'user_profile': user_profile,
    'context': mem.get('context') or {},
}
result = {
    'memory': memory_value,
    'user_id': left.get('user_id'),
    'query': left.get('query'),
    'user_input': user_input,
    'right': reply,
    'user_profile_update': profile_update,
}
"""
# Source for the "code-build-turn-for-vector" node: packs the current turn
# into a vector-store document with a content-derived id and user_id metadata
# (the engine filters retrieval per user on that metadata).
# NOTE(review): uses `hashlib` without importing it — presumably the sandbox
# pre-imports it; confirm against the engine.
# Fix: restored the embedded code's stripped indentation.
CODE_BUILD_TURN_FOR_VECTOR = r"""reply = input_data.get('right') or ''
if isinstance(reply, dict):
    reply = reply.get('right') or reply.get('content') or str(reply)
query = input_data.get('query') or ''
user_id = str(input_data.get('user_id') or 'default')
# Deterministic id: the same (user, query, reply) upserts to the same record.
raw = (user_id + '\n' + str(query) + '\n' + str(reply)).encode('utf-8', errors='ignore')
doc_id = 'turn_' + hashlib.sha256(raw).hexdigest()[:24]
text = '用户:' + str(query) + '\n助手' + str(reply)
result = {
    'text': text,
    'user_id': user_id,
    'id': doc_id,
    'metadata': {'user_id': user_id},
}
"""
def _patch_nodes(wf: dict) -> None:
nodes = wf.get("nodes") or []
for n in nodes:
nid = n.get("id")
if nid == "llm-unified":
n.setdefault("data", {})["prompt"] = LLM_PROMPT
elif nid == "code-build-context":
n.setdefault("data", {})["code"] = CODE_BUILD_CONTEXT
elif nid == "code-build-memory-value":
n.setdefault("data", {})["code"] = CODE_BUILD_MEMORY_VALUE
elif nid == "code-build-turn-for-vector":
n.setdefault("data", {})["code"] = CODE_BUILD_TURN_FOR_VECTOR
elif nid == "cache-update-summary":
d = n.setdefault("data", {})
d["value"] = "memory"
elif nid == "transform-for-vector-upsert":
m = n.setdefault("data", {}).setdefault("mapping", {})
m["metadata"] = "{{left.metadata}}"
def main() -> int:
    """Login, duplicate the source agent, patch its workflow, and save it.

    Returns:
        Process exit code: 0 on success, 1 on any API failure (details are
        printed to stderr).

    Fix: function body had lost its indentation; structure restored.
    """
    # 1) Password login (form-encoded body) to obtain a bearer token.
    r = requests.post(
        f"{BASE}/api/v1/auth/login",
        data={"username": USER, "password": PWD},
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        timeout=15,
    )
    if r.status_code != 200:
        print("登录失败:", r.status_code, r.text[:500], file=sys.stderr)
        return 1
    token = r.json().get("access_token")
    if not token:
        print("无 access_token", file=sys.stderr)
        return 1
    h = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
    # 2) Duplicate the source agent under the new name (expects 201 Created).
    dup = requests.post(
        f"{BASE}/api/v1/agents/{SOURCE_AGENT_ID}/duplicate",
        headers=h,
        json={"name": NEW_NAME},
        timeout=30,
    )
    if dup.status_code != 201:
        print("复制失败:", dup.status_code, dup.text[:800], file=sys.stderr)
        return 1
    new_id = dup.json()["id"]
    print("已创建副本:", new_id, NEW_NAME)
    # 3) Fetch the copy's full config and patch the workflow nodes in place.
    g = requests.get(f"{BASE}/api/v1/agents/{new_id}", headers=h, timeout=30)
    if g.status_code != 200:
        print("读取 Agent 失败:", g.text, file=sys.stderr)
        return 1
    agent = g.json()
    wf = agent["workflow_config"]
    _patch_nodes(wf)
    # 4) Persist the patched workflow together with the new description.
    up = requests.put(
        f"{BASE}/api/v1/agents/{new_id}",
        headers=h,
        json={"description": NEW_DESC, "workflow_config": wf},
        timeout=60,
    )
    if up.status_code != 200:
        print("更新失败:", up.status_code, up.text[:800], file=sys.stderr)
        return 1
    print("已更新LLM 提示、code-build-context / memory-value / vector-turn、cache-update-summary.value、upsert.metadata")
    print("Agent ID:", new_id)
    print(json.dumps({"id": new_id, "name": NEW_NAME}, ensure_ascii=False))
    return 0
if __name__ == "__main__":
    # Propagate main()'s exit code to the shell.
    # Fix: the raise was not indented under the guard.
    raise SystemExit(main())