feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖
- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser
- 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行)
- 新增 agent_vector_memory / knowledge_base 模型及对应数据库表
- 实现 SSE 流式响应与 Agent 预算控制
- AgentChat.vue 集成 MainLayout 导航布局
- 完善测试体系:7 个新测试文件共 110 个测试覆盖
- 修复 conftest.py SQLite 内存数据库连接隔离问题

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from app.services.persistent_memory_service import (
|
||||
save_persistent_memory,
|
||||
persist_enabled,
|
||||
)
|
||||
from app.services.embedding_service import embedding_service, VectorEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,25 +36,31 @@ class AgentMemory:
|
||||
session_key: Optional[str] = None,
|
||||
persist: bool = True,
|
||||
max_history: int = 20,
|
||||
vector_memory_enabled: bool = True,
|
||||
vector_memory_top_k: int = 5,
|
||||
):
|
||||
self.scope_kind = scope_kind
|
||||
self.scope_id = scope_id or "default"
|
||||
self.session_key = session_key or "default_session"
|
||||
self.persist = persist and persist_enabled()
|
||||
self.max_history = max_history
|
||||
self.vector_memory_enabled = vector_memory_enabled
|
||||
self.vector_memory_top_k = vector_memory_top_k
|
||||
# 从长期记忆加载的上下文(启动时加载)
|
||||
self._long_term_context: Dict[str, Any] = {}
|
||||
# 记录已压缩的消息数,避免重复压缩
|
||||
self._last_compressed_msg_count = 0
|
||||
|
||||
async def initialize(self) -> str:
|
||||
async def initialize(self, query: str = "") -> str:
|
||||
"""
|
||||
初始化记忆:从 DB/Redis 加载长期记忆,构造初始上下文文本。
|
||||
初始化记忆:从 DB/Redis 加载长期记忆 + 向量检索相关历史。
|
||||
返回注入 system prompt 的记忆文本块。
|
||||
"""
|
||||
if not self.persist or not self.scope_id:
|
||||
return ""
|
||||
|
||||
parts: List[str] = []
|
||||
|
||||
db: Optional[Session] = None
|
||||
try:
|
||||
db = SessionLocal()
|
||||
@@ -62,8 +69,6 @@ class AgentMemory:
|
||||
)
|
||||
if payload and isinstance(payload, dict):
|
||||
self._long_term_context = payload
|
||||
# 构建注入 system prompt 的记忆文本
|
||||
parts = []
|
||||
profile = payload.get("user_profile")
|
||||
if profile and isinstance(profile, dict):
|
||||
profile_text = json.dumps(profile, ensure_ascii=False)
|
||||
@@ -78,15 +83,93 @@ class AgentMemory:
|
||||
if history and isinstance(history, list) and len(history) > 0:
|
||||
summary = self._summarize_history(history)
|
||||
parts.append(f"## 历史对话摘要\n{summary}")
|
||||
|
||||
if parts:
|
||||
return "\n\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.warning("加载长期记忆失败: %s", e)
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
return ""
|
||||
|
||||
# 2. 向量检索:查找语义相关的历史对话
|
||||
if self.vector_memory_enabled and self.scope_kind and self.scope_id:
|
||||
vector_text = await self._vector_search(query)
|
||||
if vector_text:
|
||||
parts.append(vector_text)
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
    async def _vector_search(self, query: str = "") -> str:
        """Retrieve semantically related historical memories as a formatted text block.

        Loads up to the 50 most recent vector-memory rows for this scope from the
        database, then either:
          - with a non-empty ``query``: embeds the query and ranks entries via
            ``embedding_service.similarity_search`` (top ``vector_memory_top_k``), or
          - without a query: falls back to the most recent ``vector_memory_top_k``
            entries, each assigned a neutral score of 1.0.

        Args:
            query: Optional user query used for semantic ranking. Empty/blank
                means "no query" and triggers the recency fallback.

        Returns:
            A markdown-style text block (header + numbered entries) suitable for
            injection into a system prompt, or "" when nothing matched or any
            error occurred (failures are logged, never raised).
        """
        # Local import to avoid a module-level import cycle with the models package
        # — presumably; TODO confirm against the project's import graph.
        from app.models.agent_vector_memory import AgentVectorMemory

        db: Optional[Session] = None
        try:
            db = SessionLocal()
            # Fetch this scope's vector memories, newest first.
            rows = (
                db.query(AgentVectorMemory)
                .filter(
                    AgentVectorMemory.scope_kind == self.scope_kind,
                    AgentVectorMemory.scope_id == self.scope_id,
                )
                .order_by(AgentVectorMemory.created_at.desc())
                .limit(50)  # cap the candidate set for similarity computation
                .all()
            )

            if not rows:
                return ""

            # Deserialize stored embeddings into in-memory VectorEntry dicts.
            # Rows with no stored embedding get an empty vector.
            entries: List[VectorEntry] = []
            for row in rows:
                emb = embedding_service.deserialize_embedding(row.embedding) if row.embedding else []
                entries.append({
                    "id": row.id,
                    "scope_kind": row.scope_kind,
                    "scope_id": row.scope_id,
                    "content_text": row.content_text,
                    "embedding": emb,
                    "metadata": row.metadata_ or {},
                })

            matched: List[VectorEntry] = []

            if query and query.strip():
                # With a query: embed it and run a semantic similarity search.
                # NOTE: if query embedding generation returns a falsy value,
                # `matched` stays empty and "" is returned below.
                query_emb = await embedding_service.generate_embedding(query)
                if query_emb:
                    matched = await embedding_service.similarity_search(
                        query_emb, entries, top_k=self.vector_memory_top_k
                    )
            else:
                # Without a query: take the most recent top-k entries verbatim.
                matched = entries[: self.vector_memory_top_k]
                for m in matched:
                    # Neutral score so the "(match score)" suffix below is skipped.
                    m["score"] = 1.0

            if not matched:
                return ""

            # Format matches as a markdown block: header + numbered entries,
            # each truncated to 500 chars of content.
            lines = ["## 相关历史记忆"]
            for i, m in enumerate(matched, 1):
                text = m.get("content_text", "")[:500]
                meta = m.get("metadata", {})
                entry_type = meta.get("type", "对话")
                lines.append(f"{i}. [{entry_type}] {text}")
                # Only show the similarity score when it carries information
                # (i.e. a real match score below the neutral 1.0).
                if m.get("score", 1.0) < 1.0:
                    lines[-1] += f" (匹配度: {m['score']:.2f})"

            return "\n".join(lines)

        except Exception as e:
            # Best-effort feature: degrade to "no memory" rather than failing the chat.
            logger.warning("向量检索失败: %s", e)
            return ""
        finally:
            if db:
                db.close()
|
||||
|
||||
async def save_context(
|
||||
self, user_message: str, assistant_reply: str,
|
||||
@@ -114,12 +197,50 @@ class AgentMemory:
|
||||
db, self.scope_kind, self.scope_id,
|
||||
self.session_key, self._long_term_context,
|
||||
)
|
||||
|
||||
# 保存向量记忆(异步生成 embedding 并存储)
|
||||
if self.vector_memory_enabled:
|
||||
await self._save_vector_memory(
|
||||
db, user_message, assistant_reply
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("保存长期记忆失败: %s", e)
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
    async def _save_vector_memory(
        self, db: Session, user_message: str, assistant_reply: str,
    ) -> None:
        """Generate an embedding for this conversation turn and persist it.

        Builds a combined "用户/助手" text from the turn, embeds up to its first
        8000 characters, and stores a row in the vector-memory table with the
        first 2000 characters of the text (so the embedding may cover more text
        than what is stored for display — NOTE(review): confirm this asymmetry
        is intentional).

        Args:
            db: An open SQLAlchemy session; committed on success, rolled back
                on failure. The caller owns the session lifecycle.
            user_message: The user's message for this turn.
            assistant_reply: The assistant's reply for this turn.

        Returns:
            None. Failures are logged at WARNING and swallowed (best-effort).
        """
        # Local import to avoid a module-level import cycle with the models
        # package — presumably; TODO confirm against the project's import graph.
        from app.models.agent_vector_memory import AgentVectorMemory

        content_text = f"用户: {user_message}\n助手: {assistant_reply}"
        # Cap embedding input size.
        if len(content_text) > 8000:
            content_text = content_text[:8000]

        try:
            # Generate the embedding; an empty/falsy result is stored as NULL.
            embedding = await embedding_service.generate_embedding(content_text)
            embedding_json = embedding_service.serialize_embedding(embedding) if embedding else ""

            record = AgentVectorMemory(
                scope_kind=self.scope_kind,
                scope_id=self.scope_id,
                session_key=self.session_key,
                # Only the first 2000 chars are persisted for retrieval display.
                content_text=content_text[:2000],
                embedding=embedding_json or None,
                metadata_={
                    # Marks this row as a conversation turn (read back by
                    # _vector_search as the entry-type label).
                    "type": "conversation_turn",
                },
            )
            db.add(record)
            db.commit()
            logger.debug("已保存向量记忆 (scope=%s/%s)", self.scope_kind, self.scope_id)
        except Exception as e:
            # Best-effort persistence: log and roll back, never break the chat flow.
            logger.warning("保存向量记忆失败: %s", e)
            db.rollback()
|
||||
|
||||
async def _compress_and_summarize(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user