Files
aiagent/backend/app/agent_runtime/memory.py
renjianbo 7f4aeb021b fix: Feishu channel agents file_write permission blocked + memory system tests & docs
- Fix 8 Feishu agent handlers to use permission_level="acceptEdits" so file_write
  tool works without Web UI approval popup (lingxi/renshenguo/suyao/tiantian/orange/main/schedule)
- Add P5-P7 memory improvements: offline keyword fallback, team sharing, file-based memory
- Add auto_dream_service for daily memory consolidation
- Add 99 memory system test cases (basic 18 + advanced 43 + pytest 38)
- Add platform capability assessment report and unfinished project checklist

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-06-14 20:35:12 +08:00

902 lines
38 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Agent 记忆管理:包装已有 persistent_memory_service,提供会话级和长期记忆。
支持 LLM 自动压缩总结对话历史。
"""
from __future__ import annotations
import json
import logging
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.services.persistent_memory_service import (
load_persistent_memory,
save_persistent_memory,
persist_enabled,
)
from app.services.embedding_service import embedding_service, VectorEntry
logger = logging.getLogger(__name__)
class AgentMemory:
    """
    Layered memory manager for an agent.

    - Working memory: the current session's message list. It is owned by the
      caller (the original notes name AgentRuntime); this class only trims it
      through ``trim_messages``.
    - Long-term memory: a payload (user profile, context, key facts, topics)
      loaded and saved through ``persistent_memory_service``.
    - Memory compression: an LLM summarises the conversation and the extracted
      profile / facts / topics are merged into the long-term payload and also
      written to the vector-memory table.
    """

    # Memory type categories used to tag stored memories.
    MEMORY_TYPES = ("user", "feedback", "project", "reference")

    # Keyword tables for ``_infer_memory_type``; checked in this order.
    _FEEDBACK_KEYWORDS = (
        "不对", "错误", "错了", "报错", "bug", "不正确", "有问题",
        "改一下", "修正", "纠正", "不要这样", "不行", "不是这个",
        "不对的", "反馈", "建议", "应该", "能不能", "可以不要",
    )
    _REFERENCE_KEYWORDS = (
        "http://", "https://", "配置", ".env", "api", "端口",
        "数据库", "地址", "密码", "密钥", "token", "url",
        "路径", "文件", "目录", "安装", "部署",
    )
    # "bug" also appears in the feedback table, which is checked first, so it
    # can never select "project"; kept so the table matches the original.
    _PROJECT_KEYWORDS = (
        "任务", "目标", "进度", "完成", "计划", "需求", "项目",
        "开发", "测试", "上线", "版本", "发布", "迭代",
        "bug", "修复", "功能", "实现", "提交",
    )

    # Sort rank for global-knowledge confidence; anything else ranks last (2).
    _CONFIDENCE_RANK = {"high": 0, "medium": 1}

    # LLM settings shared by rerank and compression.
    _PLACEHOLDER_API_KEY = "your-openai-api-key"
    _DEFAULT_LLM_BASE_URL = "https://api.deepseek.com"
    _LLM_MODEL = "deepseek-v4-flash"

    def __init__(
        self,
        scope_kind: str = "agent",
        scope_id: Optional[str] = None,
        session_key: Optional[str] = None,
        persist: bool = True,
        max_history: int = 20,
        vector_memory_enabled: bool = True,
        vector_memory_top_k: int = 5,
        vector_memory_rerank: bool = False,
        memory_type_filter: Optional[List[str]] = None,
        team_id: Optional[str] = None,
        team_share_enabled: bool = False,
        memory_dir_enabled: bool = False,
        memory_dir_path: str = "",
    ):
        """
        Configure the memory manager; no I/O happens here apart from the
        ``persist_enabled()`` check.

        Args:
            scope_kind: Scope namespace for stored memories.
            scope_id: Scope identifier; falls back to ``"default"``.
            session_key: Session identifier; falls back to ``"default_session"``.
            persist: Requested persistence; only effective when
                ``persist_enabled()`` is also true.
            max_history: Message budget used by ``trim_messages``.
            vector_memory_enabled: Enable vector-memory reads and writes.
            vector_memory_top_k: Number of vector memories injected.
            vector_memory_rerank: Re-rank vector candidates with an LLM.
            memory_type_filter: Allowed memory types; ``None`` means all.
            team_id: Team pool to read from (and write to when sharing is on).
            team_share_enabled: Copy each saved vector memory to the team pool.
            memory_dir_enabled: Enable the file-based memory store.
            memory_dir_path: Path handed to the file-based memory store.
        """
        self.scope_kind = scope_kind
        self.scope_id = scope_id or "default"
        self.session_key = session_key or "default_session"
        self.persist = persist and persist_enabled()
        self.max_history = max_history
        self.vector_memory_enabled = vector_memory_enabled
        self.vector_memory_top_k = vector_memory_top_k
        self.vector_memory_rerank = vector_memory_rerank
        self.memory_type_filter = memory_type_filter  # None = every type
        self.team_id = team_id
        self.team_share_enabled = team_share_enabled
        # File-based memory; the store itself is created lazily.
        self.memory_dir_enabled = memory_dir_enabled
        self.memory_dir_path = memory_dir_path
        self._file_store = None
        # Long-term payload, filled by ``initialize``.
        self._long_term_context: Dict[str, Any] = {}
        # Message count at the last compression, to avoid re-compressing.
        self._last_compressed_msg_count = 0
        # Strong references to fire-and-forget tasks: the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._background_tasks: set = set()

    def _get_file_store(self):
        """Return the file memory store, creating it on first use.

        Returns ``None`` when file-based memory is disabled.
        """
        if self._file_store is None and self.memory_dir_enabled:
            from app.services.file_memory_service import get_file_memory_store

            self._file_store = get_file_memory_store(self.memory_dir_path)
        return self._file_store

    async def initialize(self, query: str = "") -> str:
        """
        Load long-term memory and gather related memories for ``query``.

        Sources, in order: persistent payload, vector memory, file-based
        memory, global knowledge. Each source is best-effort; a failing source
        is logged and skipped.

        Returns:
            A text block intended for the system prompt, or ``""`` when
            persistence is disabled or nothing was found.
        """
        if not self.persist or not self.scope_id:
            return ""

        parts: List[str] = []

        # 1. Persistent payload: profile, context, history summary.
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            payload = load_persistent_memory(
                db, self.scope_kind, self.scope_id, self.session_key
            )
            if payload and isinstance(payload, dict):
                self._long_term_context = payload
                profile = payload.get("user_profile")
                if profile and isinstance(profile, dict):
                    profile_text = json.dumps(profile, ensure_ascii=False)
                    parts.append(f"## 用户画像\n{profile_text}")
                context = payload.get("context")
                if context and isinstance(context, dict):
                    ctx_text = json.dumps(context, ensure_ascii=False)
                    parts.append(f"## 上下文\n{ctx_text}")
                history = payload.get("conversation_history")
                if history and isinstance(history, list):
                    summary = self._summarize_history(history)
                    parts.append(f"## 历史对话摘要\n{summary}")
        except Exception as e:
            logger.warning("加载长期记忆失败: %s", e)
        finally:
            if db:
                db.close()

        # 2. Vector memory: semantically related history.
        if self.vector_memory_enabled and self.scope_kind and self.scope_id:
            vector_text = await self._vector_search(query)
            if vector_text:
                parts.append(vector_text)

        # 3. File-based memory.
        file_text = self._file_memory_search(query)
        if file_text:
            parts.append(file_text)

        # 4. Global knowledge.
        global_text = await self._global_knowledge_search(query)
        if global_text:
            parts.append(global_text)

        return "\n\n".join(parts)

    def _file_memory_search(self, query: str) -> str:
        """Search the file memory store for ``query`` and format the hits.

        Best-effort like the other sources: any store error is logged and
        yields ``""`` instead of aborting ``initialize``.
        """
        try:
            store = self._get_file_store()
            if not (store and store.memory_count > 0 and query):
                return ""
            file_results = store.search(query, top_k=3)
            if not file_results:
                return ""
            # NOTE(review): header kept exactly as found in the source; the
            # closing bracket looks missing — confirm against the repository.
            lines = ["## 文件记忆(本地 MEMORY.md"]
            for i, r in enumerate(file_results, 1):
                mem_type = r.get("type", "reference")
                content = (r.get("content") or "")[:300]
                score = r.get("score", 0)
                line = f"{i}. [{mem_type}] {content}"
                if score < 1.0:
                    line += f" (匹配度: {score:.2f})"
                lines.append(line)
            return "\n".join(lines)
        except Exception as e:
            logger.warning("文件记忆检索失败: %s", e)
            return ""

    async def _vector_search(self, query: str = "") -> str:
        """
        Retrieve related vector memories and format them as a text block.

        Loads the 50 newest rows of this scope (plus the 30 newest rows of the
        team pool when ``team_id`` is set), applies ``memory_type_filter``,
        then:
        - with a query: embedding similarity search, optionally re-ranked by
          an LLM; when no embedding can be generated, keyword search;
        - without a query: the first ``vector_memory_top_k`` entries.

        Returns ``""`` when nothing matches or on any error.
        """
        from app.models.agent_vector_memory import AgentVectorMemory

        top_k = self.vector_memory_top_k
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            rows = list(
                db.query(AgentVectorMemory)
                .filter(
                    AgentVectorMemory.scope_kind == self.scope_kind,
                    AgentVectorMemory.scope_id == self.scope_id,
                )
                .order_by(AgentVectorMemory.created_at.desc())
                .limit(50)
                .all()
            )
            # Team sharing: append the team pool after the agent's own rows.
            if self.team_id:
                rows.extend(
                    db.query(AgentVectorMemory)
                    .filter(
                        AgentVectorMemory.scope_kind == "team",
                        AgentVectorMemory.scope_id == self.team_id,
                    )
                    .order_by(AgentVectorMemory.created_at.desc())
                    .limit(30)
                    .all()
                )
            if not rows:
                return ""

            entries: List[VectorEntry] = []
            for row in rows:
                meta = row.metadata_ or {}
                row_memory_type = meta.get(
                    "memory_type", meta.get("type", "conversation_turn")
                )
                if (
                    self.memory_type_filter
                    and row_memory_type not in self.memory_type_filter
                ):
                    continue
                emb = (
                    embedding_service.deserialize_embedding(row.embedding)
                    if row.embedding
                    else []
                )
                entries.append({
                    "id": row.id,
                    "scope_kind": row.scope_kind,
                    "scope_id": row.scope_id,
                    "content_text": row.content_text,
                    "embedding": emb,
                    "metadata": meta,
                })
            if not entries:
                return ""

            matched: List[VectorEntry] = []
            if query and query.strip():
                query_emb = await embedding_service.generate_embedding(query)
                if query_emb:
                    # Over-fetch (4x top_k, at least 20) so rerank has room.
                    candidate_k = max(20, top_k * 4)
                    candidates = await embedding_service.similarity_search(
                        query_emb, entries, top_k=min(candidate_k, len(entries))
                    )
                    if self.vector_memory_rerank and len(candidates) > top_k:
                        matched = await self._llm_rerank(query, candidates)
                    if not matched:
                        matched = candidates[:top_k]
                else:
                    # Offline fallback: no embedding, use keyword matching.
                    logger.info("Embedding 不可用,降级为离线关键词匹配")
                    matched = embedding_service.keyword_search(
                        query, entries, top_k=top_k, min_score=0.05,
                    )
            else:
                # No query: take the leading entries and mark them as exact.
                matched = entries[:top_k]
                for m in matched:
                    m["score"] = 1.0
            if not matched:
                return ""

            lines = ["## 相关历史记忆"]
            for i, m in enumerate(matched, 1):
                text = (m.get("content_text") or "")[:500]
                meta = m.get("metadata", {})
                mem_type = meta.get("memory_type", meta.get("type", "对话"))
                # Mark memories that came from the team pool.
                source_tag = " [团队共享]" if m.get("scope_kind", "") == "team" else ""
                line = f"{i}. [{mem_type}]{source_tag} {text}"
                if m.get("score", 1.0) < 1.0:
                    line += f" (匹配度: {m['score']:.2f})"
                lines.append(line)
            return "\n".join(lines)
        except Exception as e:
            logger.warning("向量检索失败: %s", e)
            return ""
        finally:
            if db:
                db.close()

    @staticmethod
    def _resolve_llm_credentials(settings: Any) -> tuple:
        """Return ``(api_key, base_url)`` for the summarisation/rerank LLM.

        DeepSeek settings win over OpenAI settings. If the chosen key is the
        ``your-openai-api-key`` placeholder, only the DeepSeek settings are
        used (which may leave the key empty).
        """
        api_key = settings.DEEPSEEK_API_KEY or settings.OPENAI_API_KEY or ""
        base_url = (
            settings.DEEPSEEK_BASE_URL
            or settings.OPENAI_BASE_URL
            or AgentMemory._DEFAULT_LLM_BASE_URL
        )
        if api_key == AgentMemory._PLACEHOLDER_API_KEY:
            api_key = settings.DEEPSEEK_API_KEY or ""
            base_url = settings.DEEPSEEK_BASE_URL or AgentMemory._DEFAULT_LLM_BASE_URL
        return api_key, base_url

    @staticmethod
    def _parse_llm_json(raw: str) -> Any:
        """Parse JSON from an LLM reply, tolerating a Markdown code fence.

        Accepts both a json-tagged fence and a bare fence around the payload.

        Raises:
            json.JSONDecodeError: when the remaining text is not valid JSON.
        """
        text = (raw or "").strip()
        if text.startswith("```"):
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
        text = text.removesuffix("```")
        return json.loads(text.strip())

    @staticmethod
    def _select_reranked(
        scored: List[Any], candidates: List[Dict[str, Any]], top_k: int,
    ) -> List[Dict[str, Any]]:
        """Pick the ``top_k`` best candidates from the LLM's score list.

        Malformed items (non-dict, non-integer or out-of-range index,
        non-numeric score) are skipped instead of discarding the whole rerank,
        and each candidate is returned at most once. A missing score sorts as
        0 and is stored as 5.0. Stored ``score`` is the LLM score divided by 10.
        """
        ranked = []
        for item in scored:
            if not isinstance(item, dict):
                continue
            idx = item.get("index", -1)
            if isinstance(idx, bool) or not isinstance(idx, int):
                continue
            if not 0 <= idx < len(candidates):
                continue
            raw_score = item.get("score")
            if raw_score is None:
                sort_score, final_score = 0.0, 5.0
            else:
                try:
                    sort_score = final_score = float(raw_score)
                except (TypeError, ValueError):
                    continue
            ranked.append((sort_score, idx, final_score))
        # Stable sort keeps the LLM's own order among equal scores.
        ranked.sort(key=lambda entry: entry[0], reverse=True)

        result: List[Dict[str, Any]] = []
        seen = set()
        for _, idx, final_score in ranked:
            if idx in seen:
                continue
            seen.add(idx)
            candidates[idx]["score"] = final_score / 10.0
            result.append(candidates[idx])
            if len(result) >= top_k:
                break
        return result

    async def _llm_rerank(
        self, query: str, candidates: List[VectorEntry],
    ) -> List[VectorEntry]:
        """
        Re-rank vector-search candidates with an LLM.

        The model is asked to score each candidate (1-10) against ``query``;
        the best ``vector_memory_top_k`` are returned. On any failure (no API
        key, request error, unparsable reply, no usable item) the first
        ``vector_memory_top_k`` candidates are returned in their given order.
        """
        from openai import AsyncOpenAI
        from app.core.config import settings

        top_k = self.vector_memory_top_k
        fallback = candidates[:top_k]
        if not candidates or len(candidates) <= top_k:
            return fallback
        try:
            items_text = []
            for idx, c in enumerate(candidates):
                content = (c.get("content_text") or "")[:300]
                mem_type = c.get("metadata", {}).get("memory_type", "unknown")
                items_text.append(f"[{idx}] [{mem_type}] {content}")
            rerank_prompt = (
                "你是一个记忆检索排序助手。请根据用户查询对以下记忆条目按相关性打分1-10分\n"
                "只输出 JSON 数组,每个元素包含 index 和 score按 score 降序排列。\n"
                "只保留 score >= 4 的结果。最多返回 {} 条。\n\n"
                "用户查询: {}\n\n记忆条目:\n{}"
            ).format(
                top_k,
                query[:500],
                "\n".join(items_text),
            )
            api_key, base_url = self._resolve_llm_credentials(settings)
            if not api_key:
                return fallback
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
            resp = await client.chat.completions.create(
                model=self._LLM_MODEL,
                messages=[{"role": "user", "content": rerank_prompt}],
                temperature=0.1,
                max_tokens=512,
                timeout=15,
            )
            scored = self._parse_llm_json(resp.choices[0].message.content or "")
            if not isinstance(scored, list):
                return fallback
            result = self._select_reranked(scored, candidates, top_k)
            if result:
                logger.info("LLM Rerank: %d 候选 → %d 精选", len(candidates), len(result))
                return result
            return fallback
        except Exception as e:
            logger.warning("LLM Rerank 失败,使用向量排序: %s", e)
            return fallback

    @staticmethod
    def _utcnow_naive():
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated ``datetime.utcnow()``.
        """
        from datetime import datetime, timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _global_knowledge_search(self, query: str = "") -> str:
        """
        Retrieve entries from the GlobalKnowledge table as a text block.

        Considers the 50 newest non-expired rows. With a query, rows that have
        an embedding are ranked by similarity (top 5). Without a query, the 5
        best rows by confidence (high, medium, other) are returned; ties keep
        newest-first order. Returns ``""`` when nothing matches or on error.
        """
        from app.models.agent import GlobalKnowledge

        db: Optional[Session] = None
        try:
            db = SessionLocal()
            now = self._utcnow_naive()
            # Not expired: expires_at IS NULL or expires_at > now.
            rows = (
                db.query(GlobalKnowledge)
                .filter(
                    (GlobalKnowledge.expires_at.is_(None))
                    | (GlobalKnowledge.expires_at > now)
                )
                .order_by(GlobalKnowledge.created_at.desc())
                .limit(50)
                .all()
            )
            if not rows:
                return ""

            if not (query and query.strip()):
                recent = sorted(
                    rows,
                    key=lambda r: self._CONFIDENCE_RANK.get(r.confidence, 2),
                )[:5]
                lines = ["## 全局知识库(最近)"]
                for i, row in enumerate(recent, 1):
                    tag_str = f" [{(', '.join(row.tags[:3]))}]" if row.tags else ""
                    conf_str = (
                        f" (置信度:{row.confidence})"
                        if row.confidence and row.confidence != "medium"
                        else ""
                    )
                    lines.append(f"{i}.{tag_str}{conf_str} {(row.content or '')[:500]}")
                return "\n".join(lines)

            # Query given: only rows with a usable embedding can be ranked.
            entries: List[VectorEntry] = []
            for row in rows:
                if not row.embedding:
                    continue
                try:
                    emb = embedding_service.deserialize_embedding(row.embedding)
                except Exception:
                    emb = []
                if not emb:
                    continue
                entries.append({
                    "id": row.id,
                    "scope_kind": "global",
                    "scope_id": "global",
                    "content_text": row.content,
                    "embedding": emb,
                    "metadata": {
                        "source_agent_id": row.source_agent_id,
                        "tags": row.tags or [],
                        "confidence": row.confidence or "medium",
                    },
                })
            if not entries:
                return ""
            query_emb = await embedding_service.generate_embedding(query)
            if not query_emb:
                return ""
            matched = await embedding_service.similarity_search(
                query_emb, entries, top_k=min(5, len(entries)),
            )
            if not matched:
                return ""
            lines = ["## 全局知识库"]
            for i, m in enumerate(matched, 1):
                tags = m.get("metadata", {}).get("tags", [])
                conf = m.get("metadata", {}).get("confidence", "medium")
                tag_str = f" [{', '.join(tags[:3])}]" if tags else ""
                conf_str = f" (置信度:{conf})" if conf != "medium" else ""
                lines.append(
                    f"{i}.{tag_str}{conf_str} {(m.get('content_text') or '')[:500]}"
                )
            return "\n".join(lines)
        except Exception as e:
            logger.warning("全局知识检索失败: %s", e)
            return ""
        finally:
            if db:
                db.close()

    async def save_global_knowledge(
        self, content: str, source_agent_id: str = "",
        source_user_id: str = "", tags: Optional[List[str]] = None,
        confidence: str = "medium", ttl_hours: int = 0,
    ) -> None:
        """Write an entry to the global knowledge pool.

        Content shorter than 20 characters is ignored. An entry is skipped as
        a duplicate when one of the 200 newest rows shares its first 500
        characters. ``ttl_hours > 0`` sets ``expires_at``; 0 means no expiry.
        Stored content is capped at 2000 characters. Errors are logged and
        rolled back, never raised.
        """
        from datetime import timedelta
        from app.models.agent import GlobalKnowledge

        if not content or len(content) < 20:
            return
        db: Optional[Session] = None
        try:
            db = SessionLocal()
            # Dedupe on the 500-character prefix. Comparing the prefixes
            # directly is equivalent to comparing their hashes.
            fingerprint = content[:500]
            recent = (
                db.query(GlobalKnowledge)
                .order_by(GlobalKnowledge.created_at.desc())
                .limit(200)
                .all()
            )
            if any((row.content or "")[:500] == fingerprint for row in recent):
                logger.info("全局知识去重:已存在相同条目,跳过写入")
                return

            # Embedding is optional; the entry is stored without it on failure.
            embedding_json = ""
            try:
                emb = await embedding_service.generate_embedding(content)
                if emb:
                    embedding_json = embedding_service.serialize_embedding(emb) or ""
            except Exception as emb_err:
                logger.debug("全局知识 embedding 生成失败: %s", emb_err)

            expires_at = None
            if ttl_hours > 0:
                expires_at = self._utcnow_naive() + timedelta(hours=ttl_hours)

            record = GlobalKnowledge(
                content=content[:2000],
                embedding=embedding_json or None,
                source_agent_id=source_agent_id or "",
                source_user_id=source_user_id or "",
                tags=tags or [],
                confidence=confidence or "medium",
                expires_at=expires_at,
                scope_kind=self.scope_kind,
                scope_id=self.scope_id or "global",
            )
            db.add(record)
            db.commit()
            logger.info("已写入全局知识: agent=%s tags=%s confidence=%s",
                        source_agent_id, tags, confidence)
        except Exception as e:
            logger.warning("保存全局知识失败: %s", e)
            if db:
                db.rollback()
        finally:
            if db:
                db.close()

    def _spawn_background(self, coro) -> None:
        """Schedule ``coro`` without awaiting it, keeping it referenced.

        The task is held in ``_background_tasks`` until it completes so it
        cannot be garbage-collected while still running.
        """
        import asyncio

        task = asyncio.ensure_future(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def save_context(
        self, user_message: str, assistant_reply: str,
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Persist one conversation turn.

        Awaited path: update the basic context, save it, write the vector
        memory and the file-based memory. Background path: LLM compression of
        ``messages``, scheduled without being awaited. Does nothing when
        persistence is disabled. Storage errors are logged, never raised.
        """
        if not self.persist or not self.scope_id:
            return

        # Basic context; tolerate a stored value that is not a dict.
        ctx = self._long_term_context.get("context")
        if not isinstance(ctx, dict):
            ctx = {}
        ctx["last_user_message"] = user_message[:500]
        ctx["last_assistant_reply"] = assistant_reply[:500]
        self._long_term_context["context"] = ctx

        # NOTE(review): the trigger relies on len(messages) growing; if the
        # caller passes an already trimmed list it may stop firing — confirm.
        if messages and len(messages) > self._last_compressed_msg_count + 2:
            self._last_compressed_msg_count = len(messages)
            # Snapshot so later caller-side mutation does not affect the task.
            self._spawn_background(
                self._background_compress_and_save(list(messages))
            )

        mem_type = self._infer_memory_type(user_message, assistant_reply)

        db: Optional[Session] = None
        try:
            db = SessionLocal()
            # The background compression may overwrite this later.
            save_persistent_memory(
                db, self.scope_kind, self.scope_id,
                self.session_key, self._long_term_context,
            )
            if self.vector_memory_enabled:
                await self._save_vector_memory(
                    db, user_message, assistant_reply, memory_type=mem_type,
                )
        except Exception as e:
            logger.warning("保存长期记忆失败: %s", e)
        finally:
            if db:
                db.close()

        # File-based memory is the fallback store, so it is written in its
        # own try block and still runs when the database write failed.
        try:
            store = self._get_file_store()
            if store:
                content = f"用户: {user_message[:300]}\n助手: {assistant_reply[:300]}"
                # NOTE(review): len(ctx) is the number of context keys and is
                # nearly constant, so names may repeat across turns — confirm
                # whether store.save is expected to overwrite.
                store.save(
                    name=f"{self.scope_id}_{self.session_key}_{len(ctx)}",
                    content=content,
                    mem_type=mem_type,
                )
        except Exception as e:
            logger.warning("保存文件记忆失败: %s", e)

    async def _save_vector_memory(
        self, db: Session, user_message: str, assistant_reply: str,
        memory_type: str = "conversation_turn",
    ) -> None:
        """Embed one turn and store it in the vector-memory table.

        The embedding is computed on at most 8000 characters; the stored text
        is capped at 2000. When team sharing is on, a copy is written to the
        team pool; a failure there is rolled back and does not affect the
        already committed primary record. Errors are logged, never raised.
        """
        from app.models.agent_vector_memory import AgentVectorMemory

        content_text = f"用户: {user_message}\n助手: {assistant_reply}"[:8000]
        try:
            embedding = await embedding_service.generate_embedding(content_text)
            embedding_json = (
                embedding_service.serialize_embedding(embedding) if embedding else ""
            )
            record = AgentVectorMemory(
                scope_kind=self.scope_kind,
                scope_id=self.scope_id,
                session_key=self.session_key,
                content_text=content_text[:2000],
                embedding=embedding_json or None,
                metadata_={
                    "type": memory_type,
                    "memory_type": memory_type,
                },
            )
            db.add(record)
            db.commit()

            if self.team_id and self.team_share_enabled:
                try:
                    team_record = AgentVectorMemory(
                        scope_kind="team",
                        scope_id=self.team_id,
                        session_key=self.session_key,
                        content_text=content_text[:2000],
                        embedding=embedding_json or None,
                        metadata_={
                            "type": memory_type,
                            "memory_type": memory_type,
                            "source_scope": f"{self.scope_kind}/{self.scope_id}",
                            "shared_by": self.scope_id,
                        },
                    )
                    db.add(team_record)
                    db.commit()
                    logger.debug("已同步到团队记忆池 (team=%s)", self.team_id)
                except Exception as team_err:
                    # Team sync is optional; keep the primary record.
                    db.rollback()
                    logger.debug("同步团队记忆池失败 (team=%s): %s",
                                 self.team_id, team_err)
            logger.debug("已保存向量记忆 (scope=%s/%s, type=%s)",
                         self.scope_kind, self.scope_id, memory_type)
        except Exception as e:
            logger.warning("保存向量记忆失败: %s", e)
            db.rollback()

    async def _background_compress_and_save(
        self, messages: List[Dict[str, Any]],
    ) -> None:
        """
        Compress ``messages`` with the LLM, then persist the long-term payload.

        Runs as a fire-and-forget task from ``save_context``; every error is
        logged and swallowed.
        """
        try:
            await self._compress_and_summarize(messages)
        except Exception as e:
            logger.warning("后台压缩总结失败: %s", e)
            return

        db: Optional[Session] = None
        try:
            db = SessionLocal()
            save_persistent_memory(
                db, self.scope_kind, self.scope_id,
                self.session_key, self._long_term_context,
            )
        except Exception as e:
            logger.warning("后台压缩保存 persistent_memory 失败: %s", e)
        finally:
            if db:
                db.close()

    @staticmethod
    def _build_conversation(
        messages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Convert runtime messages into the list sent to the summariser.

        System messages are dropped. Tool results become ``user`` messages
        with a labelled, 200-character excerpt. Other content is cut to 500.
        ``content`` of ``None`` (usual for assistant tool-call messages) is
        treated as empty text instead of raising.
        """
        conversation: List[Dict[str, Any]] = []
        for m in messages:
            role = m.get("role", "")
            if role == "system":
                continue
            raw = m.get("content")
            if raw is None:
                content = ""
            elif isinstance(raw, (str, list)):
                content = raw
            else:
                content = str(raw)
            if role == "tool":
                name = m.get("name", "tool")
                conversation.append({
                    "role": "user",
                    "content": f"[工具 {name} 执行结果]\n{content[:200]}",
                })
            else:
                conversation.append({"role": role, "content": content[:500]})
        return conversation

    def _merge_summary_result(self, result: Dict[str, Any]) -> tuple:
        """Merge an LLM summary into ``_long_term_context``.

        Returns ``(summary, new_facts, new_topics, profile_updated)`` where the
        lists contain only non-empty strings. Fields with an unexpected type
        (for example ``null``) are treated as empty rather than raising.
        Facts and topics are de-duplicated and capped at the last 20.
        """
        def _str_items(value: Any) -> List[str]:
            if not isinstance(value, list):
                return []
            return [item for item in value if isinstance(item, str) and item]

        def _list_or_empty(value: Any) -> list:
            return list(value) if isinstance(value, list) else []

        context = self._long_term_context

        new_profile = result.get("user_profile")
        profile_updated = isinstance(new_profile, dict) and bool(new_profile)
        if profile_updated:
            profile = context.get("user_profile")
            if not isinstance(profile, dict):
                profile = {}
            profile.update(new_profile)  # new values override old ones
            context["user_profile"] = profile

        new_facts = _str_items(result.get("key_facts"))
        all_facts = list(
            dict.fromkeys(_list_or_empty(context.get("key_facts")) + new_facts)
        )
        context["key_facts"] = all_facts[-20:]

        summary = result.get("summary", "")
        if not isinstance(summary, str):
            summary = ""
        if summary:
            ctx = context.get("context")
            if not isinstance(ctx, dict):
                ctx = {}
            ctx["compressed_summary"] = summary
            context["context"] = ctx

        new_topics = _str_items(result.get("topics"))
        if new_topics:
            all_topics = list(
                dict.fromkeys(_list_or_empty(context.get("topics")) + new_topics)
            )
            context["topics"] = all_topics[-20:]

        return summary, new_facts, new_topics, profile_updated

    async def _compress_and_summarize(
        self, messages: List[Dict[str, Any]]
    ) -> None:
        """
        Summarise the conversation with an LLM and merge the result.

        Only the last 10 non-system messages are sent. The reply is expected
        to be a JSON object with ``user_profile``, ``key_facts``, ``summary``
        and ``topics``. Request and parsing errors are logged and swallowed.
        """
        from openai import AsyncOpenAI
        from app.core.config import settings

        conversation = self._build_conversation(messages)
        if len(conversation) < 2:
            return

        summary_prompt = (
            "你是一个记忆管理助手。请分析以下对话历史,提取关于用户的关键信息。\n\n"
            "请返回 JSON 格式(不要 markdown 包裹),包含以下字段:\n"
            "1. user_profile: 用户画像对象,包含用户的偏好、角色、关键需求等\n"
            "2. key_facts: 从对话中提取的关键事实列表(字符串数组)\n"
            "3. summary: 对话的简要总结100字以内\n"
            "4. topics: 讨论过的话题列表(字符串数组)\n\n"
            "如果没有足够信息,相应字段设为空对象或空数组。"
        )
        summary_messages = [
            {"role": "system", "content": summary_prompt},
            *conversation[-10:],
        ]
        try:
            api_key, base_url = self._resolve_llm_credentials(settings)
            if not api_key:
                logger.warning("记忆压缩:未配置 API Key跳过")
                return
            client = AsyncOpenAI(api_key=api_key, base_url=base_url)
            resp = await client.chat.completions.create(
                model=self._LLM_MODEL,
                messages=summary_messages,
                temperature=0.3,
                max_tokens=1024,
                timeout=30,
            )
            result = self._parse_llm_json(resp.choices[0].message.content or "")
            if not isinstance(result, dict):
                logger.warning("记忆压缩LLM 返回非 JSON 格式,跳过")
                return

            summary, new_facts, topics, profile_updated = (
                self._merge_summary_result(result)
            )
            logger.info("记忆压缩总结完成: profile=%s facts=%d topics=%d",
                        "updated" if profile_updated else "unchanged",
                        len(new_facts), len(topics))
            # Make the compressed result available to semantic retrieval.
            await self._save_compressed_memories(summary, new_facts, topics)
        except json.JSONDecodeError:
            logger.warning("记忆压缩LLM 返回非 JSON 格式,跳过")
        except Exception as e:
            logger.warning("记忆压缩失败: %s", e)

    async def _save_compressed_memories(
        self, summary: str, facts: List[str], topics: List[str],
    ) -> None:
        """
        Store the compression result in the vector-memory table.

        The summary and each topic are stored with ``memory_type="project"``,
        each fact longer than 10 characters with ``memory_type="reference"``.
        A failing item is logged and skipped. Errors never propagate.
        """
        from app.models.agent_vector_memory import AgentVectorMemory

        memories_to_save: List[tuple] = []  # (content, memory_type)
        if summary:
            memories_to_save.append((f"[对话摘要] {summary[:1500]}", "project"))
        for fact in facts or []:
            if isinstance(fact, str) and len(fact) > 10:
                memories_to_save.append((f"[关键事实] {fact[:1500]}", "reference"))
        for topic in topics or []:
            if isinstance(topic, str) and topic:
                memories_to_save.append((f"[话题] {topic[:500]}", "project"))
        if not memories_to_save:
            return

        db: Optional[Session] = None
        try:
            db = SessionLocal()
            saved = 0
            for content, mem_type in memories_to_save:
                try:
                    embedding = await embedding_service.generate_embedding(content)
                    embedding_json = (
                        embedding_service.serialize_embedding(embedding)
                        if embedding
                        else ""
                    )
                    db.add(AgentVectorMemory(
                        scope_kind=self.scope_kind,
                        scope_id=self.scope_id,
                        session_key=self.session_key,
                        content_text=content[:2000],
                        embedding=embedding_json or None,
                        metadata_={
                            "type": "compressed_summary",
                            "memory_type": mem_type,
                            "source": "auto_compress",
                        },
                    ))
                    saved += 1
                except Exception as item_err:
                    # One bad item must not block the others.
                    logger.debug("压缩记忆单条写入失败: %s", item_err)
            db.commit()
            # Report what was actually queued, not what was attempted.
            logger.info("已向量化压缩记忆: %d 条 (scope=%s/%s)",
                        saved, self.scope_kind, self.scope_id)
        except Exception as e:
            logger.warning("压缩记忆向量化失败: %s", e)
            if db:
                db.rollback()
        finally:
            if db:
                db.close()

    def trim_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Trim ``messages`` to roughly ``max_history`` entries.

        The input list is returned unchanged when it already fits. Otherwise
        every system message is kept (moved to the front) and the most recent
        non-system messages fill the remaining budget (at least one).

        Tool-call pairing is preserved: if the cut lands on a ``tool`` message,
        the window is extended backwards to the ``assistant`` message(s) whose
        ``tool_calls`` cover it, so the result can exceed ``max_history``. The
        backward search stops at a ``user`` message; tool messages that remain
        orphaned at the start are dropped.
        """
        if len(messages) <= self.max_history:
            return messages

        system_msgs: List[Dict[str, Any]] = []
        other_msgs: List[Dict[str, Any]] = []
        for msg in messages:
            bucket = system_msgs if msg.get("role") == "system" else other_msgs
            bucket.append(msg)

        max_keep = max(1, self.max_history - len(system_msgs))
        start_idx = max(0, len(other_msgs) - max_keep)

        cut_on_tool = (
            0 < start_idx < len(other_msgs)
            and other_msgs[start_idx].get("role") == "tool"
        )
        if cut_on_tool:
            # Count the run of tool messages that opens the window.
            needed = 0
            for msg in other_msgs[start_idx:]:
                if msg.get("role") != "tool":
                    break
                needed += 1
            # Walk back to the assistant message(s) that issued those calls;
            # one assistant message may carry several tool_calls.
            cursor = start_idx - 1
            while cursor >= 0 and needed > 0:
                msg = other_msgs[cursor]
                role = msg.get("role")
                if role == "assistant" and msg.get("tool_calls"):
                    needed -= len(msg["tool_calls"])
                elif role == "user":
                    # Previous turn reached: give up extending.
                    break
                cursor -= 1
            if needed <= 0:
                start_idx = cursor + 1

        trimmed = other_msgs[start_idx:]
        # Final safety net: skip tool messages still orphaned at the start.
        first_kept = 0
        while first_kept < len(trimmed) and trimmed[first_kept].get("role") == "tool":
            first_kept += 1
        return system_msgs + trimmed[first_kept:]

    @staticmethod
    def _summarize_history(history: List[Dict[str, Any]]) -> str:
        """Return a one-line summary stating how many user turns exist."""
        turns = sum(1 for m in history if m.get("role") == "user")
        return f"{turns} 轮历史对话(详情已存入长期记忆)"

    @staticmethod
    def _infer_memory_type(user_message: str, assistant_reply: str) -> str:
        """
        Classify a turn as ``feedback``, ``reference``, ``project`` or ``user``.

        Pure keyword matching on the lower-cased concatenation of both texts,
        no LLM call. Categories are tested in the order above and the first
        hit wins; ``user`` is the default.
        """
        combined = (user_message + " " + assistant_reply).lower()
        ordered_tables = (
            ("feedback", AgentMemory._FEEDBACK_KEYWORDS),
            ("reference", AgentMemory._REFERENCE_KEYWORDS),
            ("project", AgentMemory._PROJECT_KEYWORDS),
        )
        for label, keywords in ordered_tables:
            if any(kw in combined for kw in keywords):
                return label
        return "user"