chore(api/services): apply ruff reformatting (#7599)

Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Bowen Liang
2024-08-26 13:43:57 +08:00
committed by GitHub
parent 979422cdc6
commit 17fd773a30
49 changed files with 2630 additions and 2655 deletions

View File

@@ -1,14 +1,12 @@
from services.auth.firecrawl import FirecrawlAuth
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
if provider == 'firecrawl':
if provider == "firecrawl":
self.auth = FirecrawlAuth(credentials)
else:
raise ValueError('Invalid provider')
raise ValueError("Invalid provider")
def validate_credentials(self):
return self.auth.validate_credentials()

View File

@@ -7,39 +7,43 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str) -> list:
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.disabled.is_(False)
).all()
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all()
)
return data_source_api_key_bindings
@staticmethod
def create_provider_auth(tenant_id: str, args: dict):
auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
if auth_result:
# Encrypt the api key
api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
args['credentials']['config']['api_key'] = api_key
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
args["credentials"]["config"]["api_key"] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
data_source_api_key_binding.tenant_id = tenant_id
data_source_api_key_binding.category = args['category']
data_source_api_key_binding.provider = args['provider']
data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
data_source_api_key_binding.category = args["category"]
data_source_api_key_binding.provider = args["provider"]
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
DataSourceApiKeyAuthBinding.disabled.is_(False)
).first()
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
DataSourceApiKeyAuthBinding.disabled.is_(False),
)
.first()
)
if not data_source_api_key_bindings:
return None
credentials = json.loads(data_source_api_key_bindings.credentials)
@@ -47,24 +51,24 @@ class ApiKeyAuthService:
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.id == binding_id
).first()
data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding)
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
.first()
)
if data_source_api_key_binding:
db.session.delete(data_source_api_key_binding)
db.session.commit()
@classmethod
def validate_api_key_auth_args(cls, args):
if 'category' not in args or not args['category']:
raise ValueError('category is required')
if 'provider' not in args or not args['provider']:
raise ValueError('provider is required')
if 'credentials' not in args or not args['credentials']:
raise ValueError('credentials is required')
if not isinstance(args['credentials'], dict):
raise ValueError('credentials must be a dictionary')
if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
raise ValueError('auth_type is required')
if "category" not in args or not args["category"]:
raise ValueError("category is required")
if "provider" not in args or not args["provider"]:
raise ValueError("provider is required")
if "credentials" not in args or not args["credentials"]:
raise ValueError("credentials is required")
if not isinstance(args["credentials"], dict):
raise ValueError("credentials must be a dictionary")
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
raise ValueError("auth_type is required")

View File

@@ -8,49 +8,40 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get('auth_type')
if auth_type != 'bearer':
raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
self.api_key = credentials.get('config').get('api_key', None)
self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
self.api_key = credentials.get("config").get("api_key", None)
self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev")
if not self.api_key:
raise ValueError('No API key provided')
raise ValueError("No API key provided")
def validate_credentials(self):
headers = self._prepare_headers()
options = {
'url': 'https://example.com',
'crawlerOptions': {
'excludes': [],
'includes': [],
'limit': 1
},
'pageOptions': {
'onlyMainContent': True
}
"url": "https://example.com",
"crawlerOptions": {"excludes": [], "includes": [], "limit": 1},
"pageOptions": {"onlyMainContent": True},
}
response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
response = self._post_request(f"{self.base_url}/v0/crawl", options, headers)
if response.status_code == 200:
return True
else:
self._handle_error(response)
def _prepare_headers(self):
return {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in [402, 409, 500]:
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get('error', 'Unknown error occurred')
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")