Feat/firecrawl data source (#5232)

Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>
Authored by Jyong on 2024-06-15 02:46:02 +08:00, committed by GitHub
parent 918ebe1620
commit ba5f8afaa8
36 changed files with 1174 additions and 64 deletions

@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod


class ApiKeyAuthBase(ABC):
    def __init__(self, credentials: dict):
        self.credentials = credentials

    @abstractmethod
    def validate_credentials(self):
        raise NotImplementedError


@@ -0,0 +1,14 @@
from services.auth.firecrawl import FirecrawlAuth


class ApiKeyAuthFactory:

    def __init__(self, provider: str, credentials: dict):
        if provider == 'firecrawl':
            self.auth = FirecrawlAuth(credentials)
        else:
            raise ValueError('Invalid provider')

    def validate_credentials(self):
        return self.auth.validate_credentials()
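
For orientation, a minimal sketch of how a caller would exercise the factory; the API key below is a placeholder, not a real credential:

credentials = {
    'auth_type': 'bearer',
    'config': {
        'api_key': '<your-firecrawl-api-key>',     # placeholder
        'base_url': 'https://api.firecrawl.dev'    # optional; this is the default
    }
}
# Returns True when the provider accepts the key; raises ValueError for unknown providers.
ApiKeyAuthFactory('firecrawl', credentials).validate_credentials()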


@@ -0,0 +1,70 @@
import json

from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory


class ApiKeyAuthService:

    @staticmethod
    def get_provider_auth_list(tenant_id: str) -> list:
        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
            DataSourceApiKeyAuthBinding.disabled.is_(False)
        ).all()
        return data_source_api_key_bindings

    @staticmethod
    def create_provider_auth(tenant_id: str, args: dict):
        auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
        if auth_result:
            # Encrypt the api key
            api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
            args['credentials']['config']['api_key'] = api_key
            data_source_api_key_binding = DataSourceApiKeyAuthBinding()
            data_source_api_key_binding.tenant_id = tenant_id
            data_source_api_key_binding.category = args['category']
            data_source_api_key_binding.provider = args['provider']
            data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
            db.session.add(data_source_api_key_binding)
            db.session.commit()

    @staticmethod
    def get_auth_credentials(tenant_id: str, category: str, provider: str):
        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
            DataSourceApiKeyAuthBinding.category == category,
            DataSourceApiKeyAuthBinding.provider == provider,
            DataSourceApiKeyAuthBinding.disabled.is_(False)
        ).first()
        if not data_source_api_key_bindings:
            return None
        credentials = json.loads(data_source_api_key_bindings.credentials)
        return credentials

    @staticmethod
    def delete_provider_auth(tenant_id: str, binding_id: str):
        data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
            DataSourceApiKeyAuthBinding.id == binding_id
        ).first()
        if data_source_api_key_binding:
            db.session.delete(data_source_api_key_binding)
            db.session.commit()

    @classmethod
    def validate_api_key_auth_args(cls, args):
        if 'category' not in args or not args['category']:
            raise ValueError('category is required')
        if 'provider' not in args or not args['provider']:
            raise ValueError('provider is required')
        if 'credentials' not in args or not args['credentials']:
            raise ValueError('credentials is required')
        if not isinstance(args['credentials'], dict):
            raise ValueError('credentials must be a dictionary')
        if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
            raise ValueError('auth_type is required')
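
The shape of the args dict these methods expect is not spelled out elsewhere in this diff; the sketch below is inferred from validate_api_key_auth_args and create_provider_auth, with placeholder values:

args = {
    'category': 'website',        # stored on the binding record
    'provider': 'firecrawl',      # dispatched through ApiKeyAuthFactory
    'credentials': {
        'auth_type': 'bearer',
        'config': {'api_key': '<plaintext-key>'}  # encrypted via encrypter.encrypt_token before persisting
    }
}
ApiKeyAuthService.validate_api_key_auth_args(args)
ApiKeyAuthService.create_provider_auth('<tenant-id>', args)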


@@ -0,0 +1,56 @@
import json

import requests

from services.auth.api_key_auth_base import ApiKeyAuthBase


class FirecrawlAuth(ApiKeyAuthBase):
    def __init__(self, credentials: dict):
        super().__init__(credentials)
        auth_type = credentials.get('auth_type')
        if auth_type != 'bearer':
            raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
        self.api_key = credentials.get('config').get('api_key', None)
        self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')

        if not self.api_key:
            raise ValueError('No API key provided')

    def validate_credentials(self):
        headers = self._prepare_headers()
        options = {
            'url': 'https://example.com',
            'crawlerOptions': {
                'excludes': [],
                'includes': [],
                'limit': 1
            },
            'pageOptions': {
                'onlyMainContent': True
            }
        }
        response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
        if response.status_code == 200:
            return True
        else:
            self._handle_error(response)

    def _prepare_headers(self):
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }

    def _post_request(self, url, data, headers):
        return requests.post(url, headers=headers, json=data)

    def _handle_error(self, response):
        if response.status_code in [402, 409, 500]:
            error_message = response.json().get('error', 'Unknown error occurred')
            raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
        else:
            if response.text:
                error_message = json.loads(response.text).get('error', 'Unknown error occurred')
                raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
            raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
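
FirecrawlAuth receives the plaintext credentials (it runs before the key is encrypted and stored). A self-hosted Firecrawl deployment can be targeted via config.base_url; the host below is a placeholder:

auth = FirecrawlAuth({
    'auth_type': 'bearer',                                # anything else raises ValueError
    'config': {
        'api_key': '<your-firecrawl-api-key>',            # placeholder
        'base_url': 'https://firecrawl.internal.example'  # omit to use https://api.firecrawl.dev
    }
})
auth.validate_credentials()  # POSTs a one-page crawl of https://example.com to {base_url}/v0/crawl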


@@ -31,7 +31,7 @@ from models.dataset import (
    DocumentSegment,
)
from models.model import UploadFile
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
@@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
+from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task


class DatasetService:
@@ -508,18 +509,40 @@ class DocumentService:
    @staticmethod
    def retry_document(dataset_id: str, documents: list[Document]):
        for document in documents:
            # add retry flag
            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
            cache_result = redis_client.get(retry_indexing_cache_key)
            if cache_result is not None:
                raise ValueError("Document is being retried, please try again later")
            # retry document indexing
            document.indexing_status = 'waiting'
            db.session.add(document)
            db.session.commit()
            # add retry flag
            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
            redis_client.setex(retry_indexing_cache_key, 600, 1)
        # trigger async task
        document_ids = [document.id for document in documents]
        retry_document_indexing_task.delay(dataset_id, document_ids)

    @staticmethod
    def sync_website_document(dataset_id: str, document: Document):
        # add sync flag
        sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
        cache_result = redis_client.get(sync_indexing_cache_key)
        if cache_result is not None:
            raise ValueError("Document is being synced, please try again later")
        # sync document indexing
        document.indexing_status = 'waiting'
        data_source_info = document.data_source_info_dict
        data_source_info['mode'] = 'scrape'
        document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
        db.session.add(document)
        db.session.commit()

        redis_client.setex(sync_indexing_cache_key, 600, 1)

        sync_website_document_indexing_task.delay(dataset_id, document.id)

    @staticmethod
    def get_documents_position(dataset_id):
        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
        if document:
@@ -545,6 +568,9 @@ class DocumentService:
            notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
            for notion_info in notion_info_list:
                count = count + len(notion_info['pages'])
        elif document_data["data_source"]["type"] == "website_crawl":
            website_info = document_data["data_source"]['info_list']['website_info_list']
            count = len(website_info['urls'])
        batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
        if count > batch_upload_limit:
            raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -683,12 +709,12 @@ class DocumentService:
                        exist_document[data_source_info['notion_page_id']] = document.id
                for notion_info in notion_info_list:
                    workspace_id = notion_info['workspace_id']
-                   data_source_binding = DataSourceBinding.query.filter(
+                   data_source_binding = DataSourceOauthBinding.query.filter(
                        db.and_(
-                           DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                           DataSourceBinding.provider == 'notion',
-                           DataSourceBinding.disabled == False,
-                           DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                           DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                           DataSourceOauthBinding.provider == 'notion',
+                           DataSourceOauthBinding.disabled == False,
+                           DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                        )
                    ).first()
                    if not data_source_binding:
@@ -717,6 +743,28 @@ class DocumentService:
                # delete not selected documents
                if len(exist_document) > 0:
                    clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
            elif document_data["data_source"]["type"] == "website_crawl":
                website_info = document_data["data_source"]['info_list']['website_info_list']
                urls = website_info['urls']
                for url in urls:
                    data_source_info = {
                        'url': url,
                        'provider': website_info['provider'],
                        'job_id': website_info['job_id'],
                        'only_main_content': website_info.get('only_main_content', False),
                        'mode': 'crawl',
                    }
                    document = DocumentService.build_document(dataset, dataset_process_rule.id,
                                                              document_data["data_source"]["type"],
                                                              document_data["doc_form"],
                                                              document_data["doc_language"],
                                                              data_source_info, created_from, position,
                                                              account, url, batch)
                    db.session.add(document)
                    db.session.flush()
                    document_ids.append(document.id)
                    documents.append(document)
                    position += 1
            db.session.commit()

            # trigger async task
@@ -818,12 +866,12 @@ class DocumentService:
                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                for notion_info in notion_info_list:
                    workspace_id = notion_info['workspace_id']
-                   data_source_binding = DataSourceBinding.query.filter(
+                   data_source_binding = DataSourceOauthBinding.query.filter(
                        db.and_(
-                           DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                           DataSourceBinding.provider == 'notion',
-                           DataSourceBinding.disabled == False,
-                           DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                           DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                           DataSourceOauthBinding.provider == 'notion',
+                           DataSourceOauthBinding.disabled == False,
+                           DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                        )
                    ).first()
                    if not data_source_binding:
@@ -835,6 +883,17 @@ class DocumentService:
"notion_page_icon": page['page_icon'],
"type": page['type']
}
elif document_data["data_source"]["type"] == "website_crawl":
website_info = document_data["data_source"]['info_list']['website_info_list']
urls = website_info['urls']
for url in urls:
data_source_info = {
'url': url,
'provider': website_info['provider'],
'job_id': website_info['job_id'],
'only_main_content': website_info.get('only_main_content', False),
'mode': 'crawl',
}
document.data_source_type = document_data["data_source"]["type"]
document.data_source_info = json.dumps(data_source_info)
document.name = file_name
@@ -873,6 +932,9 @@ class DocumentService:
            notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
            for notion_info in notion_info_list:
                count = count + len(notion_info['pages'])
        elif document_data["data_source"]["type"] == "website_crawl":
            website_info = document_data["data_source"]['info_list']['website_info_list']
            count = len(website_info['urls'])
        batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
        if count > batch_upload_limit:
            raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -973,6 +1035,10 @@ class DocumentService:
            if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
                'notion_info_list']:
                raise ValueError("Notion source info is required")
        if args['data_source']['type'] == 'website_crawl':
            if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
                'website_info_list']:
                raise ValueError("Website source info is required")

    @classmethod
    def process_rule_args_validate(cls, args: dict):
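
For reference, the new website_crawl branches above read a payload of roughly this shape out of document_data (a sketch inferred from the code; the URLs and job_id are placeholders):

document_data = {
    'data_source': {
        'type': 'website_crawl',
        'info_list': {
            'website_info_list': {
                'provider': 'firecrawl',
                'job_id': '<firecrawl-crawl-job-id>',
                'urls': ['https://example.com/docs'],
                'only_main_content': True
            }
        }
    },
    # doc_form, doc_language, process rule, etc. are unchanged by this commit
}

Each URL becomes one Document whose data_source_info records url, provider, job_id, only_main_content and mode='crawl'; sync_website_document later flips mode to 'scrape' before re-indexing.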


@@ -0,0 +1,171 @@
import datetime
import json

from flask_login import current_user

from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService


class WebsiteService:

    @classmethod
    def document_create_args_validate(cls, args: dict):
        if 'url' not in args or not args['url']:
            raise ValueError('url is required')
        if 'options' not in args or not args['options']:
            raise ValueError('options is required')
        if 'limit' not in args['options'] or not args['options']['limit']:
            raise ValueError('limit is required')

    @classmethod
    def crawl_url(cls, args: dict) -> dict:
        provider = args.get('provider')
        url = args.get('url')
        options = args.get('options')
        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            # decrypt api_key
            api_key = encrypter.decrypt_token(
                tenant_id=current_user.current_tenant_id,
                token=credentials.get('config').get('api_key')
            )
            firecrawl_app = FirecrawlApp(api_key=api_key,
                                         base_url=credentials.get('config').get('base_url', None))
            crawl_sub_pages = options.get('crawl_sub_pages', False)
            only_main_content = options.get('only_main_content', False)
            if not crawl_sub_pages:
                params = {
                    'crawlerOptions': {
                        "includes": [],
                        "excludes": [],
                        "generateImgAltText": True,
                        "limit": 1,
                        'returnOnlyUrls': False,
                        'pageOptions': {
                            'onlyMainContent': only_main_content,
                            "includeHtml": False
                        }
                    }
                }
            else:
                includes = options.get('includes').split(',') if options.get('includes') else []
                excludes = options.get('excludes').split(',') if options.get('excludes') else []
                params = {
                    'crawlerOptions': {
                        "includes": includes if includes else [],
                        "excludes": excludes if excludes else [],
                        "generateImgAltText": True,
                        "limit": options.get('limit', 1),
                        'returnOnlyUrls': False,
                        'pageOptions': {
                            'onlyMainContent': only_main_content,
                            "includeHtml": False
                        }
                    }
                }
                if options.get('max_depth'):
                    params['crawlerOptions']['maxDepth'] = options.get('max_depth')
            job_id = firecrawl_app.crawl_url(url, params)
            website_crawl_time_cache_key = f'website_crawl_{job_id}'
            time = str(datetime.datetime.now().timestamp())
            redis_client.setex(website_crawl_time_cache_key, 3600, time)
            return {
                'status': 'active',
                'job_id': job_id
            }
        else:
            raise ValueError('Invalid provider')

    @classmethod
    def get_crawl_status(cls, job_id: str, provider: str) -> dict:
        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            # decrypt api_key
            api_key = encrypter.decrypt_token(
                tenant_id=current_user.current_tenant_id,
                token=credentials.get('config').get('api_key')
            )
            firecrawl_app = FirecrawlApp(api_key=api_key,
                                         base_url=credentials.get('config').get('base_url', None))
            result = firecrawl_app.check_crawl_status(job_id)
            crawl_status_data = {
                'status': result.get('status', 'active'),
                'job_id': job_id,
                'total': result.get('total', 0),
                'current': result.get('current', 0),
                'data': result.get('data', [])
            }
            if crawl_status_data['status'] == 'completed':
                website_crawl_time_cache_key = f'website_crawl_{job_id}'
                start_time = redis_client.get(website_crawl_time_cache_key)
                if start_time:
                    end_time = datetime.datetime.now().timestamp()
                    time_consuming = abs(end_time - float(start_time))
                    crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
                    redis_client.delete(website_crawl_time_cache_key)
        else:
            raise ValueError('Invalid provider')
        return crawl_status_data

    @classmethod
    def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            file_key = 'website_files/' + job_id + '.txt'
            if storage.exists(file_key):
                data = storage.load_once(file_key)
                if data:
                    data = json.loads(data.decode('utf-8'))
            else:
                # decrypt api_key
                api_key = encrypter.decrypt_token(
                    tenant_id=tenant_id,
                    token=credentials.get('config').get('api_key')
                )
                firecrawl_app = FirecrawlApp(api_key=api_key,
                                             base_url=credentials.get('config').get('base_url', None))
                result = firecrawl_app.check_crawl_status(job_id)
                if result.get('status') != 'completed':
                    raise ValueError('Crawl job is not completed')
                data = result.get('data')
            if data:
                for item in data:
                    if item.get('source_url') == url:
                        return item
            return None
        else:
            raise ValueError('Invalid provider')

    @classmethod
    def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            # decrypt api_key
            api_key = encrypter.decrypt_token(
                tenant_id=tenant_id,
                token=credentials.get('config').get('api_key')
            )
            firecrawl_app = FirecrawlApp(api_key=api_key,
                                         base_url=credentials.get('config').get('base_url', None))
            params = {
                'pageOptions': {
                    'onlyMainContent': only_main_content,
                    "includeHtml": False
                }
            }
            result = firecrawl_app.scrape_url(url, params)
            return result
        else:
            raise ValueError('Invalid provider')
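
A rough end-to-end usage sketch, illustrative only and not part of this commit; all values are placeholders, and crawl_url / get_crawl_status must run in a request context because they read current_user:

args = {
    'provider': 'firecrawl',
    'url': 'https://example.com',
    'options': {
        'crawl_sub_pages': True,
        'only_main_content': True,
        'includes': '',      # comma-separated patterns; empty means no filter
        'excludes': '',
        'limit': 10,
        'max_depth': 2
    }
}
WebsiteService.document_create_args_validate(args)
job = WebsiteService.crawl_url(args)                      # {'status': 'active', 'job_id': ...}
status = WebsiteService.get_crawl_status(job['job_id'], 'firecrawl')
if status['status'] == 'completed':
    page = WebsiteService.get_crawl_url_data(job['job_id'], 'firecrawl',
                                             'https://example.com', '<tenant-id>')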