136 lines
4.7 KiB
Python
136 lines
4.7 KiB
Python
|
|
"""
|
|||
|
|
Agent 记忆管理:包装已有 persistent_memory_service,提供会话级和长期记忆。
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
from typing import Any, Dict, List, Optional
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.core.database import SessionLocal
|
|||
|
|
from app.services.persistent_memory_service import (
|
|||
|
|
load_persistent_memory,
|
|||
|
|
save_persistent_memory,
|
|||
|
|
persist_enabled,
|
|||
|
|
)
|
|||
|
|
from app.core.config import settings
|
|||
|
|
|
|||
|
|
# Module-level logger; used below to report long-term memory load/save failures.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AgentMemory:
    """
    Layered memory manager for an agent.

    - Working memory: the current session's message list. It is managed
      directly by AgentRuntime; this class only trims it (``trim_messages``).
    - Long-term memory: a per-(scope_kind, scope_id, session_key) payload
      (user profile, context, conversation history) loaded and saved through
      the persistent memory service.
    - Context compression: ``trim_messages`` drops the oldest non-system
      messages once the list grows beyond ``max_history``. Stored history is
      only summarized as a turn count (no LLM summarization).
    """

    def __init__(
        self,
        scope_kind: str = "agent",
        scope_id: Optional[str] = None,
        session_key: Optional[str] = None,
        persist: bool = True,
        max_history: int = 20,
    ):
        """
        Args:
            scope_kind: Kind of scope the memory is stored under.
            scope_id: Identifier within ``scope_kind``; defaults to "default".
            session_key: Session identifier; defaults to "default_session".
            persist: Request persistence. It takes effect only when
                ``persist_enabled()`` also returns a truthy value.
            max_history: Maximum message count targeted by ``trim_messages``.
        """
        self.scope_kind = scope_kind
        self.scope_id = scope_id or "default"
        self.session_key = session_key or "default_session"
        # Short-circuits: persist_enabled() is only consulted when persist is truthy.
        self.persist = persist and persist_enabled()
        self.max_history = max_history
        # Long-term payload: filled by initialize(), updated by save_context().
        self._long_term_context: Dict[str, Any] = {}

    async def initialize(self) -> str:
        """
        Load long-term memory and render it as text for the system prompt.

        On success the loaded payload is cached on the instance so that later
        ``save_context`` calls update it rather than starting from scratch.

        Returns:
            The rendered memory sections joined by blank lines. Returns ""
            when persistence is disabled, the stored payload is empty or not a
            dict, or loading fails. Failures are logged, not raised.
        """
        if not self.persist or not self.scope_id:
            return ""

        # NOTE(review): the DB calls below are synchronous and run on the
        # event-loop thread; consider offloading them if they can be slow.
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            payload = load_persistent_memory(
                db, self.scope_kind, self.scope_id, self.session_key
            )
            if isinstance(payload, dict) and payload:
                self._long_term_context = payload
                return self._render_memory_text(payload)
        except Exception as e:
            # Best effort: a memory failure must not abort the agent turn.
            logger.warning("加载长期记忆失败: %s", e, exc_info=True)
        finally:
            if db is not None:
                db.close()
        return ""

    @classmethod
    def _render_memory_text(cls, payload: Dict[str, Any]) -> str:
        """Render a loaded payload as prompt sections; "" when nothing applies."""
        parts: List[str] = []

        profile = payload.get("user_profile")
        if profile and isinstance(profile, dict):
            profile_text = json.dumps(profile, ensure_ascii=False)
            parts.append(f"## 用户画像\n{profile_text}")

        context = payload.get("context")
        if context and isinstance(context, dict):
            ctx_text = json.dumps(context, ensure_ascii=False)
            parts.append(f"## 上下文\n{ctx_text}")

        history = payload.get("conversation_history")
        if history and isinstance(history, list):
            summary = cls._summarize_history(history)
            parts.append(f"## 历史对话摘要\n{summary}")

        return "\n\n".join(parts)

    async def save_context(
        self, user_message: str, assistant_reply: str
    ) -> None:
        """
        Record the latest exchange in long-term memory and persist it.

        The first 500 characters of each message are stored under
        ``context["last_user_message"]`` and ``context["last_assistant_reply"]``.
        Persistence failures are logged, not raised.
        """
        if not self.persist or not self.scope_id:
            return

        # The stored "context" may be missing or malformed (e.g. null). Start
        # fresh in that case so the updates below cannot raise.
        ctx = self._long_term_context.get("context")
        if not isinstance(ctx, dict):
            ctx = {}
        ctx["last_user_message"] = (user_message or "")[:500]
        ctx["last_assistant_reply"] = (assistant_reply or "")[:500]
        self._long_term_context["context"] = ctx

        # NOTE(review): if initialize() was never called or failed, the payload
        # saved here contains only "context". Whether that overwrites a stored
        # user_profile / conversation_history depends on save_persistent_memory
        # semantics — confirm.
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            save_persistent_memory(
                db, self.scope_kind, self.scope_id,
                self.session_key, self._long_term_context,
            )
        except Exception as e:
            # Best effort: losing one save must not abort the agent turn.
            logger.warning("保存长期记忆失败: %s", e, exc_info=True)
        finally:
            if db is not None:
                db.close()

    def trim_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Trim the message list toward ``max_history`` entries.

        Every system message is kept and moved to the front, in original
        order. The remaining budget is filled with the most recent non-system
        messages. If the system messages alone reach ``max_history``, only
        they are returned, so the result may exceed ``max_history`` in that
        case. A list already within the limit is returned as-is (same object).
        """
        if len(messages) <= self.max_history:
            return messages

        system_msgs = [m for m in messages if m.get("role") == "system"]
        other_msgs = [m for m in messages if m.get("role") != "system"]

        # Clamp the budget: with 0, other_msgs[-0:] would return *every*
        # message, and a negative value would slice from the wrong end.
        budget = max(self.max_history - len(system_msgs), 0)
        recent = other_msgs[-budget:] if budget else []
        # NOTE(review): if the runtime uses tool-call messages, the cut may
        # separate a tool result from the assistant message that requested it
        # — verify against the message format AgentRuntime sends.
        return system_msgs + recent

    @staticmethod
    def _summarize_history(history: List[Dict[str, Any]]) -> str:
        """Summarize stored history as a count of user turns (no LLM compression)."""
        # Skip non-dict entries so one malformed record doesn't abort loading.
        turns = sum(
            1 for m in history if isinstance(m, dict) and m.get("role") == "user"
        )
        return f"共 {turns} 轮历史对话(详情已存入长期记忆)"
|