""" 知识库服务 — 文档管理、分块、Embedding、语义检索、RAG。 流程: 上传 → 解析 → 分块 → Embedding → 存储 检索 → 查询 → Embedding → 余弦相似度 → Top-K RAG → 检索 + 格式化上下文 → 返回给 Agent/用户 """ from __future__ import annotations import json import logging import os import uuid from typing import Any, Dict, List, Optional from sqlalchemy.orm import Session from app.core.config import settings from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk from app.services.document_parser import parse_document from app.services.embedding_service import embedding_service, VectorEntry from app.services.text_chunker import chunk_text logger = logging.getLogger(__name__) # 上传文件存储根目录 UPLOAD_DIR = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "kb_uploads", ) def _ensure_upload_dir(): """确保上传目录存在。""" os.makedirs(UPLOAD_DIR, exist_ok=True) def _get_kb_dir(kb_id: str) -> str: """返回知识库文件存放目录。""" d = os.path.join(UPLOAD_DIR, kb_id) os.makedirs(d, exist_ok=True) return d # ─── 知识库 CRUD ──────────────────────────────────────────────── def create_knowledge_base( db: Session, name: str, user_id: str, description: str = "", chunk_size: int = 500, chunk_overlap: int = 50, ) -> KnowledgeBase: """创建知识库。""" kb = KnowledgeBase( name=name, description=description, user_id=user_id, chunk_size=max(50, min(2000, chunk_size)), chunk_overlap=max(0, min(chunk_size // 2, chunk_overlap)), ) db.add(kb) db.commit() db.refresh(kb) logger.info("知识库已创建: %s (%s)", kb.name, kb.id) return kb def list_knowledge_bases(db: Session, user_id: Optional[str] = None) -> List[KnowledgeBase]: """列出知识库。""" q = db.query(KnowledgeBase) if user_id: q = q.filter(KnowledgeBase.user_id == user_id) return q.order_by(KnowledgeBase.updated_at.desc()).all() def get_knowledge_base(db: Session, kb_id: str) -> Optional[KnowledgeBase]: """获取知识库详情。""" return db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() def delete_knowledge_base(db: Session, kb_id: str) -> bool: """删除知识库(连带文档和分块)。""" kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() if not kb: return False # 删除磁盘文件 kb_dir = _get_kb_dir(kb_id) if os.path.isdir(kb_dir): import shutil shutil.rmtree(kb_dir, ignore_errors=True) db.delete(kb) db.commit() logger.info("知识库已删除: %s", kb_id) return True # ─── 文档管理 ─────────────────────────────────────────────────── async def upload_document( db: Session, kb_id: str, filename: str, file_content: bytes, ) -> Document: """ 上传文档到知识库: 1. 保存原始文件 2. 解析为文本 3. 分块 4. 生成 Embedding 5. 
存储分块 """ kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == kb_id).first() if not kb: raise ValueError(f"知识库不存在: {kb_id}") _ensure_upload_dir() kb_dir = _get_kb_dir(kb_id) # 提取文件类型 ext = os.path.splitext(filename)[1].lower().lstrip(".") if ext in ("txt", "md", "pdf", "docx", "csv"): file_type = ext else: file_type = "txt" # 保存原始文件 file_path = os.path.join(kb_dir, f"{uuid.uuid4()}_{filename}") with open(file_path, "wb") as f: f.write(file_content) file_size = len(file_content) # 创建文档记录 doc = Document( kb_id=kb_id, filename=filename, file_type=file_type, file_size=file_size, status="processing", ) db.add(doc) db.commit() db.refresh(doc) try: # 解析文本 text = parse_document(file_path, file_type) if not text: doc.status = "failed" doc.error_message = "文档解析失败或内容为空" db.commit() return doc # 分块 chunks = chunk_text( text, chunk_size=kb.chunk_size, chunk_overlap=kb.chunk_overlap, ) if not chunks: doc.status = "failed" doc.error_message = "文档分块后为空" db.commit() return doc logger.info("文档分块完成: %s → %d 块", filename, len(chunks)) # 批量生成 Embedding embeddings: List[Optional[List[float]]] = [] try: embeddings = await embedding_service.generate_embeddings(chunks) except Exception as e: logger.warning("批量生成 embedding 失败,逐块回退: %s", e) for c in chunks: try: emb = await embedding_service.generate_embedding(c) embeddings.append(emb) except Exception: embeddings.append(None) # 存储分块 chunk_records = [] for i, (chunk_text_content, emb) in enumerate(zip(chunks, embeddings)): record = DocumentChunk( document_id=doc.id, kb_id=kb_id, chunk_index=i, content=chunk_text_content, embedding=json.dumps(emb) if emb else None, metadata_={ "filename": filename, "file_type": file_type, "chunk_index": i, "has_embedding": emb is not None, }, ) chunk_records.append(record) db.add_all(chunk_records) # 更新文档状态 doc.status = "ready" doc.chunk_count = len(chunks) kb.doc_count = db.query(Document).filter( Document.kb_id == kb_id, Document.status == "ready" ).count() db.commit() logger.info("文档处理完成: %s (%d 块, embedding=%s)", filename, len(chunks), "yes" if any(e for e in embeddings) else "no") except Exception as e: db.rollback() doc = db.query(Document).filter(Document.id == doc.id).first() if doc: doc.status = "failed" doc.error_message = str(e)[:500] db.commit() logger.error("文档处理失败: %s: %s", filename, e, exc_info=True) return doc def list_documents(db: Session, kb_id: str) -> List[Document]: """列出知识库中的文档。""" return ( db.query(Document) .filter(Document.kb_id == kb_id) .order_by(Document.created_at.desc()) .all() ) def delete_document(db: Session, doc_id: str) -> bool: """删除文档(连带分块)。""" doc = db.query(Document).filter(Document.id == doc_id).first() if not doc: return False # 减少知识库文档计数 kb = db.query(KnowledgeBase).filter(KnowledgeBase.id == doc.kb_id).first() db.delete(doc) if kb: kb.doc_count = db.query(Document).filter( Document.kb_id == kb.id, Document.status == "ready" ).count() db.commit() logger.info("文档已删除: %s", doc_id) return True # ─── 语义检索 ─────────────────────────────────────────────────── async def search( db: Session, kb_id: str, query: str, top_k: int = 5, min_score: float = 0.3, ) -> List[Dict[str, Any]]: """ 语义搜索知识库。 流程:查询文本 → Embedding → 余弦相似度匹配所有分块 → Top-K """ # 生成查询 embedding query_emb = await embedding_service.generate_embedding(query) if not query_emb: logger.warning("搜索失败:无法生成查询 embedding") return [] # 加载该知识库所有分块 chunks = ( db.query(DocumentChunk) .filter(DocumentChunk.kb_id == kb_id) .all() ) if not chunks: return [] # 构建 VectorEntry 列表 entries: List[VectorEntry] = [] for c in chunks: if not c.embedding: continue 
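
# Illustration only: scoring is delegated to embedding_service.similarity_search.
# The module docstring describes the ranking as cosine similarity over all chunk
# embeddings; a minimal pure-Python sketch of that score, under that assumption:
def _cosine_similarity_sketch(a: List[float], b: List[float]) -> float:
    """Hypothetical cosine score in [-1, 1]; returns 0.0 for a zero vector."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(y * y for y in b) ** 0.5
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)
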
async def search(
    db: Session,
    kb_id: str,
    query: str,
    top_k: int = 5,
    min_score: float = 0.3,
) -> List[Dict[str, Any]]:
    """
    Semantic search over a knowledge base.

    Flow: query text → embedding → cosine similarity against all chunks → top-K.
    """
    # Embed the query
    query_emb = await embedding_service.generate_embedding(query)
    if not query_emb:
        logger.warning("Search aborted: could not generate query embedding")
        return []

    # Load all chunks of this knowledge base
    chunks = (
        db.query(DocumentChunk)
        .filter(DocumentChunk.kb_id == kb_id)
        .all()
    )
    if not chunks:
        return []

    # Build the VectorEntry list, skipping chunks without a usable embedding
    entries: List[VectorEntry] = []
    for c in chunks:
        if not c.embedding:
            continue
        try:
            emb = json.loads(c.embedding) if isinstance(c.embedding, str) else c.embedding
        except (json.JSONDecodeError, TypeError):
            continue
        entries.append({
            "id": c.id,
            "scope_kind": "kb",
            "scope_id": kb_id,
            "content_text": c.content,
            "embedding": emb,
            "metadata": c.metadata_ or {},
        })
    if not entries:
        return []

    # Similarity search
    matched = await embedding_service.similarity_search(
        query_emb, entries, top_k=top_k, min_score=min_score
    )
    return [
        {
            "chunk_id": m["id"],
            "content": m["content_text"],
            "score": m["score"],
            "metadata": m.get("metadata", {}),
        }
        for m in matched
    ]


async def rag_query(
    db: Session,
    kb_id: str,
    query: str,
    top_k: int = 5,
    min_score: float = 0.3,
) -> Dict[str, Any]:
    """
    RAG query: retrieve relevant snippets and format them as context.

    Returns:
        {
            "query": "...",
            "context": "Answer based on the following material:\n\n[snippet 1]\n[snippet 2]\n...",
            "sources": [...],
            "found": bool,
        }
    """
    results = await search(db, kb_id, query, top_k=top_k, min_score=min_score)
    if not results:
        return {
            "query": query,
            "context": "",
            "sources": [],
            "found": False,
        }

    # Format the context block
    lines = ["Answer the user's question based on the following material:\n"]
    for i, r in enumerate(results, 1):
        source = r["metadata"].get("filename", "unknown source")
        lines.append(f"[{i}] (source: {source}):\n{r['content']}\n")
    context = "\n".join(lines)

    sources = [
        {
            "content": r["content"],
            "score": r["score"],
            "source": r["metadata"].get("filename", "unknown"),
        }
        for r in results
    ]
    return {
        "query": query,
        "context": context,
        "sources": sources,
        "found": True,
    }
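
# Example usage (sketch): an end-to-end create → upload → rag_query flow.
# `SessionLocal` and its import path below are hypothetical; adjust to wherever
# this app actually exposes its SQLAlchemy session factory. Not executed on
# import; run it with e.g. `asyncio.run(_demo_rag_flow())`.
async def _demo_rag_flow() -> None:
    from app.core.database import SessionLocal  # hypothetical import path

    db = SessionLocal()
    try:
        kb = create_knowledge_base(db, name="product-docs", user_id="demo-user")
        with open("manual.txt", "rb") as f:  # any local txt/md/pdf/docx/csv file
            await upload_document(db, kb.id, "manual.txt", f.read())
        answer = await rag_query(db, kb.id, "How do I reset the device?", top_k=3)
        print(answer["context"] if answer["found"] else "No relevant material found.")
    finally:
        db.close()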