Files
aiagent/backend/app/api/auth.py

443 lines
14 KiB
Python
Raw Normal View History

2026-01-19 00:09:36 +08:00
"""
认证相关API
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
2026-03-06 22:31:41 +08:00
from pydantic import BaseModel, field_validator
import re
import secrets
2026-01-19 00:09:36 +08:00
import logging
from app.core.database import get_db
from app.core.security import verify_password, get_password_hash, create_access_token
from app.models.user import User
from datetime import datetime, timedelta
2026-01-19 00:09:36 +08:00
from app.core.config import settings
from app.core.exceptions import ConflictError, UnauthorizedError, NotFoundError
from app.core.redis_client import get_redis_client
2026-01-19 00:09:36 +08:00
logger = logging.getLogger(__name__)

# Every auth route lives under this prefix; the OAuth2 schemes point clients
# at the login route below it to obtain a bearer token.
_AUTH_PREFIX = "/api/v1/auth"
_LOGIN_URL = f"{_AUTH_PREFIX}/login"

router = APIRouter(
    prefix=_AUTH_PREFIX,
    tags=["auth"],
    responses={
        401: {"description": "未授权"},
        400: {"description": "请求参数错误"},
        500: {"description": "服务器内部错误"},
    },
)

# Strict scheme: a missing/malformed Authorization header is rejected by FastAPI.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=_LOGIN_URL)
# Lenient scheme: yields None instead of raising, for optional-login endpoints.
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl=_LOGIN_URL, auto_error=False)
class UserCreate(BaseModel):
    """Request body for ``POST /register``.

    The email is trimmed and lower-cased so the uniqueness check and later
    lookups by email (the password-reset models lower-case their input too)
    compare like with like.
    """

    username: str
    email: str
    password: str

    @field_validator("username")
    @classmethod
    def username_not_blank(cls, v: str) -> str:
        # An empty or whitespace-only username would otherwise be stored as-is.
        if not v.strip():
            raise ValueError("用户名不能为空")
        return v

    @field_validator("email")
    @classmethod
    def email_format(cls, v: str) -> str:
        # Require local@domain.tld with no whitespace anywhere. The old pattern
        # (``[^@]`` with ``re.match`` and ``$``) let spaces and newlines through,
        # and the address is later written into an email ``To`` header.
        v = v.strip()
        if not v or not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", v):
            raise ValueError("邮箱格式无效")
        return v.lower()

    @field_validator("password")
    @classmethod
    def password_length(cls, v: str) -> str:
        # Same bounds as ResetPasswordRequest.new_password, so registration and
        # reset enforce one password policy (and an empty password is rejected).
        if len(v) < 6:
            raise ValueError("密码不少于 6 个字符")
        if len(v) > 32:
            raise ValueError("密码不超过 32 个字符")
        return v
class UserResponse(BaseModel):
    """Public view of a user; response model of ``POST /register``."""

    # Pydantic v2 configuration. This module already uses the v2
    # ``field_validator`` API, and the nested ``class Config`` form is
    # deprecated in v2. ``from_attributes`` lets the endpoint return the ORM
    # ``User`` object directly.
    model_config = {"from_attributes": True}

    id: str
    username: str
    email: str
    role: str
class MeResponse(BaseModel):
    """Full profile of the current user, including the workspace list (``GET /me``)."""
    id: str
    username: str
    email: str
    role: str
    # Workspace entries as returned by get_user_workspaces(); the element
    # schema is not constrained here. Pydantic copies this default per instance.
    workspaces: list = []
    # Workspace id taken from the token's "ws" claim; None when that claim is empty.
    current_workspace_id: str | None = None
2026-01-19 00:09:36 +08:00
class Token(BaseModel):
    """Access-token response body returned by ``POST /login``."""
    access_token: str
    # Always "bearer": the token is sent back in an ``Authorization: Bearer`` header.
    token_type: str = "bearer"
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: Session = Depends(get_db)):
    """Create a new user account.

    Raises:
        ConflictError: the username or email is already taken, including the
            case where a concurrent request claims it between the pre-checks
            below and the commit.
    """
    from sqlalchemy.exc import IntegrityError

    # Pre-checks give a precise message in the common, non-racing case.
    if db.query(User).filter(User.username == user_data.username).first():
        raise ConflictError("用户名已存在")
    if db.query(User).filter(User.email == user_data.email).first():
        raise ConflictError("邮箱已存在")

    user = User(
        username=user_data.username,
        email=user_data.email,
        password_hash=get_password_hash(user_data.password),
    )
    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # NOTE(review): assumes unique constraints on username/email; any other
        # integrity violation would also be reported as a conflict here.
        # Roll back so the session stays usable and answer 409 instead of 500.
        db.rollback()
        raise ConflictError("用户名或邮箱已存在") from None
    db.refresh(user)
    return user
def _get_user_default_workspace_id(db: Session, user: User) -> str | None:
    """Pick the workspace id to embed in a freshly issued token.

    Preference order:
      1. the active system-default workspace (is_default == 1), if ``user``
         is a member of it;
      2. otherwise the workspace of the first membership row the query returns;
      3. None when the user belongs to no workspace at all.
    """
    from app.models.workspace import Workspace, WorkspaceMembership

    # Query object only; no SQL runs until .first() is called on it.
    user_memberships = db.query(WorkspaceMembership).filter(
        WorkspaceMembership.user_id == user.id
    )

    system_default = (
        db.query(Workspace)
        .filter(Workspace.is_default == 1, Workspace.status == "active")
        .first()
    )
    if system_default:
        in_default = user_memberships.filter(
            WorkspaceMembership.workspace_id == system_default.id
        ).first()
        if in_default:
            return system_default.id

    fallback = user_memberships.first()
    return fallback.workspace_id if fallback else None
@router.post("/login", response_model=Token)
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db),
    client_type: str = "web",
):
    """Authenticate with username/password and issue a JWT.

    ``client_type`` "android" or "ios" selects
    settings.JWT_MOBILE_TOKEN_EXPIRE_MINUTES. Any other value, including the
    default "web", uses settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES. The token
    carries ``sub`` (user id), ``username`` and ``ws``: the user's default
    workspace id, or "" when the user has none.

    Raises:
        UnauthorizedError: unknown username or wrong password. Both cases use
            the same message.
    """
    user = db.query(User).filter(User.username == form_data.username).first()
    if not user or not verify_password(form_data.password, user.password_hash):
        # Lazy %r formatting escapes control characters, so a crafted username
        # cannot forge extra log lines.
        logger.warning("登录失败: 用户名 %r", form_data.username)
        raise UnauthorizedError("用户名或密码错误")

    if client_type in ("android", "ios"):
        lifetime_minutes = settings.JWT_MOBILE_TOKEN_EXPIRE_MINUTES
    else:
        lifetime_minutes = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES

    ws_id = _get_user_default_workspace_id(db, user)
    access_token = create_access_token(
        data={"sub": user.id, "username": user.username, "ws": ws_id or ""},
        expires_delta=timedelta(minutes=lifetime_minutes),
    )
    return {"access_token": access_token, "token_type": "bearer"}
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
) -> User:
    """FastAPI dependency: resolve the bearer JWT to its ``User`` row.

    Raises:
        UnauthorizedError: decode_access_token returned None for the token, or
            the payload has no "sub" claim.
        NotFoundError: no user row matches the "sub" claim.
    """
    from app.core.security import decode_access_token

    claims = decode_access_token(token)
    subject = claims.get("sub") if claims is not None else None
    if subject is None:
        raise UnauthorizedError("无效的访问令牌")

    account = db.query(User).filter(User.id == subject).first()
    if account is None:
        raise NotFoundError("用户", subject)
    return account
@router.get("/me", response_model=MeResponse)
async def get_me(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
):
    """Return the caller's profile, workspace list and token workspace.

    ``current_workspace_id`` is the token's "ws" claim, or None when that
    claim is missing or empty.

    Raises:
        UnauthorizedError: the token does not decode or has no "sub" claim.
        NotFoundError: no user row matches the "sub" claim.
    """
    from app.core.security import decode_access_token
    from app.services.workspace_service import get_user_workspaces

    claims = decode_access_token(token)
    uid = None if claims is None else claims.get("sub")
    if uid is None:
        raise UnauthorizedError("无效的访问令牌")

    me = db.query(User).filter(User.id == uid).first()
    if me is None:
        raise NotFoundError("用户", uid)

    memberships = get_user_workspaces(db, me)
    return {
        "id": me.id,
        "username": me.username,
        "email": me.email,
        "role": me.role,
        "workspaces": memberships,
        "current_workspace_id": claims.get("ws") or None,
    }
@router.post("/switch-workspace/{workspace_id}")
async def switch_workspace(
    workspace_id: str,
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
    client_type: str = "web",
):
    """Re-issue the caller's JWT with ``workspace_id`` in the "ws" claim.

    ``client_type`` mirrors ``/login``: "android"/"ios" receive the mobile
    token lifetime, anything else the web lifetime. Without it, a mobile
    client switching workspaces would be handed a short web-lifetime token.
    The parameter defaults to "web", so existing callers are unaffected.

    Raises:
        UnauthorizedError: the token does not decode or has no "sub" claim.
        NotFoundError: no user row matches the "sub" claim.
        HTTPException(403): check_workspace_access denies access.
    """
    from app.core.security import decode_access_token
    from app.services.workspace_service import check_workspace_access

    payload = decode_access_token(token)
    if payload is None:
        raise UnauthorizedError("无效的访问令牌")
    user_id = payload.get("sub")
    if user_id is None:
        raise UnauthorizedError("无效的访问令牌")
    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise NotFoundError("用户", user_id)
    if not check_workspace_access(db, user, workspace_id):
        raise HTTPException(status_code=403, detail="无权访问此工作区")

    if client_type in ("android", "ios"):
        lifetime_minutes = settings.JWT_MOBILE_TOKEN_EXPIRE_MINUTES
    else:
        lifetime_minutes = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
    new_token = create_access_token(
        data={"sub": user.id, "username": user.username, "ws": workspace_id},
        expires_delta=timedelta(minutes=lifetime_minutes),
    )
    return {"access_token": new_token, "token_type": "bearer", "workspace_id": workspace_id}
# --- Password reset ---------------------------------------------------------
# Lifetime of an issued reset code, in seconds (10 minutes).
RESET_CODE_TTL_SEC = 10 * 60
# Minimum gap between two reset-code requests for the same email, in seconds.
RESET_RATE_LIMIT_SEC = 60
class ForgotPasswordRequest(BaseModel):
    """Request body for ``POST /forgot-password``."""

    email: str

    @field_validator("email")
    @classmethod
    def email_format(cls, v: str) -> str:
        # Trim, then require local@domain.tld with no whitespace anywhere. The
        # address is written into the reset mail's ``To`` header, so embedded
        # newlines must be rejected. Lower-cased like UserCreate.email so the
        # lookup matches the stored address.
        v = v.strip()
        if not v or not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", v):
            raise ValueError("邮箱格式无效")
        return v.lower()
class ResetPasswordRequest(BaseModel):
    """Request body for ``POST /reset-password``."""

    email: str
    # The emailed verification code; the endpoint strips surrounding whitespace.
    code: str
    new_password: str

    @field_validator("email")
    @classmethod
    def email_format(cls, v: str) -> str:
        # Trim, then require local@domain.tld with no whitespace anywhere, and
        # lower-case like UserCreate.email so the lookup matches the stored address.
        v = v.strip()
        if not v or not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", v):
            raise ValueError("邮箱格式无效")
        return v.lower()

    @field_validator("new_password")
    @classmethod
    def password_length(cls, v: str) -> str:
        # Accept 6 to 32 characters inclusive.
        if len(v) < 6:
            raise ValueError("密码不少于 6 个字符")
        if len(v) > 32:
            raise ValueError("密码不超过 32 个字符")
        return v
async def _send_reset_email(email: str, code: str) -> bool:
    """Email a password-reset code to ``email`` via SMTP.

    Settings read: SMTP_HOST (default "smtp.qq.com"), SMTP_PORT (default 587),
    SMTP_USER and SMTP_PASSWORD.

    Returns:
        True once aiosmtplib.send completes. False, after logging a warning,
        when SMTP credentials are not configured or anything in the send path
        raises.
    """
    try:
        import aiosmtplib
        from email.mime.multipart import MIMEMultipart
        from email.mime.text import MIMEText

        smtp_host = getattr(settings, 'SMTP_HOST', '') or 'smtp.qq.com'
        smtp_port = int(getattr(settings, 'SMTP_PORT', 0) or 587)
        smtp_user = getattr(settings, 'SMTP_USER', '') or ''
        smtp_password = getattr(settings, 'SMTP_PASSWORD', '') or ''
        if not smtp_user or not smtp_password:
            logger.warning("SMTP 未配置,无法发送邮件。重置码: %s", code)
            return False

        msg = MIMEMultipart()
        msg['From'] = smtp_user
        msg['To'] = email
        msg['Subject'] = '天工智能体 - 密码重置验证码'
        msg.attach(MIMEText(
            f'您的密码重置验证码是:<b>{code}</b><br><br>'
            f'验证码 10 分钟内有效。如非本人操作请忽略此邮件。',
            'html', 'utf-8'
        ))

        # Port 465 speaks implicit TLS from the first byte (use_tls). Port 587
        # is message submission: plaintext first, then upgraded with STARTTLS.
        # The previous ``use_tls=smtp_port == 587`` attempted implicit TLS
        # against the STARTTLS port, so the handshake failed on the default port.
        if smtp_port == 465:
            tls_options = {"use_tls": True}
        elif smtp_port == 587:
            tls_options = {"start_tls": True}
        else:
            tls_options = {}  # leave other ports to aiosmtplib's defaults

        await aiosmtplib.send(
            msg, hostname=smtp_host, port=smtp_port,
            username=smtp_user, password=smtp_password,
            **tls_options,
        )
        logger.info("密码重置邮件已发送至 %s", email)
        return True
    except Exception as e:
        # Best-effort delivery: the caller decides what to do with False.
        logger.warning("邮件发送失败: %s,重置码: %s", e, code)
        return False
@router.post("/forgot-password")
async def forgot_password(body: ForgotPasswordRequest, db: Session = Depends(get_db)):
    """Generate a 6-digit password-reset code and email it.

    The code is stored for RESET_CODE_TTL_SEC seconds. Each email may request
    a code at most once per RESET_RATE_LIMIT_SEC seconds; otherwise HTTP 429.

    Storage is Redis when get_redis_client() returns a client. Otherwise it
    falls back to dicts attached to this function
    (``forgot_password._memory_store`` / ``_memory_rate``), which
    reset_password also reads. That fallback is lost on restart and is
    per-process.

    When email delivery fails, the plaintext code is returned as ``dev_code``
    only if settings.DEBUG is truthy. Otherwise a 503 is returned. Previously
    any SMTP failure handed the code to whoever made the request.

    Raises:
        HTTPException(429): a code was requested for this email too recently.
        HTTPException(503): delivery failed and DEBUG is off.
    """
    user = db.query(User).filter(User.email == body.email).first()
    if not user:
        # Do not reveal whether the email is registered.
        return {"message": "如果邮箱已注册,验证码已发送"}

    redis = get_redis_client()
    rate_key = f"pwd_reset_rate:{body.email}"
    code_key = f"pwd_reset_code:{body.email}"
    code_str = str(secrets.randbelow(900000) + 100000)  # uniform over 100000..999999

    if redis:
        if redis.exists(rate_key):
            ttl = redis.ttl(rate_key)
            raise HTTPException(
                status_code=429,
                detail=f"操作过于频繁,请 {ttl} 秒后重试"
            )
        redis.setex(code_key, RESET_CODE_TTL_SEC, code_str)
        redis.setex(rate_key, RESET_RATE_LIMIT_SEC, "1")
    else:
        if not hasattr(forgot_password, '_memory_store'):
            forgot_password._memory_store = {}
            forgot_password._memory_rate = {}
        store = forgot_password._memory_store
        rate = forgot_password._memory_rate
        now = datetime.utcnow()

        # Prune expired entries so the fallback cannot grow without bound.
        for addr in [a for a, entry in store.items() if entry["expires_at"] <= now]:
            del store[addr]
        for addr in [a for a, until in rate.items() if until <= now]:
            del rate[addr]

        # The rate limit was previously recorded but never checked in this mode.
        blocked_until = rate.get(body.email)
        if blocked_until is not None:  # after pruning, only future deadlines remain
            remaining = max(1, int((blocked_until - now).total_seconds()))
            raise HTTPException(
                status_code=429,
                detail=f"操作过于频繁,请 {remaining} 秒后重试"
            )

        store[body.email] = {
            "code": code_str,
            "expires_at": now + timedelta(seconds=RESET_CODE_TTL_SEC),
        }
        rate[body.email] = now + timedelta(seconds=RESET_RATE_LIMIT_SEC)

    if await _send_reset_email(body.email, code_str):
        return {"message": "验证码已发送至邮箱"}

    # NOTE(review): assumes settings exposes a DEBUG flag; confirm the name.
    if getattr(settings, "DEBUG", False):
        logger.info("开发模式:%s 的密码重置验证码为 %s", body.email, code_str)
        return {
            "message": "验证码已生成",
            "dev_code": code_str,
        }
    logger.error("密码重置邮件发送失败: %s", body.email)
    raise HTTPException(status_code=503, detail="邮件发送失败,请稍后重试")
# Wrong guesses tolerated per email before the outstanding code is discarded.
# Without a cap, the 900,000-value code space could be brute-forced within
# RESET_CODE_TTL_SEC.
_RESET_MAX_ATTEMPTS = 5


@router.post("/reset-password")
async def reset_password(body: ResetPasswordRequest, db: Session = Depends(get_db)):
    """Set a new password using the code issued by ``/forgot-password``.

    A code is single-use and is also discarded after _RESET_MAX_ATTEMPTS wrong
    guesses. The Redis failure counter is keyed by email and lives
    RESET_CODE_TTL_SEC from the first failure, so failures on an earlier code
    can count toward a newly issued one.

    Raises:
        HTTPException(400): email not registered, no live code, or wrong code.
    """
    user = db.query(User).filter(User.email == body.email).first()
    if not user:
        raise HTTPException(status_code=400, detail="邮箱未注册")

    redis = get_redis_client()
    code_key = f"pwd_reset_code:{body.email}"
    attempts_key = f"pwd_reset_attempts:{body.email}"
    # In-memory fallback populated by forgot_password when Redis is unavailable.
    memory_store = getattr(forgot_password, '_memory_store', None)

    stored_code = None
    entry = {}
    if redis:
        stored_code = redis.get(code_key)
        if isinstance(stored_code, bytes):
            # redis-py returns bytes unless the client uses decode_responses=True;
            # bytes would never compare equal to the submitted str code.
            stored_code = stored_code.decode("utf-8")
    elif memory_store is not None:
        entry = memory_store.get(body.email) or {}
        if entry and entry.get("expires_at", datetime.min) > datetime.utcnow():
            stored_code = entry.get("code")

    if not stored_code:
        raise HTTPException(status_code=400, detail="验证码已过期或未请求")

    # Constant-time compare on bytes: compare_digest rejects non-ASCII str
    # arguments with TypeError, and the submitted code is untrusted input.
    submitted = body.code.strip().encode("utf-8")
    if not secrets.compare_digest(stored_code.encode("utf-8"), submitted):
        if redis:
            attempts = redis.incr(attempts_key)
            if attempts == 1:
                redis.expire(attempts_key, RESET_CODE_TTL_SEC)
            if attempts >= _RESET_MAX_ATTEMPTS:
                redis.delete(code_key)
                redis.delete(attempts_key)
        elif memory_store is not None:
            # ``entry`` is the live dict inside memory_store, so this persists.
            entry["attempts"] = entry.get("attempts", 0) + 1
            if entry["attempts"] >= _RESET_MAX_ATTEMPTS:
                memory_store.pop(body.email, None)
        raise HTTPException(status_code=400, detail="验证码错误")

    user.password_hash = get_password_hash(body.new_password)
    db.commit()

    # Single use: discard the code (and its failure counter) once consumed.
    if redis:
        redis.delete(code_key)
        redis.delete(attempts_key)
    elif memory_store is not None:
        memory_store.pop(body.email, None)

    logger.info("用户 %s 密码重置成功", user.username)
    return {"message": "密码重置成功,请使用新密码登录"}
async def get_optional_user(
    token: str | None = Depends(oauth2_scheme_optional),
    db: Session = Depends(get_db),
) -> User | None:
    """FastAPI dependency for endpoints where logging in is optional.

    Returns the matching ``User`` for a usable bearer token. Returns None when
    there is no token, decode_access_token returns None, the payload lacks
    "sub", or no user row matches.
    """
    if token:
        from app.core.security import decode_access_token

        claims = decode_access_token(token)
        subject = None if claims is None else claims.get("sub")
        if subject is not None:
            return db.query(User).filter(User.id == subject).first()
    return None