376 lines
10 KiB
Python
376 lines
10 KiB
Python
|
|
"""
|
|||
|
|
知识库服务 — 文档管理、分块、Embedding、语义检索、RAG。
|
|||
|
|
|
|||
|
|
流程:
|
|||
|
|
上传 → 解析 → 分块 → Embedding → 存储
|
|||
|
|
检索 → 查询 → Embedding → 余弦相似度 → Top-K
|
|||
|
|
RAG → 检索 + 格式化上下文 → 返回给 Agent/用户
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
import os
|
|||
|
|
import uuid
|
|||
|
|
from typing import Any, Dict, List, Optional
|
|||
|
|
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.core.config import settings
|
|||
|
|
from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk
|
|||
|
|
from app.services.document_parser import parse_document
|
|||
|
|
from app.services.embedding_service import embedding_service, VectorEntry
|
|||
|
|
from app.services.text_chunker import chunk_text
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)

# Root directory for uploaded files: "<project root>/kb_uploads", where the
# project root is resolved three directory levels above this module.
UPLOAD_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    "kb_uploads",
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _ensure_upload_dir():
    """Create the upload root directory if it does not already exist."""
    if not os.path.isdir(UPLOAD_DIR):
        os.makedirs(UPLOAD_DIR, exist_ok=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _get_kb_dir(kb_id: str) -> str:
    """Return the storage directory for one knowledge base, creating it if needed."""
    path = os.path.join(UPLOAD_DIR, kb_id)
    os.makedirs(path, exist_ok=True)
    return path
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ─── 知识库 CRUD ────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_knowledge_base(
    db: Session,
    name: str,
    user_id: str,
    description: str = "",
    chunk_size: int = 500,
    chunk_overlap: int = 50,
) -> KnowledgeBase:
    """Create and persist a knowledge base.

    Args:
        db: Database session.
        name: Display name of the knowledge base.
        user_id: Owner of the knowledge base.
        description: Optional free-text description.
        chunk_size: Target text-chunk size; clamped to [50, 2000].
        chunk_overlap: Overlap between adjacent chunks; clamped to
            [0, effective_chunk_size // 2].

    Returns:
        The persisted, refreshed ``KnowledgeBase`` row.
    """
    # Clamp the size first, then bound the overlap by the *clamped* size.
    # Clamping against the raw chunk_size could permit an overlap larger
    # than the effective chunk size (e.g. chunk_size=5000 -> size=2000 but
    # overlap allowed up to 2500).
    size = max(50, min(2000, chunk_size))
    overlap = max(0, min(size // 2, chunk_overlap))
    kb = KnowledgeBase(
        name=name,
        description=description,
        user_id=user_id,
        chunk_size=size,
        chunk_overlap=overlap,
    )
    db.add(kb)
    db.commit()
    db.refresh(kb)
    logger.info("知识库已创建: %s (%s)", kb.name, kb.id)
    return kb
|
|||
|
|
|
|||
|
|
|
|||
|
|
def list_knowledge_bases(db: Session, user_id: Optional[str] = None) -> List[KnowledgeBase]:
    """List knowledge bases, optionally restricted to one user, most recently updated first."""
    query = db.query(KnowledgeBase)
    if user_id:
        query = query.filter(KnowledgeBase.user_id == user_id)
    return query.order_by(KnowledgeBase.updated_at.desc()).all()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_knowledge_base(db: Session, kb_id: str) -> Optional[KnowledgeBase]:
    """Fetch a single knowledge base by id; None when it does not exist."""
    return (
        db.query(KnowledgeBase)
        .filter(KnowledgeBase.id == kb_id)
        .first()
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def delete_knowledge_base(db: Session, kb_id: str) -> bool:
    """Delete a knowledge base (documents and chunks go with it).

    Removes the on-disk upload directory first, then deletes the DB row.

    Args:
        db: Database session.
        kb_id: Id of the knowledge base to delete.

    Returns:
        True if the knowledge base existed and was deleted, else False.
    """
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        return False
    # Build the path directly rather than via _get_kb_dir(), which creates
    # the directory as a side effect — creating it just to delete it again.
    kb_dir = os.path.join(UPLOAD_DIR, kb_id)
    if os.path.isdir(kb_dir):
        import shutil
        shutil.rmtree(kb_dir, ignore_errors=True)
    db.delete(kb)
    db.commit()
    logger.info("知识库已删除: %s", kb_id)
    return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ─── 文档管理 ───────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def upload_document(
    db: Session,
    kb_id: str,
    filename: str,
    file_content: bytes,
) -> Document:
    """
    Upload a document into a knowledge base:

    1. Save the raw file to disk
    2. Parse it into plain text
    3. Split the text into chunks
    4. Generate embeddings
    5. Persist the chunks

    Args:
        db: Database session.
        kb_id: Target knowledge-base id.
        filename: Original file name supplied by the client (untrusted).
        file_content: Raw file bytes.

    Returns:
        The ``Document`` row; ``status`` is "ready" on success or "failed"
        (with ``error_message`` populated) when any processing step fails.

    Raises:
        ValueError: If the knowledge base does not exist.
    """
    kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first()
    if not kb:
        raise ValueError(f"知识库不存在: {kb_id}")

    _ensure_upload_dir()
    kb_dir = _get_kb_dir(kb_id)

    # Derive the file type from the extension; anything unrecognized is
    # treated as plain text.
    ext = os.path.splitext(filename)[1].lower().lstrip(".")
    file_type = ext if ext in ("txt", "md", "pdf", "docx", "csv") else "txt"

    # Save the raw file under a UUID prefix so names can never collide.
    # basename() strips any path components from the client-supplied name
    # (path-traversal guard on untrusted input).
    safe_name = os.path.basename(filename)
    file_path = os.path.join(kb_dir, f"{uuid.uuid4()}_{safe_name}")
    with open(file_path, "wb") as f:
        f.write(file_content)
    file_size = len(file_content)

    # Create the document record up front so later failures can be recorded
    # against it.
    doc = Document(
        kb_id=kb_id,
        filename=filename,
        file_type=file_type,
        file_size=file_size,
        status="processing",
    )
    db.add(doc)
    db.commit()
    db.refresh(doc)

    try:
        # Parse to plain text.
        text = parse_document(file_path, file_type)
        if not text:
            doc.status = "failed"
            doc.error_message = "文档解析失败或内容为空"
            db.commit()
            return doc

        # Chunk using the knowledge base's configured parameters.
        chunks = chunk_text(
            text,
            chunk_size=kb.chunk_size,
            chunk_overlap=kb.chunk_overlap,
        )
        if not chunks:
            doc.status = "failed"
            doc.error_message = "文档分块后为空"
            db.commit()
            return doc

        logger.info("文档分块完成: %s → %d 块", filename, len(chunks))

        # Batch-generate embeddings; on failure fall back to per-chunk calls
        # so a single bad chunk does not sink the whole document.
        embeddings: List[Optional[List[float]]] = []
        try:
            embeddings = await embedding_service.generate_embeddings(chunks)
        except Exception as e:
            logger.warning("批量生成 embedding 失败,逐块回退: %s", e)
            for c in chunks:
                try:
                    emb = await embedding_service.generate_embedding(c)
                    embeddings.append(emb)
                except Exception:
                    embeddings.append(None)

        # Persist the chunks; embeddings are stored as JSON strings, or NULL
        # when generation failed for that chunk.
        chunk_records = []
        for i, (chunk_text_content, emb) in enumerate(zip(chunks, embeddings)):
            record = DocumentChunk(
                document_id=doc.id,
                kb_id=kb_id,
                chunk_index=i,
                content=chunk_text_content,
                embedding=json.dumps(emb) if emb else None,
                metadata_={
                    "filename": filename,
                    "file_type": file_type,
                    "chunk_index": i,
                    "has_embedding": emb is not None,
                },
            )
            chunk_records.append(record)

        db.add_all(chunk_records)

        # Mark the document ready and refresh the KB's ready-document count.
        doc.status = "ready"
        doc.chunk_count = len(chunks)
        kb.doc_count = db.query(Document).filter(
            Document.kb_id == kb_id, Document.status == "ready"
        ).count()

        db.commit()
        logger.info("文档处理完成: %s (%d 块, embedding=%s)",
                    filename, len(chunks),
                    "yes" if any(e for e in embeddings) else "no")

    except Exception as e:
        # Roll back the failed transaction, then record the failure on the
        # (re-fetched) document row in a fresh transaction.
        db.rollback()
        doc = db.query(Document).filter(Document.id == doc.id).first()
        if doc:
            doc.status = "failed"
            doc.error_message = str(e)[:500]
            db.commit()
        logger.error("文档处理失败: %s: %s", filename, e, exc_info=True)

    return doc
|
|||
|
|
|
|||
|
|
|
|||
|
|
def list_documents(db: Session, kb_id: str) -> List[Document]:
    """List all documents in a knowledge base, newest first."""
    query = db.query(Document).filter(Document.kb_id == kb_id)
    return query.order_by(Document.created_at.desc()).all()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def delete_document(db: Session, doc_id: str) -> bool:
    """Delete a document (its chunks go with it).

    Also refreshes the owning knowledge base's ready-document count.

    Returns:
        True if the document existed and was removed, else False.
    """
    doc = db.query(Document).filter(Document.id == doc_id).first()
    if doc is None:
        return False
    # Look up the parent KB before the delete so we can update its counter.
    parent_kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == doc.kb_id).first()
    db.delete(doc)
    if parent_kb is not None:
        # The count query autoflushes the pending delete, so the removed
        # document is excluded from the new total.
        parent_kb.doc_count = (
            db.query(Document)
            .filter(Document.kb_id == parent_kb.id, Document.status == "ready")
            .count()
        )
    db.commit()
    logger.info("文档已删除: %s", doc_id)
    return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ─── 语义检索 ───────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def search(
    db: Session,
    kb_id: str,
    query: str,
    top_k: int = 5,
    min_score: float = 0.3,
) -> List[Dict[str, Any]]:
    """
    Semantic search over a knowledge base.

    Pipeline: query text → embedding → cosine similarity against all stored
    chunks → top-K hits at or above ``min_score``.
    """
    # Embed the query; without an embedding there is nothing to match.
    query_emb = await embedding_service.generate_embedding(query)
    if not query_emb:
        logger.warning("搜索失败:无法生成查询 embedding")
        return []

    # Load every chunk belonging to this knowledge base.
    all_chunks = (
        db.query(DocumentChunk)
        .filter(DocumentChunk.kb_id == kb_id)
        .all()
    )
    if not all_chunks:
        return []

    # Build vector entries, skipping chunks whose embedding is missing or
    # cannot be decoded.
    entries: List[VectorEntry] = []
    for chunk in all_chunks:
        raw = chunk.embedding
        if not raw:
            continue
        try:
            vector = json.loads(raw) if isinstance(raw, str) else raw
        except (json.JSONDecodeError, TypeError):
            continue
        entries.append({
            "id": chunk.id,
            "scope_kind": "kb",
            "scope_id": kb_id,
            "content_text": chunk.content,
            "embedding": vector,
            "metadata": chunk.metadata_ or {},
        })

    if not entries:
        return []

    # Rank by cosine similarity.
    matched = await embedding_service.similarity_search(
        query_emb, entries, top_k=top_k, min_score=min_score
    )

    return [
        {
            "chunk_id": hit["id"],
            "content": hit["content_text"],
            "score": hit["score"],
            "metadata": hit.get("metadata", {}),
        }
        for hit in matched
    ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def rag_query(
    db: Session,
    kb_id: str,
    query: str,
    top_k: int = 5,
    min_score: float = 0.3,
) -> Dict[str, Any]:
    """
    RAG query: retrieve relevant fragments and format them into a context block.

    Returns a dict of the form::

        {
            "query": "...",
            "context": "根据以下资料回答:\\n\\n[fragment 1]\\n[fragment 2]\\n...",
            "sources": [...],
            "found": bool,
        }
    """
    hits = await search(db, kb_id, query, top_k=top_k, min_score=min_score)

    if not hits:
        return {
            "query": query,
            "context": "",
            "sources": [],
            "found": False,
        }

    # Assemble the numbered context block, one fragment per hit.
    parts = ["根据以下资料回答用户问题:\n"]
    parts.extend(
        f"[{idx}] (来源: {hit['metadata'].get('filename', '未知来源')}):\n{hit['content']}\n"
        for idx, hit in enumerate(hits, 1)
    )

    return {
        "query": query,
        "context": "\n".join(parts),
        "sources": [
            {
                "content": hit["content"],
                "score": hit["score"],
                "source": hit["metadata"].get("filename", "未知"),
            }
            for hit in hits
        ],
        "found": True,
    }
|