第一次提交
This commit is contained in:
116
backend/app/api/auth.py
Normal file
116
backend/app/api/auth.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
认证相关API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, EmailStr
|
||||
import logging
|
||||
from app.core.database import get_db
|
||||
from app.core.security import verify_password, get_password_hash, create_access_token
|
||||
from app.models.user import User
|
||||
from datetime import timedelta
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import ConflictError, UnauthorizedError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1/auth",
|
||||
tags=["auth"],
|
||||
responses={
|
||||
401: {"description": "未授权"},
|
||||
400: {"description": "请求参数错误"},
|
||||
500: {"description": "服务器内部错误"}
|
||||
}
|
||||
)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
"""用户创建模型"""
|
||||
username: str
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""用户响应模型"""
|
||||
id: str
|
||||
username: str
|
||||
email: str
|
||||
role: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
"""令牌响应模型"""
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(user_data: UserCreate, db: Session = Depends(get_db)):
|
||||
"""用户注册"""
|
||||
# 检查用户名是否已存在
|
||||
if db.query(User).filter(User.username == user_data.username).first():
|
||||
raise ConflictError("用户名已存在")
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if db.query(User).filter(User.email == user_data.email).first():
|
||||
raise ConflictError("邮箱已存在")
|
||||
|
||||
# 创建新用户
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
user = User(
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
"""用户登录"""
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
|
||||
if not user or not verify_password(form_data.password, user.password_hash):
|
||||
logger.warning(f"登录失败: 用户名 {form_data.username}")
|
||||
raise UnauthorizedError("用户名或密码错误")
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.id, "username": user.username}
|
||||
)
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
from app.core.security import decode_access_token
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
raise UnauthorizedError("无效的访问令牌")
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise UnauthorizedError("无效的访问令牌")
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise NotFoundError("用户", user_id)
|
||||
|
||||
return user
|
||||
Reference in New Issue
Block a user