""" 知识库 RAG API。 提供知识库管理、文档上传、语义搜索和 RAG 查询接口。 """ from __future__ import annotations import logging from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile from pydantic import BaseModel from sqlalchemy.orm import Session from app.core.database import get_db from app.api.auth import get_current_user from app.models.user import User from app.services.knowledge_service import ( create_knowledge_base, delete_document, delete_knowledge_base, get_knowledge_base, list_documents, list_knowledge_bases, rag_query, search, upload_document, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/knowledge-bases", tags=["knowledge-base"]) # ─── Schema ────────────────────────────────────────────────────── class KBCreateRequest(BaseModel): name: str description: str = "" chunk_size: int = 500 chunk_overlap: int = 50 class KBResponse(BaseModel): id: str name: str description: Optional[str] = "" user_id: Optional[str] = "" chunk_size: int = 500 chunk_overlap: int = 50 doc_count: int = 0 created_at: Optional[str] = "" updated_at: Optional[str] = "" class DocumentResponse(BaseModel): id: str kb_id: str filename: str file_type: str file_size: int = 0 status: str = "pending" error_message: Optional[str] = None chunk_count: int = 0 created_at: Optional[str] = None class SearchRequest(BaseModel): query: str top_k: int = 5 min_score: float = 0.3 class SearchResult(BaseModel): chunk_id: str content: str score: float metadata: Dict[str, Any] = {} class SearchResponse(BaseModel): results: List[SearchResult] = [] class RAGRequest(BaseModel): query: str top_k: int = 5 min_score: float = 0.3 class RAGResponse(BaseModel): query: str context: str = "" sources: List[Dict[str, Any]] = [] found: bool = False # ─── 知识库 CRUD ────────────────────────────────────────────── @router.post("", response_model=KBResponse) async def api_create_kb( req: KBCreateRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """创建知识库。""" kb = create_knowledge_base( db=db, name=req.name, user_id=current_user.id, description=req.description, chunk_size=req.chunk_size, chunk_overlap=req.chunk_overlap, ) return KBResponse(**kb.to_dict()) @router.get("", response_model=List[KBResponse]) async def api_list_kb( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """列出知识库。""" kbs = list_knowledge_bases(db, user_id=current_user.id) return [KBResponse(**kb.to_dict()) for kb in kbs] @router.get("/{kb_id}", response_model=KBResponse) async def api_get_kb( kb_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """获取知识库详情。""" kb = get_knowledge_base(db, kb_id) if not kb: raise HTTPException(status_code=404, detail="知识库不存在") return KBResponse(**kb.to_dict()) @router.delete("/{kb_id}") async def api_delete_kb( kb_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """删除知识库。""" ok = delete_knowledge_base(db, kb_id) if not ok: raise HTTPException(status_code=404, detail="知识库不存在") return {"message": "知识库已删除"} # ─── 文档管理 ────────────────────────────────────────────────── @router.post("/{kb_id}/documents", response_model=DocumentResponse) async def api_upload_document( kb_id: str, file: UploadFile = File(...), current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """上传文档到知识库(自动解析、分块、生成 Embedding)。""" if not file.filename: raise HTTPException(status_code=400, detail="文件名不能为空") content = await file.read() if not content: raise HTTPException(status_code=400, detail="文件内容为空") try: doc = await upload_document( db=db, kb_id=kb_id, filename=file.filename, file_content=content, ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) if doc.status == "failed": # 返回成功但告知处理失败 return DocumentResponse(**doc.to_dict()) return DocumentResponse(**doc.to_dict()) @router.get("/{kb_id}/documents", response_model=List[DocumentResponse]) async def api_list_documents( kb_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """列出知识库中的文档。""" docs = list_documents(db, kb_id) return [DocumentResponse(**d.to_dict()) for d in docs] @router.delete("/{kb_id}/documents/{doc_id}") async def api_delete_document( kb_id: str, doc_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """删除文档。""" ok = delete_document(db, doc_id) if not ok: raise HTTPException(status_code=404, detail="文档不存在") return {"message": "文档已删除"} # ─── 搜索 & RAG ──────────────────────────────────────────────── @router.post("/{kb_id}/search", response_model=SearchResponse) async def api_search( kb_id: str, req: SearchRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """语义搜索知识库。""" results = await search( db=db, kb_id=kb_id, query=req.query, top_k=req.top_k, min_score=req.min_score, ) return SearchResponse(results=[SearchResult(**r) for r in results]) @router.post("/{kb_id}/rag", response_model=RAGResponse) async def api_rag( kb_id: str, req: RAGRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """RAG 查询:搜索相关片段并格式化为上下文。""" result = await rag_query( db=db, kb_id=kb_id, query=req.query, top_k=req.top_k, min_score=req.min_score, ) return RAGResponse(**result)