""" Agent 记忆管理:包装已有 persistent_memory_service,提供会话级和长期记忆。 """ from __future__ import annotations import json import logging from typing import Any, Dict, List, Optional from sqlalchemy.orm import Session from app.core.database import SessionLocal from app.services.persistent_memory_service import ( load_persistent_memory, save_persistent_memory, persist_enabled, ) from app.core.config import settings logger = logging.getLogger(__name__) class AgentMemory: """ 分层记忆管理器: - 工作记忆:当前会话消息列表(由 AgentRuntime 直接管理) - 长期记忆:从 MySQL 加载/保存的用户画像和关键事实 - 上下文压缩:对话过长时自动裁剪或总结 """ def __init__( self, scope_kind: str = "agent", scope_id: Optional[str] = None, session_key: Optional[str] = None, persist: bool = True, max_history: int = 20, ): self.scope_kind = scope_kind self.scope_id = scope_id or "default" self.session_key = session_key or "default_session" self.persist = persist and persist_enabled() self.max_history = max_history # 从长期记忆加载的上下文(启动时加载) self._long_term_context: Dict[str, Any] = {} async def initialize(self) -> str: """ 初始化记忆:从 DB/Redis 加载长期记忆,构造初始上下文文本。 返回注入 system prompt 的记忆文本块。 """ if not self.persist or not self.scope_id: return "" db: Optional[Session] = None try: db = SessionLocal() payload = load_persistent_memory( db, self.scope_kind, self.scope_id, self.session_key ) if payload and isinstance(payload, dict): self._long_term_context = payload # 构建注入 system prompt 的记忆文本 parts = [] profile = payload.get("user_profile") if profile and isinstance(profile, dict): profile_text = json.dumps(profile, ensure_ascii=False) parts.append(f"## 用户画像\n{profile_text}") context = payload.get("context") if context and isinstance(context, dict): ctx_text = json.dumps(context, ensure_ascii=False) parts.append(f"## 上下文\n{ctx_text}") history = payload.get("conversation_history") if history and isinstance(history, list) and len(history) > 0: summary = self._summarize_history(history) parts.append(f"## 历史对话摘要\n{summary}") if parts: return "\n\n".join(parts) except Exception as e: logger.warning("加载长期记忆失败: %s", e) finally: if db: db.close() return "" async def save_context( self, user_message: str, assistant_reply: str ) -> None: """将单轮对话保存到长期记忆。""" if not self.persist or not self.scope_id: return # 更新上下文 ctx = self._long_term_context.get("context", {}) ctx["last_user_message"] = user_message[:500] ctx["last_assistant_reply"] = assistant_reply[:500] self._long_term_context["context"] = ctx db: Optional[Session] = None try: db = SessionLocal() save_persistent_memory( db, self.scope_kind, self.scope_id, self.session_key, self._long_term_context, ) except Exception as e: logger.warning("保存长期记忆失败: %s", e) finally: if db: db.close() def trim_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 裁剪消息列表:保留最近的 N 条,但始终保留第一条 system 消息。 """ if len(messages) <= self.max_history: return messages system_msgs = [m for m in messages if m.get("role") == "system"] other_msgs = [m for m in messages if m.get("role") != "system"] trimmed = other_msgs[-(self.max_history - len(system_msgs)):] return system_msgs + trimmed @staticmethod def _summarize_history(history: List[Dict[str, Any]]) -> str: """简单汇总历史对话(不做 LLM 压缩,仅计数)。""" turns = 0 for m in history: if m.get("role") == "user": turns += 1 return f"共 {turns} 轮历史对话(详情已存入长期记忆)"