Feat/firecrawl data source (#5232)
Co-authored-by: Nicolas <nicolascamara29@gmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
0
api/services/auth/__init__.py
Normal file
0
api/services/auth/__init__.py
Normal file
10
api/services/auth/api_key_auth_base.py
Normal file
10
api/services/auth/api_key_auth_base.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ApiKeyAuthBase(ABC):
    """Abstract base for API-key style credential validators.

    Subclasses receive the raw credentials mapping at construction time and
    must supply their own ``validate_credentials`` implementation.
    """

    def __init__(self, credentials: dict):
        """Keep a reference to the credentials mapping for subclasses to use."""
        self.credentials = credentials

    @abstractmethod
    def validate_credentials(self):
        """Check the stored credentials; concrete subclasses must override."""
        raise NotImplementedError
|
||||
14
api/services/auth/api_key_auth_factory.py
Normal file
14
api/services/auth/api_key_auth_factory.py
Normal file
@@ -0,0 +1,14 @@
|
||||
|
||||
from services.auth.firecrawl import FirecrawlAuth
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
    """Select the credential validator that matches a provider name.

    Only the ``'firecrawl'`` provider is recognised; any other name is
    rejected with ``ValueError``.
    """

    def __init__(self, provider: str, credentials: dict):
        """Build the validator for *provider* from *credentials*.

        Raises:
            ValueError: if *provider* is not a supported provider name.
        """
        # Guard clause: reject unknown providers before constructing anything.
        if provider != 'firecrawl':
            raise ValueError('Invalid provider')
        self.auth = FirecrawlAuth(credentials)

    def validate_credentials(self):
        """Delegate to the selected validator and return its result."""
        validator = self.auth
        return validator.validate_credentials()
|
||||
70
api/services/auth/api_key_auth_service.py
Normal file
70
api/services/auth/api_key_auth_service.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import json
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
|
||||
class ApiKeyAuthService:
    """Create, read, delete and validate ``DataSourceApiKeyAuthBinding`` rows."""

    @staticmethod
    def get_provider_auth_list(tenant_id: str) -> list:
        """Return every non-disabled binding that belongs to *tenant_id*."""
        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
            DataSourceApiKeyAuthBinding.disabled.is_(False)
        ).all()
        return data_source_api_key_bindings

    @staticmethod
    def create_provider_auth(tenant_id: str, args: dict):
        """Validate the supplied credentials and persist them as a new binding.

        ``args`` must carry ``category``, ``provider`` and ``credentials``;
        ``credentials['config']['api_key']`` is passed through
        ``encrypter.encrypt_token`` before the credentials are stored as JSON.

        Nothing is stored when validation does not return a truthy result.
        The caller's ``args`` mapping is left untouched.
        """
        auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
        if auth_result:
            # Encrypt the api key on a copy so the plaintext key in the
            # caller's ``args`` is not silently replaced by the encrypted one.
            credentials = dict(args['credentials'])
            config = dict(credentials['config'])
            config['api_key'] = encrypter.encrypt_token(tenant_id, config['api_key'])
            credentials['config'] = config

            data_source_api_key_binding = DataSourceApiKeyAuthBinding()
            data_source_api_key_binding.tenant_id = tenant_id
            data_source_api_key_binding.category = args['category']
            data_source_api_key_binding.provider = args['provider']
            data_source_api_key_binding.credentials = json.dumps(credentials, ensure_ascii=False)
            db.session.add(data_source_api_key_binding)
            db.session.commit()

    @staticmethod
    def get_auth_credentials(tenant_id: str, category: str, provider: str):
        """Return the stored credentials dict for the first matching binding.

        Only non-disabled bindings are considered. Returns ``None`` when no
        binding matches. The credentials are returned as decoded from the
        stored JSON; no decryption is performed here.
        """
        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
            DataSourceApiKeyAuthBinding.category == category,
            DataSourceApiKeyAuthBinding.provider == provider,
            DataSourceApiKeyAuthBinding.disabled.is_(False)
        ).first()
        if not data_source_api_key_bindings:
            return None
        credentials = json.loads(data_source_api_key_bindings.credentials)
        return credentials

    @staticmethod
    def delete_provider_auth(tenant_id: str, binding_id: str):
        """Delete the binding with *binding_id* owned by *tenant_id*, if any."""
        data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
            DataSourceApiKeyAuthBinding.id == binding_id
        ).first()
        if data_source_api_key_binding:
            db.session.delete(data_source_api_key_binding)
            db.session.commit()

    @classmethod
    def validate_api_key_auth_args(cls, args):
        """Check that *args* has the fields needed to create a binding.

        Raises:
            ValueError: if ``category``, ``provider`` or ``credentials`` is
                missing or empty, if ``credentials`` is not a dict, or if
                ``credentials['auth_type']`` is missing or empty.
        """
        if 'category' not in args or not args['category']:
            raise ValueError('category is required')
        if 'provider' not in args or not args['provider']:
            raise ValueError('provider is required')
        if 'credentials' not in args or not args['credentials']:
            raise ValueError('credentials is required')
        if not isinstance(args['credentials'], dict):
            raise ValueError('credentials must be a dictionary')
        if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
            raise ValueError('auth_type is required')
|
||||
|
||||
56
api/services/auth/firecrawl.py
Normal file
56
api/services/auth/firecrawl.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
    """Bearer-token credential check against a Firecrawl endpoint.

    Reads ``credentials['auth_type']`` (must be ``'bearer'``) and, from
    ``credentials['config']``, the ``api_key`` and an optional ``base_url``.
    """

    def __init__(self, credentials: dict):
        """Extract the API key and base URL from *credentials*.

        Raises:
            ValueError: if ``auth_type`` is not ``'bearer'``, or if no API key
                is available (including a missing or non-dict ``config``).
        """
        super().__init__(credentials)
        auth_type = credentials.get('auth_type')
        if auth_type != 'bearer':
            raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')

        # A missing, None or non-dict 'config' used to fail with
        # AttributeError on ``.get``; treat it as "no key supplied" so the
        # caller gets the intended ValueError below.
        config = credentials.get('config')
        if not isinstance(config, dict):
            config = {}

        self.api_key = config.get('api_key', None)
        # ``or`` also covers a 'base_url' key that is present but None/empty,
        # which would otherwise produce a URL such as 'None/v0/crawl'.
        self.base_url = config.get('base_url') or 'https://api.firecrawl.dev'

        if not self.api_key:
            raise ValueError('No API key provided')

    def validate_credentials(self):
        """Submit a one-page crawl of https://example.com to test the key.

        Returns:
            True when the endpoint responds with HTTP 200.

        Raises:
            Exception: for any other status code (see ``_handle_error``).
        """
        headers = self._prepare_headers()
        options = {
            'url': 'https://example.com',
            'crawlerOptions': {
                'excludes': [],
                'includes': [],
                'limit': 1
            },
            'pageOptions': {
                'onlyMainContent': True
            }
        }
        response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
        if response.status_code == 200:
            return True
        else:
            self._handle_error(response)

    def _prepare_headers(self):
        """Return the JSON content-type and bearer Authorization headers."""
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }

    def _post_request(self, url, data, headers):
        """POST *data* as JSON to *url* and return the response object."""
        # NOTE(review): no timeout is passed, so an unresponsive server can
        # block this call indefinitely - consider adding one.
        return requests.post(url, headers=headers, json=data)

    def _handle_error(self, response):
        """Raise an ``Exception`` describing a non-200 response.

        Status codes 402, 409 and 500, and any response with a non-empty
        body, report the body's ``error`` field (or a generic fallback).
        Other responses with an empty body are reported as unexpected.
        """
        status_code = response.status_code
        if status_code in [402, 409, 500] or response.text:
            error_message = self._extract_error_message(response)
            raise Exception(f'Failed to authorize. Status code: {status_code}. Error: {error_message}')
        raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {status_code}')

    @staticmethod
    def _extract_error_message(response):
        """Return the ``error`` field of a JSON body, or a generic fallback."""
        try:
            payload = response.json()
        except ValueError:
            # Body is empty or not JSON (e.g. an HTML error page); do not let
            # the decode failure mask the authorization error itself.
            return 'Unknown error occurred'
        if isinstance(payload, dict):
            return payload.get('error', 'Unknown error occurred')
        # Valid JSON that is not an object has no 'error' field to read.
        return 'Unknown error occurred'
|
||||
Reference in New Issue
Block a user