Files
aiagent/backend/app/api/knowledge_base.py

252 lines
6.7 KiB
Python
Raw Normal View History

"""
知识库 RAG API
提供知识库管理文档上传语义搜索和 RAG 查询接口
"""
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.services.knowledge_service import (
    create_knowledge_base,
    delete_document,
    delete_knowledge_base,
    get_knowledge_base,
    list_documents,
    list_knowledge_bases,
    rag_query,
    search,
    upload_document,
)
# Module-level logger (not referenced by the endpoints visible in this section).
logger = logging.getLogger(__name__)
# Every route below is mounted under /api/v1/knowledge-bases and grouped
# under the "knowledge-base" tag in the generated OpenAPI documentation.
router = APIRouter(prefix="/api/v1/knowledge-bases", tags=["knowledge-base"])
# ─── Schema ──────────────────────────────────────────────────────
class KBCreateRequest(BaseModel):
    """Request body for creating a knowledge base.

    ``chunk_size`` and ``chunk_overlap`` are forwarded unchanged to
    ``create_knowledge_base`` and control how uploaded documents are split.
    Nonsensical values (empty name, non-positive chunk size, negative
    overlap) are rejected here with a 422 instead of reaching the chunker.
    """

    # Display name of the knowledge base; must not be empty.
    name: str = Field(..., min_length=1)
    description: str = ""
    # Length of each chunk; must be positive. The unit (characters or
    # tokens) is defined by the knowledge service; presumably characters.
    chunk_size: int = Field(500, gt=0)
    # Amount shared between consecutive chunks; must not be negative.
    # (Overlap < chunk_size is a cross-field rule, enforced by the create endpoint.)
    chunk_overlap: int = Field(50, ge=0)
class KBResponse(BaseModel):
    """Knowledge-base representation returned by the API.

    Instances are built from ``kb.to_dict()``. Every field except ``id``
    and ``name`` has a default, so a dict missing optional keys still
    validates.
    """

    id: str
    name: str
    description: Optional[str] = ""
    # ID of the user the knowledge base was created for.
    user_id: Optional[str] = ""
    # Chunking parameters the knowledge base was created with.
    chunk_size: int = 500
    chunk_overlap: int = 50
    # Number of documents in the knowledge base (presumably computed by to_dict(); verify).
    doc_count: int = 0
    # Timestamps as strings; the exact format is whatever to_dict() produces.
    created_at: Optional[str] = ""
    updated_at: Optional[str] = ""
class DocumentResponse(BaseModel):
    """Document metadata returned by the document endpoints.

    Instances are built from ``doc.to_dict()``.
    """

    id: str
    # ID of the knowledge base that owns this document.
    kb_id: str
    filename: str
    file_type: str
    # Size of the uploaded file; presumably in bytes — confirm against to_dict().
    file_size: int = 0
    # Processing status: defaults to "pending"; the upload endpoint treats
    # "failed" as a processing failure. Other values come from the service.
    status: str = "pending"
    # Presumably set when processing failed; None otherwise.
    error_message: Optional[str] = None
    # Number of chunks the document was split into.
    chunk_count: int = 0
    created_at: Optional[str] = None
class SearchRequest(BaseModel):
    """Request body for semantic search over a knowledge base.

    An empty query or a ``top_k`` below 1 cannot yield a meaningful
    search, so both are rejected with a 422 before ``search`` is called.
    """

    # Text to search for; must not be empty.
    query: str = Field(..., min_length=1)
    # Maximum number of chunks to return; must be at least 1.
    top_k: int = Field(5, ge=1)
    # Score threshold passed to search(); presumably lower-scoring chunks are dropped.
    min_score: float = 0.3
class SearchResult(BaseModel):
    """A single chunk returned by semantic search.

    Built from each dict yielded by ``search``.
    """

    chunk_id: str
    # Text of the matched chunk.
    content: str
    # Relevance score of the chunk for the query.
    score: float
    # Additional chunk metadata; the key schema is defined by the knowledge service.
    metadata: Dict[str, Any] = {}
class SearchResponse(BaseModel):
    """Response body for semantic search: the chunks returned by ``search``."""

    results: List[SearchResult] = []
class RAGRequest(BaseModel):
    """Request body for a RAG query against a knowledge base.

    Uses the same validation as the semantic-search request: an empty
    query or a ``top_k`` below 1 is rejected with a 422.
    """

    # Question / text to retrieve context for; must not be empty.
    query: str = Field(..., min_length=1)
    # Maximum number of chunks to include in the context; must be at least 1.
    top_k: int = Field(5, ge=1)
    # Score threshold passed to rag_query(); presumably lower-scoring chunks are dropped.
    min_score: float = 0.3
class RAGResponse(BaseModel):
    """Result of a RAG query, built from the dict returned by ``rag_query``."""

    # The query that was answered.
    query: str
    # Retrieved chunks formatted into a single context string.
    context: str = ""
    # Descriptors of the chunks used as context; schema defined by the service.
    sources: List[Dict[str, Any]] = []
    # Presumably True when at least one relevant chunk was retrieved — verify.
    found: bool = False
# ─── 知识库 CRUD ──────────────────────────────────────────────
@router.post("", response_model=KBResponse)
async def api_create_kb(
    req: KBCreateRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create a knowledge base owned by the current user.

    Raises:
        HTTPException: 400 if ``chunk_overlap`` is not smaller than
            ``chunk_size``.
    """
    # With an overlap >= the chunk size, consecutive chunks cannot advance
    # through the text (assuming a sliding-window chunker), so reject it
    # before it is persisted on the knowledge base.
    if req.chunk_overlap >= req.chunk_size:
        raise HTTPException(
            status_code=400, detail="chunk_overlap 必须小于 chunk_size"
        )
    kb = create_knowledge_base(
        db=db,
        name=req.name,
        user_id=current_user.id,
        description=req.description,
        chunk_size=req.chunk_size,
        chunk_overlap=req.chunk_overlap,
    )
    return KBResponse(**kb.to_dict())
@router.get("", response_model=List[KBResponse])
async def api_list_kb(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return every knowledge base whose ``user_id`` matches the caller."""
    owned = list_knowledge_bases(db, user_id=current_user.id)
    return [KBResponse(**entry.to_dict()) for entry in owned]
@router.get("/{kb_id}", response_model=KBResponse)
async def api_get_kb(
    kb_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return the details of a knowledge base owned by the current user.

    Raises:
        HTTPException: 404 if the knowledge base does not exist or belongs
            to another user. 404 is used instead of 403 so that other users'
            knowledge-base IDs are not disclosed.
    """
    kb = get_knowledge_base(db, kb_id)
    # Object-level authorization: the lookup is by ID only, so without this
    # check any authenticated user could read any knowledge base.
    # NOTE(review): assumes the model exposes its owner as ``user_id``;
    # knowledge bases without an owner stay accessible as before.
    owner_id = getattr(kb, "user_id", None)
    if not kb or (owner_id and str(owner_id) != str(current_user.id)):
        raise HTTPException(status_code=404, detail="知识库不存在")
    return KBResponse(**kb.to_dict())
@router.delete("/{kb_id}")
async def api_delete_kb(
    kb_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete a knowledge base owned by the current user.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 if the knowledge base does not exist, belongs to
            another user, or ``delete_knowledge_base`` reports failure.
    """
    kb = get_knowledge_base(db, kb_id)
    # Object-level authorization: without this check any authenticated user
    # could delete any knowledge base by ID.
    # NOTE(review): assumes the model exposes its owner as ``user_id``;
    # knowledge bases without an owner stay deletable as before.
    owner_id = getattr(kb, "user_id", None)
    if not kb or (owner_id and str(owner_id) != str(current_user.id)):
        raise HTTPException(status_code=404, detail="知识库不存在")
    if not delete_knowledge_base(db, kb_id):
        raise HTTPException(status_code=404, detail="知识库不存在")
    return {"message": "知识库已删除"}
# ─── 文档管理 ──────────────────────────────────────────────────
@router.post("/{kb_id}/documents", response_model=DocumentResponse)
async def api_upload_document(
    kb_id: str,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Upload a document into a knowledge base owned by the current user.

    The knowledge service parses, chunks and embeds the file. A document
    whose processing failed is still returned with a 200 (its ``status``
    is "failed") rather than being turned into an HTTP error.

    Raises:
        HTTPException: 400 if the filename or the file content is empty;
            404 if the knowledge base does not exist, belongs to another
            user, or ``upload_document`` raises ``ValueError``.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名不能为空")
    # Authorize before reading the upload into memory.
    # NOTE(review): assumes the model exposes its owner as ``user_id``;
    # knowledge bases without an owner stay writable as before.
    kb = get_knowledge_base(db, kb_id)
    owner_id = getattr(kb, "user_id", None)
    if not kb or (owner_id and str(owner_id) != str(current_user.id)):
        raise HTTPException(status_code=404, detail="知识库不存在")
    # NOTE(review): the whole file is read into memory; no size limit is
    # enforced in this endpoint.
    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="文件内容为空")
    try:
        doc = await upload_document(
            db=db,
            kb_id=kb_id,
            filename=file.filename,
            file_content=content,
        )
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    # Failed processing is reported through doc.status, not an HTTP error.
    return DocumentResponse(**doc.to_dict())
@router.get("/{kb_id}/documents", response_model=List[DocumentResponse])
async def api_list_documents(
    kb_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List the documents of a knowledge base owned by the current user.

    Raises:
        HTTPException: 404 if the knowledge base does not exist or belongs
            to another user.
    """
    # Object-level authorization: documents are listed by kb_id only, so
    # without this check any user could enumerate another user's documents.
    # NOTE(review): assumes the model exposes its owner as ``user_id``;
    # knowledge bases without an owner stay listable as before.
    kb = get_knowledge_base(db, kb_id)
    owner_id = getattr(kb, "user_id", None)
    if not kb or (owner_id and str(owner_id) != str(current_user.id)):
        raise HTTPException(status_code=404, detail="知识库不存在")
    docs = list_documents(db, kb_id)
    return [DocumentResponse(**d.to_dict()) for d in docs]
@router.delete("/{kb_id}/documents/{doc_id}")
async def api_delete_document(
    kb_id: str,
    doc_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete a document from a knowledge base owned by the current user.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 if the knowledge base does not exist or belongs
            to another user, or if the document is not part of that
            knowledge base or could not be deleted.
    """
    # NOTE(review): assumes the model exposes its owner as ``user_id``;
    # knowledge bases without an owner stay accessible as before.
    kb = get_knowledge_base(db, kb_id)
    owner_id = getattr(kb, "user_id", None)
    if not kb or (owner_id and str(owner_id) != str(current_user.id)):
        raise HTTPException(status_code=404, detail="知识库不存在")
    # delete_document() only receives doc_id, so confirm the document really
    # belongs to this knowledge base; otherwise a document in any knowledge
    # base could be deleted through an arbitrary kb_id path.
    in_kb = any(
        str(d.to_dict()["id"]) == doc_id for d in list_documents(db, kb_id)
    )
    if not in_kb or not delete_document(db, doc_id):
        raise HTTPException(status_code=404, detail="文档不存在")
    return {"message": "文档已删除"}
# ─── 搜索 & RAG ────────────────────────────────────────────────
@router.post("/{kb_id}/search", response_model=SearchResponse)
async def api_search(
    kb_id: str,
    req: SearchRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Run a semantic search over a knowledge base owned by the current user.

    Raises:
        HTTPException: 404 if the knowledge base does not exist or belongs
            to another user.
    """
    # Object-level authorization: without this check any authenticated user
    # could search (and read chunks of) another user's knowledge base.
    # NOTE(review): assumes the model exposes its owner as ``user_id``;
    # knowledge bases without an owner stay searchable as before.
    kb = get_knowledge_base(db, kb_id)
    owner_id = getattr(kb, "user_id", None)
    if not kb or (owner_id and str(owner_id) != str(current_user.id)):
        raise HTTPException(status_code=404, detail="知识库不存在")
    results = await search(
        db=db,
        kb_id=kb_id,
        query=req.query,
        top_k=req.top_k,
        min_score=req.min_score,
    )
    # Each result dict must carry the SearchResult keys.
    return SearchResponse(results=[SearchResult(**r) for r in results])
@router.post("/{kb_id}/rag", response_model=RAGResponse)
async def api_rag(
    kb_id: str,
    req: RAGRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """RAG query: retrieve relevant chunks and format them as context.

    Only knowledge bases owned by the current user can be queried.

    Raises:
        HTTPException: 404 if the knowledge base does not exist or belongs
            to another user.
    """
    # Object-level authorization: without this check any authenticated user
    # could pull context out of another user's knowledge base.
    # NOTE(review): assumes the model exposes its owner as ``user_id``;
    # knowledge bases without an owner stay queryable as before.
    kb = get_knowledge_base(db, kb_id)
    owner_id = getattr(kb, "user_id", None)
    if not kb or (owner_id and str(owner_id) != str(current_user.id)):
        raise HTTPException(status_code=404, detail="知识库不存在")
    result = await rag_query(
        db=db,
        kb_id=kb_id,
        query=req.query,
        top_k=req.top_k,
        min_score=req.min_score,
    )
    return RAGResponse(**result)