128 lines
3.9 KiB
Python
128 lines
3.9 KiB
Python
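"""
用户会话记忆持久化:与 Cache 节点 user_memory_* 键对齐,写入 MySQL,Redis 仍作热缓存。

User session memory persistence: aligned with the Cache node's ``user_memory_*``
keys; payloads are persisted to MySQL while Redis remains the hot cache.
"""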
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import uuid
|
||
from typing import Any, Dict, Optional, Tuple
|
||
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.core.config import settings
|
||
from app.models.persistent_user_memory import PersistentUserMemory
|
||
|
||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

# Prefix shared by the Cache node's per-session user-memory keys
# (``user_memory_<session_key>``); whatever follows it is the session key.
USER_MEMORY_PREFIX: str = "user_memory_"
|
||
|
||
|
||
def parse_memory_scope(workflow_id: str) -> Tuple[Optional[str], Optional[str]]:
    """Derive the DB persistence scope from a Celery task's ``workflow_id``.

    In Celery tasks the ``workflow_id`` is ``agent_{uuid}`` for an Agent and a
    plain UUID string for a workflow.

    Args:
        workflow_id: Identifier taken from the task. Surrounding whitespace is
            ignored; non-string values (e.g. a ``uuid.UUID``) are converted with
            ``str()``.

    Returns:
        ``("agent", <id>)`` for ``agent_``-prefixed ids with a non-empty suffix,
        ``("workflow", <id>)`` for a canonical 8-4-4-4-12 UUID string (returned
        as given, after stripping), or ``(None, None)`` when the id is empty or
        unrecognised, meaning the memory should not be persisted to the DB.
    """
    if not workflow_id:
        return None, None

    # Strip once so the agent-prefix check and the UUID check see the same value.
    s = str(workflow_id).strip()

    agent_prefix = "agent_"
    if s.startswith(agent_prefix) and len(s) > len(agent_prefix):
        return "agent", s[len(agent_prefix):]

    # Standard UUID: require a parseable value in canonical hyphenated layout.
    # uuid.UUID() alone would accept misplaced hyphens, so compare the
    # round-tripped canonical form as well.
    if len(s) == 36:
        try:
            parsed = uuid.UUID(s)
        except ValueError:
            return None, None
        if str(parsed) == s.lower():
            return "workflow", s

    return None, None
|
||
|
||
|
||
def is_user_memory_redis_key(key: str) -> bool:
    """Tell whether ``key`` is a per-session user-memory Redis key.

    A key qualifies when it is a string starting with ``USER_MEMORY_PREFIX``.
    Non-string input yields ``False`` instead of raising.
    """
    if not isinstance(key, str):
        return False
    return key.startswith(USER_MEMORY_PREFIX)
|
||
|
||
|
||
def session_key_from_user_memory_key(key: str) -> Optional[str]:
    """Extract the session key that follows ``USER_MEMORY_PREFIX`` in ``key``.

    Returns ``None`` when ``key`` is not a user-memory key, or when nothing
    follows the prefix.
    """
    if is_user_memory_redis_key(key):
        suffix = key[len(USER_MEMORY_PREFIX):]
        return suffix or None
    return None
|
||
|
||
|
||
def merge_memory_payloads(
    base: Optional[Dict[str, Any]], overlay: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Merge two memory payloads (e.g. the DB copy and the Redis copy).

    Merge rules, applied key by key over ``overlay``:

    * ``conversation_history``: when both sides hold lists, keep the longer
      one (base wins on equal length).
    * ``user_profile`` / ``context``: shallow dict merge, overlay keys win. If
      the base value is not a dict, the overlay dict replaces it.
    * any other key: the overlay value replaces the base value.

    Non-dict arguments are treated as absent. The inputs are not mutated at
    the top level, but non-merged nested values are shared (shallow copy).

    Returns:
        The merged dict, or ``None`` when both payloads are empty/absent.
    """
    if not isinstance(base, dict):
        base = None
    if not isinstance(overlay, dict):
        overlay = None

    if not base:
        return dict(overlay) if overlay else None

    merged = dict(base)
    if not overlay:
        return merged

    for key, value in overlay.items():
        current = merged.get(key)
        if key == "conversation_history" and isinstance(value, list) and isinstance(current, list):
            # NOTE(review): on equal length the base history is kept even if its
            # content differs (e.g. a capped sliding window); confirm which side
            # callers pass as the fresher copy.
            if len(value) > len(current):
                merged[key] = value
        elif key in ("user_profile", "context") and isinstance(value, dict):
            # Guard against legacy non-dict values, which `{**x}` cannot unpack.
            merged[key] = {**current, **value} if isinstance(current, dict) else dict(value)
        else:
            merged[key] = value
    return merged
|
||
|
||
|
||
def load_persistent_memory(
    db: Session, scope_kind: str, scope_id: str, session_key: str
) -> Optional[Dict[str, Any]]:
    """Fetch the stored memory payload for one session, or ``None``.

    Looks up the first ``PersistentUserMemory`` row matching ``scope_kind``,
    ``scope_id`` and ``session_key``. Returns a shallow copy of its ``payload``
    when that payload is a dict. Returns ``None`` when no row matches or when
    the stored payload is not a dict.
    """
    criteria = (
        PersistentUserMemory.scope_kind == scope_kind,
        PersistentUserMemory.scope_id == scope_id,
        PersistentUserMemory.session_key == session_key,
    )
    record = db.query(PersistentUserMemory).filter(*criteria).first()
    if record and isinstance(record.payload, dict):
        # Shallow copy: top-level edits by the caller do not touch the ORM-held dict.
        return dict(record.payload)
    return None
|
||
|
||
|
||
def save_persistent_memory(
    db: Session, scope_kind: str, scope_id: str, session_key: str, payload: Dict[str, Any]
) -> None:
    """Upsert the memory payload for one (scope_kind, scope_id, session_key) and commit.

    If a matching ``PersistentUserMemory`` row exists, the first one found has
    its ``payload`` replaced. Otherwise a new row is inserted with a random UUID4
    string as its ``id``.

    Raises:
        Any exception from the query, flush or commit is re-raised after the
        session has been rolled back.
    """
    # NOTE(review): lookup-then-insert is not atomic; two concurrent writers for
    # the same key may both insert unless the table has a unique constraint on
    # these columns (not visible here) - confirm.
    try:
        row = (
            db.query(PersistentUserMemory)
            .filter(
                PersistentUserMemory.scope_kind == scope_kind,
                PersistentUserMemory.scope_id == scope_id,
                PersistentUserMemory.session_key == session_key,
            )
            .first()
        )
        if row:
            row.payload = payload
        else:
            row = PersistentUserMemory(
                id=str(uuid.uuid4()),
                scope_kind=scope_kind,
                scope_id=scope_id,
                session_key=session_key,
                payload=payload,
            )
            db.add(row)
        db.commit()
    except Exception:
        # A failed flush/commit leaves the Session unusable until it is rolled
        # back; this function owns the commit, so it also owns the cleanup.
        db.rollback()
        raise
|
||
|
||
|
||
def delete_persistent_memory(db: Session, scope_kind: str, scope_id: str, session_key: str) -> None:
    """Delete every ``PersistentUserMemory`` row for one session key and commit.

    Rows are matched on ``scope_kind``, ``scope_id`` and ``session_key``. Deleting
    a key that has no rows is not an error.

    Raises:
        Any exception from the bulk delete or the commit is re-raised after the
        session has been rolled back.
    """
    try:
        db.query(PersistentUserMemory).filter(
            PersistentUserMemory.scope_kind == scope_kind,
            PersistentUserMemory.scope_id == scope_id,
            PersistentUserMemory.session_key == session_key,
        ).delete()
        db.commit()
    except Exception:
        # Leave the caller's Session usable after a failed delete/commit.
        db.rollback()
        raise
|
||
|
||
|
||
def persist_enabled() -> bool:
    """Return whether user memory should also be persisted to the database.

    Reads ``settings.MEMORY_PERSIST_DB_ENABLED`` and defaults to ``True`` when
    the setting is absent. String values (e.g. raw environment variables) are
    parsed: ``""``, ``"0"``, ``"false"``, ``"no"`` and ``"off"`` (case- and
    whitespace-insensitive) disable persistence. Plain ``bool("false")`` would
    be ``True``. Any other value is converted with ``bool()``.
    """
    value = getattr(settings, "MEMORY_PERSIST_DB_ENABLED", True)
    if isinstance(value, str):
        return value.strip().lower() not in ("", "0", "false", "no", "off")
    return bool(value)
|