252 lines
6.7 KiB
Python
252 lines
6.7 KiB
Python
|
|
"""
Knowledge base RAG API.

Provides endpoints for knowledge base management, document upload,
semantic search, and RAG queries.
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
from typing import Any, Dict, List, Optional
|
|||
|
|
|
|||
|
|
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.core.database import get_db
|
|||
|
|
from app.api.auth import get_current_user
|
|||
|
|
from app.models.user import User
|
|||
|
|
from app.services.knowledge_service import (
|
|||
|
|
create_knowledge_base,
|
|||
|
|
delete_document,
|
|||
|
|
delete_knowledge_base,
|
|||
|
|
get_knowledge_base,
|
|||
|
|
list_documents,
|
|||
|
|
list_knowledge_bases,
|
|||
|
|
rag_query,
|
|||
|
|
search,
|
|||
|
|
upload_document,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Every route declared in this module is mounted under /api/v1/knowledge-bases
# and grouped under the "knowledge-base" tag.
router = APIRouter(
    prefix="/api/v1/knowledge-bases",
    tags=["knowledge-base"],
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ─── Schema ──────────────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
|
|||
|
|
class KBCreateRequest(BaseModel):
    # Request body accepted by the knowledge-base creation endpoint.
    # Documented with `#` comments instead of a class docstring so the
    # generated schema description stays exactly as it was.

    # Name of the knowledge base; the only required field.
    name: str

    # Optional description; defaults to an empty string.
    description: str = ""

    # Forwarded unchanged to create_knowledge_base.
    # NOTE(review): presumably the chunk length and overlap used when
    # splitting documents; the unit is not visible here — confirm in the
    # service layer.
    chunk_size: int = 500
    chunk_overlap: int = 50
|
|||
|
|
|
|||
|
|
|
|||
|
|
class KBResponse(BaseModel):
    # Response shape for a single knowledge base. The endpoints build it
    # with KBResponse(**kb.to_dict()), so the field names below must match
    # the keys produced by that to_dict().

    # Required identity fields.
    id: str
    name: str

    # Optional descriptive / ownership fields; empty string when absent.
    description: Optional[str] = ""
    user_id: Optional[str] = ""

    # Chunking settings; defaults mirror KBCreateRequest.
    chunk_size: int = 500
    chunk_overlap: int = 50

    # Number of documents; 0 when to_dict() does not provide it.
    doc_count: int = 0

    # Timestamps are carried as strings.
    # NOTE(review): presumably ISO-8601 text produced by to_dict() — confirm.
    created_at: Optional[str] = ""
    updated_at: Optional[str] = ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DocumentResponse(BaseModel):
    # Response shape for a single document. The endpoints build it with
    # DocumentResponse(**doc.to_dict()), so field names must match the keys
    # produced by that to_dict().

    # Required fields: document id, owning knowledge base id, file name/type.
    id: str
    kb_id: str
    filename: str
    file_type: str

    # Size of the uploaded file; 0 when not provided.
    # NOTE(review): presumably bytes — confirm against the service layer.
    file_size: int = 0

    # Processing state; "pending" by default. The upload endpoint compares
    # the document's status against "failed".
    status: str = "pending"

    # Failure details, if any.
    error_message: Optional[str] = None

    # Number of chunks stored for the document; 0 when not provided.
    chunk_count: int = 0

    # Creation timestamp carried as a string.
    created_at: Optional[str] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SearchRequest(BaseModel):
    # Request body for the semantic search endpoint; every field is passed
    # through to the `search` service call as a keyword argument.

    # Query text (required).
    query: str

    # Forwarded as `top_k`; defaults to 5.
    top_k: int = 5

    # Forwarded as `min_score`; defaults to 0.3.
    # NOTE(review): no bounds are validated for top_k / min_score here.
    min_score: float = 0.3
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SearchResult(BaseModel):
    # One search hit. The search endpoint builds it with SearchResult(**r),
    # so each item returned by the service must be a mapping with these keys.

    # Identifier of the matched chunk.
    chunk_id: str

    # Text content of the matched chunk.
    content: str

    # Relevance score reported by the service.
    score: float

    # Extra key/value data attached to the hit; empty by default.
    metadata: Dict[str, Any] = {}
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SearchResponse(BaseModel):
    # Envelope returned by the semantic search endpoint.

    # Search hits; an empty list when nothing is supplied.
    results: List[SearchResult] = []
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RAGRequest(BaseModel):
    # Request body for the RAG endpoint; every field is passed through to
    # the `rag_query` service call as a keyword argument.

    # Query text (required).
    query: str

    # Forwarded as `top_k`; defaults to 5.
    top_k: int = 5

    # Forwarded as `min_score`; defaults to 0.3.
    # NOTE(review): no bounds are validated for top_k / min_score here.
    min_score: float = 0.3
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RAGResponse(BaseModel):
    # Response of the RAG endpoint. It is built with RAGResponse(**result),
    # so the mapping returned by `rag_query` must use these keys.

    # The query (required).
    query: str

    # Context text assembled for the query; empty string by default.
    context: str = ""

    # Source descriptors as free-form dictionaries; empty list by default.
    sources: List[Dict[str, Any]] = []

    # Flag defaulting to False.
    # NOTE(review): presumably True when at least one relevant chunk was
    # found — confirm in rag_query.
    found: bool = False
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ─── Knowledge base CRUD ──────────────────────────────────────
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("", response_model=KBResponse)
async def api_create_kb(
    req: KBCreateRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """创建知识库。"""
    # Create a knowledge base for the authenticated user.
    #
    # The docstring above is kept verbatim because FastAPI publishes route
    # docstrings as the operation description.
    #
    # Args:
    #     req: Validated creation payload.
    #     current_user: Authenticated user; its id is stored as the owner.
    #     db: Database session injected by get_db.
    #
    # Returns:
    #     KBResponse built from the created object's to_dict().
    created = create_knowledge_base(
        db=db,
        name=req.name,
        user_id=current_user.id,
        description=req.description,
        chunk_size=req.chunk_size,
        chunk_overlap=req.chunk_overlap,
    )
    fields = created.to_dict()
    return KBResponse(**fields)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("", response_model=List[KBResponse])
async def api_list_kb(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """列出知识库。"""
    # List knowledge bases for the authenticated user.
    #
    # The docstring above is kept verbatim because FastAPI publishes route
    # docstrings as the operation description.
    #
    # Args:
    #     current_user: Authenticated user; its id is passed as `user_id`.
    #     db: Database session injected by get_db.
    #
    # Returns:
    #     One KBResponse per object returned by list_knowledge_bases, in the
    #     order the service returns them.
    serialized = (
        entry.to_dict()
        for entry in list_knowledge_bases(db, user_id=current_user.id)
    )
    return [KBResponse(**fields) for fields in serialized]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{kb_id}", response_model=KBResponse)
async def api_get_kb(
    kb_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """获取知识库详情。"""
    # Return the details of one knowledge base.
    #
    # The docstring above is kept verbatim because FastAPI publishes route
    # docstrings as the operation description.
    #
    # Args:
    #     kb_id: Knowledge base id taken from the URL path.
    #     current_user: Authenticated user making the request.
    #     db: Database session injected by get_db.
    #
    # Returns:
    #     KBResponse built from the knowledge base's to_dict().
    #
    # Raises:
    #     HTTPException(404): the knowledge base does not exist, or it is
    #         owned by a different user.
    kb = get_knowledge_base(db, kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")

    fields = kb.to_dict()

    # get_knowledge_base is not given the caller's identity, so ownership has
    # to be enforced here; otherwise any authenticated user could read any
    # knowledge base by guessing its id. Entries without an owner (None / "")
    # stay readable. A foreign entry is reported as "not found" so ids cannot
    # be probed. Both sides are stringified because to_dict() may serialise
    # the id differently from the User model.
    owner_id = fields.get("user_id")
    if owner_id not in (None, "") and str(owner_id) != str(current_user.id):
        raise HTTPException(status_code=404, detail="知识库不存在")

    return KBResponse(**fields)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/{kb_id}")
async def api_delete_kb(
    kb_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """删除知识库。"""
    # Delete one knowledge base.
    #
    # The docstring above is kept verbatim because FastAPI publishes route
    # docstrings as the operation description.
    #
    # Args:
    #     kb_id: Knowledge base id taken from the URL path.
    #     current_user: Authenticated user making the request.
    #     db: Database session injected by get_db.
    #
    # Returns:
    #     A dict with a confirmation message.
    #
    # Raises:
    #     HTTPException(404): the knowledge base does not exist, is owned by
    #         a different user, or delete_knowledge_base reports failure.
    kb = get_knowledge_base(db, kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")

    # delete_knowledge_base is not given the caller's identity, so without
    # this check any authenticated user could delete another user's
    # knowledge base. Entries without an owner (None / "") remain deletable,
    # matching the previous behaviour for them.
    owner_id = kb.to_dict().get("user_id")
    if owner_id not in (None, "") and str(owner_id) != str(current_user.id):
        raise HTTPException(status_code=404, detail="知识库不存在")

    if not delete_knowledge_base(db, kb_id):
        raise HTTPException(status_code=404, detail="知识库不存在")
    return {"message": "知识库已删除"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ─── Document management ──────────────────────────────────────
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/{kb_id}/documents", response_model=DocumentResponse)
async def api_upload_document(
    kb_id: str,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """上传文档到知识库（自动解析、分块、生成 Embedding）。"""
    # Upload a document into a knowledge base.
    #
    # The docstring above is kept verbatim because FastAPI publishes route
    # docstrings as the operation description.
    #
    # Args:
    #     kb_id: Target knowledge base id taken from the URL path.
    #     file: Multipart file upload (required).
    #     current_user: Authenticated user (authentication only).
    #     db: Database session injected by get_db.
    #
    # Returns:
    #     DocumentResponse built from the stored document's to_dict(). A
    #     document whose status is "failed" is still returned with a success
    #     status code; the failure is conveyed through the response fields.
    #
    # Raises:
    #     HTTPException(400): missing file name or empty file content.
    #     HTTPException(404): upload_document raised ValueError.
    #
    # NOTE(review): the whole upload is read into memory with no size limit,
    # and ownership of kb_id by current_user is not checked here — confirm
    # whether either is enforced elsewhere.
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名不能为空")

    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="文件内容为空")

    try:
        doc = await upload_document(
            db=db,
            kb_id=kb_id,
            filename=file.filename,
            file_content=content,
        )
    except ValueError as e:
        # Chain the cause so the original error stays visible in tracebacks.
        raise HTTPException(status_code=404, detail=str(e)) from e

    # The failed and successful cases return the same payload, so a single
    # return replaces the former duplicated branch. The failure is logged so
    # it is also visible on the server side.
    if doc.status == "failed":
        logger.warning(
            "Document processing failed: kb_id=%r filename=%r",
            kb_id,
            file.filename,
        )
    return DocumentResponse(**doc.to_dict())
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{kb_id}/documents", response_model=List[DocumentResponse])
async def api_list_documents(
    kb_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """列出知识库中的文档。"""
    # List the documents of one knowledge base.
    #
    # The docstring above is kept verbatim because FastAPI publishes route
    # docstrings as the operation description.
    #
    # Args:
    #     kb_id: Knowledge base id taken from the URL path.
    #     current_user: Authenticated user (authentication only).
    #     db: Database session injected by get_db.
    #
    # Returns:
    #     One DocumentResponse per object returned by list_documents, in the
    #     order the service returns them.
    #
    # NOTE(review): ownership of kb_id by current_user is not checked here —
    # confirm whether it is enforced elsewhere.
    serialized = (entry.to_dict() for entry in list_documents(db, kb_id))
    return [DocumentResponse(**fields) for fields in serialized]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/{kb_id}/documents/{doc_id}")
async def api_delete_document(
    kb_id: str,
    doc_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """删除文档。"""
    # Delete one document from a knowledge base.
    #
    # The docstring above is kept verbatim because FastAPI publishes route
    # docstrings as the operation description.
    #
    # Args:
    #     kb_id: Knowledge base id taken from the URL path.
    #     doc_id: Document id taken from the URL path.
    #     current_user: Authenticated user (authentication only).
    #     db: Database session injected by get_db.
    #
    # Returns:
    #     A dict with a confirmation message.
    #
    # Raises:
    #     HTTPException(404): the document is not listed under kb_id, or
    #         delete_document reports failure.
    #
    # NOTE(review): ownership of kb_id by current_user is not checked here —
    # confirm whether it is enforced elsewhere.

    # delete_document only receives doc_id, so kb_id from the path used to be
    # ignored and a document could be deleted through any knowledge base's
    # URL. Require the document to be listed under kb_id first. The id is
    # read from to_dict() because that is the key DocumentResponse relies on.
    listed_under_kb = any(
        str(entry.to_dict().get("id")) == doc_id
        for entry in list_documents(db, kb_id)
    )
    if not listed_under_kb or not delete_document(db, doc_id):
        raise HTTPException(status_code=404, detail="文档不存在")
    return {"message": "文档已删除"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ─── Search & RAG ─────────────────────────────────────────────
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/{kb_id}/search", response_model=SearchResponse)
async def api_search(
    kb_id: str,
    req: SearchRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """语义搜索知识库。"""
    # Run a semantic search against one knowledge base.
    #
    # The docstring above is kept verbatim because FastAPI publishes route
    # docstrings as the operation description.
    #
    # Args:
    #     kb_id: Knowledge base id taken from the URL path.
    #     req: Validated search payload (query, top_k, min_score).
    #     current_user: Authenticated user (authentication only).
    #     db: Database session injected by get_db.
    #
    # Returns:
    #     SearchResponse wrapping one SearchResult per item returned by the
    #     `search` service; each item is unpacked as keyword arguments.
    #
    # NOTE(review): ownership of kb_id by current_user is not checked here —
    # confirm whether it is enforced elsewhere.
    hits = await search(
        db=db,
        kb_id=kb_id,
        query=req.query,
        top_k=req.top_k,
        min_score=req.min_score,
    )
    parsed_hits = [SearchResult(**hit) for hit in hits]
    return SearchResponse(results=parsed_hits)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/{kb_id}/rag", response_model=RAGResponse)
async def api_rag(
    kb_id: str,
    req: RAGRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """RAG 查询：搜索相关片段并格式化为上下文。"""
    # Run a RAG query against one knowledge base.
    #
    # The docstring above is kept verbatim because FastAPI publishes route
    # docstrings as the operation description.
    #
    # Args:
    #     kb_id: Knowledge base id taken from the URL path.
    #     req: Validated RAG payload (query, top_k, min_score).
    #     current_user: Authenticated user (authentication only).
    #     db: Database session injected by get_db.
    #
    # Returns:
    #     RAGResponse built by unpacking the mapping returned by `rag_query`.
    #
    # NOTE(review): ownership of kb_id by current_user is not checked here —
    # confirm whether it is enforced elsewhere.
    answer = await rag_query(
        db=db,
        kb_id=kb_id,
        query=req.query,
        top_k=req.top_k,
        min_score=req.min_score,
    )
    return RAGResponse(**answer)
|