"""Authentication API: registration, login, workspace switching and password reset."""

import logging
import math
import re
import secrets
from datetime import datetime, timedelta, timezone

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.core.security import verify_password, get_password_hash, create_access_token
from app.models.user import User
from app.core.config import settings
from app.core.exceptions import ConflictError, UnauthorizedError, NotFoundError
from app.core.redis_client import get_redis_client

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/api/v1/auth",
    tags=["auth"],
    responses={
        401: {"description": "未授权"},
        400: {"description": "请求参数错误"},
        500: {"description": "服务器内部错误"},
    },
)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
# Same scheme, but with auto_error=False a missing token yields None instead of
# an automatic 401, so endpoints can also serve anonymous callers.
oauth2_scheme_optional = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login",
    auto_error=False,
)

# Deliberately loose "local@domain.tld" shape check shared by all request models.
_EMAIL_RE = re.compile(r"^[^@]+@[^@]+\.[^@]+$")

# client_type values that receive the long-lived mobile token.
_MOBILE_CLIENT_TYPES = ("android", "ios")


def _normalize_email(value: str) -> str:
    """Validate the email shape and return the address lower-cased.

    Raises:
        ValueError: if ``value`` is empty or does not match the loose
            ``local@domain.tld`` pattern (pydantic reports it as a 422).
    """
    if not value or not _EMAIL_RE.match(value):
        raise ValueError("邮箱格式无效")
    return value.lower()


def _token_expiry(client_type: str) -> timedelta:
    """Return the JWT lifetime: the mobile setting for android/ios, else the web setting."""
    if client_type in _MOBILE_CLIENT_TYPES:
        return timedelta(minutes=settings.JWT_MOBILE_TOKEN_EXPIRE_MINUTES)
    return timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)


def _decode_token_or_401(token: str) -> dict:
    """Decode ``token`` and return its payload, which always carries ``sub``.

    Raises:
        UnauthorizedError: if the token cannot be decoded or has no ``sub`` claim.
    """
    # Imported at call time (as the original handlers did) so the function is
    # resolved lazily from app.core.security.
    from app.core.security import decode_access_token

    payload = decode_access_token(token)
    if payload is None or payload.get("sub") is None:
        raise UnauthorizedError("无效的访问令牌")
    return payload


def _get_user_or_404(db: Session, user_id: str) -> User:
    """Load the user with ``user_id``.

    Raises:
        NotFoundError: if no such user exists (e.g. deleted after the token was issued).
    """
    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise NotFoundError("用户", user_id)
    return user


class UserCreate(BaseModel):
    """Registration request body; ``email`` is validated and lower-cased."""

    username: str
    email: str
    password: str

    @field_validator("email")
    @classmethod
    def email_format(cls, v: str) -> str:
        return _normalize_email(v)


class UserResponse(BaseModel):
    """Public user fields, readable directly from a ``User`` ORM object."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    username: str
    email: str
    role: str


class MeResponse(BaseModel):
    """Current user's profile plus the workspaces they can access."""

    id: str
    username: str
    email: str
    role: str
    workspaces: list = []
    current_workspace_id: str | None = None


class Token(BaseModel):
    """Access-token response."""

    access_token: str
    token_type: str = "bearer"


@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: Session = Depends(get_db)):
    """Create a new user account.

    Raises:
        ConflictError: if the username or email is already taken, including the
            case where a concurrent registration wins between the checks and the commit.
    """
    if db.query(User).filter(User.username == user_data.username).first():
        raise ConflictError("用户名已存在")
    if db.query(User).filter(User.email == user_data.email).first():
        raise ConflictError("邮箱已存在")

    user = User(
        username=user_data.username,
        email=user_data.email,
        password_hash=get_password_hash(user_data.password),
    )
    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # The pre-checks above are not atomic; a parallel request can insert the
        # same username/email first. Presumably the table has unique constraints
        # on these columns — TODO confirm against the User model.
        db.rollback()
        raise ConflictError("用户名或邮箱已存在") from None
    db.refresh(user)
    return user


def _get_user_default_workspace_id(db: Session, user: User) -> str | None:
    """Return the workspace id to embed in a fresh token, or None.

    Prefers the active system-default workspace when ``user`` is a member of
    it; otherwise falls back to the workspace of the first membership found.
    """
    from app.models.workspace import Workspace, WorkspaceMembership

    default_ws = (
        db.query(Workspace)
        .filter(Workspace.is_default == 1, Workspace.status == "active")
        .first()
    )
    if default_ws:
        membership = (
            db.query(WorkspaceMembership)
            .filter(
                WorkspaceMembership.workspace_id == default_ws.id,
                WorkspaceMembership.user_id == user.id,
            )
            .first()
        )
        if membership:
            return default_ws.id

    first_membership = (
        db.query(WorkspaceMembership)
        .filter(WorkspaceMembership.user_id == user.id)
        .first()
    )
    if first_membership:
        return first_membership.workspace_id
    return None


@router.post("/login", response_model=Token)
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db),
    client_type: str = "web",
):
    """Log in with username/password.

    ``client_type=android``/``ios`` gets the mobile token lifetime; anything
    else gets the web lifetime.
    """
    user = db.query(User).filter(User.username == form_data.username).first()
    if not user or not verify_password(form_data.password, user.password_hash):
        logger.warning("登录失败: 用户名 %s", form_data.username)
        raise UnauthorizedError("用户名或密码错误")

    ws_id = _get_user_default_workspace_id(db, user)
    access_token = create_access_token(
        data={"sub": user.id, "username": user.username, "ws": ws_id or ""},
        expires_delta=_token_expiry(client_type),
    )
    return {"access_token": access_token, "token_type": "bearer"}


async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
) -> User:
    """FastAPI dependency: resolve the bearer token to its ``User``.

    Raises:
        UnauthorizedError: invalid token or missing ``sub`` claim.
        NotFoundError: the token's user no longer exists.
    """
    payload = _decode_token_or_401(token)
    return _get_user_or_404(db, payload["sub"])


@router.get("/me", response_model=MeResponse)
async def get_me(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
):
    """Return the current user, their workspaces and the token's current workspace."""
    from app.services.workspace_service import get_user_workspaces

    payload = _decode_token_or_401(token)
    user = _get_user_or_404(db, payload["sub"])
    workspaces = get_user_workspaces(db, user)
    current_ws_id = payload.get("ws", "")
    return {
        "id": user.id,
        "username": user.username,
        "email": user.email,
        "role": user.role,
        "workspaces": workspaces,
        # Tokens store "" when no workspace was selected; expose that as null.
        "current_workspace_id": current_ws_id or None,
    }


@router.post("/switch-workspace/{workspace_id}")
async def switch_workspace(
    workspace_id: str,
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
    client_type: str = "web",
):
    """Re-issue the JWT with ``ws`` set to ``workspace_id``.

    ``client_type`` mirrors ``/login`` so mobile clients keep their long-lived
    token after switching (defaults to the web lifetime as before).

    Raises:
        HTTPException(403): the user has no access to ``workspace_id``.
    """
    from app.services.workspace_service import check_workspace_access

    payload = _decode_token_or_401(token)
    user = _get_user_or_404(db, payload["sub"])
    if not check_workspace_access(db, user, workspace_id):
        raise HTTPException(status_code=403, detail="无权访问此工作区")

    new_token = create_access_token(
        data={"sub": user.id, "username": user.username, "ws": workspace_id},
        expires_delta=_token_expiry(client_type),
    )
    return {"access_token": new_token, "token_type": "bearer", "workspace_id": workspace_id}


# ─── Password reset ───────────────────────────────────────────────

RESET_CODE_TTL_SEC = 600  # a reset code is valid for 10 minutes
RESET_RATE_LIMIT_SEC = 60  # at most one code request per email per 60 seconds
RESET_MAX_ATTEMPTS = 5  # wrong guesses allowed before the code is invalidated

# In-process fallback used only when Redis is unavailable (lost on restart,
# not shared between worker processes).
_memory_reset_codes: dict[str, dict] = {}  # email -> {"code", "expires_at", "attempts"}
_memory_reset_rate: dict[str, datetime] = {}  # email -> end of rate-limit window


class ForgotPasswordRequest(BaseModel):
    """Body of /forgot-password."""

    email: str

    @field_validator("email")
    @classmethod
    def email_format(cls, v: str) -> str:
        return _normalize_email(v)


class ResetPasswordRequest(BaseModel):
    """Body of /reset-password; ``new_password`` must be 6-32 characters."""

    email: str
    code: str
    new_password: str

    @field_validator("email")
    @classmethod
    def email_format(cls, v: str) -> str:
        return _normalize_email(v)

    @field_validator("new_password")
    @classmethod
    def password_length(cls, v: str) -> str:
        if len(v) < 6:
            raise ValueError("密码不少于 6 个字符")
        if len(v) > 32:
            raise ValueError("密码不超过 32 个字符")
        return v


def _is_dev_mode() -> bool:
    """True only when ``settings.DEBUG`` is truthy; a missing setting means production."""
    return bool(getattr(settings, "DEBUG", False))


def _utcnow() -> datetime:
    return datetime.now(timezone.utc)


def _reset_code_key(email: str) -> str:
    return f"pwd_reset_code:{email}"


def _reset_rate_key(email: str) -> str:
    return f"pwd_reset_rate:{email}"


def _reset_attempts_key(email: str) -> str:
    return f"pwd_reset_attempts:{email}"


def _purge_expired_memory_entries(now: datetime) -> None:
    """Drop expired entries so the in-memory fallback cannot grow without bound."""
    for email in [e for e, entry in _memory_reset_codes.items() if entry["expires_at"] <= now]:
        del _memory_reset_codes[email]
    for email in [e for e, until in _memory_reset_rate.items() if until <= now]:
        del _memory_reset_rate[email]


def _consume_reset_rate_limit(redis, email: str) -> None:
    """Start a new rate-limit window for ``email``.

    Raises:
        HTTPException(429): a window for ``email`` is still open.
    """
    if redis:
        rate_key = _reset_rate_key(email)
        if redis.exists(rate_key):
            ttl = redis.ttl(rate_key)
            # TTL is -1/-2 for keys without expiry / just expired; never show that.
            wait = ttl if ttl and ttl > 0 else RESET_RATE_LIMIT_SEC
            raise HTTPException(status_code=429, detail=f"操作过于频繁,请 {wait} 秒后重试")
        redis.setex(rate_key, RESET_RATE_LIMIT_SEC, "1")
        return

    now = _utcnow()
    _purge_expired_memory_entries(now)
    until = _memory_reset_rate.get(email)
    if until is not None and until > now:
        wait = max(1, math.ceil((until - now).total_seconds()))
        raise HTTPException(status_code=429, detail=f"操作过于频繁,请 {wait} 秒后重试")
    _memory_reset_rate[email] = now + timedelta(seconds=RESET_RATE_LIMIT_SEC)


def _store_reset_code(redis, email: str, code: str) -> None:
    """Save a fresh code for ``email`` and reset its failed-attempt counter."""
    if redis:
        redis.setex(_reset_code_key(email), RESET_CODE_TTL_SEC, code)
        redis.delete(_reset_attempts_key(email))
        return
    _memory_reset_codes[email] = {
        "code": code,
        "expires_at": _utcnow() + timedelta(seconds=RESET_CODE_TTL_SEC),
        "attempts": 0,
    }


def _load_reset_code(redis, email: str) -> str | None:
    """Return the live code for ``email``, or None if absent/expired."""
    if redis:
        value = redis.get(_reset_code_key(email))
        # Without decode_responses=True redis-py returns bytes.
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        return value or None
    entry = _memory_reset_codes.get(email)
    if entry and entry["expires_at"] > _utcnow():
        return entry["code"]
    return None


def _record_failed_attempt(redis, email: str) -> int:
    """Count one wrong guess for ``email`` and return the running total."""
    if redis:
        key = _reset_attempts_key(email)
        attempts = int(redis.incr(key))
        if attempts == 1:
            # Counter lives no longer than a code can.
            redis.expire(key, RESET_CODE_TTL_SEC)
        return attempts
    entry = _memory_reset_codes.get(email)
    if entry is None:
        return RESET_MAX_ATTEMPTS
    entry["attempts"] += 1
    return entry["attempts"]


def _clear_reset_code(redis, email: str) -> None:
    """Invalidate the code and attempt counter for ``email``."""
    if redis:
        redis.delete(_reset_code_key(email), _reset_attempts_key(email))
    else:
        _memory_reset_codes.pop(email, None)


def _codes_match(stored: str, submitted: str) -> bool:
    """Constant-time comparison; bytes so non-ASCII input cannot raise TypeError."""
    return secrets.compare_digest(stored.encode("utf-8"), submitted.strip().encode("utf-8"))


async def _send_reset_email(email: str, code: str) -> bool:
    """Email the reset code to ``email``; return True if the SMTP send succeeded.

    Best-effort: missing SMTP credentials, a missing ``aiosmtplib`` package or
    any delivery error is logged and reported as False, never raised. The code
    itself is never written to the log here.
    """
    try:
        import aiosmtplib
        from email.mime.multipart import MIMEMultipart
        from email.mime.text import MIMEText

        smtp_host = getattr(settings, 'SMTP_HOST', '') or 'smtp.qq.com'
        smtp_port = int(getattr(settings, 'SMTP_PORT', 0) or 587)
        smtp_user = getattr(settings, 'SMTP_USER', '') or ''
        smtp_password = getattr(settings, 'SMTP_PASSWORD', '') or ''
        if not smtp_user or not smtp_password:
            logger.warning("SMTP 未配置,无法发送邮件")
            return False

        msg = MIMEMultipart()
        msg['From'] = smtp_user
        msg['To'] = email
        msg['Subject'] = '天工智能体 - 密码重置验证码'
        msg.attach(MIMEText(
            f'您的密码重置验证码是:{code}<br><br>'
            f'验证码 {RESET_CODE_TTL_SEC // 60} 分钟内有效。如非本人操作请忽略此邮件。',
            'html',
            'utf-8',
        ))

        # Port 465 is SMTPS (TLS from the first byte). Other ports (normally
        # 587) must upgrade via STARTTLS before credentials are sent; requesting
        # implicit TLS there makes the handshake fail.
        implicit_tls = smtp_port == 465
        await aiosmtplib.send(
            msg,
            hostname=smtp_host,
            port=smtp_port,
            username=smtp_user,
            password=smtp_password,
            use_tls=implicit_tls,
            start_tls=not implicit_tls,
        )
        logger.info("密码重置邮件已发送至 %s", email)
        return True
    except Exception as e:
        logger.warning("邮件发送失败: %s", e)
        return False


@router.post("/forgot-password")
async def forgot_password(body: ForgotPasswordRequest, db: Session = Depends(get_db)):
    """Generate a 6-digit reset code and email it.

    The rate limit applies to every email, registered or not, so a 429 does not
    reveal whether an account exists. ``dev_code`` is returned only when
    ``settings.DEBUG`` is on and the email could not be sent.

    Raises:
        HTTPException(429): a code was requested for this email too recently.
        HTTPException(503): the email could not be sent outside dev mode.
    """
    redis = get_redis_client()
    _consume_reset_rate_limit(redis, body.email)

    user = db.query(User).filter(User.email == body.email).first()
    if not user:
        # Same answer as success so the endpoint does not reveal registration.
        return {"message": "如果邮箱已注册,验证码已发送"}

    code_str = str(secrets.randbelow(900000) + 100000)  # always 6 digits
    _store_reset_code(redis, body.email, code_str)

    if await _send_reset_email(body.email, code_str):
        return {"message": "验证码已发送至邮箱"}

    if _is_dev_mode():
        # Development/testing without SMTP: surface the code to the caller.
        logger.info("开发模式:%s 的密码重置验证码为 %s", body.email, code_str)
        return {
            "message": "验证码已生成",
            "dev_code": code_str,
        }
    raise HTTPException(status_code=503, detail="邮件服务暂不可用,请稍后重试")


@router.post("/reset-password")
async def reset_password(body: ResetPasswordRequest, db: Session = Depends(get_db)):
    """Set a new password using a code from /forgot-password.

    After ``RESET_MAX_ATTEMPTS`` wrong guesses the code is invalidated. An
    unknown email is reported like a missing code so registration is not leaked.

    Raises:
        HTTPException(400): code missing/expired, wrong, or too many attempts.
    """
    redis = get_redis_client()

    stored_code = _load_reset_code(redis, body.email)
    if not stored_code:
        raise HTTPException(status_code=400, detail="验证码已过期或未请求")

    if not _codes_match(stored_code, body.code):
        if _record_failed_attempt(redis, body.email) >= RESET_MAX_ATTEMPTS:
            _clear_reset_code(redis, body.email)
            raise HTTPException(status_code=400, detail="验证码错误次数过多,请重新获取")
        raise HTTPException(status_code=400, detail="验证码错误")

    user = db.query(User).filter(User.email == body.email).first()
    if not user:
        _clear_reset_code(redis, body.email)
        raise HTTPException(status_code=400, detail="验证码已过期或未请求")

    user.password_hash = get_password_hash(body.new_password)
    db.commit()

    # One-time use: the code cannot be replayed after a successful reset.
    _clear_reset_code(redis, body.email)
    logger.info("用户 %s 密码重置成功", user.username)
    return {"message": "密码重置成功,请使用新密码登录"}


async def get_optional_user(
    token: str | None = Depends(oauth2_scheme_optional),
    db: Session = Depends(get_db),
) -> User | None:
    """FastAPI dependency for optional login.

    Returns None (never raises) when no token is sent, the token is invalid,
    it has no ``sub`` claim, or the user no longer exists.
    """
    if not token:
        return None
    from app.core.security import decode_access_token

    payload = decode_access_token(token)
    if payload is None:
        return None
    user_id = payload.get("sub")
    if user_id is None:
        return None
    return db.query(User).filter(User.id == user_id).first()