refactor: port reqparse to Pydantic model (#28949)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1,28 +1,53 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from models import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
active_check_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
||||
)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ActivateCheckQuery(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
email: EmailStr | None = Field(default=None)
|
||||
token: str
|
||||
|
||||
|
||||
class ActivatePayload(BaseModel):
|
||||
workspace_id: str | None = Field(default=None)
|
||||
email: EmailStr | None = Field(default=None)
|
||||
token: str
|
||||
name: str = Field(..., max_length=30)
|
||||
interface_language: str = Field(...)
|
||||
timezone: str = Field(...)
|
||||
|
||||
@field_validator("interface_language")
|
||||
@classmethod
|
||||
def validate_lang(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_tz(cls, value: str) -> str:
|
||||
return timezone(value)
|
||||
|
||||
|
||||
for model in (ActivateCheckQuery, ActivatePayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
class ActivateCheckApi(Resource):
|
||||
@console_ns.doc("check_activation_token")
|
||||
@console_ns.doc(description="Check if activation token is valid")
|
||||
@console_ns.expect(active_check_parser)
|
||||
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
@@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
args = active_check_parser.parse_args()
|
||||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workspaceId = args["workspace_id"]
|
||||
reg_email = args["email"]
|
||||
token = args["token"]
|
||||
workspaceId = args.workspace_id
|
||||
reg_email = args.email
|
||||
token = args.token
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
if invitation:
|
||||
@@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
|
||||
return {"is_valid": False}
|
||||
|
||||
|
||||
active_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate")
|
||||
class ActivateApi(Resource):
|
||||
@console_ns.doc("activate_account")
|
||||
@console_ns.doc(description="Activate account with invitation token")
|
||||
@console_ns.expect(active_parser)
|
||||
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
@@ -85,19 +99,19 @@ class ActivateApi(Resource):
|
||||
)
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
args = active_parser.parse_args()
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
|
||||
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||
|
||||
account = invitation["account"]
|
||||
account.name = args["name"]
|
||||
account.name = args.name
|
||||
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_language = args.interface_language
|
||||
account.timezone = args.timezone
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
@@ -1,12 +1,26 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
provider: str = Field(...)
|
||||
credentials: dict = Field(...)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ApiKeyAuthBindingPayload.__name__,
|
||||
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
@@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -5,12 +5,11 @@ from flask import current_app, redirect, request
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -14,16 +15,45 @@ from controllers.console.auth.error import (
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
|
||||
from ..error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class EmailRegisterSendPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
language: str | None = Field(default=None, description="Language code")
|
||||
|
||||
|
||||
class EmailRegisterValidityPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class EmailRegisterResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
class EmailRegisterSendEmailApi(Resource):
|
||||
@@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
language = "en-US"
|
||||
if args["language"] in languages:
|
||||
language = args["language"]
|
||||
if args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
|
||||
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args.email
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
token_data = AccountService.get_email_register_data(args["token"])
|
||||
token_data = AccountService.get_email_register_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args["email"])
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args.email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_email_register_token(
|
||||
user_email, code=args["code"], additional_data={"phase": "register"}
|
||||
user_email, code=args.code, additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args["email"])
|
||||
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Validate passwords match
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
if args.new_password != args.password_confirm:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get register data
|
||||
register_data = AccountService.get_email_register_data(args["token"])
|
||||
register_data = AccountService.get_email_register_data(args.token)
|
||||
if not register_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
@@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_email_register_token(args["token"])
|
||||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
email = register_data.get("email", "")
|
||||
|
||||
@@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args["password_confirm"])
|
||||
account = self._create_new_account(email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
|
||||
@@ -2,7 +2,8 @@ import base64
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@console_ns.doc("send_forgot_password_email")
|
||||
@console_ns.doc(description="Send password reset email")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
@@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
email=args["email"],
|
||||
email=args.email,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
@@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@console_ns.doc("check_forgot_password_code")
|
||||
@console_ns.doc(description="Verify password reset code")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"code": fields.String(required=True, description="Verification code"),
|
||||
"token": fields.String(required=True, description="Reset token"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
@@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args.email
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
token_data = AccountService.get_reset_password_data(args["token"])
|
||||
token_data = AccountService.get_reset_password_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
AccountService.revoke_reset_password_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args["code"], additional_data={"phase": "reset"}
|
||||
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
||||
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@console_ns.doc("reset_password")
|
||||
@console_ns.doc(description="Reset password with verification token")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordResetRequest",
|
||||
{
|
||||
"token": fields.String(required=True, description="Verification token"),
|
||||
"new_password": fields.String(required=True, description="New password"),
|
||||
"password_confirm": fields.String(required=True, description="Password confirmation"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
@@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Validate passwords match
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
if args.new_password != args.password_confirm:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get reset data
|
||||
reset_data = AccountService.get_reset_password_data(args["token"])
|
||||
reset_data = AccountService.get_reset_password_data(args.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
@@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
AccountService.revoke_reset_password_token(args.token)
|
||||
|
||||
# Generate secure salt and hash password
|
||||
salt = secrets.token_bytes(16)
|
||||
password_hashed = hash_password(args["new_password"], salt)
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@@ -23,7 +24,7 @@ from controllers.console.error import (
|
||||
)
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
@@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
password: str = Field(..., description="Password")
|
||||
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||
|
||||
|
||||
class EmailPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class EmailCodeLoginPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(LoginPayload)
|
||||
reg(EmailPayload)
|
||||
reg(EmailCodeLoginPayload)
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
@@ -47,41 +78,36 @@ class LoginApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[LoginPayload.__name__])
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=str, required=True, location="json")
|
||||
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invitation = args["invite_token"]
|
||||
# TODO: why invitation is re-assigned with different type?
|
||||
invitation = args.invite_token # type: ignore
|
||||
if invitation:
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
|
||||
|
||||
try:
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
data = invitation.get("data", {}) # type: ignore
|
||||
invitee_email = data.get("email") if data else None
|
||||
if invitee_email != args["email"]:
|
||||
if invitee_email != args.email:
|
||||
raise InvalidEmailError()
|
||||
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
|
||||
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||
else:
|
||||
account = AccountService.authenticate(args["email"], args["password"])
|
||||
account = AccountService.authenticate(args.email, args.password)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args["email"])
|
||||
AccountService.add_login_error_rate_limit(args.email)
|
||||
raise AuthenticationFailedError()
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@@ -97,7 +123,7 @@ class LoginApi(Resource):
|
||||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
@@ -134,25 +160,21 @@ class LogoutApi(Resource):
|
||||
class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
email=args["email"],
|
||||
email=args.email,
|
||||
account=account,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
@@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
@console_ns.route("/email-code-login")
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
|
||||
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
@@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
@console_ns.route("/email-code-login/validity")
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args["email"]
|
||||
language = args["language"]
|
||||
user_email = args.email
|
||||
language = args.language
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
if token_data["email"] != args.email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
if token_data["code"] != args.code:
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
except AccountRegisterError:
|
||||
@@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource):
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
@@ -3,7 +3,8 @@ from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@@ -20,15 +21,34 @@ R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OAuthClientPayload(BaseModel):
|
||||
client_id: str
|
||||
|
||||
|
||||
class OAuthProviderRequest(BaseModel):
|
||||
client_id: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
class OAuthTokenRequest(BaseModel):
|
||||
client_id: str
|
||||
grant_type: str
|
||||
code: str | None = None
|
||||
client_secret: str | None = None
|
||||
redirect_uri: str | None = None
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
json_data = request.get_json()
|
||||
if json_data is None:
|
||||
raise BadRequest("client_id is required")
|
||||
|
||||
payload = OAuthClientPayload.model_validate(json_data)
|
||||
client_id = payload.client_id
|
||||
|
||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
@@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
payload = OAuthProviderRequest.model_validate(request.get_json())
|
||||
redirect_uri = payload.redirect_uri
|
||||
|
||||
# check if redirect_uri is valid
|
||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
@@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("grant_type", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=False, location="json")
|
||||
.add_argument("client_secret", type=str, required=False, location="json")
|
||||
.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
)
|
||||
parsed_args = parser.parse_args()
|
||||
payload = OAuthTokenRequest.model_validate(request.get_json())
|
||||
|
||||
try:
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
grant_type = OAuthGrantType(payload.grant_type)
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not parsed_args["code"]:
|
||||
if not payload.code:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
||||
if payload.client_secret != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
||||
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
||||
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
@@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not parsed_args["refresh_token"]:
|
||||
if not payload.refresh_token:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
||||
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user