128 lines
3.9 KiB
Python
128 lines
3.9 KiB
Python
|
|
"""
|
|||
|
|
用户会话记忆持久化:与 Cache 节点 user_memory_* 键对齐,写入 MySQL,Redis 仍作热缓存。
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
import uuid
|
|||
|
|
from typing import Any, Dict, Optional, Tuple
|
|||
|
|
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.core.config import settings
|
|||
|
|
from app.models.persistent_user_memory import PersistentUserMemory
|
|||
|
|
|
|||
|
|
# Module-level logger for this persistence helper module.
logger = logging.getLogger(__name__)

# Prefix shared by the Cache-node Redis keys that hold per-session user memory
# ("user_memory_*"). Whatever follows the prefix is treated as the session key
# by session_key_from_user_memory_key.
USER_MEMORY_PREFIX = "user_memory_"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def parse_memory_scope(workflow_id: str) -> Tuple[Optional[str], Optional[str]]:
    """Map a Celery task's ``workflow_id`` to a DB persistence scope.

    In Celery tasks the ``workflow_id`` is ``agent_{uuid}`` for agents and a
    plain UUID string for workflows.

    Returns:
        ``("agent", <text after the "agent_" prefix>)`` for agent ids,
        ``("workflow", <uuid string as given>)`` for canonical hyphenated
        UUIDs (8-4-4-4-12, any letter case), or ``(None, None)`` when the id
        matches neither form, meaning memory is not persisted to the DB.
    """
    if not workflow_id:
        return None, None

    # Surrounding whitespace is never part of an id; normalize once so both
    # branches agree (previously only the UUID branch was stripped).
    s = workflow_id.strip()

    agent_prefix = "agent_"
    if s.startswith(agent_prefix):
        agent_id = s[len(agent_prefix):]
        if agent_id:
            return "agent", agent_id
        return None, None

    # Require a real UUID in canonical form. uuid.UUID alone is too lenient
    # (it also accepts braces, "urn:uuid:" and hyphens in arbitrary places),
    # so compare its canonical rendering against the input.
    try:
        parsed = uuid.UUID(s)
    except ValueError:
        return None, None
    if str(parsed) == s.lower():
        return "workflow", s
    return None, None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def is_user_memory_redis_key(key: str) -> bool:
    """Tell whether ``key`` is a user-memory Redis key.

    Only strings starting with ``USER_MEMORY_PREFIX`` qualify; any non-string
    value yields False.
    """
    if not isinstance(key, str):
        return False
    return key.startswith(USER_MEMORY_PREFIX)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def session_key_from_user_memory_key(key: str) -> Optional[str]:
    """Extract the session key from a user-memory Redis key.

    Returns the text that follows ``USER_MEMORY_PREFIX``, or None when ``key``
    is not a user-memory key or nothing follows the prefix.
    """
    if is_user_memory_redis_key(key):
        session = key[len(USER_MEMORY_PREFIX):]
        return session or None
    return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def merge_memory_payloads(
    base: Optional[Dict[str, Any]], overlay: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Merge the DB payload with the Redis payload.

    Each key of ``overlay`` is folded into a copy of ``base``:

    * ``conversation_history``: keep the longer list (``base`` wins ties).
      A non-list overlay value (e.g. None) never replaces a list in ``base``.
    * ``user_profile`` / ``context``: shallow dict merge, overlay keys win.
      A non-dict overlay value never replaces a dict in ``base``; a non-dict
      value found in ``base`` is ignored in favour of the overlay dict instead
      of raising TypeError.
    * any other key: the overlay value replaces the base value.

    Returns:
        None when both inputs are empty/missing; otherwise a new top-level
        dict. The inputs are not mutated, but nested values may be shared
        with them.
    """
    if not base and not overlay:
        return None
    if not base:
        return dict(overlay) if isinstance(overlay, dict) else None
    if not overlay or not isinstance(overlay, dict):
        return dict(base)

    merged = dict(base)
    for key, value in overlay.items():
        current = merged.get(key)
        if key == "conversation_history":
            if isinstance(value, list):
                if not isinstance(current, list) or len(value) > len(current):
                    merged[key] = value
            elif not isinstance(current, list):
                # Neither side holds a list: fall back to "overlay wins".
                merged[key] = value
        elif key in ("user_profile", "context"):
            if isinstance(value, dict):
                existing = current if isinstance(current, dict) else {}
                merged[key] = {**existing, **value}
            elif not isinstance(current, dict):
                # Neither side holds a dict: fall back to "overlay wins".
                merged[key] = value
        else:
            merged[key] = value
    return merged
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_persistent_memory(
    db: Session, scope_kind: str, scope_id: str, session_key: str
) -> Optional[Dict[str, Any]]:
    """Fetch the stored memory payload for one (scope, session) triple.

    Looks up the first ``PersistentUserMemory`` row matching ``scope_kind``,
    ``scope_id`` and ``session_key``. Returns a shallow copy of its payload,
    so callers may mutate the result without touching the ORM-tracked value.
    Returns None when no row matches or the stored payload is not a dict.
    """
    record = (
        db.query(PersistentUserMemory)
        .filter_by(scope_kind=scope_kind, scope_id=scope_id, session_key=session_key)
        .first()
    )
    stored = record.payload if record else None
    return dict(stored) if isinstance(stored, dict) else None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_persistent_memory(
    db: Session, scope_kind: str, scope_id: str, session_key: str, payload: Dict[str, Any]
) -> None:
    """Upsert the memory payload for ``(scope_kind, scope_id, session_key)`` and commit.

    If a matching ``PersistentUserMemory`` row exists, its ``payload`` is
    replaced. Otherwise a new row is inserted with a random UUID4 string id.

    Raises:
        Any exception from the query, flush or commit is re-raised after
        ``db.rollback()``, so the caller's session is left usable instead of
        stuck in a failed transaction.
    """
    try:
        row = (
            db.query(PersistentUserMemory)
            .filter(
                PersistentUserMemory.scope_kind == scope_kind,
                PersistentUserMemory.scope_id == scope_id,
                PersistentUserMemory.session_key == session_key,
            )
            .first()
        )
        if row:
            row.payload = payload
        else:
            # NOTE(review): select-then-insert is not atomic; two concurrent
            # writers for a new key can both reach this branch. Whether that
            # raises or yields duplicate rows depends on a unique constraint on
            # (scope_kind, scope_id, session_key) — confirm in the model/schema.
            row = PersistentUserMemory(
                id=str(uuid.uuid4()),
                scope_kind=scope_kind,
                scope_id=scope_id,
                session_key=session_key,
                payload=payload,
            )
            db.add(row)
        db.commit()
    except Exception:
        db.rollback()
        raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
def delete_persistent_memory(db: Session, scope_kind: str, scope_id: str, session_key: str) -> None:
    """Delete every stored memory row for ``(scope_kind, scope_id, session_key)`` and commit.

    Uses a bulk ``Query.delete()``. Deleting a key that has no rows is a no-op
    apart from the commit.

    Raises:
        Any exception from the delete or commit is re-raised after
        ``db.rollback()``, so the caller's session is left usable.
    """
    try:
        db.query(PersistentUserMemory).filter(
            PersistentUserMemory.scope_kind == scope_kind,
            PersistentUserMemory.scope_id == scope_id,
            PersistentUserMemory.session_key == session_key,
        ).delete()
        db.commit()
    except Exception:
        db.rollback()
        raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
def persist_enabled() -> bool:
    """Return whether user memory should also be persisted to the database.

    Reads ``settings.MEMORY_PERSIST_DB_ENABLED`` and defaults to True when the
    attribute is absent.

    String values are parsed rather than passed to ``bool()``. Plain
    ``bool("false")`` is True, so a string flag could otherwise never disable
    persistence. The strings "", "0", "false", "no" and "off" (case- and
    whitespace-insensitive) count as False. Non-string values keep the
    original ``bool(...)`` semantics.
    """
    flag = getattr(settings, "MEMORY_PERSIST_DB_ENABLED", True)
    if isinstance(flag, str):
        return flag.strip().lower() not in ("", "0", "false", "no", "off")
    return bool(flag)
|