Files
aiagent/backend/app/services/persistent_memory_service.py

128 lines
3.9 KiB
Python
Raw Normal View History

"""
用户会话记忆持久化 Cache 节点 user_memory_* 键对齐写入 MySQLRedis 仍作热缓存
"""
from __future__ import annotations

import copy
import logging
import uuid
from typing import Any, Dict, Optional, Tuple

from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.persistent_user_memory import PersistentUserMemory
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
# Prefix of the Redis keys that hold per-session user memory (the Cache node's
# ``user_memory_*`` keys); the remainder of such a key is the session key.
USER_MEMORY_PREFIX: str = "user_memory_"
def parse_memory_scope(workflow_id: str) -> Tuple[Optional[str], Optional[str]]:
    """Derive the DB persistence scope from a Celery task's ``workflow_id``.

    In Celery tasks ``workflow_id`` is ``"agent_{id}"`` for agents and a bare
    UUID string for workflows.

    Args:
        workflow_id: identifier passed to the task; may be empty or ``None``.

    Returns:
        ``("agent", suffix)`` when the id starts with ``agent_`` followed by at
        least one character, ``("workflow", uuid_str)`` when the
        whitespace-stripped id is a canonical hyphenated UUID (8-4-4-4-12 hex,
        any letter case), otherwise ``(None, None)``, meaning the memory must
        not be persisted to the DB.
    """
    if not workflow_id:
        return None, None

    agent_prefix = "agent_"
    if workflow_id.startswith(agent_prefix) and len(workflow_id) > len(agent_prefix):
        # The agent suffix is returned verbatim (neither stripped nor validated).
        return "agent", workflow_id[len(agent_prefix):]

    candidate = workflow_id.strip()
    try:
        canonical = str(uuid.UUID(candidate))
    except ValueError:
        return None, None
    # uuid.UUID also accepts braces, a "urn:uuid:" prefix, un-hyphenated hex and
    # misplaced hyphens; only the canonical hyphenated form counts as a workflow id.
    if canonical == candidate.lower():
        return "workflow", candidate
    return None, None
def is_user_memory_redis_key(key: str) -> bool:
    """Return True only for a ``str`` key beginning with ``USER_MEMORY_PREFIX``."""
    if not isinstance(key, str):
        return False
    return key.startswith(USER_MEMORY_PREFIX)
def session_key_from_user_memory_key(key: str) -> Optional[str]:
    """Extract the session key from a ``user_memory_*`` Redis key.

    Returns:
        The text after ``USER_MEMORY_PREFIX``, or ``None`` when ``key`` is not a
        user-memory key or nothing follows the prefix.
    """
    if is_user_memory_redis_key(key):
        suffix = key[len(USER_MEMORY_PREFIX):]
        if suffix:
            return suffix
    return None
def merge_memory_payloads(
    base: Optional[Dict[str, Any]], overlay: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Merge a DB memory payload (``base``) with a Redis one (``overlay``).

    Rules, applied to each key of ``overlay``:

    * ``conversation_history``: if ``base`` holds a list, the overlay list
      replaces it only when strictly longer; a non-list overlay value never
      replaces an existing base list. Otherwise the overlay value is taken.
    * ``user_profile`` / ``context``: a dict overlay is shallow-merged on top of
      the base dict (a non-dict base value counts as empty); a non-dict overlay
      value never replaces an existing base dict.
    * any other key: the overlay value wins.

    Args:
        base: payload loaded from the database, or ``None``.
        overlay: payload read from Redis, or ``None``.

    Returns:
        ``None`` when both inputs are empty. When only one is non-empty, a
        shallow copy of it (``None`` if that one is a non-dict overlay).
        Otherwise a new top-level dict. Nested values are shared with the
        inputs, except the freshly merged ``user_profile``/``context`` dicts.
    """
    if not base and not overlay:
        return None
    if not base:
        return dict(overlay) if isinstance(overlay, dict) else None
    if not overlay:
        return dict(base)
    out = dict(base)
    if not isinstance(overlay, dict):
        return out

    for key, value in overlay.items():
        current = out.get(key)
        if key == "conversation_history":
            if isinstance(current, list):
                # Keep the longer history; a malformed overlay value (e.g. None
                # from a stale cache entry) must not wipe the persisted history.
                if isinstance(value, list) and len(value) > len(current):
                    out[key] = value
            else:
                out[key] = value
        elif key in ("user_profile", "context"):
            if isinstance(value, dict):
                # Treat a non-dict base value as empty instead of letting ** raise.
                previous = current if isinstance(current, dict) else {}
                out[key] = {**previous, **value}
            elif not isinstance(current, dict):
                out[key] = value
        else:
            out[key] = value
    return out
def load_persistent_memory(
    db: Session, scope_kind: str, scope_id: str, session_key: str
) -> Optional[Dict[str, Any]]:
    """Fetch the persisted memory payload for one (scope_kind, scope_id, session_key).

    Args:
        db: SQLAlchemy session used for the query.
        scope_kind: scope type, e.g. as returned by ``parse_memory_scope``.
        scope_id: scope identifier, e.g. as returned by ``parse_memory_scope``.
        session_key: session identifier (presumably the suffix produced by
            ``session_key_from_user_memory_key`` — confirm against callers).

    Returns:
        A deep copy of the stored payload, or ``None`` when no row matches or
        the stored payload is not a dict.
    """
    row = (
        db.query(PersistentUserMemory)
        .filter(
            PersistentUserMemory.scope_kind == scope_kind,
            PersistentUserMemory.scope_id == scope_id,
            PersistentUserMemory.session_key == session_key,
        )
        .first()
    )
    if row is None or not isinstance(row.payload, dict):
        return None
    # Deep copy: a shallow dict() copy would share nested lists/dicts (such as the
    # conversation history) with the ORM instance, so in-place edits by the caller
    # would silently change the value the session holds for row.payload.
    return copy.deepcopy(row.payload)
def save_persistent_memory(
    db: Session, scope_kind: str, scope_id: str, session_key: str, payload: Dict[str, Any]
) -> None:
    """Upsert the memory payload for (scope_kind, scope_id, session_key) and commit.

    Updates the payload of the first matching row, or inserts a new row whose
    id is a random UUID4 string.

    Args:
        db: SQLAlchemy session; it is committed by this function.
        scope_kind: scope type, e.g. as returned by ``parse_memory_scope``.
        scope_id: scope identifier, e.g. as returned by ``parse_memory_scope``.
        session_key: session identifier of the memory entry.
        payload: memory payload to store; a deep copy is persisted.

    Raises:
        Any exception from the query, flush or commit is re-raised after the
        session has been rolled back.
    """
    # Persist a private snapshot so later in-place changes to the caller's dict
    # (or its nested lists) cannot alter the value attached to the ORM row.
    snapshot = copy.deepcopy(payload)
    try:
        row = (
            db.query(PersistentUserMemory)
            .filter(
                PersistentUserMemory.scope_kind == scope_kind,
                PersistentUserMemory.scope_id == scope_id,
                PersistentUserMemory.session_key == session_key,
            )
            .first()
        )
        if row is not None:
            row.payload = snapshot
        else:
            db.add(
                PersistentUserMemory(
                    id=str(uuid.uuid4()),
                    scope_kind=scope_kind,
                    scope_id=scope_id,
                    session_key=session_key,
                    payload=snapshot,
                )
            )
        db.commit()
    except Exception:
        # Without a rollback the session stays in a failed transaction and every
        # later use of it raises until someone rolls back.
        db.rollback()
        raise
def delete_persistent_memory(db: Session, scope_kind: str, scope_id: str, session_key: str) -> None:
    """Delete the persisted memory row(s) for (scope_kind, scope_id, session_key) and commit.

    Deleting a key with no matching row only performs the commit.

    Args:
        db: SQLAlchemy session; it is committed by this function.
        scope_kind: scope type, e.g. as returned by ``parse_memory_scope``.
        scope_id: scope identifier, e.g. as returned by ``parse_memory_scope``.
        session_key: session identifier of the memory entry.

    Raises:
        Any exception from the bulk delete or commit is re-raised after the
        session has been rolled back.
    """
    try:
        db.query(PersistentUserMemory).filter(
            PersistentUserMemory.scope_kind == scope_kind,
            PersistentUserMemory.scope_id == scope_id,
            PersistentUserMemory.session_key == session_key,
        ).delete()
        db.commit()
    except Exception:
        # Keep the session usable for the caller instead of stuck in a failed transaction.
        db.rollback()
        raise
def persist_enabled() -> bool:
    """Report whether session memory should also be persisted to the DB.

    Reads ``settings.MEMORY_PERSIST_DB_ENABLED`` by truthiness and treats a
    missing setting as enabled.
    """
    enabled = getattr(settings, "MEMORY_PERSIST_DB_ENABLED", True)
    return True if enabled else False