Files
aiagent/backend/app/services/persistent_memory_service.py

128 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
用户会话记忆持久化:与 Cache 节点 user_memory_* 键对齐,写入 MySQLRedis 仍作热缓存。
"""
from __future__ import annotations

import logging
import uuid
from typing import Any, Dict, Optional, Tuple

from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified

from app.core.config import settings
from app.models.persistent_user_memory import PersistentUserMemory
# Module-level logger.
logger = logging.getLogger(__name__)
# Prefix of the Redis keys that hold per-session user memory (the Cache node's
# ``user_memory_*`` keys); the remainder of such a key is the session key
# extracted by session_key_from_user_memory_key.
USER_MEMORY_PREFIX = "user_memory_"
def parse_memory_scope(workflow_id: str) -> Tuple[Optional[str], Optional[str]]:
    """Derive the DB persistence scope from a Celery task's ``workflow_id``.

    In Celery tasks the id is ``agent_{uuid}`` for an Agent and a plain UUID
    string for a workflow.

    Args:
        workflow_id: Raw id passed to the task. Surrounding whitespace is ignored.

    Returns:
        ``("agent", <text after the "agent_" prefix>)`` for agent ids,
        ``("workflow", <uuid string>)`` for a canonical 8-4-4-4-12 hex UUID, or
        ``(None, None)`` when the id is empty, not a string, or unrecognised,
        meaning the memory should not be persisted to the DB.
    """
    if not isinstance(workflow_id, str):
        return None, None
    s = workflow_id.strip()
    if not s:
        return None, None

    prefix = "agent_"
    if s.startswith(prefix) and len(s) > len(prefix):
        return "agent", s[len(prefix):]

    # Standard UUID. A length + dash-count check alone also accepts non-hex
    # characters and misplaced dashes. Round-tripping through uuid.UUID and
    # comparing with the canonical lowercase form rejects both.
    if len(s) == 36:
        try:
            canonical = str(uuid.UUID(s))
        except ValueError:
            return None, None
        if canonical == s.lower():
            return "workflow", s
    return None, None
def is_user_memory_redis_key(key: str) -> bool:
    """Tell whether ``key`` is a Redis key holding per-session user memory.

    Non-string keys yield False instead of raising.
    """
    if not isinstance(key, str):
        return False
    return key.startswith(USER_MEMORY_PREFIX)
def session_key_from_user_memory_key(key: str) -> Optional[str]:
    """Extract the session key from a ``user_memory_<session>`` Redis key.

    Returns None when ``key`` is not a user-memory key or nothing follows the prefix.
    """
    if is_user_memory_redis_key(key):
        session_key = key[len(USER_MEMORY_PREFIX):]
        if session_key:
            return session_key
    return None
def merge_memory_payloads(
    base: Optional[Dict[str, Any]], overlay: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Merge a DB memory payload (``base``) with a Redis one (``overlay``).

    Rules, applied per top-level key of ``overlay``:

    - ``conversation_history``: when both sides are lists, the longer one wins
      (on equal length the ``base`` list is kept, so an empty overlay history
      never replaces a base history list).
    - ``user_profile`` / ``context``: when the overlay value is a dict, it is
      shallow-merged over the base value (overlay keys win). A non-dict base
      value is discarded rather than raising ``TypeError``.
    - For those three keys, an overlay value of ``None`` does not erase a
      non-``None`` base value.
    - Every other case: the overlay value replaces the base value.

    Returns:
        A new top-level dict, or None when both inputs are empty/None. The inputs
        are not mutated, but nested values (other than the newly built
        profile/context dicts) are shared with them.
    """
    if not base and not overlay:
        return None
    if not base:
        return dict(overlay) if isinstance(overlay, dict) else None

    merged = dict(base)
    if not overlay or not isinstance(overlay, dict):
        return merged

    structured_keys = ("conversation_history", "user_profile", "context")
    for key, value in overlay.items():
        current = merged.get(key)
        if value is None and current is not None and key in structured_keys:
            # A None in the overlay carries no information; keep the base value.
            continue
        if key == "conversation_history" and isinstance(value, list) and isinstance(current, list):
            if len(value) > len(current):
                merged[key] = value
        elif key in ("user_profile", "context") and isinstance(value, dict):
            # Only merge onto a dict; ``**`` on any other truthy value would raise.
            combined = dict(current) if isinstance(current, dict) else {}
            combined.update(value)
            merged[key] = combined
        else:
            merged[key] = value
    return merged
def load_persistent_memory(
    db: Session, scope_kind: str, scope_id: str, session_key: str
) -> Optional[Dict[str, Any]]:
    """Fetch the stored memory payload for one (scope_kind, scope_id, session_key).

    Returns a shallow copy of the first matching row's ``payload`` dict, or None
    when no row matches or the stored payload is not a dict. Because the copy is
    shallow, nested lists/dicts are still shared with the ORM instance.
    """
    model = PersistentUserMemory
    record = (
        db.query(model)
        .filter(
            model.scope_kind == scope_kind,
            model.scope_id == scope_id,
            model.session_key == session_key,
        )
        .first()
    )
    if record and isinstance(record.payload, dict):
        return dict(record.payload)
    return None
def save_persistent_memory(
    db: Session, scope_kind: str, scope_id: str, session_key: str, payload: Dict[str, Any]
) -> None:
    """Upsert the memory payload for one (scope_kind, scope_id, session_key) and commit.

    Updates ``payload`` on the first matching ``PersistentUserMemory`` row, or
    inserts a new row (id = random UUID4 string) when none matches.

    Raises:
        Any exception raised by ``db.commit()``. The session is rolled back before
        re-raising so the caller can keep using it.
    """
    row = (
        db.query(PersistentUserMemory)
        .filter(
            PersistentUserMemory.scope_kind == scope_kind,
            PersistentUserMemory.scope_id == scope_id,
            PersistentUserMemory.session_key == session_key,
        )
        .first()
    )
    if row:
        row.payload = payload
        # NOTE: ``payload`` is presumably a JSON-type column (model not shown here).
        # SQLAlchemy does not track in-place mutation of plain JSON values. It skips
        # the UPDATE when the assigned value compares equal to the loaded one, which
        # happens if the loaded dict was mutated through a shared nested reference
        # (load_persistent_memory returns only a shallow copy). Force the UPDATE.
        flag_modified(row, "payload")
    else:
        # NOTE(review): read-then-insert is not atomic. Two concurrent first saves for
        # the same key can both reach this branch. Whether that produces duplicate
        # rows or an IntegrityError depends on the table's constraints -- confirm.
        row = PersistentUserMemory(
            id=str(uuid.uuid4()),
            scope_kind=scope_kind,
            scope_id=scope_id,
            session_key=session_key,
            payload=payload,
        )
        db.add(row)
    try:
        db.commit()
    except Exception:
        # Do not leave the caller's Session in a failed-transaction state.
        db.rollback()
        raise
def delete_persistent_memory(db: Session, scope_kind: str, scope_id: str, session_key: str) -> None:
    """Delete all stored rows for one (scope_kind, scope_id, session_key) and commit.

    Deleting a key that has no row is a no-op apart from the commit.

    Raises:
        Any exception raised by the bulk delete or ``db.commit()``. The session is
        rolled back before re-raising so the caller can keep using it.
    """
    query = db.query(PersistentUserMemory).filter(
        PersistentUserMemory.scope_kind == scope_kind,
        PersistentUserMemory.scope_id == scope_id,
        PersistentUserMemory.session_key == session_key,
    )
    try:
        query.delete()
        db.commit()
    except Exception:
        # Do not leave the caller's Session in a failed-transaction state.
        db.rollback()
        raise
def persist_enabled() -> bool:
    """Return whether user memory should also be persisted to the database.

    Reads ``settings.MEMORY_PERSIST_DB_ENABLED`` and defaults to True when the
    attribute is absent. A string value (e.g. a raw, unparsed environment
    variable) is interpreted case-insensitively: "", "0", "false", "no" and
    "off" disable persistence instead of being truthy as ``bool("false")`` is.
    Any other value is converted with ``bool()``.
    """
    value = getattr(settings, "MEMORY_PERSIST_DB_ENABLED", True)
    if isinstance(value, str):
        return value.strip().lower() not in ("", "0", "false", "no", "off")
    return bool(value)