""" Agent 记忆管理:包装已有 persistent_memory_service,提供会话级和长期记忆。 支持 LLM 自动压缩总结对话历史。 """ from __future__ import annotations import json import logging from typing import Any, Dict, List, Optional from sqlalchemy.orm import Session from app.core.database import SessionLocal from app.services.persistent_memory_service import ( load_persistent_memory, save_persistent_memory, persist_enabled, ) from app.services.embedding_service import embedding_service, VectorEntry logger = logging.getLogger(__name__) class AgentMemory: """ 分层记忆管理器: - 工作记忆:当前会话消息列表(由 AgentRuntime 直接管理) - 长期记忆:从 MySQL 加载/保存的用户画像和关键事实 - 记忆压缩:LLM 自动总结对话历史,提取关键信息存入长期记忆 """ def __init__( self, scope_kind: str = "agent", scope_id: Optional[str] = None, session_key: Optional[str] = None, persist: bool = True, max_history: int = 20, vector_memory_enabled: bool = True, vector_memory_top_k: int = 5, ): self.scope_kind = scope_kind self.scope_id = scope_id or "default" self.session_key = session_key or "default_session" self.persist = persist and persist_enabled() self.max_history = max_history self.vector_memory_enabled = vector_memory_enabled self.vector_memory_top_k = vector_memory_top_k # 从长期记忆加载的上下文(启动时加载) self._long_term_context: Dict[str, Any] = {} # 记录已压缩的消息数,避免重复压缩 self._last_compressed_msg_count = 0 async def initialize(self, query: str = "") -> str: """ 初始化记忆:从 DB/Redis 加载长期记忆 + 向量检索相关历史。 返回注入 system prompt 的记忆文本块。 """ if not self.persist or not self.scope_id: return "" parts: List[str] = [] db: Optional[Session] = None try: db = SessionLocal() payload = load_persistent_memory( db, self.scope_kind, self.scope_id, self.session_key ) if payload and isinstance(payload, dict): self._long_term_context = payload profile = payload.get("user_profile") if profile and isinstance(profile, dict): profile_text = json.dumps(profile, ensure_ascii=False) parts.append(f"## 用户画像\n{profile_text}") context = payload.get("context") if context and isinstance(context, dict): ctx_text = json.dumps(context, ensure_ascii=False) parts.append(f"## 上下文\n{ctx_text}") history = payload.get("conversation_history") if history and isinstance(history, list) and len(history) > 0: summary = self._summarize_history(history) parts.append(f"## 历史对话摘要\n{summary}") except Exception as e: logger.warning("加载长期记忆失败: %s", e) finally: if db: db.close() # 2. 向量检索:查找语义相关的历史对话 if self.vector_memory_enabled and self.scope_kind and self.scope_id: vector_text = await self._vector_search(query) if vector_text: parts.append(vector_text) # 3. 
    async def initialize(self, query: str = "") -> str:
        """
        Initialize memory: load long-term memory from DB/Redis and retrieve
        related history via vector search. Returns the memory text block to
        inject into the system prompt.
        """
        if not self.persist or not self.scope_id:
            return ""

        parts: List[str] = []

        # 1. Structured long-term memory: profile, context, history summary
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            payload = load_persistent_memory(
                db, self.scope_kind, self.scope_id, self.session_key
            )
            if payload and isinstance(payload, dict):
                self._long_term_context = payload
                profile = payload.get("user_profile")
                if profile and isinstance(profile, dict):
                    profile_text = json.dumps(profile, ensure_ascii=False)
                    parts.append(f"## 用户画像\n{profile_text}")
                context = payload.get("context")
                if context and isinstance(context, dict):
                    ctx_text = json.dumps(context, ensure_ascii=False)
                    parts.append(f"## 上下文\n{ctx_text}")
                history = payload.get("conversation_history")
                if history and isinstance(history, list) and len(history) > 0:
                    summary = self._summarize_history(history)
                    parts.append(f"## 历史对话摘要\n{summary}")
        except Exception as e:
            logger.warning("Failed to load long-term memory: %s", e)
        finally:
            if db:
                db.close()

        # 2. Vector search: look up semantically related past conversations
        if self.vector_memory_enabled and self.scope_kind and self.scope_id:
            vector_text = await self._vector_search(query)
            if vector_text:
                parts.append(vector_text)

        # 3. Global knowledge: load related entries from the GlobalKnowledge table
        global_text = await self._global_knowledge_search(query)
        if global_text:
            parts.append(global_text)

        return "\n\n".join(parts) if parts else ""

    async def _vector_search(self, query: str = "") -> str:
        """
        Vector-search semantically related past memories and return them as a
        formatted text block. Without a query, return the most recent top-k
        entries instead.
        """
        from app.models.agent_vector_memory import AgentVectorMemory

        db: Optional[Session] = None
        try:
            db = SessionLocal()
            # Fetch vector memories for the current scope (newest first);
            # cap at the latest 50 rows for similarity computation
            rows = (
                db.query(AgentVectorMemory)
                .filter(
                    AgentVectorMemory.scope_kind == self.scope_kind,
                    AgentVectorMemory.scope_id == self.scope_id,
                )
                .order_by(AgentVectorMemory.created_at.desc())
                .limit(50)
                .all()
            )
            if not rows:
                return ""

            entries: List[VectorEntry] = []
            for row in rows:
                emb = (
                    embedding_service.deserialize_embedding(row.embedding)
                    if row.embedding
                    else []
                )
                entries.append({
                    "id": row.id,
                    "scope_kind": row.scope_kind,
                    "scope_id": row.scope_id,
                    "content_text": row.content_text,
                    "embedding": emb,
                    "metadata": row.metadata_ or {},
                })

            matched: List[VectorEntry] = []
            if query and query.strip():
                # With a query: embed it and run a semantic search
                query_emb = await embedding_service.generate_embedding(query)
                if query_emb:
                    matched = await embedding_service.similarity_search(
                        query_emb, entries, top_k=self.vector_memory_top_k
                    )
            else:
                # Without a query: fall back to the most recent entries
                matched = entries[: self.vector_memory_top_k]
                for m in matched:
                    m["score"] = 1.0

            if not matched:
                return ""

            # Format as a text block
            lines = ["## 相关历史记忆"]
            for i, m in enumerate(matched, 1):
                text = m.get("content_text", "")[:500]
                meta = m.get("metadata", {})
                entry_type = meta.get("type", "对话")
                lines.append(f"{i}. [{entry_type}] {text}")
                if m.get("score", 1.0) < 1.0:
                    lines[-1] += f" (匹配度: {m['score']:.2f})"
            return "\n".join(lines)
        except Exception as e:
            logger.warning("Vector search failed: %s", e)
            return ""
        finally:
            if db:
                db.close()

    async def _global_knowledge_search(self, query: str = "") -> str:
        """Retrieve related global knowledge entries from the GlobalKnowledge table."""
        from app.models.agent import GlobalKnowledge

        db: Optional[Session] = None
        try:
            db = SessionLocal()
            rows = (
                db.query(GlobalKnowledge)
                .order_by(GlobalKnowledge.created_at.desc())
                .limit(50)
                .all()
            )
            if not rows:
                return ""

            # With a query, filter by vector similarity; otherwise return the
            # most recent entries
            if query and query.strip():
                entries: List[VectorEntry] = []
                for row in rows:
                    if not row.embedding:
                        continue
                    try:
                        emb = embedding_service.deserialize_embedding(row.embedding)
                    except Exception:
                        emb = []
                    if emb:
                        entries.append({
                            "id": row.id,
                            "scope_kind": "global",
                            "scope_id": "global",
                            "content_text": row.content,
                            "embedding": emb,
                            "metadata": {
                                "source_agent_id": row.source_agent_id,
                                "tags": row.tags or [],
                            },
                        })
                if entries:
                    query_emb = await embedding_service.generate_embedding(query)
                    if query_emb:
                        matched = await embedding_service.similarity_search(
                            query_emb,
                            entries,
                            top_k=min(5, len(entries)),
                        )
                        if matched:
                            lines = ["## 全局知识库"]
                            for i, m in enumerate(matched, 1):
                                tags = m.get("metadata", {}).get("tags", [])
                                tag_str = f" [{', '.join(tags[:3])}]" if tags else ""
                                lines.append(
                                    f"{i}.{tag_str} {m.get('content_text', '')[:500]}"
                                )
                            return "\n".join(lines)
            else:
                # No query: return the 5 most recent global knowledge entries
                recent = rows[:5]
                if recent:
                    lines = ["## 全局知识库(最近)"]
                    for i, row in enumerate(recent, 1):
                        tag_str = f" [{', '.join(row.tags[:3])}]" if row.tags else ""
                        lines.append(f"{i}.{tag_str} {row.content[:500]}")
                    return "\n".join(lines)
            return ""
        except Exception as e:
            logger.warning("Global knowledge search failed: %s", e)
            return ""
        finally:
            if db:
                db.close()
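    # Illustrative sketch (assumption): the ranking that
    # embedding_service.similarity_search is expected to perform over the
    # VectorEntry lists built above, i.e. cosine similarity against the query
    # embedding, highest first. This helper is hypothetical and is not called
    # anywhere in this module.
    @staticmethod
    def _cosine_similarity_sketch(a: List[float], b: List[float]) -> float:
        """Hypothetical helper: cosine similarity between two embeddings."""
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return dot / (norm_a * norm_b)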
    async def save_global_knowledge(
        self,
        content: str,
        source_agent_id: str = "",
        source_user_id: str = "",
        tags: Optional[List[str]] = None,
    ) -> None:
        """Write a knowledge entry into the global knowledge pool."""
        from app.models.agent import GlobalKnowledge

        if not content or len(content) < 20:
            return

        db: Optional[Session] = None
        try:
            db = SessionLocal()
            # Generate the embedding (best effort; the entry is still saved
            # without one)
            embedding_json = ""
            try:
                emb = await embedding_service.generate_embedding(content)
                if emb:
                    embedding_json = embedding_service.serialize_embedding(emb) or ""
            except Exception:
                pass

            record = GlobalKnowledge(
                content=content[:2000],
                embedding=embedding_json or None,
                source_agent_id=source_agent_id or "",
                source_user_id=source_user_id or "",
                tags=tags or [],
                scope_kind=self.scope_kind,
                scope_id=self.scope_id or "global",
            )
            db.add(record)
            db.commit()
            logger.info("Saved global knowledge: agent=%s tags=%s", source_agent_id, tags)
        except Exception as e:
            logger.warning("Failed to save global knowledge: %s", e)
            if db:
                db.rollback()
        finally:
            if db:
                db.close()

    async def save_context(
        self,
        user_message: str,
        assistant_reply: str,
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """
        Save a single conversation turn to long-term memory. If the full
        message list is provided, the LLM compresses and summarizes it.
        """
        if not self.persist or not self.scope_id:
            return

        # Update the context with the latest turn
        ctx = self._long_term_context.get("context", {})
        ctx["last_user_message"] = user_message[:500]
        ctx["last_assistant_reply"] = assistant_reply[:500]
        self._long_term_context["context"] = ctx

        # If a full message list is given and enough new messages have
        # accumulated, run LLM compression/summarization
        if messages and len(messages) > self._last_compressed_msg_count + 2:
            await self._compress_and_summarize(messages)
            self._last_compressed_msg_count = len(messages)

        db: Optional[Session] = None
        try:
            db = SessionLocal()
            save_persistent_memory(
                db,
                self.scope_kind,
                self.scope_id,
                self.session_key,
                self._long_term_context,
            )
            # Save vector memory (generate the embedding and store it)
            if self.vector_memory_enabled:
                await self._save_vector_memory(db, user_message, assistant_reply)
        except Exception as e:
            logger.warning("Failed to save long-term memory: %s", e)
        finally:
            if db:
                db.close()

    async def _save_vector_memory(
        self,
        db: Session,
        user_message: str,
        assistant_reply: str,
    ) -> None:
        """Generate an embedding and save it to the vector memory table."""
        from app.models.agent_vector_memory import AgentVectorMemory

        content_text = f"用户: {user_message}\n助手: {assistant_reply}"
        if len(content_text) > 8000:
            content_text = content_text[:8000]

        try:
            embedding = await embedding_service.generate_embedding(content_text)
            embedding_json = (
                embedding_service.serialize_embedding(embedding) if embedding else ""
            )
            record = AgentVectorMemory(
                scope_kind=self.scope_kind,
                scope_id=self.scope_id,
                session_key=self.session_key,
                content_text=content_text[:2000],
                embedding=embedding_json or None,
                metadata_={
                    "type": "conversation_turn",
                },
            )
            db.add(record)
            db.commit()
            logger.debug("Saved vector memory (scope=%s/%s)", self.scope_kind, self.scope_id)
        except Exception as e:
            logger.warning("Failed to save vector memory: %s", e)
            db.rollback()
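    # Example of the JSON shape that _compress_and_summarize (below) asks the
    # LLM to return; the values are illustrative, not real model output:
    #
    #   {
    #       "user_profile": {"role": "运维工程师", "偏好": "中文回答"},
    #       "key_facts": ["服务部署在 K8s 上"],
    #       "summary": "讨论了服务的部署与回滚流程。",
    #       "topics": ["部署", "回滚"]
    #   }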
    async def _compress_and_summarize(
        self, messages: List[Dict[str, Any]]
    ) -> None:
        """
        Use an LLM to compress and summarize the conversation history,
        extracting a user profile and key facts. Only non-system messages
        are processed.
        """
        from openai import AsyncOpenAI
        from app.core.config import settings

        # Extract conversation messages (drop system messages; condense tool results)
        conversation = []
        for m in messages:
            role = m.get("role", "")
            if role == "system":
                continue
            if role == "tool":
                # Condense the tool result and attribute it to the user role,
                # since the summarizer only accepts user/assistant turns
                content = m.get("content", "")
                name = m.get("name", "tool")
                conversation.append({
                    "role": "user",
                    "content": f"[工具 {name} 执行结果]\n{content[:200]}",
                })
            else:
                conversation.append({"role": role, "content": m.get("content", "")[:500]})

        if len(conversation) < 2:
            return

        # Build the summarization prompt
        summary_prompt = (
            "你是一个记忆管理助手。请分析以下对话历史,提取关于用户的关键信息。\n\n"
            "请返回 JSON 格式(不要 markdown 包裹),包含以下字段:\n"
            "1. user_profile: 用户画像对象,包含用户的偏好、角色、关键需求等\n"
            "2. key_facts: 从对话中提取的关键事实列表(字符串数组)\n"
            "3. summary: 对话的简要总结(100字以内)\n"
            "4. topics: 讨论过的话题列表(字符串数组)\n\n"
            "如果没有足够信息,相应字段设为空对象或空数组。"
        )
        summary_messages = [
            {"role": "system", "content": summary_prompt},
            *conversation[-10:],  # only the 10 most recent messages
        ]

        try:
            api_key = settings.DEEPSEEK_API_KEY or settings.OPENAI_API_KEY or ""
            base_url = settings.DEEPSEEK_BASE_URL or settings.OPENAI_BASE_URL or "https://api.deepseek.com"
            if api_key == "your-openai-api-key":
                # The OpenAI key is still the placeholder; fall back to DeepSeek
                api_key = settings.DEEPSEEK_API_KEY or ""
                base_url = settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com"
            if not api_key:
                logger.warning("Memory compression: no API key configured, skipping")
                return

            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
            resp = await client.chat.completions.create(
                model="deepseek-chat",
                messages=summary_messages,
                temperature=0.3,
                max_tokens=1024,
                timeout=30,
            )
            raw = resp.choices[0].message.content or ""

            # Parse the JSON (strip an optional ```json fence first)
            result = json.loads(
                raw.strip().removeprefix("```json").removesuffix("```").strip()
            )

            # Merge the profile into long-term memory (new info overrides old)
            existing_profile = self._long_term_context.get("user_profile", {})
            new_profile = result.get("user_profile", {})
            if isinstance(new_profile, dict) and new_profile:
                existing_profile.update(new_profile)
                self._long_term_context["user_profile"] = existing_profile

            # Merge key facts (deduplicated, keep at most 20)
            existing_facts = self._long_term_context.get("key_facts", [])
            new_facts = result.get("key_facts", [])
            if isinstance(new_facts, list):
                all_facts = list(dict.fromkeys(existing_facts + new_facts))
                self._long_term_context["key_facts"] = all_facts[-20:]

            # Update the compressed summary
            summary = result.get("summary", "")
            if summary:
                ctx = self._long_term_context.get("context", {})
                ctx["compressed_summary"] = summary
                self._long_term_context["context"] = ctx

            # Record topics (deduplicated, keep at most 20)
            topics = result.get("topics", [])
            if isinstance(topics, list) and topics:
                existing_topics = self._long_term_context.get("topics", [])
                all_topics = list(dict.fromkeys(existing_topics + topics))
                self._long_term_context["topics"] = all_topics[-20:]

            logger.info(
                "Memory compression finished: profile=%s facts=%d topics=%d",
                "updated" if new_profile else "unchanged",
                len(new_facts),
                len(topics),
            )
        except json.JSONDecodeError:
            logger.warning("Memory compression: LLM returned non-JSON output, skipping")
        except Exception as e:
            logger.warning("Memory compression failed: %s", e)

    def trim_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Trim the message list: keep the most recent N messages, always
        preserving all system messages.
        """
        if len(messages) <= self.max_history:
            return messages
        system_msgs = [m for m in messages if m.get("role") == "system"]
        other_msgs = [m for m in messages if m.get("role") != "system"]
        # Guard against a non-positive budget: other_msgs[-0:] would keep
        # everything instead of nothing
        keep = max(self.max_history - len(system_msgs), 0)
        trimmed = other_msgs[-keep:] if keep > 0 else []
        return system_msgs + trimmed

    @staticmethod
    def _summarize_history(history: List[Dict[str, Any]]) -> str:
        """Summarize the stored conversation history."""
        turns = 0
        for m in history:
            if m.get("role") == "user":
                turns += 1
        return f"共 {turns} 轮历史对话(详情已存入长期记忆)"
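# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not imported by the runtime). The
# scope/session values are made up, and `reply` stands in for a real LLM call.
# It shows the expected call order: initialize() before the turn to build the
# memory block for the system prompt, save_context() after the turn.
async def _demo_agent_memory() -> None:  # pragma: no cover
    memory = AgentMemory(scope_kind="agent", scope_id="agent-42", session_key="sess-1")
    # 1. Load long-term memory plus semantically related history
    memory_block = await memory.initialize(query="之前聊过哪些部署方案?")
    messages = [
        {"role": "system", "content": "你是助手。\n\n" + memory_block},
        {"role": "user", "content": "继续上次的话题"},
    ]
    reply = "好的,我们接着聊部署方案。"  # placeholder for the real LLM reply
    # 2. Persist the turn; compression and vector writes happen inside
    await memory.save_context("继续上次的话题", reply, messages=messages)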