Files
aiagent/backend/app/agent_runtime/memory.py

136 lines
4.7 KiB
Python
Raw Normal View History

"""
Agent 记忆管理：包装已有的 persistent_memory_service，提供会话级和长期记忆。
"""
from __future__ import annotations

import asyncio
import json
import logging
from typing import Any, Dict, List, Optional

from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.database import SessionLocal
from app.services.persistent_memory_service import (
    load_persistent_memory,
    persist_enabled,
    save_persistent_memory,
)
# Module logger; load/save failures of long-term memory are reported here as warnings.
logger = logging.getLogger(__name__)
class AgentMemory:
    """
    Layered memory manager for an agent session.

    - Working memory: the current session's message list, owned by the caller
      (AgentRuntime); this class only caps its length via ``trim_messages``.
    - Long-term memory: a payload loaded with ``load_persistent_memory`` and
      written back with ``save_persistent_memory`` (keys read here:
      ``user_profile``, ``context``, ``conversation_history``).
    - Context compression: ``trim_messages`` drops the oldest non-system
      messages once the list exceeds ``max_history``.

    Persistence is best-effort: load/save failures are logged, never raised.
    """

    def __init__(
        self,
        scope_kind: str = "agent",
        scope_id: Optional[str] = None,
        session_key: Optional[str] = None,
        persist: bool = True,
        max_history: int = 20,
    ):
        """
        Args:
            scope_kind: Scope kind passed through to the persistence service.
            scope_id: Scope identifier; falls back to ``"default"``.
            session_key: Session identifier; falls back to ``"default_session"``.
            persist: Request persistence. It is effective only when
                ``persist_enabled()`` is also true.
            max_history: Maximum number of messages kept by ``trim_messages``.
        """
        self.scope_kind = scope_kind
        self.scope_id = scope_id or "default"
        self.session_key = session_key or "default_session"
        self.persist = persist and persist_enabled()
        self.max_history = max_history
        # Long-term payload: filled by ``initialize``, written back by ``save_context``.
        self._long_term_context: Dict[str, Any] = {}

    async def initialize(self) -> str:
        """
        Load long-term memory and build the text block for the system prompt.

        The blocking storage call runs in the loop's default executor so it
        does not stall the event loop.

        Returns:
            Markdown sections ("## 用户画像", "## 上下文", "## 历史对话摘要")
            joined by blank lines. Returns ``""`` when persistence is off,
            nothing usable was stored, or loading failed (failures are logged).
        """
        if not self.persist or not self.scope_id:
            return ""
        try:
            payload = await asyncio.get_running_loop().run_in_executor(
                None, self._load_payload_sync
            )
            if not isinstance(payload, dict) or not payload:
                return ""
            self._long_term_context = payload
            return self._render_memory_text(payload)
        except Exception as e:
            # Best-effort: a memory failure must not abort the agent run.
            logger.warning("加载长期记忆失败: %s", e, exc_info=True)
            return ""

    def _load_payload_sync(self) -> Any:
        """Open a DB session, load the stored payload, and always close the session."""
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            return load_persistent_memory(
                db, self.scope_kind, self.scope_id, self.session_key
            )
        finally:
            if db is not None:
                db.close()

    @staticmethod
    def _render_memory_text(payload: Dict[str, Any]) -> str:
        """Format a loaded payload as markdown sections; ``""`` if none apply."""
        parts: List[str] = []
        profile = payload.get("user_profile")
        if profile and isinstance(profile, dict):
            profile_text = json.dumps(profile, ensure_ascii=False)
            parts.append(f"## 用户画像\n{profile_text}")
        context = payload.get("context")
        if context and isinstance(context, dict):
            ctx_text = json.dumps(context, ensure_ascii=False)
            parts.append(f"## 上下文\n{ctx_text}")
        history = payload.get("conversation_history")
        if history and isinstance(history, list):
            summary = AgentMemory._summarize_history(history)
            parts.append(f"## 历史对话摘要\n{summary}")
        return "\n\n".join(parts)

    async def save_context(
        self, user_message: str, assistant_reply: str
    ) -> None:
        """
        Record the latest exchange in the long-term payload and persist it.

        Only the first 500 characters of each message are stored, under
        ``context["last_user_message"]`` and ``context["last_assistant_reply"]``.
        Save failures are logged, not raised.
        """
        if not self.persist or not self.scope_id:
            return
        existing = self._long_term_context.get("context")
        # Stored payloads are not schema-checked: replace a non-dict "context"
        # instead of crashing on item assignment. Building a fresh dict (rather
        # than mutating in place) keeps the snapshot handed to the worker stable.
        ctx: Dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
        ctx["last_user_message"] = (user_message or "")[:500]
        ctx["last_assistant_reply"] = (assistant_reply or "")[:500]
        self._long_term_context["context"] = ctx
        # NOTE(review): if ``initialize`` failed or was never awaited, this payload
        # holds only "context"; whether saving it overwrites previously stored
        # fields depends on save_persistent_memory - confirm.
        snapshot = dict(self._long_term_context)
        try:
            await asyncio.get_running_loop().run_in_executor(
                None, self._save_payload_sync, snapshot
            )
        except Exception as e:
            logger.warning("保存长期记忆失败: %s", e, exc_info=True)

    def _save_payload_sync(self, payload: Dict[str, Any]) -> None:
        """Open a DB session, save *payload*, and always close the session."""
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            save_persistent_memory(
                db, self.scope_kind, self.scope_id,
                self.session_key, payload,
            )
        finally:
            if db is not None:
                db.close()

    def trim_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Cap the message list at ``max_history`` entries.

        Every ``system`` message is kept and moved to the front, in original
        order, even if that alone exceeds ``max_history``. Any remaining budget
        is filled with the most recent non-system messages. A list already
        within the limit is returned as-is (same object).
        """
        if len(messages) <= self.max_history:
            return messages
        system_msgs = [m for m in messages if m.get("role") == "system"]
        other_msgs = [m for m in messages if m.get("role") != "system"]
        # Clamp at 0: a budget of 0 would turn ``other_msgs[-0:]`` into the whole
        # list, and a negative budget would keep almost all of it.
        budget = max(self.max_history - len(system_msgs), 0)
        recent = other_msgs[-budget:] if budget else []
        return system_msgs + recent

    @staticmethod
    def _summarize_history(history: List[Dict[str, Any]]) -> str:
        """
        Summarize stored history cheaply by counting user turns (no LLM compression).

        Non-dict entries are skipped, so one malformed record cannot abort the
        whole memory load.
        """
        turns = sum(
            1 for m in history if isinstance(m, dict) and m.get("role") == "user"
        )
        return f"{turns} 轮历史对话(详情已存入长期记忆)"