Files
aiagent/backend/app/agent_runtime/memory.py

370 lines
14 KiB
Python
Raw Normal View History

"""
Agent 记忆管理包装已有 persistent_memory_service提供会话级和长期记忆
支持 LLM 自动压缩总结对话历史
"""
from __future__ import annotations
import json
import logging
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.services.persistent_memory_service import (
load_persistent_memory,
save_persistent_memory,
persist_enabled,
)
from app.services.embedding_service import embedding_service, VectorEntry
logger = logging.getLogger(__name__)
class AgentMemory:
"""
分层记忆管理器
- 工作记忆当前会话消息列表 AgentRuntime 直接管理
- 长期记忆 MySQL 加载/保存的用户画像和关键事实
- 记忆压缩LLM 自动总结对话历史提取关键信息存入长期记忆
"""
def __init__(
self,
scope_kind: str = "agent",
scope_id: Optional[str] = None,
session_key: Optional[str] = None,
persist: bool = True,
max_history: int = 20,
vector_memory_enabled: bool = True,
vector_memory_top_k: int = 5,
):
self.scope_kind = scope_kind
self.scope_id = scope_id or "default"
self.session_key = session_key or "default_session"
self.persist = persist and persist_enabled()
self.max_history = max_history
self.vector_memory_enabled = vector_memory_enabled
self.vector_memory_top_k = vector_memory_top_k
# 从长期记忆加载的上下文(启动时加载)
self._long_term_context: Dict[str, Any] = {}
# 记录已压缩的消息数,避免重复压缩
self._last_compressed_msg_count = 0
async def initialize(self, query: str = "") -> str:
"""
初始化记忆 DB/Redis 加载长期记忆 + 向量检索相关历史
返回注入 system prompt 的记忆文本块
"""
if not self.persist or not self.scope_id:
return ""
parts: List[str] = []
db: Optional[Session] = None
try:
db = SessionLocal()
payload = load_persistent_memory(
db, self.scope_kind, self.scope_id, self.session_key
)
if payload and isinstance(payload, dict):
self._long_term_context = payload
profile = payload.get("user_profile")
if profile and isinstance(profile, dict):
profile_text = json.dumps(profile, ensure_ascii=False)
parts.append(f"## 用户画像\n{profile_text}")
context = payload.get("context")
if context and isinstance(context, dict):
ctx_text = json.dumps(context, ensure_ascii=False)
parts.append(f"## 上下文\n{ctx_text}")
history = payload.get("conversation_history")
if history and isinstance(history, list) and len(history) > 0:
summary = self._summarize_history(history)
parts.append(f"## 历史对话摘要\n{summary}")
except Exception as e:
logger.warning("加载长期记忆失败: %s", e)
finally:
if db:
db.close()
# 2. 向量检索:查找语义相关的历史对话
if self.vector_memory_enabled and self.scope_kind and self.scope_id:
vector_text = await self._vector_search(query)
if vector_text:
parts.append(vector_text)
return "\n\n".join(parts) if parts else ""
async def _vector_search(self, query: str = "") -> str:
"""
向量检索语义相关的历史记忆返回格式化的文本块
若无 query 则返回最近 Top-5 条记忆
"""
from app.models.agent_vector_memory import AgentVectorMemory
db: Optional[Session] = None
try:
db = SessionLocal()
# 查询当前 scope 的所有向量记忆(按时间倒序)
rows = (
db.query(AgentVectorMemory)
.filter(
AgentVectorMemory.scope_kind == self.scope_kind,
AgentVectorMemory.scope_id == self.scope_id,
)
.order_by(AgentVectorMemory.created_at.desc())
.limit(50) # 最多取最近 50 条做相似度计算
.all()
)
if not rows:
return ""
entries: List[VectorEntry] = []
for row in rows:
emb = embedding_service.deserialize_embedding(row.embedding) if row.embedding else []
entries.append({
"id": row.id,
"scope_kind": row.scope_kind,
"scope_id": row.scope_id,
"content_text": row.content_text,
"embedding": emb,
"metadata": row.metadata_ or {},
})
matched: List[VectorEntry] = []
if query and query.strip():
# 有 query生成 embedding 做语义搜索
query_emb = await embedding_service.generate_embedding(query)
if query_emb:
matched = await embedding_service.similarity_search(
query_emb, entries, top_k=self.vector_memory_top_k
)
else:
# 无 query返回最近几条
matched = entries[: self.vector_memory_top_k]
for m in matched:
m["score"] = 1.0
if not matched:
return ""
# 格式化为文本块
lines = ["## 相关历史记忆"]
for i, m in enumerate(matched, 1):
text = m.get("content_text", "")[:500]
meta = m.get("metadata", {})
entry_type = meta.get("type", "对话")
lines.append(f"{i}. [{entry_type}] {text}")
if m.get("score", 1.0) < 1.0:
lines[-1] += f" (匹配度: {m['score']:.2f})"
return "\n".join(lines)
except Exception as e:
logger.warning("向量检索失败: %s", e)
return ""
finally:
if db:
db.close()
async def save_context(
self, user_message: str, assistant_reply: str,
messages: Optional[List[Dict[str, Any]]] = None,
) -> None:
"""将单轮对话保存到长期记忆。如有消息列表LLM 自动压缩总结。"""
if not self.persist or not self.scope_id:
return
# 更新上下文
ctx = self._long_term_context.get("context", {})
ctx["last_user_message"] = user_message[:500]
ctx["last_assistant_reply"] = assistant_reply[:500]
self._long_term_context["context"] = ctx
# 如果有完整消息列表且新增了足够多的消息,运行 LLM 压缩总结
if messages and len(messages) > self._last_compressed_msg_count + 2:
await self._compress_and_summarize(messages)
self._last_compressed_msg_count = len(messages)
db: Optional[Session] = None
try:
db = SessionLocal()
save_persistent_memory(
db, self.scope_kind, self.scope_id,
self.session_key, self._long_term_context,
)
# 保存向量记忆(异步生成 embedding 并存储)
if self.vector_memory_enabled:
await self._save_vector_memory(
db, user_message, assistant_reply
)
except Exception as e:
logger.warning("保存长期记忆失败: %s", e)
finally:
if db:
db.close()
async def _save_vector_memory(
self, db: Session, user_message: str, assistant_reply: str,
) -> None:
"""生成 embedding 并保存到向量记忆表。"""
from app.models.agent_vector_memory import AgentVectorMemory
content_text = f"用户: {user_message}\n助手: {assistant_reply}"
if len(content_text) > 8000:
content_text = content_text[:8000]
try:
# 生成 embedding
embedding = await embedding_service.generate_embedding(content_text)
embedding_json = embedding_service.serialize_embedding(embedding) if embedding else ""
record = AgentVectorMemory(
scope_kind=self.scope_kind,
scope_id=self.scope_id,
session_key=self.session_key,
content_text=content_text[:2000],
embedding=embedding_json or None,
metadata_={
"type": "conversation_turn",
},
)
db.add(record)
db.commit()
logger.debug("已保存向量记忆 (scope=%s/%s)", self.scope_kind, self.scope_id)
except Exception as e:
logger.warning("保存向量记忆失败: %s", e)
db.rollback()
async def _compress_and_summarize(
self, messages: List[Dict[str, Any]]
) -> None:
"""
使用 LLM 压缩总结对话历史提取用户画像和关键事实
只处理非 system 消息
"""
from openai import AsyncOpenAI
from app.core.config import settings
# 提取对话消息(去掉 system 和 tool 消息)
conversation = []
for m in messages:
role = m.get("role", "")
if role == "system":
continue
if role == "tool":
# 工具结果精简后加入
content = m.get("content", "")
name = m.get("name", "tool")
conversation.append({"role": "user" if role == "tool" else role, "content": f"[工具 {name} 执行结果]\n{content[:200]}"})
else:
conversation.append({"role": role, "content": m.get("content", "")[:500]})
if len(conversation) < 2:
return
# 构建总结 prompt
summary_prompt = (
"你是一个记忆管理助手。请分析以下对话历史,提取关于用户的关键信息。\n\n"
"请返回 JSON 格式(不要 markdown 包裹),包含以下字段:\n"
"1. user_profile: 用户画像对象,包含用户的偏好、角色、关键需求等\n"
"2. key_facts: 从对话中提取的关键事实列表(字符串数组)\n"
"3. summary: 对话的简要总结100字以内\n"
"4. topics: 讨论过的话题列表(字符串数组)\n\n"
"如果没有足够信息,相应字段设为空对象或空数组。"
)
summary_messages = [
{"role": "system", "content": summary_prompt},
*conversation[-10:], # 只取最近 10 条消息
]
try:
api_key = settings.DEEPSEEK_API_KEY or settings.OPENAI_API_KEY or ""
base_url = settings.DEEPSEEK_BASE_URL or settings.OPENAI_BASE_URL or "https://api.deepseek.com"
if api_key == "your-openai-api-key":
api_key = settings.DEEPSEEK_API_KEY or ""
base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com"
if not api_key:
logger.warning("记忆压缩:未配置 API Key跳过")
return
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
resp = await client.chat.completions.create(
model="deepseek-v4-flash",
messages=summary_messages,
temperature=0.3,
max_tokens=1024,
timeout=30,
)
raw = resp.choices[0].message.content or ""
# 解析 JSON
result = json.loads(raw.strip().removeprefix("```json").removesuffix("```").strip())
# 合并到长期记忆
existing_profile = self._long_term_context.get("user_profile", {})
new_profile = result.get("user_profile", {})
if isinstance(new_profile, dict) and new_profile:
# 合并画像(新信息覆盖旧信息)
existing_profile.update(new_profile)
self._long_term_context["user_profile"] = existing_profile
# 合并关键事实
existing_facts = self._long_term_context.get("key_facts", [])
new_facts = result.get("key_facts", [])
if isinstance(new_facts, list):
all_facts = list(dict.fromkeys(existing_facts + new_facts)) # 去重
self._long_term_context["key_facts"] = all_facts[-20:] # 最多保留 20 条
# 更新摘要
summary = result.get("summary", "")
if summary:
ctx = self._long_term_context.get("context", {})
ctx["compressed_summary"] = summary
self._long_term_context["context"] = ctx
# 记录话题
topics = result.get("topics", [])
if isinstance(topics, list) and topics:
existing_topics = self._long_term_context.get("topics", [])
all_topics = list(dict.fromkeys(existing_topics + topics))
self._long_term_context["topics"] = all_topics[-20:]
logger.info("记忆压缩总结完成: profile=%s facts=%d topics=%d",
"updated" if new_profile else "unchanged",
len(new_facts), len(topics))
except json.JSONDecodeError:
logger.warning("记忆压缩LLM 返回非 JSON 格式,跳过")
except Exception as e:
logger.warning("记忆压缩失败: %s", e)
def trim_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
裁剪消息列表保留最近的 N 但始终保留第一条 system 消息
"""
if len(messages) <= self.max_history:
return messages
system_msgs = [m for m in messages if m.get("role") == "system"]
other_msgs = [m for m in messages if m.get("role") != "system"]
trimmed = other_msgs[-(self.max_history - len(system_msgs)):]
return system_msgs + trimmed
@staticmethod
def _summarize_history(history: List[Dict[str, Any]]) -> str:
"""汇总历史对话。"""
turns = 0
for m in history:
if m.get("role") == "user":
turns += 1
return f"{turns} 轮历史对话(详情已存入长期记忆)"