- Fix delete agent 500: clean up FK records (agent_llm_logs, permissions, schedules, executions, team_members) and unbind goals/tasks before delete - Remove hardcoded personality templates in Android, replace with dynamic system prompt generation from name + description - Set promptSectionsEnabled=false to bypass PromptComposer for personality - Add Tencent Cloud Linux deployment guide (Docker Compose) - Accumulated backend service updates, frontend UI fixes, Android app changes Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
443 lines
14 KiB
Python
443 lines
14 KiB
Python
"""
|
||
认证相关API
|
||
"""
|
||
import logging
import re
import secrets
from datetime import datetime, timedelta

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.database import get_db
from app.core.exceptions import ConflictError, NotFoundError, UnauthorizedError
from app.core.redis_client import get_redis_client
from app.core.security import create_access_token, get_password_hash, verify_password
from app.models.user import User
|
||
|
||
logger = logging.getLogger(__name__)

# Token-issuing endpoint advertised by both OAuth2 schemes below.
_TOKEN_URL = "/api/v1/auth/login"

router = APIRouter(
    prefix="/api/v1/auth",
    tags=["auth"],
    responses={
        401: {"description": "未授权"},
        400: {"description": "请求参数错误"},
        500: {"description": "服务器内部错误"},
    },
)

# Bearer-token extractor used by endpoints that require authentication.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=_TOKEN_URL)

# Same extractor with auto_error=False, for endpoints where login is optional
# (see get_optional_user, which treats a missing token as "anonymous").
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl=_TOKEN_URL, auto_error=False)
|
||
|
||
|
||
class UserCreate(BaseModel):
    """Request body for user registration.

    Validation:
        username: must contain at least one non-whitespace character; the value
            is stored unchanged (login compares usernames verbatim).
        email: surrounding whitespace is trimmed, the address must look like
            ``local@domain.tld`` with no embedded whitespace, and it is
            lower-cased before storage.
        password: 6-32 characters, the same bounds the password-reset flow
            enforces, so any password accepted here can also be set by a reset.
    """

    username: str
    email: str
    password: str

    @field_validator("username")
    @classmethod
    def username_not_blank(cls, v: str) -> str:
        if not v.strip():
            raise ValueError("用户名不能为空")
        return v

    @field_validator("email")
    @classmethod
    def email_format(cls, v: str) -> str:
        # Trim first: a stored " a@b.com" would never match later lookups.
        v = v.strip()
        if not v or not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", v):
            raise ValueError("邮箱格式无效")
        return v.lower()

    @field_validator("password")
    @classmethod
    def password_length(cls, v: str) -> str:
        if len(v) < 6:
            raise ValueError("密码不少于 6 个字符")
        if len(v) > 32:
            raise ValueError("密码不超过 32 个字符")
        return v
|
||
|
||
|
||
class UserResponse(BaseModel):
    """Public view of a user, used as the response model of ``/register``.

    ``register`` returns the ORM ``User`` object directly, so the model must
    be able to read attributes (``from_attributes``) rather than dict keys.
    """

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    model_config = ConfigDict(from_attributes=True)

    id: str
    username: str
    email: str
    role: str
|
||
|
||
|
||
class MeResponse(BaseModel):
    """Response model of ``GET /me``: the caller's profile plus workspace context.

    ``workspaces`` is whatever ``get_user_workspaces`` returns; its element
    shape is defined in that service (not visible here).
    ``current_workspace_id`` carries the token's ``ws`` claim, or None.
    """

    id: str
    username: str
    email: str
    role: str

    # Workspace context.
    workspaces: list = []
    current_workspace_id: str | None = None
|
||
|
||
|
||
class Token(BaseModel):
    """Bearer-token payload returned by ``/login``."""

    access_token: str
    # Always "bearer" for tokens issued by this router.
    token_type: str = "bearer"
|
||
|
||
|
||
def _raise_if_identity_taken(db: Session, username: str, email: str) -> None:
    """Raise ConflictError if *username* or *email* already belongs to a user."""
    if db.query(User).filter(User.username == username).first():
        raise ConflictError("用户名已存在")
    if db.query(User).filter(User.email == email).first():
        raise ConflictError("邮箱已存在")


@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: Session = Depends(get_db)):
    """Create a new user account.

    Args:
        user_data: validated registration payload.
        db: database session.

    Returns:
        The newly created ``User`` (serialized through ``UserResponse``).

    Raises:
        ConflictError: the username or email is already in use, including when
            a concurrent request claims it between the pre-check and the commit.
    """
    from sqlalchemy.exc import IntegrityError

    _raise_if_identity_taken(db, user_data.username, user_data.email)

    user = User(
        username=user_data.username,
        email=user_data.email,
        password_hash=get_password_hash(user_data.password),
    )
    db.add(user)
    try:
        db.commit()
    except IntegrityError:
        # A parallel registration may have won the race after our pre-check.
        # Re-check so that case surfaces as a conflict rather than a 500, and
        # re-raise any other integrity failure unchanged.
        db.rollback()
        _raise_if_identity_taken(db, user_data.username, user_data.email)
        raise
    db.refresh(user)

    return user
|
||
|
||
|
||
def _get_user_default_workspace_id(db: Session, user: User) -> str | None:
    """Pick the workspace id to embed in a freshly issued token.

    The active system-default workspace wins if *user* is a member of it.
    Otherwise the workspace of the first membership row the query returns is
    used. Returns None when the user belongs to no workspace.
    """
    from app.models.workspace import Workspace, WorkspaceMembership

    user_memberships = db.query(WorkspaceMembership).filter(
        WorkspaceMembership.user_id == user.id
    )

    default_ws = (
        db.query(Workspace)
        .filter(Workspace.is_default == 1, Workspace.status == "active")
        .first()
    )
    if default_ws is not None:
        in_default = user_memberships.filter(
            WorkspaceMembership.workspace_id == default_ws.id
        ).first()
        if in_default is not None:
            return default_ws.id

    # Fallback: any membership (no ORDER BY, so "first" is database-defined).
    fallback = user_memberships.first()
    return fallback.workspace_id if fallback is not None else None
|
||
|
||
|
||
@router.post("/login", response_model=Token)
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db),
    client_type: str = "web"
):
    """Authenticate with username/password and issue a JWT.

    Args:
        form_data: OAuth2 password form carrying ``username`` and ``password``.
        db: database session.
        client_type: "android" or "ios" (case-insensitive) selects the mobile
            lifetime ``settings.JWT_MOBILE_TOKEN_EXPIRE_MINUTES``; any other
            value uses ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES``.

    Returns:
        ``{"access_token": ..., "token_type": "bearer"}``. The token carries
        ``sub`` (user id), ``username`` and ``ws`` (default workspace id, or "").

    Raises:
        UnauthorizedError: unknown username or wrong password (one message for
            both, so the response does not reveal which usernames exist).
    """
    user = db.query(User).filter(User.username == form_data.username).first()

    if not user or not verify_password(form_data.password, user.password_hash):
        # %r escapes control characters, so a crafted username cannot inject
        # fake log lines.
        logger.warning("登录失败: 用户名 %r", form_data.username)
        raise UnauthorizedError("用户名或密码错误")

    if client_type.lower() in ("android", "ios"):
        lifetime_minutes = settings.JWT_MOBILE_TOKEN_EXPIRE_MINUTES
    else:
        lifetime_minutes = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES

    ws_id = _get_user_default_workspace_id(db, user)

    access_token = create_access_token(
        data={"sub": user.id, "username": user.username, "ws": ws_id or ""},
        expires_delta=timedelta(minutes=lifetime_minutes),
    )

    return {"access_token": access_token, "token_type": "bearer"}
|
||
|
||
|
||
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db)
) -> User:
    """FastAPI dependency resolving the bearer token to its ``User`` row.

    Raises:
        UnauthorizedError: the token cannot be decoded or has no ``sub`` claim.
        NotFoundError: the ``sub`` claim names a user that no longer exists.
    """
    from app.core.security import decode_access_token

    claims = decode_access_token(token)
    subject = None if claims is None else claims.get("sub")
    if subject is None:
        raise UnauthorizedError("无效的访问令牌")

    account = db.query(User).filter(User.id == subject).first()
    if account is None:
        raise NotFoundError("用户", subject)
    return account
|
||
|
||
|
||
@router.get("/me", response_model=MeResponse)
async def get_me(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db)
):
    """Return the caller's profile, their workspaces and the token's workspace.

    ``current_workspace_id`` mirrors the token's ``ws`` claim; an empty or
    missing claim is reported as None.

    Raises:
        UnauthorizedError: the token cannot be decoded or has no ``sub`` claim.
        NotFoundError: the token's user no longer exists.
    """
    from app.core.security import decode_access_token
    from app.services.workspace_service import get_user_workspaces

    claims = decode_access_token(token)
    subject = None if claims is None else claims.get("sub")
    if subject is None:
        raise UnauthorizedError("无效的访问令牌")

    account = db.query(User).filter(User.id == subject).first()
    if account is None:
        raise NotFoundError("用户", subject)

    member_workspaces = get_user_workspaces(db, account)
    ws_claim = claims.get("ws", "")

    return {
        "id": account.id,
        "username": account.username,
        "email": account.email,
        "role": account.role,
        "workspaces": member_workspaces,
        "current_workspace_id": ws_claim or None,
    }
|
||
|
||
|
||
@router.post("/switch-workspace/{workspace_id}")
async def switch_workspace(
    workspace_id: str,
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db),
    client_type: str = "web",
):
    """Re-issue the caller's JWT with ``ws`` set to *workspace_id*.

    Args:
        workspace_id: workspace to switch to; access is checked first.
        token: current bearer token.
        db: database session.
        client_type: as in ``/login``. "android"/"ios" (case-insensitive) get
            the mobile lifetime. Without this, a mobile client switching
            workspace would silently receive a short web-lifetime token.
            The default keeps the previous web behaviour.

    Returns:
        ``{"access_token", "token_type", "workspace_id"}``.

    Raises:
        UnauthorizedError: the token cannot be decoded or has no ``sub`` claim.
        NotFoundError: the token's user no longer exists.
        HTTPException: 403 when the user may not access *workspace_id*.
    """
    from app.core.security import decode_access_token
    from app.services.workspace_service import check_workspace_access

    payload = decode_access_token(token)
    if payload is None:
        raise UnauthorizedError("无效的访问令牌")

    user_id = payload.get("sub")
    if user_id is None:
        raise UnauthorizedError("无效的访问令牌")

    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise NotFoundError("用户", user_id)

    if not check_workspace_access(db, user, workspace_id):
        raise HTTPException(status_code=403, detail="无权访问此工作区")

    if client_type.lower() in ("android", "ios"):
        lifetime_minutes = settings.JWT_MOBILE_TOKEN_EXPIRE_MINUTES
    else:
        lifetime_minutes = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES

    new_token = create_access_token(
        data={"sub": user.id, "username": user.username, "ws": workspace_id},
        expires_delta=timedelta(minutes=lifetime_minutes),
    )

    return {"access_token": new_token, "token_type": "bearer", "workspace_id": workspace_id}
|
||
|
||
|
||
# ─── Password reset ─────────────────────────────────────────

# How long an issued reset code stays valid: 10 minutes.
RESET_CODE_TTL_SEC = 10 * 60
# Minimum spacing between two code requests for the same email.
RESET_RATE_LIMIT_SEC = 60
|
||
|
||
|
||
class ForgotPasswordRequest(BaseModel):
    """Request body for ``POST /forgot-password``.

    The email is trimmed, must look like ``local@domain.tld`` with no embedded
    whitespace, and is lower-cased so lookups match the stored form.
    """

    email: str

    @field_validator("email")
    @classmethod
    def email_format(cls, v: str) -> str:
        v = v.strip()
        if not v or not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", v):
            raise ValueError("邮箱格式无效")
        return v.lower()
|
||
|
||
|
||
class ResetPasswordRequest(BaseModel):
    """Request body for ``POST /reset-password``.

    Validation:
        email: trimmed, must look like ``local@domain.tld`` with no embedded
            whitespace, lower-cased.
        code: accepted as-is here; the endpoint strips it before comparing.
        new_password: 6-32 characters.
    """

    email: str
    code: str
    new_password: str

    @field_validator("email")
    @classmethod
    def email_format(cls, v: str) -> str:
        v = v.strip()
        if not v or not re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", v):
            raise ValueError("邮箱格式无效")
        return v.lower()

    @field_validator("new_password")
    @classmethod
    def password_length(cls, v: str) -> str:
        if len(v) < 6:
            raise ValueError("密码不少于 6 个字符")
        if len(v) > 32:
            raise ValueError("密码不超过 32 个字符")
        return v
|
||
|
||
|
||
async def _send_reset_email(email: str, code: str) -> bool:
    """E-mail a password-reset code via SMTP (best effort).

    SMTP settings are read from ``settings`` (``SMTP_HOST``, default
    smtp.qq.com; ``SMTP_PORT``, default 587; ``SMTP_USER``; ``SMTP_PASSWORD``).
    Port 465 uses implicit TLS. Port 587 connects in plaintext and upgrades via
    STARTTLS; using implicit TLS there makes every send fail.

    Any failure (aiosmtplib not installed, credentials missing, network or
    auth errors) is logged and reported as False rather than raised. The code
    itself is never written to the log.

    Args:
        email: recipient address.
        code: the reset code to embed in the message.

    Returns:
        True if ``aiosmtplib.send`` completed, False otherwise.
    """
    try:
        import aiosmtplib
        from email.mime.text import MIMEText
        from email.mime.multipart import MIMEMultipart

        smtp_host = getattr(settings, 'SMTP_HOST', '') or 'smtp.qq.com'
        smtp_port = int(getattr(settings, 'SMTP_PORT', 0) or 587)
        smtp_user = getattr(settings, 'SMTP_USER', '') or ''
        smtp_password = getattr(settings, 'SMTP_PASSWORD', '') or ''

        if not smtp_user or not smtp_password:
            logger.warning("SMTP 未配置,无法发送邮件")
            return False

        msg = MIMEMultipart()
        msg['From'] = smtp_user
        msg['To'] = email
        msg['Subject'] = '天工智能体 - 密码重置验证码'
        msg.attach(MIMEText(
            f'您的密码重置验证码是:<b>{code}</b><br><br>'
            f'验证码 10 分钟内有效。如非本人操作请忽略此邮件。',
            'html', 'utf-8'
        ))

        # 465 = implicit TLS; 587 = plaintext + STARTTLS upgrade. Other ports
        # keep aiosmtplib's default STARTTLS handling.
        tls_options = {"use_tls": smtp_port == 465}
        if smtp_port == 587:
            tls_options["start_tls"] = True

        await aiosmtplib.send(
            msg, hostname=smtp_host, port=smtp_port,
            username=smtp_user, password=smtp_password,
            **tls_options,
        )
        logger.info("密码重置邮件已发送至 %s", email)
        return True
    except Exception as e:
        # Deliberate best effort: the caller decides how to handle failure.
        logger.warning("邮件发送失败: %s", e)
        return False
|
||
|
||
|
||
@router.post("/forgot-password")
async def forgot_password(body: ForgotPasswordRequest, db: Session = Depends(get_db)):
    """Issue a 6-digit password-reset code and e-mail it to the account owner.

    Anti-enumeration: the per-email rate limit runs *before* the user lookup,
    and registered and unregistered emails get the same message. The response
    therefore does not reveal which emails have accounts.

    Storage: with Redis, the code is kept under ``pwd_reset_code:<email>``
    (TTL RESET_CODE_TTL_SEC) and a rate marker under ``pwd_reset_rate:<email>``
    (TTL RESET_RATE_LIMIT_SEC). Without Redis, both are kept in dicts attached
    to this function (``_memory_store`` / ``_memory_rate``). Those dicts are
    per-process and lost on restart; expired entries are pruned on each call.

    The code is echoed back as ``dev_code`` only when e-mail delivery failed
    AND ``settings.DEBUG`` is truthy. Returning it unconditionally let anyone
    reset any account whenever SMTP was down or misconfigured.
    NOTE(review): assumes ``settings.DEBUG`` is the dev-mode flag; confirm.

    Raises:
        HTTPException: 429 when the same email requested a code within the
            rate-limit window.
    """
    generic_response = {"message": "如果邮箱已注册,验证码已发送"}
    redis = get_redis_client()
    rate_key = f"pwd_reset_rate:{body.email}"
    code_key = f"pwd_reset_code:{body.email}"

    # ── Rate limit (applies to every email, registered or not) ──
    if redis:
        if redis.exists(rate_key):
            ttl = redis.ttl(rate_key)
            raise HTTPException(
                status_code=429,
                detail=f"操作过于频繁,请 {ttl} 秒后重试"
            )
        redis.setex(rate_key, RESET_RATE_LIMIT_SEC, "1")
    else:
        # No Redis: in-memory fallback (per process, lost on restart).
        if not hasattr(forgot_password, '_memory_store'):
            forgot_password._memory_store = {}
        if not hasattr(forgot_password, '_memory_rate'):
            forgot_password._memory_rate = {}
        code_store = forgot_password._memory_store
        rate_store = forgot_password._memory_rate
        now = datetime.utcnow()

        # Prune in place so the dicts cannot grow without bound.
        for stale in [e for e, until in rate_store.items() if until <= now]:
            del rate_store[stale]
        for stale in [
            e for e, entry in code_store.items()
            if entry.get("expires_at", datetime.min) <= now
        ]:
            del code_store[stale]

        blocked_until = rate_store.get(body.email)
        if blocked_until is not None:
            remaining = max(1, int((blocked_until - now).total_seconds()))
            raise HTTPException(
                status_code=429,
                detail=f"操作过于频繁,请 {remaining} 秒后重试"
            )
        rate_store[body.email] = now + timedelta(seconds=RESET_RATE_LIMIT_SEC)

    user = db.query(User).filter(User.email == body.email).first()
    if not user:
        return generic_response

    code_str = str(secrets.randbelow(900000) + 100000)  # always 6 digits

    if redis:
        redis.setex(code_key, RESET_CODE_TTL_SEC, code_str)
    else:
        forgot_password._memory_store[body.email] = {
            "code": code_str,
            "expires_at": datetime.utcnow() + timedelta(seconds=RESET_CODE_TTL_SEC),
        }

    sent = await _send_reset_email(body.email, code_str)
    if sent:
        return generic_response

    if getattr(settings, "DEBUG", False):
        # Development only: surface the code so the flow can be tested without SMTP.
        logger.info("开发模式:%s 的密码重置验证码为 %s", body.email, code_str)
        return {
            "message": "验证码已生成",
            "dev_code": code_str,
        }

    logger.warning("密码重置邮件未能发送至 %s", body.email)
    return generic_response
|
||
|
||
|
||
# Wrong guesses tolerated per issued code before it is invalidated. Without a
# cap, a 6-digit code can be brute-forced well within its validity window.
RESET_MAX_ATTEMPTS = 5


@router.post("/reset-password")
async def reset_password(body: ResetPasswordRequest, db: Session = Depends(get_db)):
    """Set a new password using a code issued by ``/forgot-password``.

    Codes are read from Redis (``pwd_reset_code:<email>``) or, without Redis,
    from ``forgot_password._memory_store``. Each wrong guess is counted. After
    RESET_MAX_ATTEMPTS failures the code is deleted and a new one must be
    requested. On success the code is deleted so it cannot be reused.

    Raises:
        HTTPException: 400 when the email has no valid code (an unregistered
            email gets the same answer, so registration is not revealed),
            when the code is wrong, or when too many wrong codes were tried.
    """
    expired_detail = "验证码已过期或未请求"
    redis = get_redis_client()
    code_key = f"pwd_reset_code:{body.email}"
    submitted = body.code.strip()

    user = db.query(User).filter(User.email == body.email).first()
    if not user:
        raise HTTPException(status_code=400, detail=expired_detail)

    memory_store = None if redis else getattr(forgot_password, '_memory_store', None)
    entry = None
    stored_code = None

    if redis:
        stored_code = redis.get(code_key)
        if isinstance(stored_code, bytes):
            # Clients created without decode_responses return bytes, which
            # would never compare equal to the submitted str.
            stored_code = stored_code.decode("utf-8")
    elif memory_store is not None:
        entry = memory_store.get(body.email)
        if entry and entry.get("expires_at", datetime.min) > datetime.utcnow():
            stored_code = entry.get("code")

    if not stored_code:
        raise HTTPException(status_code=400, detail=expired_detail)

    # Failure counter is keyed by the code itself, so a newly issued code
    # starts from zero without /forgot-password having to reset anything.
    fail_key = f"pwd_reset_fail:{body.email}:{stored_code}"

    # Constant-time comparison; compare bytes so non-ASCII input cannot raise.
    if not secrets.compare_digest(stored_code.encode("utf-8"), submitted.encode("utf-8")):
        if redis:
            failures = redis.incr(fail_key)
            if failures == 1:
                redis.expire(fail_key, RESET_CODE_TTL_SEC)
        else:
            failures = entry.get("attempts", 0) + 1
            entry["attempts"] = failures

        if failures >= RESET_MAX_ATTEMPTS:
            if redis:
                redis.delete(code_key, fail_key)
            else:
                memory_store.pop(body.email, None)
            raise HTTPException(status_code=400, detail="验证码错误次数过多,请重新获取")
        raise HTTPException(status_code=400, detail="验证码错误")

    user.password_hash = get_password_hash(body.new_password)
    db.commit()

    # Single use: remove the code (and its failure counter).
    if redis:
        redis.delete(code_key, fail_key)
    else:
        memory_store.pop(body.email, None)

    logger.info("用户 %s 密码重置成功", user.username)
    return {"message": "密码重置成功,请使用新密码登录"}
|
||
|
||
|
||
async def get_optional_user(
    token: str | None = Depends(oauth2_scheme_optional),
    db: Session = Depends(get_db)
) -> User | None:
    """FastAPI dependency for endpoints where login is optional.

    Returns the ``User`` named by the token's ``sub`` claim. Returns None when
    no token was sent, the token cannot be decoded, it has no ``sub`` claim,
    or the user row does not exist.
    """
    if not token:
        return None

    from app.core.security import decode_access_token

    claims = decode_access_token(token)
    subject = None if claims is None else claims.get("sub")
    if subject is None:
        return None
    return db.query(User).filter(User.id == subject).first()
|