Feat/enterprise sso (#3602)
This commit is contained in:
0
api/services/enterprise/__init__.py
Normal file
0
api/services/enterprise/__init__.py
Normal file
20
api/services/enterprise/base.py
Normal file
20
api/services/enterprise/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class EnterpriseRequest:
|
||||
base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
|
||||
secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')
|
||||
|
||||
@classmethod
|
||||
def send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Enterprise-Api-Secret-Key": cls.secret_key
|
||||
}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
|
||||
return response.json()
|
||||
28
api/services/enterprise/enterprise_feature_service.py
Normal file
28
api/services/enterprise/enterprise_feature_service.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
class EnterpriseFeatureModel(BaseModel):
|
||||
sso_enforced_for_signin: bool = False
|
||||
sso_enforced_for_signin_protocol: str = ''
|
||||
|
||||
|
||||
class EnterpriseFeatureService:
|
||||
|
||||
@classmethod
|
||||
def get_enterprise_features(cls) -> EnterpriseFeatureModel:
|
||||
features = EnterpriseFeatureModel()
|
||||
|
||||
if current_app.config['ENTERPRISE_ENABLED']:
|
||||
cls._fulfill_params_from_enterprise(features)
|
||||
|
||||
return features
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_enterprise(cls, features):
|
||||
enterprise_info = EnterpriseService.get_info()
|
||||
|
||||
features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
|
||||
features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
|
||||
8
api/services/enterprise/enterprise_service.py
Normal file
8
api/services/enterprise/enterprise_service.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
|
||||
class EnterpriseService:
|
||||
|
||||
@classmethod
|
||||
def get_info(cls):
|
||||
return EnterpriseRequest.send_request('GET', '/info')
|
||||
60
api/services/enterprise/enterprise_sso_service.py
Normal file
60
api/services/enterprise/enterprise_sso_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import logging
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnterpriseSSOService:
|
||||
|
||||
@classmethod
|
||||
def get_sso_saml_login(cls) -> str:
|
||||
return EnterpriseRequest.send_request('GET', '/sso/saml/login')
|
||||
|
||||
@classmethod
|
||||
def post_sso_saml_acs(cls, saml_response: str) -> str:
|
||||
response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response})
|
||||
if 'email' not in response or response['email'] is None:
|
||||
logger.exception(response)
|
||||
raise Exception('Saml response is invalid')
|
||||
|
||||
return cls.login_with_email(response.get('email'))
|
||||
|
||||
@classmethod
|
||||
def get_sso_oidc_login(cls):
|
||||
return EnterpriseRequest.send_request('GET', '/sso/oidc/login')
|
||||
|
||||
@classmethod
|
||||
def get_sso_oidc_callback(cls, args: dict):
|
||||
state_from_query = args['state']
|
||||
code_from_query = args['code']
|
||||
state_from_cookies = args['oidc-state']
|
||||
|
||||
if state_from_cookies != state_from_query:
|
||||
raise Exception('invalid state or code')
|
||||
|
||||
response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query})
|
||||
if 'email' not in response or response['email'] is None:
|
||||
logger.exception(response)
|
||||
raise Exception('OIDC response is invalid')
|
||||
|
||||
return cls.login_with_email(response.get('email'))
|
||||
|
||||
@classmethod
|
||||
def login_with_email(cls, email: str) -> str:
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
if account is None:
|
||||
raise Exception('account not found, please contact system admin to invite you to join in a workspace')
|
||||
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise Exception('account is banned, please contact system admin')
|
||||
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
raise Exception("workspace not found, please contact system admin to invite you to join in a workspace")
|
||||
|
||||
token = AccountService.get_account_jwt_token(account)
|
||||
|
||||
return token
|
||||
Reference in New Issue
Block a user