Files
aiagent/backend/app/agent_runtime/memory.py
renjianbo 7b9e0826de feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖
- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser
- 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行)
- 新增 agent_vector_memory / knowledge_base 模型及对应数据库表
- 实现 SSE 流式响应与 Agent 预算控制
- AgentChat.vue 集成 MainLayout 导航布局
- 完善测试体系:7 个新测试文件共 110 个测试覆盖
- 修复 conftest.py SQLite 内存数据库连接隔离问题

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-01 22:30:46 +08:00

370 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Agent 记忆管理:包装已有 persistent_memory_service,提供会话级和长期记忆。
支持 LLM 自动压缩总结对话历史。
"""
from __future__ import annotations
import json
import logging
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.services.persistent_memory_service import (
load_persistent_memory,
save_persistent_memory,
persist_enabled,
)
from app.services.embedding_service import embedding_service, VectorEntry
logger = logging.getLogger(__name__)
class AgentMemory:
    """Layered memory manager.

    - Working memory: the current session's message list (managed directly by
      AgentRuntime; this class only trims it via :meth:`trim_messages`).
    - Long-term memory: user profile and key facts loaded from / saved to the
      database via ``persistent_memory_service``.
    - Memory compression: an LLM summarizes the conversation history and the
      extracted key information is merged into long-term memory.
    """

    def __init__(
        self,
        scope_kind: str = "agent",
        scope_id: Optional[str] = None,
        session_key: Optional[str] = None,
        persist: bool = True,
        max_history: int = 20,
        vector_memory_enabled: bool = True,
        vector_memory_top_k: int = 5,
    ):
        """
        Args:
            scope_kind: Memory scope category (e.g. ``"agent"``).
            scope_id: Identifier within the scope; falls back to ``"default"``.
            session_key: Session identifier; falls back to ``"default_session"``.
            persist: Whether to persist long-term memory. Effective only when
                the global ``persist_enabled()`` switch is also on.
            max_history: Maximum number of messages kept by ``trim_messages``.
            vector_memory_enabled: Whether to store/retrieve vector memories.
            vector_memory_top_k: Number of vector-search hits to inject.
        """
        self.scope_kind = scope_kind
        self.scope_id = scope_id or "default"
        self.session_key = session_key or "default_session"
        # Short-circuits: persist=False never touches the global switch.
        self.persist = persist and persist_enabled()
        self.max_history = max_history
        self.vector_memory_enabled = vector_memory_enabled
        self.vector_memory_top_k = vector_memory_top_k
        # Long-term context loaded at startup (user_profile / context / ...).
        self._long_term_context: Dict[str, Any] = {}
        # Count of messages already compressed, to avoid re-summarizing.
        self._last_compressed_msg_count = 0

    async def initialize(self, query: str = "") -> str:
        """
        Initialize memory: load long-term memory from DB/Redis and vector-search
        for related history.

        Args:
            query: Optional text used for semantic retrieval of past turns.

        Returns:
            A memory text block to inject into the system prompt, or ``""``.
        """
        if not self.persist or not self.scope_id:
            return ""
        parts: List[str] = []
        # 1. Structured long-term memory: profile / context / history summary.
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            payload = load_persistent_memory(
                db, self.scope_kind, self.scope_id, self.session_key
            )
            if payload and isinstance(payload, dict):
                self._long_term_context = payload
                profile = payload.get("user_profile")
                if profile and isinstance(profile, dict):
                    profile_text = json.dumps(profile, ensure_ascii=False)
                    parts.append(f"## 用户画像\n{profile_text}")
                context = payload.get("context")
                if context and isinstance(context, dict):
                    ctx_text = json.dumps(context, ensure_ascii=False)
                    parts.append(f"## 上下文\n{ctx_text}")
                history = payload.get("conversation_history")
                if history and isinstance(history, list) and len(history) > 0:
                    summary = self._summarize_history(history)
                    parts.append(f"## 历史对话摘要\n{summary}")
        except Exception as e:
            logger.warning("加载长期记忆失败: %s", e)
        finally:
            if db:
                db.close()
        # 2. Vector search: semantically related past conversations.
        if self.vector_memory_enabled and self.scope_kind and self.scope_id:
            vector_text = await self._vector_search(query)
            if vector_text:
                parts.append(vector_text)
        return "\n\n".join(parts) if parts else ""

    async def _vector_search(self, query: str = "") -> str:
        """
        Vector-search semantically related memories and format them as text.

        Without a ``query``, falls back to the most recent ``top_k`` entries.

        Returns:
            A formatted text block, or ``""`` when nothing matches or on error.
        """
        from app.models.agent_vector_memory import AgentVectorMemory
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            # Fetch this scope's vector memories, newest first; cap at 50 rows
            # so similarity scoring stays cheap.
            rows = (
                db.query(AgentVectorMemory)
                .filter(
                    AgentVectorMemory.scope_kind == self.scope_kind,
                    AgentVectorMemory.scope_id == self.scope_id,
                )
                .order_by(AgentVectorMemory.created_at.desc())
                .limit(50)
                .all()
            )
            if not rows:
                return ""
            entries: List[VectorEntry] = []
            for row in rows:
                emb = (
                    embedding_service.deserialize_embedding(row.embedding)
                    if row.embedding
                    else []
                )
                entries.append({
                    "id": row.id,
                    "scope_kind": row.scope_kind,
                    "scope_id": row.scope_id,
                    "content_text": row.content_text,
                    "embedding": emb,
                    "metadata": row.metadata_ or {},
                })
            matched: List[VectorEntry] = []
            if query and query.strip():
                # With a query: embed it and run a semantic similarity search.
                query_emb = await embedding_service.generate_embedding(query)
                if query_emb:
                    matched = await embedding_service.similarity_search(
                        query_emb, entries, top_k=self.vector_memory_top_k
                    )
            else:
                # Without a query: just take the most recent entries.
                matched = entries[: self.vector_memory_top_k]
                for m in matched:
                    m["score"] = 1.0
            if not matched:
                return ""
            # Format hits as a markdown text block for prompt injection.
            lines = ["## 相关历史记忆"]
            for i, m in enumerate(matched, 1):
                text = m.get("content_text", "")[:500]
                meta = m.get("metadata", {})
                entry_type = meta.get("type", "对话")
                lines.append(f"{i}. [{entry_type}] {text}")
                # Only real similarity hits (score < 1.0) show a match score.
                if m.get("score", 1.0) < 1.0:
                    lines[-1] += f" (匹配度: {m['score']:.2f})"
            return "\n".join(lines)
        except Exception as e:
            logger.warning("向量检索失败: %s", e)
            return ""
        finally:
            if db:
                db.close()

    async def save_context(
        self, user_message: str, assistant_reply: str,
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """
        Save one conversation turn to long-term memory.

        When a full message list is provided and enough new messages have
        accumulated, an LLM compression/summarization pass is run first.

        Args:
            user_message: The user's message for this turn.
            assistant_reply: The assistant's reply for this turn.
            messages: Optional full message list for LLM compression.
        """
        if not self.persist or not self.scope_id:
            return
        # Record the latest turn (truncated) in the long-term context.
        ctx = self._long_term_context.get("context", {})
        ctx["last_user_message"] = user_message[:500]
        ctx["last_assistant_reply"] = assistant_reply[:500]
        self._long_term_context["context"] = ctx
        # Run LLM compression only when >2 messages arrived since last pass.
        if messages and len(messages) > self._last_compressed_msg_count + 2:
            await self._compress_and_summarize(messages)
            self._last_compressed_msg_count = len(messages)
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            save_persistent_memory(
                db, self.scope_kind, self.scope_id,
                self.session_key, self._long_term_context,
            )
            # Also store a vector memory (embedding generated asynchronously).
            if self.vector_memory_enabled:
                await self._save_vector_memory(
                    db, user_message, assistant_reply
                )
        except Exception as e:
            logger.warning("保存长期记忆失败: %s", e)
        finally:
            if db:
                db.close()

    async def _save_vector_memory(
        self, db: Session, user_message: str, assistant_reply: str,
    ) -> None:
        """Generate an embedding and persist a row in the vector-memory table."""
        from app.models.agent_vector_memory import AgentVectorMemory
        content_text = f"用户: {user_message}\n助手: {assistant_reply}"
        # Cap embedding input length; stored text is capped further below.
        if len(content_text) > 8000:
            content_text = content_text[:8000]
        try:
            embedding = await embedding_service.generate_embedding(content_text)
            embedding_json = (
                embedding_service.serialize_embedding(embedding)
                if embedding
                else ""
            )
            record = AgentVectorMemory(
                scope_kind=self.scope_kind,
                scope_id=self.scope_id,
                session_key=self.session_key,
                content_text=content_text[:2000],
                embedding=embedding_json or None,
                metadata_={
                    "type": "conversation_turn",
                },
            )
            db.add(record)
            db.commit()
            logger.debug("已保存向量记忆 (scope=%s/%s)", self.scope_kind, self.scope_id)
        except Exception as e:
            logger.warning("保存向量记忆失败: %s", e)
            # Best-effort: roll back so the session stays usable for the caller.
            db.rollback()

    async def _compress_and_summarize(
        self, messages: List[Dict[str, Any]]
    ) -> None:
        """
        Use an LLM to compress the conversation history, extracting a user
        profile and key facts into the long-term context.

        System messages are skipped; tool results are abbreviated. All
        failures are logged and swallowed (compression is best-effort).
        """
        from openai import AsyncOpenAI
        from app.core.config import settings
        # Build the conversation to summarize (no system messages).
        conversation = []
        for m in messages:
            role = m.get("role", "")
            if role == "system":
                continue
            if role == "tool":
                # Tool results are abbreviated and re-attributed to "user" so
                # the summarizer sees a plain user/assistant transcript.
                content = m.get("content", "")
                name = m.get("name", "tool")
                conversation.append({
                    "role": "user",
                    "content": f"[工具 {name} 执行结果]\n{content[:200]}",
                })
            else:
                conversation.append({"role": role, "content": m.get("content", "")[:500]})
        if len(conversation) < 2:
            return
        # Summarization prompt: request a strict-JSON response.
        summary_prompt = (
            "你是一个记忆管理助手。请分析以下对话历史,提取关于用户的关键信息。\n\n"
            "请返回 JSON 格式(不要 markdown 包裹),包含以下字段:\n"
            "1. user_profile: 用户画像对象,包含用户的偏好、角色、关键需求等\n"
            "2. key_facts: 从对话中提取的关键事实列表(字符串数组)\n"
            "3. summary: 对话的简要总结100字以内\n"
            "4. topics: 讨论过的话题列表(字符串数组)\n\n"
            "如果没有足够信息,相应字段设为空对象或空数组。"
        )
        summary_messages = [
            {"role": "system", "content": summary_prompt},
            *conversation[-10:],  # only the 10 most recent messages
        ]
        try:
            api_key = settings.DEEPSEEK_API_KEY or settings.OPENAI_API_KEY or ""
            base_url = settings.DEEPSEEK_BASE_URL or settings.OPENAI_BASE_URL or "https://api.deepseek.com"
            # Guard against the unmodified placeholder OpenAI key from config
            # templates; fall back to the DeepSeek credentials in that case.
            if api_key == "your-openai-api-key":
                api_key = settings.DEEPSEEK_API_KEY or ""
                base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com"
            if not api_key:
                logger.warning("记忆压缩:未配置 API Key跳过")
                return
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
            resp = await client.chat.completions.create(
                model="deepseek-v4-flash",
                messages=summary_messages,
                temperature=0.3,
                max_tokens=1024,
                timeout=30,
            )
            raw = resp.choices[0].message.content or ""
            # Parse JSON, tolerating an accidental ```json fence.
            result = json.loads(raw.strip().removeprefix("```json").removesuffix("```").strip())
            # Merge the user profile (new keys overwrite old ones).
            existing_profile = self._long_term_context.get("user_profile", {})
            new_profile = result.get("user_profile", {})
            if isinstance(new_profile, dict) and new_profile:
                existing_profile.update(new_profile)
                self._long_term_context["user_profile"] = existing_profile
            # Merge key facts, de-duplicated, keeping only the newest 20.
            existing_facts = self._long_term_context.get("key_facts", [])
            new_facts = result.get("key_facts", [])
            if isinstance(new_facts, list):
                all_facts = list(dict.fromkeys(existing_facts + new_facts))
                self._long_term_context["key_facts"] = all_facts[-20:]
            # Store the compressed summary in the context.
            summary = result.get("summary", "")
            if summary:
                ctx = self._long_term_context.get("context", {})
                ctx["compressed_summary"] = summary
                self._long_term_context["context"] = ctx
            # Track discussed topics, de-duplicated, newest 20.
            topics = result.get("topics", [])
            if isinstance(topics, list) and topics:
                existing_topics = self._long_term_context.get("topics", [])
                all_topics = list(dict.fromkeys(existing_topics + topics))
                self._long_term_context["topics"] = all_topics[-20:]
            logger.info("记忆压缩总结完成: profile=%s facts=%d topics=%d",
                        "updated" if new_profile else "unchanged",
                        len(new_facts), len(topics))
        except json.JSONDecodeError:
            logger.warning("记忆压缩LLM 返回非 JSON 格式,跳过")
        except Exception as e:
            logger.warning("记忆压缩失败: %s", e)

    def trim_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Trim the message list to the most recent ``max_history`` entries while
        always keeping the system messages.

        Note: if system messages alone reach ``max_history``, only they are
        kept. (The previous slice-based version computed a non-negative index
        in that case and accidentally kept *all* non-system messages.)
        """
        if len(messages) <= self.max_history:
            return messages
        system_msgs = [m for m in messages if m.get("role") == "system"]
        other_msgs = [m for m in messages if m.get("role") != "system"]
        # Remaining budget for non-system messages; clamp at zero so a large
        # system-message count can never flip the slice bound positive.
        budget = max(self.max_history - len(system_msgs), 0)
        trimmed = other_msgs[-budget:] if budget > 0 else []
        return system_msgs + trimmed

    @staticmethod
    def _summarize_history(history: List[Dict[str, Any]]) -> str:
        """Summarize stored history as a one-line turn count (user turns)."""
        turns = 0
        for m in history:
            if m.get("role") == "user":
                turns += 1
        return f"共 {turns} 轮历史对话(详情已存入长期记忆)"