117 lines
3.3 KiB
Python
117 lines
3.3 KiB
Python
|
|
"""
|
||
|
|
认证相关API
|
||
|
|
"""
|
||
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||
|
|
from sqlalchemy.orm import Session
|
||
|
|
from pydantic import BaseModel, EmailStr
|
||
|
|
import logging
|
||
|
|
from app.core.database import get_db
|
||
|
|
from app.core.security import verify_password, get_password_hash, create_access_token
|
||
|
|
from app.models.user import User
|
||
|
|
from datetime import timedelta
|
||
|
|
from app.core.config import settings
|
||
|
|
from app.core.exceptions import ConflictError, UnauthorizedError, NotFoundError
|
||
|
|
|
||
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router holding every authentication endpoint defined below. All routes are
# mounted under "/api/v1/auth", tagged "auth" in the OpenAPI docs, and share the
# documented error responses (401 unauthorized, 400 bad request parameters,
# 500 internal server error).
router = APIRouter(
    prefix="/api/v1/auth",
    tags=["auth"],
    responses={
        401: {"description": "未授权"},
        400: {"description": "请求参数错误"},
        500: {"description": "服务器内部错误"},
    },
)

# OAuth2 password-flow bearer-token dependency. The token URL is built from the
# router prefix so it stays in sync with this module's "/login" route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{router.prefix}/login")
|
||
|
|
|
||
|
|
|
||
|
|
class UserCreate(BaseModel):
    """用户创建模型"""
    # Request body for user registration ("user creation model").
    # The docstring above is kept verbatim: Pydantic uses it as the JSON-schema
    # description that appears in the OpenAPI docs.

    # Desired login name; required (no default).
    username: str

    # Contact e-mail; EmailStr makes Pydantic validate the address format.
    email: EmailStr

    # Plaintext password as submitted by the client; the register endpoint
    # hashes it with get_password_hash before storing it.
    password: str
|
||
|
|
|
||
|
|
|
||
|
|
class UserResponse(BaseModel):
    """用户响应模型"""
    # Public representation of a user returned by the API ("user response
    # model"); it deliberately has no password / password-hash field.
    # The docstring above is kept verbatim: Pydantic uses it as the JSON-schema
    # description that appears in the OpenAPI docs.

    # NOTE(review): declared as str; Pydantic v2 does not coerce int/UUID
    # values to str by default - confirm User.id is stored as a string.
    id: str
    username: str
    email: str
    role: str

    # Pydantic v2 configuration: allow validating this model from arbitrary
    # objects via attribute access, so routes can return SQLAlchemy User
    # instances directly. This replaces the nested ``class Config``, which is
    # deprecated in Pydantic v2. A plain dict is accepted because ConfigDict
    # is a TypedDict.
    model_config = {"from_attributes": True}
|
||
|
|
|
||
|
|
|
||
|
|
class Token(BaseModel):
    """令牌响应模型"""
    # Response body of the login endpoint ("token response model").
    # The docstring above is kept verbatim: Pydantic uses it as the JSON-schema
    # description that appears in the OpenAPI docs.

    # Encoded access token produced by create_access_token.
    access_token: str

    # Token scheme reported to the client; defaults to "bearer".
    token_type: str = "bearer"
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: Session = Depends(get_db)):
    """用户注册"""
    # Register a new user and return it (serialized as UserResponse, HTTP 201).
    #
    # The docstring above is kept verbatim because FastAPI publishes it as the
    # endpoint's OpenAPI description.
    #
    # Raises ConflictError when the username or e-mail is already taken,
    # including when a concurrent request claims it between the existence
    # checks and the commit.
    #
    # NOTE(review): this is an ``async def`` but every call inside (Session
    # queries, commit, get_password_hash) is synchronous, so it runs on the
    # event-loop thread; a plain ``def`` would let FastAPI run it in its
    # threadpool. Left unchanged to preserve the coroutine interface.
    from sqlalchemy.exc import IntegrityError

    # Field-specific errors for the common, non-racing duplicate case.
    if db.query(User).filter(User.username == user_data.username).first():
        raise ConflictError("用户名已存在")
    if db.query(User).filter(User.email == user_data.email).first():
        raise ConflictError("邮箱已存在")

    # Only the hash is persisted; the plaintext password never reaches the DB.
    user = User(
        username=user_data.username,
        email=user_data.email,
        password_hash=get_password_hash(user_data.password),
    )
    db.add(user)
    try:
        db.commit()
    except IntegrityError as exc:
        # The checks above are not atomic with the INSERT, so a concurrent
        # registration can win the race. Roll back so the session is usable
        # again, and report a conflict instead of an unhandled 500.
        # NOTE(review): assumes the users table enforces UNIQUE on
        # username/email; any other integrity violation is also mapped here.
        db.rollback()
        raise ConflictError("用户名或邮箱已存在") from exc
    db.refresh(user)

    return user
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """用户登录"""
    # Authenticate with the OAuth2 password form and issue a bearer access token.
    #
    # The docstring above is kept verbatim because FastAPI publishes it as the
    # endpoint's OpenAPI description.
    #
    # Raises UnauthorizedError when the username is unknown or the password
    # does not match. Both cases use the same message, so the response body
    # does not reveal which one was wrong.
    user = db.query(User).filter(User.username == form_data.username).first()

    if not user or not verify_password(form_data.password, user.password_hash):
        # Lazy %-style args: the message is only formatted if WARNING is enabled.
        logger.warning("登录失败: 用户名 %s", form_data.username)
        # NOTE(review): for an unknown username, verify_password is skipped, so
        # this path returns faster than a wrong-password attempt; that timing
        # difference can reveal which usernames exist.
        raise UnauthorizedError("用户名或密码错误")

    # RFC 7519 defines "sub" as a string (some JWT libraries reject other types
    # when decoding). str() is a no-op when user.id is already a str.
    access_token = create_access_token(
        data={"sub": str(user.id), "username": user.username}
    )

    return {"access_token": access_token, "token_type": "bearer"}
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/me", response_model=UserResponse)
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
):
    """获取当前用户信息"""
    # Return the User identified by the bearer token's "sub" claim.
    #
    # The docstring above is kept verbatim because FastAPI publishes it as the
    # endpoint's OpenAPI description.
    #
    # Raises UnauthorizedError if the token cannot be decoded or has no "sub"
    # claim, and NotFoundError if no User row matches the claimed id.
    # NOTE(review): function-local import, presumably to avoid an import cycle
    # with app.core.security - confirm before hoisting it to module level.
    from app.core.security import decode_access_token

    claims = decode_access_token(token)
    subject = None if claims is None else claims.get("sub")
    if subject is None:
        # Covers both an undecodable token and a token without a subject.
        raise UnauthorizedError("无效的访问令牌")

    current = db.query(User).filter(User.id == subject).first()
    if current is None:
        raise NotFoundError("用户", subject)

    return current
|