chore(api/controllers): Apply Ruff Formatter. (#7645)

This commit is contained in:
-LAN-
2024-08-26 15:29:10 +08:00
committed by GitHub
parent 7ae728a9a3
commit 13be84e4d4
104 changed files with 3849 additions and 3982 deletions

View File

@@ -26,52 +26,53 @@ from services.errors.account import CurrentPasswordIncorrectError as ServiceCurr
class AccountInitApi(Resource):
@setup_required
@login_required
def post(self):
account = current_user
if account.status == 'active':
if account.status == "active":
raise AccountAlreadyInitedError()
parser = reqparse.RequestParser()
if dify_config.EDITION == 'CLOUD':
parser.add_argument('invitation_code', type=str, location='json')
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument(
'interface_language', type=supported_language, required=True, location='json')
parser.add_argument('timezone', type=timezone,
required=True, location='json')
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
parser.add_argument("timezone", type=timezone, required=True, location="json")
args = parser.parse_args()
if dify_config.EDITION == 'CLOUD':
if not args['invitation_code']:
raise ValueError('invitation_code is required')
if dify_config.EDITION == "CLOUD":
if not args["invitation_code"]:
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = db.session.query(InvitationCode).filter(
InvitationCode.code == args['invitation_code'],
InvitationCode.status == 'unused',
).first()
invitation_code = (
db.session.query(InvitationCode)
.filter(
InvitationCode.code == args["invitation_code"],
InvitationCode.status == "unused",
)
.first()
)
if not invitation_code:
raise InvalidInvitationCodeError()
invitation_code.status = 'used'
invitation_code.status = "used"
invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
account.interface_language = args['interface_language']
account.timezone = args['timezone']
account.interface_theme = 'light'
account.status = 'active'
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = "active"
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
return {'result': 'success'}
return {"result": "success"}
class AccountProfileApi(Resource):
@@ -90,15 +91,14 @@ class AccountNameApi(Resource):
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
# Validate account name length
if len(args['name']) < 3 or len(args['name']) > 30:
raise ValueError(
"Account name must be between 3 and 30 characters.")
if len(args["name"]) < 3 or len(args["name"]) > 30:
raise ValueError("Account name must be between 3 and 30 characters.")
updated_account = AccountService.update_account(current_user, name=args['name'])
updated_account = AccountService.update_account(current_user, name=args["name"])
return updated_account
@@ -110,10 +110,10 @@ class AccountAvatarApi(Resource):
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('avatar', type=str, required=True, location='json')
parser.add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args()
updated_account = AccountService.update_account(current_user, avatar=args['avatar'])
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
return updated_account
@@ -125,11 +125,10 @@ class AccountInterfaceLanguageApi(Resource):
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
'interface_language', type=supported_language, required=True, location='json')
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_language=args['interface_language'])
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
return updated_account
@@ -141,11 +140,10 @@ class AccountInterfaceThemeApi(Resource):
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('interface_theme', type=str, choices=[
'light', 'dark'], required=True, location='json')
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme'])
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
return updated_account
@@ -157,15 +155,14 @@ class AccountTimezoneApi(Resource):
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('timezone', type=str,
required=True, location='json')
parser.add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args()
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
if args['timezone'] not in pytz.all_timezones:
if args["timezone"] not in pytz.all_timezones:
raise ValueError("Invalid timezone string.")
updated_account = AccountService.update_account(current_user, timezone=args['timezone'])
updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
return updated_account
@@ -177,20 +174,16 @@ class AccountPasswordApi(Resource):
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('password', type=str,
required=False, location='json')
parser.add_argument('new_password', type=str,
required=True, location='json')
parser.add_argument('repeat_new_password', type=str,
required=True, location='json')
parser.add_argument("password", type=str, required=False, location="json")
parser.add_argument("new_password", type=str, required=True, location="json")
parser.add_argument("repeat_new_password", type=str, required=True, location="json")
args = parser.parse_args()
if args['new_password'] != args['repeat_new_password']:
if args["new_password"] != args["repeat_new_password"]:
raise RepeatPasswordNotMatchError()
try:
AccountService.update_account_password(
current_user, args['password'], args['new_password'])
AccountService.update_account_password(current_user, args["password"], args["new_password"])
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
@@ -199,14 +192,14 @@ class AccountPasswordApi(Resource):
class AccountIntegrateApi(Resource):
integrate_fields = {
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'link': fields.String
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
"data": fields.List(fields.Nested(integrate_fields)),
}
@setup_required
@@ -216,10 +209,9 @@ class AccountIntegrateApi(Resource):
def get(self):
account = current_user
account_integrates = db.session.query(AccountIntegrate).filter(
AccountIntegrate.account_id == account.id).all()
account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()
base_url = request.url_root.rstrip('/')
base_url = request.url_root.rstrip("/")
oauth_base_path = "/console/api/oauth/login"
providers = ["github", "google"]
@@ -227,36 +219,38 @@ class AccountIntegrateApi(Resource):
for provider in providers:
existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None)
if existing_integrate:
integrate_data.append({
'id': existing_integrate.id,
'provider': provider,
'created_at': existing_integrate.created_at,
'is_bound': True,
'link': None
})
integrate_data.append(
{
"id": existing_integrate.id,
"provider": provider,
"created_at": existing_integrate.created_at,
"is_bound": True,
"link": None,
}
)
else:
integrate_data.append({
'id': None,
'provider': provider,
'created_at': None,
'is_bound': False,
'link': f'{base_url}{oauth_base_path}/{provider}'
})
return {'data': integrate_data}
integrate_data.append(
{
"id": None,
"provider": provider,
"created_at": None,
"is_bound": False,
"link": f"{base_url}{oauth_base_path}/{provider}",
}
)
return {"data": integrate_data}
# Register API resources
api.add_resource(AccountInitApi, '/account/init')
api.add_resource(AccountProfileApi, '/account/profile')
api.add_resource(AccountNameApi, '/account/name')
api.add_resource(AccountAvatarApi, '/account/avatar')
api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language')
api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme')
api.add_resource(AccountTimezoneApi, '/account/timezone')
api.add_resource(AccountPasswordApi, '/account/password')
api.add_resource(AccountIntegrateApi, '/account/integrates')
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
api.add_resource(AccountNameApi, "/account/name")
api.add_resource(AccountAvatarApi, "/account/avatar")
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

View File

@@ -2,36 +2,36 @@ from libs.exception import BaseHTTPException
class RepeatPasswordNotMatchError(BaseHTTPException):
error_code = 'repeat_password_not_match'
error_code = "repeat_password_not_match"
description = "New password and repeat password does not match."
code = 400
class CurrentPasswordIncorrectError(BaseHTTPException):
error_code = 'current_password_incorrect'
error_code = "current_password_incorrect"
description = "Current password is incorrect."
code = 400
class ProviderRequestFailedError(BaseHTTPException):
error_code = 'provider_request_failed'
error_code = "provider_request_failed"
description = None
code = 400
class InvalidInvitationCodeError(BaseHTTPException):
error_code = 'invalid_invitation_code'
error_code = "invalid_invitation_code"
description = "Invalid invitation code."
code = 400
class AccountAlreadyInitedError(BaseHTTPException):
error_code = 'account_already_inited'
error_code = "account_already_inited"
description = "The account has been initialized. Please refresh the page."
code = 400
class AccountNotInitializedError(BaseHTTPException):
error_code = 'account_not_initialized'
error_code = "account_not_initialized"
description = "The account has not been initialized yet. Please proceed with the initialization process first."
code = 400

View File

@@ -22,10 +22,16 @@ class LoadBalancingCredentialsValidateApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
# validate model load balancing credentials
@@ -38,18 +44,18 @@ class LoadBalancingCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
response = {"result": "success" if result else "error"}
if not result:
response['error'] = error
response["error"] = error
return response
@@ -65,10 +71,16 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
# validate model load balancing config credentials
@@ -81,26 +93,30 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials'],
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
config_id=config_id,
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
response = {"result": "success" if result else "error"}
if not result:
response['error'] = error
response["error"] = error
return response
# Load Balancing Config
api.add_resource(LoadBalancingCredentialsValidateApi,
'/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate')
api.add_resource(
LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
)
api.add_resource(LoadBalancingConfigCredentialsValidateApi,
'/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate')
api.add_resource(
LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)

View File

@@ -23,7 +23,7 @@ class MemberListApi(Resource):
@marshal_with(account_with_role_list_fields)
def get(self):
members = TenantService.get_tenant_members(current_user.current_tenant)
return {'result': 'success', 'accounts': members}, 200
return {"result": "success", "accounts": members}, 200
class MemberInviteEmailApi(Resource):
@@ -32,48 +32,46 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('members')
@cloud_edition_billing_resource_check("members")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('emails', type=str, required=True, location='json', action='append')
parser.add_argument('role', type=str, required=True, default='admin', location='json')
parser.add_argument('language', type=str, required=False, location='json')
parser.add_argument("emails", type=str, required=True, location="json", action="append")
parser.add_argument("role", type=str, required=True, default="admin", location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
invitee_emails = args['emails']
invitee_role = args['role']
interface_language = args['language']
invitee_emails = args["emails"]
invitee_role = args["role"]
interface_language = args["language"]
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
return {"code": "invalid-role", "message": "Invalid role"}, 400
inviter = current_user
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
for invitee_email in invitee_emails:
try:
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
invitation_results.append({
'status': 'success',
'email': invitee_email,
'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
})
token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"url": f"{console_web_url}/activate?email={invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append({
'status': 'success',
'email': invitee_email,
'url': f'{console_web_url}/signin'
})
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
break
except Exception as e:
invitation_results.append({
'status': 'failed',
'email': invitee_email,
'message': str(e)
})
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
return {
'result': 'success',
'invitation_results': invitation_results,
"result": "success",
"invitation_results": invitation_results,
}, 201
@@ -91,15 +89,15 @@ class MemberCancelInviteApi(Resource):
try:
TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
except services.errors.account.CannotOperateSelfError as e:
return {'code': 'cannot-operate-self', 'message': str(e)}, 400
return {"code": "cannot-operate-self", "message": str(e)}, 400
except services.errors.account.NoPermissionError as e:
return {'code': 'forbidden', 'message': str(e)}, 403
return {"code": "forbidden", "message": str(e)}, 403
except services.errors.account.MemberNotInTenantError as e:
return {'code': 'member-not-found', 'message': str(e)}, 404
return {"code": "member-not-found", "message": str(e)}, 404
except Exception as e:
raise ValueError(str(e))
return {'result': 'success'}, 204
return {"result": "success"}, 204
class MemberUpdateRoleApi(Resource):
@@ -110,12 +108,12 @@ class MemberUpdateRoleApi(Resource):
@account_initialization_required
def put(self, member_id):
parser = reqparse.RequestParser()
parser.add_argument('role', type=str, required=True, location='json')
parser.add_argument("role", type=str, required=True, location="json")
args = parser.parse_args()
new_role = args['role']
new_role = args["role"]
if not TenantAccountRole.is_valid_role(new_role):
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
return {"code": "invalid-role", "message": "Invalid role"}, 400
member = db.session.get(Account, str(member_id))
if not member:
@@ -128,7 +126,7 @@ class MemberUpdateRoleApi(Resource):
# todo: 403
return {'result': 'success'}
return {"result": "success"}
class DatasetOperatorMemberListApi(Resource):
@@ -140,11 +138,11 @@ class DatasetOperatorMemberListApi(Resource):
@marshal_with(account_with_role_list_fields)
def get(self):
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {'result': 'success', 'accounts': members}, 200
return {"result": "success", "accounts": members}, 200
api.add_resource(MemberListApi, '/workspaces/current/members')
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')
api.add_resource(MemberListApi, "/workspaces/current/members")
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")

View File

@@ -17,7 +17,6 @@ from services.model_provider_service import ModelProviderService
class ModelProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -25,21 +24,23 @@ class ModelProviderListApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=False, nullable=True,
choices=[mt.value for mt in ModelType], location='args')
parser.add_argument(
"model_type",
type=str,
required=False,
nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
)
args = parser.parse_args()
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(
tenant_id=tenant_id,
model_type=args.get('model_type')
)
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
return jsonable_encoder({"data": provider_list})
class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -47,25 +48,18 @@ class ModelProviderCredentialApi(Resource):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credentials(
tenant_id=tenant_id,
provider=provider
)
credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider)
return {
"credentials": credentials
}
return {"credentials": credentials}
class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
@@ -77,24 +71,21 @@ class ModelProviderValidateApi(Resource):
try:
model_provider_service.provider_credentials_validate(
tenant_id=tenant_id,
provider=provider,
credentials=args['credentials']
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
response = {"result": "success" if result else "error"}
if not result:
response['error'] = error
response["error"] = error
return response
class ModelProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -103,21 +94,19 @@ class ModelProviderApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
try:
model_provider_service.save_provider_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
credentials=args['credentials']
tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"]
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 201
return {"result": "success"}, 201
@setup_required
@login_required
@@ -127,12 +116,9 @@ class ModelProviderApi(Resource):
raise Forbidden()
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider
)
model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider)
return {'result': 'success'}, 204
return {"result": "success"}, 204
class ModelProviderIconApi(Resource):
@@ -146,16 +132,13 @@ class ModelProviderIconApi(Resource):
def get(self, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
provider=provider,
icon_type=icon_type,
lang=lang
provider=provider, icon_type=icon_type, lang=lang
)
return send_file(io.BytesIO(icon), mimetype=mimetype)
class PreferredProviderTypeUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -166,18 +149,22 @@ class PreferredProviderTypeUpdateApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
choices=['system', 'custom'], location='json')
parser.add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(
tenant_id=tenant_id,
provider=provider,
preferred_provider_type=args['preferred_provider_type']
tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
)
return {'result': 'success'}
return {"result": "success"}
class ModelProviderPaymentCheckoutUrlApi(Resource):
@@ -185,13 +172,15 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
if provider != 'anthropic':
raise ValueError(f'provider name {provider} is invalid')
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
BillingService.is_tenant_owner_or_admin(current_user)
data = BillingService.get_model_provider_payment_link(provider_name=provider,
tenant_id=current_user.current_tenant_id,
account_id=current_user.id,
prefilled_email=current_user.email)
data = BillingService.get_model_provider_payment_link(
provider_name=provider,
tenant_id=current_user.current_tenant_id,
account_id=current_user.id,
prefilled_email=current_user.email,
)
return data
@@ -201,10 +190,7 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
@account_initialization_required
def post(self, provider: str):
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_submit(
tenant_id=current_user.current_tenant_id,
provider=provider
)
result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
return result
@@ -215,32 +201,36 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
parser.add_argument("token", type=str, required=False, nullable=True, location="args")
args = parser.parse_args()
model_provider_service = ModelProviderService()
result = model_provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id,
provider=provider,
token=args['token']
tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
)
return result
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate')
api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>')
api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
'<string:icon_type>/<string:lang>')
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials")
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
api.add_resource(
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>"
)
api.add_resource(PreferredProviderTypeUpdateApi,
'/workspaces/current/model-providers/<string:provider>/preferred-provider-type')
api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider>/free-quota-submit')
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
'/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify')
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type"
)
api.add_resource(
ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url"
)
api.add_resource(
ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
)
api.add_resource(
ModelProviderFreeQuotaQualificationVerifyApi,
"/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
)

View File

@@ -16,27 +16,29 @@ from services.model_provider_service import ModelProviderService
class DefaultModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='args')
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id,
model_type=args['model_type']
tenant_id=tenant_id, model_type=args["model_type"]
)
return jsonable_encoder({
"data": default_model_entity
})
return jsonable_encoder({"data": default_model_entity})
@setup_required
@login_required
@@ -44,40 +46,39 @@ class DefaultModelApi(Resource):
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
model_settings = args['model_settings']
model_settings = args["model_settings"]
for model_setting in model_settings:
if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
raise ValueError('invalid model type')
if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
raise ValueError("invalid model type")
if 'provider' not in model_setting:
if "provider" not in model_setting:
continue
if 'model' not in model_setting:
raise ValueError('invalid model')
if "model" not in model_setting:
raise ValueError("invalid model")
try:
model_provider_service.update_default_model_of_model_type(
tenant_id=tenant_id,
model_type=model_setting['model_type'],
provider=model_setting['provider'],
model=model_setting['model']
model_type=model_setting["model_type"],
provider=model_setting["provider"],
model=model_setting["model"],
)
except Exception:
logging.warning(f"{model_setting['model_type']} save error")
return {'result': 'success'}
return {"result": "success"}
class ModelProviderModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -85,14 +86,9 @@ class ModelProviderModelApi(Resource):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(
tenant_id=tenant_id,
provider=provider
)
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
return jsonable_encoder({
"data": models
})
return jsonable_encoder({"data": models})
@setup_required
@login_required
@@ -104,62 +100,66 @@ class ModelProviderModelApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json')
parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json')
parser.add_argument('config_from', type=str, required=False, nullable=True, location='json')
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
model_load_balancing_service = ModelLoadBalancingService()
if ('load_balancing' in args and args['load_balancing'] and
'enabled' in args['load_balancing'] and args['load_balancing']['enabled']):
if 'configs' not in args['load_balancing']:
raise ValueError('invalid load balancing configs')
if (
"load_balancing" in args
and args["load_balancing"]
and "enabled" in args["load_balancing"]
and args["load_balancing"]["enabled"]
):
if "configs" not in args["load_balancing"]:
raise ValueError("invalid load balancing configs")
# save load balancing configs
model_load_balancing_service.update_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
configs=args['load_balancing']['configs']
model=args["model"],
model_type=args["model_type"],
configs=args["load_balancing"]["configs"],
)
# enable load balancing
model_load_balancing_service.enable_model_load_balancing(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
)
else:
# disable load balancing
model_load_balancing_service.disable_model_load_balancing(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
)
if args.get('config_from', '') != 'predefined-model':
if args.get("config_from", "") != "predefined-model":
model_provider_service = ModelProviderService()
try:
model_provider_service.save_model_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
logging.exception(f"save model credentials error: {ex}")
raise ValueError(str(ex))
return {'result': 'success'}, 200
return {"result": "success"}, 200
@setup_required
@login_required
@@ -171,24 +171,26 @@ class ModelProviderModelApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
)
return {'result': 'success'}, 204
return {"result": "success"}, 204
class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -196,38 +198,34 @@ class ModelProviderModelCredentialApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='args')
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
args = parser.parse_args()
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_model_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=args['model_type'],
model=args['model']
tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"]
)
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
)
return {
"credentials": credentials,
"load_balancing": {
"enabled": is_load_balancing_enabled,
"configs": load_balancing_configs
}
"load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
}
class ModelProviderModelEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -235,24 +233,26 @@ class ModelProviderModelEnableApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
)
return {'result': 'success'}
return {"result": "success"}
class ModelProviderModelDisableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -260,24 +260,26 @@ class ModelProviderModelDisableApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
)
return {'result': 'success'}
return {"result": "success"}
class ModelProviderModelValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -285,10 +287,16 @@ class ModelProviderModelValidateApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
@@ -300,48 +308,42 @@ class ModelProviderModelValidateApi(Resource):
model_provider_service.model_credentials_validate(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
model=args["model"],
model_type=args["model_type"],
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
response = {"result": "success" if result else "error"}
if not result:
response['error'] = error
response["error"] = error
return response
class ModelProviderModelParameterRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='args')
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
tenant_id=tenant_id,
provider=provider,
model=args['model']
tenant_id=tenant_id, provider=provider, model=args["model"]
)
return jsonable_encoder({
"data": parameter_rules
})
return jsonable_encoder({"data": parameter_rules})
class ModelProviderAvailableModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -349,27 +351,31 @@ class ModelProviderAvailableModelApi(Resource):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(
tenant_id=tenant_id,
model_type=model_type
)
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({
"data": models
})
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable',
endpoint='model-provider-model-enable')
api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable',
endpoint='model-provider-model-disable')
api.add_resource(ModelProviderModelCredentialApi,
'/workspaces/current/model-providers/<string:provider>/models/credentials')
api.add_resource(ModelProviderModelValidateApi,
'/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
api.add_resource(
ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<string:provider>/models/enable",
endpoint="model-provider-model-enable",
)
api.add_resource(
ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<string:provider>/models/disable",
endpoint="model-provider-model-disable",
)
api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
)
api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
)
api.add_resource(ModelProviderModelParameterRuleApi,
'/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@@ -28,10 +28,18 @@ class ToolProviderListApi(Resource):
tenant_id = current_user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args')
req.add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow"],
required=False,
nullable=True,
location="args",
)
args = req.parse_args()
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None))
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@@ -41,11 +49,14 @@ class ToolBuiltinProviderListToolsApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools(
user_id,
tenant_id,
provider,
))
return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools(
user_id,
tenant_id,
provider,
)
)
class ToolBuiltinProviderDeleteApi(Resource):
@setup_required
@@ -54,7 +65,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
def post(self, provider):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
@@ -63,7 +74,8 @@ class ToolBuiltinProviderDeleteApi(Resource):
tenant_id,
provider,
)
class ToolBuiltinProviderUpdateApi(Resource):
@setup_required
@login_required
@@ -71,12 +83,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
def post(self, provider):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
@@ -84,9 +96,10 @@ class ToolBuiltinProviderUpdateApi(Resource):
user_id,
tenant_id,
provider,
args['credentials'],
args["credentials"],
)
class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@@ -101,6 +114,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
provider,
)
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
@@ -108,6 +122,7 @@ class ToolBuiltinProviderIconApi(Resource):
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
class ToolApiProviderAddApi(Resource):
@setup_required
@login_required
@@ -115,35 +130,36 @@ class ToolApiProviderAddApi(Resource):
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[])
parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json')
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
return ApiToolManageService.create_api_tool_provider(
user_id,
tenant_id,
args['provider'],
args['icon'],
args['credentials'],
args['schema_type'],
args['schema'],
args.get('privacy_policy', ''),
args.get('custom_disclaimer', ''),
args.get('labels', []),
args["provider"],
args["icon"],
args["credentials"],
args["schema_type"],
args["schema"],
args.get("privacy_policy", ""),
args.get("custom_disclaimer", ""),
args.get("labels", []),
)
class ToolApiProviderGetRemoteSchemaApi(Resource):
@setup_required
@login_required
@@ -151,16 +167,17 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('url', type=str, required=True, nullable=False, location='args')
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
return ApiToolManageService.get_api_tool_provider_remote_schema(
current_user.id,
current_user.current_tenant_id,
args['url'],
args["url"],
)
class ToolApiProviderListToolsApi(Resource):
@setup_required
@login_required
@@ -171,15 +188,18 @@ class ToolApiProviderListToolsApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
))
return jsonable_encoder(
ApiToolManageService.list_api_tool_provider_tools(
user_id,
tenant_id,
args["provider"],
)
)
class ToolApiProviderUpdateApi(Resource):
@setup_required
@@ -188,37 +208,38 @@ class ToolApiProviderUpdateApi(Resource):
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json')
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
return ApiToolManageService.update_api_tool_provider(
user_id,
tenant_id,
args['provider'],
args['original_provider'],
args['icon'],
args['credentials'],
args['schema_type'],
args['schema'],
args['privacy_policy'],
args['custom_disclaimer'],
args.get('labels', []),
args["provider"],
args["original_provider"],
args["icon"],
args["credentials"],
args["schema_type"],
args["schema"],
args["privacy_policy"],
args["custom_disclaimer"],
args.get("labels", []),
)
class ToolApiProviderDeleteApi(Resource):
@setup_required
@login_required
@@ -226,22 +247,23 @@ class ToolApiProviderDeleteApi(Resource):
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return ApiToolManageService.delete_api_tool_provider(
user_id,
tenant_id,
args['provider'],
args["provider"],
)
class ToolApiProviderGetApi(Resource):
@setup_required
@login_required
@@ -252,16 +274,17 @@ class ToolApiProviderGetApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
return ApiToolManageService.get_api_tool_provider(
user_id,
tenant_id,
args['provider'],
args["provider"],
)
class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@@ -269,6 +292,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
def get(self, provider):
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
class ToolApiProviderSchemaApi(Resource):
@setup_required
@login_required
@@ -276,14 +300,15 @@ class ToolApiProviderSchemaApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return ApiToolManageService.parser_api_schema(
schema=args['schema'],
schema=args["schema"],
)
class ToolApiProviderPreviousTestApi(Resource):
@setup_required
@login_required
@@ -291,25 +316,26 @@ class ToolApiProviderPreviousTestApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args['provider_name'] if args['provider_name'] else '',
args['tool_name'],
args['credentials'],
args['parameters'],
args['schema_type'],
args['schema'],
args["provider_name"] if args["provider_name"] else "",
args["tool_name"],
args["credentials"],
args["parameters"],
args["schema_type"],
args["schema"],
)
class ToolWorkflowProviderCreateApi(Resource):
@setup_required
@login_required
@@ -317,35 +343,36 @@ class ToolWorkflowProviderCreateApi(Resource):
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json')
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
args = reqparser.parse_args()
return WorkflowToolManageService.create_workflow_tool(
user_id,
tenant_id,
args['workflow_app_id'],
args['name'],
args['label'],
args['icon'],
args['description'],
args['parameters'],
args['privacy_policy'],
args.get('labels', []),
args["workflow_app_id"],
args["name"],
args["label"],
args["icon"],
args["description"],
args["parameters"],
args["privacy_policy"],
args.get("labels", []),
)
class ToolWorkflowProviderUpdateApi(Resource):
@setup_required
@login_required
@@ -353,38 +380,39 @@ class ToolWorkflowProviderUpdateApi(Resource):
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
args = reqparser.parse_args()
if not args['workflow_tool_id']:
raise ValueError('incorrect workflow_tool_id')
if not args["workflow_tool_id"]:
raise ValueError("incorrect workflow_tool_id")
return WorkflowToolManageService.update_workflow_tool(
user_id,
tenant_id,
args['workflow_tool_id'],
args['name'],
args['label'],
args['icon'],
args['description'],
args['parameters'],
args['privacy_policy'],
args.get('labels', []),
args["workflow_tool_id"],
args["name"],
args["label"],
args["icon"],
args["description"],
args["parameters"],
args["privacy_policy"],
args.get("labels", []),
)
class ToolWorkflowProviderDeleteApi(Resource):
@setup_required
@login_required
@@ -392,21 +420,22 @@ class ToolWorkflowProviderDeleteApi(Resource):
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
args = reqparser.parse_args()
return WorkflowToolManageService.delete_workflow_tool(
user_id,
tenant_id,
args['workflow_tool_id'],
args["workflow_tool_id"],
)
class ToolWorkflowProviderGetApi(Resource):
@setup_required
@login_required
@@ -416,28 +445,29 @@ class ToolWorkflowProviderGetApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args')
parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args')
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
args = parser.parse_args()
if args.get('workflow_tool_id'):
if args.get("workflow_tool_id"):
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
user_id,
tenant_id,
args['workflow_tool_id'],
args["workflow_tool_id"],
)
elif args.get('workflow_app_id'):
elif args.get("workflow_app_id"):
tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
user_id,
tenant_id,
args['workflow_app_id'],
args["workflow_app_id"],
)
else:
raise ValueError('incorrect workflow_tool_id or workflow_app_id')
raise ValueError("incorrect workflow_tool_id or workflow_app_id")
return jsonable_encoder(tool)
class ToolWorkflowProviderListToolApi(Resource):
@setup_required
@login_required
@@ -447,15 +477,18 @@ class ToolWorkflowProviderListToolApi(Resource):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args')
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
args = parser.parse_args()
return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools(
user_id,
tenant_id,
args['workflow_tool_id'],
))
return jsonable_encoder(
WorkflowToolManageService.list_single_workflow_tools(
user_id,
tenant_id,
args["workflow_tool_id"],
)
)
class ToolBuiltinListApi(Resource):
@setup_required
@@ -465,11 +498,17 @@ class ToolBuiltinListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)])
return jsonable_encoder(
[
provider.to_dict()
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
]
)
class ToolApiListApi(Resource):
@setup_required
@login_required
@@ -478,11 +517,17 @@ class ToolApiListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)])
return jsonable_encoder(
[
provider.to_dict()
for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)
]
)
class ToolWorkflowListApi(Resource):
@setup_required
@login_required
@@ -491,11 +536,17 @@ class ToolWorkflowListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools(
user_id,
tenant_id,
)])
return jsonable_encoder(
[
provider.to_dict()
for provider in WorkflowToolManageService.list_tenant_workflow_tools(
user_id,
tenant_id,
)
]
)
class ToolLabelsApi(Resource):
@setup_required
@login_required
@@ -503,36 +554,41 @@ class ToolLabelsApi(Resource):
def get(self):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
# tool provider
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
# api tool provider
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update')
api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete')
api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get')
api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
# workflow tool provider
api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create')
api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update')
api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete')
api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get')
api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools')
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin')
api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow')
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels')
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")

View File

@@ -26,39 +26,34 @@ from services.file_service import FileService
from services.workspace_service import WorkspaceService
provider_fields = {
'provider_name': fields.String,
'provider_type': fields.String,
'is_valid': fields.Boolean,
'token_is_set': fields.Boolean,
"provider_name": fields.String,
"provider_type": fields.String,
"is_valid": fields.Boolean,
"token_is_set": fields.Boolean,
}
tenant_fields = {
'id': fields.String,
'name': fields.String,
'plan': fields.String,
'status': fields.String,
'created_at': TimestampField,
'role': fields.String,
'in_trial': fields.Boolean,
'trial_end_reason': fields.String,
'custom_config': fields.Raw(attribute='custom_config'),
"id": fields.String,
"name": fields.String,
"plan": fields.String,
"status": fields.String,
"created_at": TimestampField,
"role": fields.String,
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
}
tenants_fields = {
'id': fields.String,
'name': fields.String,
'plan': fields.String,
'status': fields.String,
'created_at': TimestampField,
'current': fields.Boolean
"id": fields.String,
"name": fields.String,
"plan": fields.String,
"status": fields.String,
"created_at": TimestampField,
"current": fields.Boolean,
}
workspace_fields = {
'id': fields.String,
'name': fields.String,
'status': fields.String,
'created_at': TimestampField
}
workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}
class TenantListApi(Resource):
@@ -71,7 +66,7 @@ class TenantListApi(Resource):
for tenant in tenants:
if tenant.id == current_user.current_tenant_id:
tenant.current = True # Set current=True for current tenant
return {'workspaces': marshal(tenants, tenants_fields)}, 200
return {"workspaces": marshal(tenants, tenants_fields)}, 200
class WorkspaceListApi(Resource):
@@ -79,31 +74,37 @@ class WorkspaceListApi(Resource):
@admin_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
.paginate(page=args['page'], per_page=args['limit'])
tenants = (
db.session.query(Tenant)
.order_by(Tenant.created_at.desc())
.paginate(page=args["page"], per_page=args["limit"])
)
has_more = False
if len(tenants.items) == args['limit']:
if len(tenants.items) == args["limit"]:
current_page_first_tenant = tenants[-1]
rest_count = db.session.query(Tenant).filter(
Tenant.created_at < current_page_first_tenant.created_at,
Tenant.id != current_page_first_tenant.id
).count()
rest_count = (
db.session.query(Tenant)
.filter(
Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id
)
.count()
)
if rest_count > 0:
has_more = True
total = db.session.query(Tenant).count()
return {
'data': marshal(tenants.items, workspace_fields),
'has_more': has_more,
'limit': args['limit'],
'page': args['page'],
'total': total
}, 200
"data": marshal(tenants.items, workspace_fields),
"has_more": has_more,
"limit": args["limit"],
"page": args["page"],
"total": total,
}, 200
class TenantApi(Resource):
@@ -112,8 +113,8 @@ class TenantApi(Resource):
@account_initialization_required
@marshal_with(tenant_fields)
def get(self):
if request.path == '/info':
logging.warning('Deprecated URL /info was used.')
if request.path == "/info":
logging.warning("Deprecated URL /info was used.")
tenant = current_user.current_tenant
@@ -125,7 +126,7 @@ class TenantApi(Resource):
tenant = tenants[0]
# else, raise Unauthorized
else:
raise Unauthorized('workspace is archived')
raise Unauthorized("workspace is archived")
return WorkspaceService.get_tenant_info(tenant), 200
@@ -136,62 +137,64 @@ class SwitchWorkspaceApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('tenant_id', type=str, required=True, location='json')
parser.add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args()
# check if tenant_id is valid, 403 if not
try:
TenantService.switch_tenant(current_user, args['tenant_id'])
TenantService.switch_tenant(current_user, args["tenant_id"])
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant
new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
class CustomConfigWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('workspace_custom')
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('remove_webapp_brand', type=bool, location='json')
parser.add_argument('replace_webapp_logo', type=str, location='json')
parser.add_argument("remove_webapp_brand", type=bool, location="json")
parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args()
tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
custom_config_dict = {
'remove_webapp_brand': args['remove_webapp_brand'],
'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') ,
"remove_webapp_brand": args["remove_webapp_brand"],
"replace_webapp_logo": args["replace_webapp_logo"]
if args["replace_webapp_logo"] is not None
else tenant.custom_config_dict.get("replace_webapp_logo"),
}
tenant.custom_config_dict = custom_config_dict
db.session.commit()
return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
class WebappLogoWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check('workspace_custom')
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
# get file from request
file = request.files['file']
file = request.files["file"]
# check file
if 'file' not in request.files:
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
extension = file.filename.split('.')[-1]
if extension.lower() not in ['svg', 'png']:
extension = file.filename.split(".")[-1]
if extension.lower() not in ["svg", "png"]:
raise UnsupportedFileTypeError()
try:
@@ -201,14 +204,14 @@ class WebappLogoWorkspaceApi(Resource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return { 'id': upload_file.id }, 201
return {"id": upload_file.id}, 201
api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants
api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants
api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info
api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated
api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config')
api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload')
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")