Pydantic models (#28697)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-11-26 23:44:14 +09:00
committed by GitHub
parent e8ca80a61a
commit 2731b04ff9
8 changed files with 1065 additions and 802 deletions

View File

@@ -1,8 +1,10 @@
from datetime import datetime
from typing import Literal
import pytz
from flask import request
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -42,20 +44,198 @@ from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _init_parser():
parser = reqparse.RequestParser()
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
return parser
class AccountInitPayload(BaseModel):
interface_language: str
timezone: str
invitation_code: str | None = None
@field_validator("interface_language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str) -> str:
return timezone(value)
class AccountNamePayload(BaseModel):
name: str = Field(min_length=3, max_length=30)
class AccountAvatarPayload(BaseModel):
avatar: str
class AccountInterfaceLanguagePayload(BaseModel):
interface_language: str
@field_validator("interface_language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
class AccountInterfaceThemePayload(BaseModel):
interface_theme: Literal["light", "dark"]
class AccountTimezonePayload(BaseModel):
timezone: str
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str) -> str:
return timezone(value)
class AccountPasswordPayload(BaseModel):
password: str | None = None
new_password: str
repeat_new_password: str
@model_validator(mode="after")
def check_passwords_match(self) -> "AccountPasswordPayload":
if self.new_password != self.repeat_new_password:
raise RepeatPasswordNotMatchError()
return self
class AccountDeletePayload(BaseModel):
token: str
code: str
class AccountDeletionFeedbackPayload(BaseModel):
email: str
feedback: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class EducationActivatePayload(BaseModel):
token: str
institution: str
role: str
class EducationAutocompleteQuery(BaseModel):
keywords: str
page: int = 0
limit: int = 20
class ChangeEmailSendPayload(BaseModel):
email: str
language: str | None = None
phase: str | None = None
token: str | None = None
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailValidityPayload(BaseModel):
email: str
code: str
token: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailResetPayload(BaseModel):
new_email: str
token: str
@field_validator("new_email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class CheckEmailUniquePayload(BaseModel):
email: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
console_ns.schema_model(
AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountInterfaceLanguagePayload.__name__,
AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountInterfaceThemePayload.__name__,
AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountTimezonePayload.__name__,
AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountPasswordPayload.__name__,
AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletePayload.__name__,
AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletionFeedbackPayload.__name__,
AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationActivatePayload.__name__,
EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationAutocompleteQuery.__name__,
EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailSendPayload.__name__,
ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailValidityPayload.__name__,
ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailResetPayload.__name__,
ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
CheckEmailUniquePayload.__name__,
CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@console_ns.expect(_init_parser())
@console_ns.expect(console_ns.models[AccountInitPayload.__name__])
@setup_required
@login_required
def post(self):
@@ -64,17 +244,18 @@ class AccountInitApi(Resource):
if account.status == "active":
raise AccountAlreadyInitedError()
args = _init_parser().parse_args()
payload = console_ns.payload or {}
args = AccountInitPayload.model_validate(payload)
if dify_config.EDITION == "CLOUD":
if not args["invitation_code"]:
if not args.invitation_code:
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
.where(
InvitationCode.code == args["invitation_code"],
InvitationCode.code == args.invitation_code,
InvitationCode.status == "unused",
)
.first()
@@ -88,8 +269,8 @@ class AccountInitApi(Resource):
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_theme = "light"
account.status = "active"
account.initialized_at = naive_utc_now()
@@ -110,137 +291,104 @@ class AccountProfileApi(Resource):
return current_user
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/account/name")
class AccountNameApi(Resource):
@console_ns.expect(parser_name)
@console_ns.expect(console_ns.models[AccountNamePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_name.parse_args()
# Validate account name length
if len(args["name"]) < 3 or len(args["name"]) > 30:
raise ValueError("Account name must be between 3 and 30 characters.")
updated_account = AccountService.update_account(current_user, name=args["name"])
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
return updated_account
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@console_ns.expect(parser_avatar)
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_avatar.parse_args()
payload = console_ns.payload or {}
args = AccountAvatarPayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return updated_account
parser_interface = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
@console_ns.expect(parser_interface)
@console_ns.expect(console_ns.models[AccountInterfaceLanguagePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_interface.parse_args()
payload = console_ns.payload or {}
args = AccountInterfaceLanguagePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return updated_account
parser_theme = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
@console_ns.expect(parser_theme)
@console_ns.expect(console_ns.models[AccountInterfaceThemePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_theme.parse_args()
payload = console_ns.payload or {}
args = AccountInterfaceThemePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return updated_account
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
@console_ns.expect(parser_timezone)
@console_ns.expect(console_ns.models[AccountTimezonePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_timezone.parse_args()
payload = console_ns.payload or {}
args = AccountTimezonePayload.model_validate(payload)
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
if args["timezone"] not in pytz.all_timezones:
raise ValueError("Invalid timezone string.")
updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return updated_account
parser_pw = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
@console_ns.expect(parser_pw)
@console_ns.expect(console_ns.models[AccountPasswordPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_pw.parse_args()
if args["new_password"] != args["repeat_new_password"]:
raise RepeatPasswordNotMatchError()
payload = console_ns.payload or {}
args = AccountPasswordPayload.model_validate(payload)
try:
AccountService.update_account_password(current_user, args["password"], args["new_password"])
AccountService.update_account_password(current_user, args.password, args.new_password)
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
@@ -316,25 +464,19 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
parser_delete = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
@console_ns.expect(parser_delete)
@console_ns.expect(console_ns.models[AccountDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
args = parser_delete.parse_args()
payload = console_ns.payload or {}
args = AccountDeletePayload.model_validate(payload)
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
if not AccountService.verify_account_deletion_code(args.token, args.code):
raise InvalidAccountDeletionCodeError()
AccountService.delete_account(account)
@@ -342,21 +484,15 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
parser_feedback = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
@console_ns.expect(parser_feedback)
@console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__])
@setup_required
def post(self):
args = parser_feedback.parse_args()
payload = console_ns.payload or {}
args = AccountDeletionFeedbackPayload.model_validate(payload)
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
BillingService.update_account_deletion_feedback(args.email, args.feedback)
return {"result": "success"}
@@ -379,14 +515,6 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
parser_edu = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
@@ -396,7 +524,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean,
}
@console_ns.expect(parser_edu)
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -405,9 +533,10 @@ class EducationApi(Resource):
def post(self):
account, _ = current_account_with_tenant()
args = parser_edu.parse_args()
payload = console_ns.payload or {}
args = EducationActivatePayload.model_validate(payload)
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
return BillingService.EducationIdentity.activate(account, args.token, args.institution, args.role)
@setup_required
@login_required
@@ -425,14 +554,6 @@ class EducationApi(Resource):
return res
parser_autocomplete = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
@@ -441,7 +562,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean,
}
@console_ns.expect(parser_autocomplete)
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -449,46 +570,39 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
args = parser_autocomplete.parse_args()
payload = request.args.to_dict(flat=True) # type: ignore
args = EducationAutocompleteQuery.model_validate(payload)
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
parser_change_email = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
@console_ns.expect(parser_change_email)
@console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
args = parser_change_email.parse_args()
payload = console_ns.payload or {}
args = ChangeEmailSendPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = None
user_email = args["email"]
if args["phase"] is not None and args["phase"] == "new_email":
if args["token"] is None:
user_email = args.email
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
reset_data = AccountService.get_change_email_data(args["token"])
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
@@ -497,118 +611,103 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidEmailError()
else:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
if account is None:
raise AccountNotFound()
token = AccountService.send_change_email_email(
account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
)
return {"result": "success", "data": token}
parser_validity = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
@console_ns.expect(parser_validity)
@console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
args = parser_validity.parse_args()
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args["email"]
user_email = args.email
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
token_data = AccountService.get_change_email_data(args["token"])
token_data = AccountService.get_change_email_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args["email"])
if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args["token"])
AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token(
user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
AccountService.reset_change_email_error_rate_limit(args["email"])
AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_reset = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
@console_ns.expect(parser_reset)
@console_ns.expect(console_ns.models[ChangeEmailResetPayload.__name__])
@enable_change_email
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
args = parser_reset.parse_args()
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
if AccountService.is_account_in_freeze(args["new_email"]):
if AccountService.is_account_in_freeze(args.new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["new_email"]):
if not AccountService.check_email_unique(args.new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args["token"])
reset_data = AccountService.get_change_email_data(args.token)
if not reset_data:
raise InvalidTokenError()
AccountService.revoke_change_email_token(args["token"])
AccountService.revoke_change_email_token(args.token)
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=args["new_email"])
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
AccountService.send_change_email_completed_notify_email(
email=args["new_email"],
email=args.new_email,
)
return updated_account
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
@console_ns.expect(parser_check)
@console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__])
@setup_required
def post(self):
args = parser_check.parse_args()
if AccountService.is_account_in_freeze(args["email"]):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
if AccountService.is_account_in_freeze(args.email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["email"]):
if not AccountService.check_email_unique(args.email):
raise EmailAlreadyInUseError()
return {"result": "success"}