aiagent/backend/app/services/knowledge_service.py
renjianbo 7b9e0826de feat: vector-memory RAG, tool marketplace, SSE streaming responses, frontend integration, and test coverage
- Add embedding_service (semantic retrieval), knowledge_service (RAG), text_chunker, and document_parser
- Add tool_registry (custom tool registry) and round out the tool-marketplace API (CRUD + code/http execution)
- Add the agent_vector_memory / knowledge_base models and their database tables
- Implement SSE streaming responses and agent budget control
- Integrate AgentChat.vue into the MainLayout navigation layout
- Expand the test suite: 7 new test files with 110 tests in total
- Fix the SQLite in-memory database connection-isolation issue in conftest.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-01 22:30:46 +08:00


"""
Knowledge-base service: document management, chunking, embeddings, semantic search, RAG.

Pipelines:
    Ingestion:  upload → parse → chunk → embed → store
    Retrieval:  query → embed → cosine similarity → top-k
    RAG:        retrieve + format context → hand back to the agent/user

See the usage sketch at the end of this module for how the pieces fit together.
"""
from __future__ import annotations

import json
import logging
import os
import uuid
from typing import Any, Dict, List, Optional

from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk
from app.services.document_parser import parse_document
from app.services.embedding_service import embedding_service, VectorEntry
from app.services.text_chunker import chunk_text

logger = logging.getLogger(__name__)

# Root directory for uploaded files
# (three levels above this module, i.e. <backend>/kb_uploads).
UPLOAD_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    "kb_uploads",
)


def _ensure_upload_dir():
    """Ensure the upload root directory exists."""
    os.makedirs(UPLOAD_DIR, exist_ok=True)


def _get_kb_dir(kb_id: str) -> str:
    """Return the storage directory for a knowledge base, creating it if needed."""
    d = os.path.join(UPLOAD_DIR, kb_id)
    os.makedirs(d, exist_ok=True)
    return d


# ─── Knowledge-base CRUD ────────────────────────────────────────

def create_knowledge_base(
    db: Session,
    name: str,
    user_id: str,
    description: str = "",
    chunk_size: int = 500,
    chunk_overlap: int = 50,
) -> KnowledgeBase:
    """Create a knowledge base."""
    # Clamp the chunking parameters; the overlap is capped at half of the
    # clamped chunk size so that chunks always advance.
    size = max(50, min(2000, chunk_size))
    overlap = max(0, min(size // 2, chunk_overlap))
    kb = KnowledgeBase(
        name=name,
        description=description,
        user_id=user_id,
        chunk_size=size,
        chunk_overlap=overlap,
    )
    db.add(kb)
    db.commit()
    db.refresh(kb)
    logger.info("Knowledge base created: %s (%s)", kb.name, kb.id)
    return kb
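

# Note on the chunking defaults above: assuming text_chunker implements a
# sliding window (its implementation is not shown in this module),
# chunk_size=500 with chunk_overlap=50 means consecutive chunks advance by
# 450 characters and share a 50-character overlap, so context at chunk
# boundaries is preserved.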


def list_knowledge_bases(db: Session, user_id: Optional[str] = None) -> List[KnowledgeBase]:
    """List knowledge bases, optionally filtered by owner."""
    q = db.query(KnowledgeBase)
    if user_id:
        q = q.filter(KnowledgeBase.user_id == user_id)
    return q.order_by(KnowledgeBase.updated_at.desc()).all()


def get_knowledge_base(db: Session, kb_id: str) -> Optional[KnowledgeBase]:
    """Fetch a single knowledge base."""
    return db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()


def delete_knowledge_base(db: Session, kb_id: str) -> bool:
    """Delete a knowledge base together with its documents and chunks."""
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        return False
    # Remove the files on disk (build the path directly so we do not create
    # a directory just to delete it).
    kb_dir = os.path.join(UPLOAD_DIR, kb_id)
    if os.path.isdir(kb_dir):
        import shutil
        shutil.rmtree(kb_dir, ignore_errors=True)
    db.delete(kb)
    db.commit()
    logger.info("Knowledge base deleted: %s", kb_id)
    return True


# ─── Document management ───────────────────────────────────────

async def upload_document(
    db: Session,
    kb_id: str,
    filename: str,
    file_content: bytes,
) -> Document:
    """
    Upload a document into a knowledge base:
      1. save the original file
      2. parse it into plain text
      3. split the text into chunks
      4. generate embeddings
      5. store the chunks
    """
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise ValueError(f"Knowledge base not found: {kb_id}")
    _ensure_upload_dir()
    kb_dir = _get_kb_dir(kb_id)
    # Determine the file type from the extension
    ext = os.path.splitext(filename)[1].lower().lstrip(".")
    if ext in ("txt", "md", "pdf", "docx", "csv"):
        file_type = ext
    else:
        file_type = "txt"
    # Save the original file
    file_path = os.path.join(kb_dir, f"{uuid.uuid4()}_{filename}")
    with open(file_path, "wb") as f:
        f.write(file_content)
    file_size = len(file_content)
    # Create the document record
    doc = Document(
        kb_id=kb_id,
        filename=filename,
        file_type=file_type,
        file_size=file_size,
        status="processing",
    )
    db.add(doc)
    db.commit()
    db.refresh(doc)
    try:
        # Parse the file into text
        text = parse_document(file_path, file_type)
        if not text:
            doc.status = "failed"
            doc.error_message = "Document could not be parsed or is empty"
            db.commit()
            return doc
        # Split into chunks
        chunks = chunk_text(
            text,
            chunk_size=kb.chunk_size,
            chunk_overlap=kb.chunk_overlap,
        )
        if not chunks:
            doc.status = "failed"
            doc.error_message = "Document produced no chunks"
            db.commit()
            return doc
        logger.info("Document chunked: %s -> %d chunks", filename, len(chunks))
        # Generate embeddings in batch, falling back to per-chunk requests
        embeddings: List[Optional[List[float]]] = []
        try:
            embeddings = await embedding_service.generate_embeddings(chunks)
        except Exception as e:
            logger.warning("Batch embedding failed, falling back to per-chunk requests: %s", e)
            for c in chunks:
                try:
                    emb = await embedding_service.generate_embedding(c)
                    embeddings.append(emb)
                except Exception:
                    embeddings.append(None)
        # Store the chunks (embeddings are serialized as JSON text)
        chunk_records = []
        for i, (chunk_text_content, emb) in enumerate(zip(chunks, embeddings)):
            record = DocumentChunk(
                document_id=doc.id,
                kb_id=kb_id,
                chunk_index=i,
                content=chunk_text_content,
                embedding=json.dumps(emb) if emb else None,
                metadata_={
                    "filename": filename,
                    "file_type": file_type,
                    "chunk_index": i,
                    "has_embedding": emb is not None,
                },
            )
            chunk_records.append(record)
        db.add_all(chunk_records)
        # Mark the document as ready and refresh the knowledge-base counters
        doc.status = "ready"
        doc.chunk_count = len(chunks)
        kb.doc_count = db.query(Document).filter(
            Document.kb_id == kb_id, Document.status == "ready"
        ).count()
        db.commit()
        logger.info("Document processed: %s (%d chunks, embedding=%s)",
                    filename, len(chunks),
                    "yes" if any(e for e in embeddings) else "no")
    except Exception as e:
        db.rollback()
        doc = db.query(Document).filter(Document.id == doc.id).first()
        if doc:
            doc.status = "failed"
            doc.error_message = str(e)[:500]
            db.commit()
        logger.error("Document processing failed: %s: %s", filename, e, exc_info=True)
    return doc


def list_documents(db: Session, kb_id: str) -> List[Document]:
    """List the documents in a knowledge base."""
    return (
        db.query(Document)
        .filter(Document.kb_id == kb_id)
        .order_by(Document.created_at.desc())
        .all()
    )


def delete_document(db: Session, doc_id: str) -> bool:
    """Delete a document together with its chunks."""
    doc = db.query(Document).filter(Document.id == doc_id).first()
    if not doc:
        return False
    # Recompute the knowledge base's document count after the delete
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == doc.kb_id).first()
    db.delete(doc)
    if kb:
        kb.doc_count = db.query(Document).filter(
            Document.kb_id == kb.id, Document.status == "ready"
        ).count()
    db.commit()
    logger.info("Document deleted: %s", doc_id)
    return True


# ─── Semantic search ────────────────────────────────────────────

async def search(
    db: Session,
    kb_id: str,
    query: str,
    top_k: int = 5,
    min_score: float = 0.3,
) -> List[Dict[str, Any]]:
    """
    Semantic search over a knowledge base.

    Flow: query text → embedding → cosine similarity against every chunk → top-k.
    """
    # Embed the query
    query_emb = await embedding_service.generate_embedding(query)
    if not query_emb:
        logger.warning("Search aborted: could not generate the query embedding")
        return []
    # Load every chunk of this knowledge base
    chunks = (
        db.query(DocumentChunk)
        .filter(DocumentChunk.kb_id == kb_id)
        .all()
    )
    if not chunks:
        return []
    # Build the VectorEntry list, skipping chunks without a usable embedding
    entries: List[VectorEntry] = []
    for c in chunks:
        if not c.embedding:
            continue
        try:
            emb = json.loads(c.embedding) if isinstance(c.embedding, str) else c.embedding
        except (json.JSONDecodeError, TypeError):
            continue
        entries.append({
            "id": c.id,
            "scope_kind": "kb",
            "scope_id": kb_id,
            "content_text": c.content,
            "embedding": emb,
            "metadata": c.metadata_ or {},
        })
    if not entries:
        return []
    # Similarity search over the collected entries
    matched = await embedding_service.similarity_search(
        query_emb, entries, top_k=top_k, min_score=min_score
    )
    results = []
    for m in matched:
        results.append({
            "chunk_id": m["id"],
            "content": m["content_text"],
            "score": m["score"],
            "metadata": m.get("metadata", {}),
        })
    return results
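

# For reference, a rough sketch of what embedding_service.similarity_search is
# assumed to do with the entries built in search() above (cosine similarity
# plus a top-k cut-off); the real implementation lives in embedding_service
# and may differ:
#
#     import math
#
#     def cosine(a: list[float], b: list[float]) -> float:
#         dot = sum(x * y for x, y in zip(a, b))
#         na = math.sqrt(sum(x * x for x in a))
#         nb = math.sqrt(sum(y * y for y in b))
#         return dot / (na * nb) if na and nb else 0.0
#
#     scored = [{**e, "score": cosine(query_emb, e["embedding"])} for e in entries]
#     scored.sort(key=lambda e: e["score"], reverse=True)
#     matched = [e for e in scored[:top_k] if e["score"] >= min_score]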


async def rag_query(
    db: Session,
    kb_id: str,
    query: str,
    top_k: int = 5,
    min_score: float = 0.3,
) -> Dict[str, Any]:
    """
    RAG query: retrieve the relevant chunks and format them as context.

    Returns:
        {
            "query": "...",
            "context": "Answer using the material below:\n\n[chunk 1]\n[chunk 2]\n...",
            "sources": [...],
            "found": True/False,
        }
    """
    results = await search(db, kb_id, query, top_k=top_k, min_score=min_score)
    if not results:
        return {
            "query": query,
            "context": "",
            "sources": [],
            "found": False,
        }
    # Format the retrieved chunks into a single context block
    lines = ["Answer the user's question using the material below:\n"]
    for i, r in enumerate(results, 1):
        source = r["metadata"].get("filename", "unknown source")
        lines.append(f"[{i}] (source: {source}):\n{r['content']}\n")
    context = "\n".join(lines)
    sources = [
        {
            "content": r["content"],
            "score": r["score"],
            "source": r["metadata"].get("filename", "unknown"),
        }
        for r in results
    ]
    return {
        "query": query,
        "context": context,
        "sources": sources,
        "found": True,
    }
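

# ─── Usage sketch ───────────────────────────────────────────────
# A minimal illustration of how the pieces above fit together. `SessionLocal`
# is a placeholder for whatever SQLAlchemy session factory the application
# exposes; the import path below is an assumption, not part of this module:
#
#     import asyncio
#     from app.db.session import SessionLocal
#
#     async def demo() -> None:
#         db = SessionLocal()
#         try:
#             kb = create_knowledge_base(db, name="manuals", user_id="user-1")
#             with open("manual.pdf", "rb") as f:
#                 await upload_document(db, kb.id, "manual.pdf", f.read())
#             answer = await rag_query(db, kb.id, "How do I reset my password?")
#             print(answer["context"])   # prompt-ready context block
#             print(answer["sources"])   # chunks with scores and filenames
#         finally:
#             db.close()
#
#     asyncio.run(demo())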