Feat/firecrawl data source (#5232)
Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
0
api/services/auth/__init__.py
Normal file
0
api/services/auth/__init__.py
Normal file
10
api/services/auth/api_key_auth_base.py
Normal file
10
api/services/auth/api_key_auth_base.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ApiKeyAuthBase(ABC):
    """Abstract base class for data-source API-key authentication providers.

    Subclasses receive the raw credential payload and must implement
    ``validate_credentials`` to check it against the remote provider.
    """

    def __init__(self, credentials: dict):
        # Keep the untouched credential payload for the subclass to interpret.
        self.credentials = credentials

    @abstractmethod
    def validate_credentials(self):
        """Validate the stored credentials; implemented by each provider."""
        raise NotImplementedError
14
api/services/auth/api_key_auth_factory.py
Normal file
14
api/services/auth/api_key_auth_factory.py
Normal file
@@ -0,0 +1,14 @@
|
||||
|
||||
from services.auth.firecrawl import FirecrawlAuth
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
    """Builds the provider-specific auth object and delegates validation to it."""

    def __init__(self, provider: str, credentials: dict):
        # Guard clause: only 'firecrawl' is supported at the moment.
        if provider != 'firecrawl':
            raise ValueError('Invalid provider')
        self.auth = FirecrawlAuth(credentials)

    def validate_credentials(self):
        """Run the underlying provider's credential check and return its result."""
        return self.auth.validate_credentials()
70
api/services/auth/api_key_auth_service.py
Normal file
70
api/services/auth/api_key_auth_service.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import json
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str) -> list:
|
||||
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
).all()
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict):
|
||||
auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
|
||||
if auth_result:
|
||||
# Encrypt the api key
|
||||
api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
|
||||
args['credentials']['config']['api_key'] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
|
||||
data_source_api_key_binding.tenant_id = tenant_id
|
||||
data_source_api_key_binding.category = args['category']
|
||||
data_source_api_key_binding.provider = args['provider']
|
||||
data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
DataSourceApiKeyAuthBinding.provider == provider,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
).first()
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.id == binding_id
|
||||
).first()
|
||||
if data_source_api_key_binding:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args):
|
||||
if 'category' not in args or not args['category']:
|
||||
raise ValueError('category is required')
|
||||
if 'provider' not in args or not args['provider']:
|
||||
raise ValueError('provider is required')
|
||||
if 'credentials' not in args or not args['credentials']:
|
||||
raise ValueError('credentials is required')
|
||||
if not isinstance(args['credentials'], dict):
|
||||
raise ValueError('credentials must be a dictionary')
|
||||
if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
|
||||
raise ValueError('auth_type is required')
|
||||
|
||||
56
api/services/auth/firecrawl.py
Normal file
56
api/services/auth/firecrawl.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
    """Validates Firecrawl API-key credentials by issuing a minimal 1-page test crawl.

    Expects ``credentials`` shaped like
    ``{'auth_type': 'bearer', 'config': {'api_key': ..., 'base_url': ...}}``.
    """

    def __init__(self, credentials: dict):
        super().__init__(credentials)
        auth_type = credentials.get('auth_type')
        if auth_type != 'bearer':
            raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
        # Guard against a missing or None 'config' section: the previous
        # credentials.get('config').get(...) chain raised AttributeError here;
        # now such input reaches the clear 'No API key provided' error below.
        config = credentials.get('config') or {}
        self.api_key = config.get('api_key', None)
        self.base_url = config.get('base_url', 'https://api.firecrawl.dev')

        if not self.api_key:
            raise ValueError('No API key provided')

    def validate_credentials(self):
        """POST a minimal crawl to /v0/crawl; return True on HTTP 200, raise otherwise."""
        headers = self._prepare_headers()
        options = {
            'url': 'https://example.com',
            'crawlerOptions': {
                'excludes': [],
                'includes': [],
                'limit': 1
            },
            'pageOptions': {
                'onlyMainContent': True
            }
        }
        response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
        if response.status_code == 200:
            return True
        # _handle_error always raises, so falling through never returns None.
        self._handle_error(response)

    def _prepare_headers(self):
        """Bearer-token auth headers for a JSON request."""
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }

    def _post_request(self, url, data, headers):
        """Thin wrapper over requests.post (kept separate for testability)."""
        return requests.post(url, headers=headers, json=data)

    def _handle_error(self, response):
        """Translate an error response into an Exception; never returns normally."""
        if response.status_code in [402, 409, 500]:
            error_message = response.json().get('error', 'Unknown error occurred')
            raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
        if response.text:
            error_message = json.loads(response.text).get('error', 'Unknown error occurred')
            raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
        raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
||||
@@ -31,7 +31,7 @@ from models.dataset import (
|
||||
DocumentSegment,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.source import DataSourceBinding
|
||||
from models.source import DataSourceOauthBinding
|
||||
from services.errors.account import NoPermissionError
|
||||
from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
|
||||
from services.errors.document import DocumentIndexingError
|
||||
@@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
|
||||
from tasks.recover_document_indexing_task import recover_document_indexing_task
|
||||
from tasks.retry_document_indexing_task import retry_document_indexing_task
|
||||
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
|
||||
|
||||
|
||||
class DatasetService:
|
||||
@@ -508,18 +509,40 @@ class DocumentService:
|
||||
@staticmethod
|
||||
def retry_document(dataset_id: str, documents: list[Document]):
|
||||
for document in documents:
|
||||
# add retry flag
|
||||
retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
|
||||
cache_result = redis_client.get(retry_indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise ValueError("Document is being retried, please try again later")
|
||||
# retry document indexing
|
||||
document.indexing_status = 'waiting'
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
# add retry flag
|
||||
retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
|
||||
|
||||
redis_client.setex(retry_indexing_cache_key, 600, 1)
|
||||
# trigger async task
|
||||
document_ids = [document.id for document in documents]
|
||||
retry_document_indexing_task.delay(dataset_id, document_ids)
|
||||
|
||||
@staticmethod
|
||||
def sync_website_document(dataset_id: str, document: Document):
|
||||
# add sync flag
|
||||
sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
|
||||
cache_result = redis_client.get(sync_indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise ValueError("Document is being synced, please try again later")
|
||||
# sync document indexing
|
||||
document.indexing_status = 'waiting'
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info['mode'] = 'scrape'
|
||||
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
redis_client.setex(sync_indexing_cache_key, 600, 1)
|
||||
|
||||
sync_website_document_indexing_task.delay(dataset_id, document.id)
|
||||
@staticmethod
|
||||
def get_documents_position(dataset_id):
|
||||
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
||||
if document:
|
||||
@@ -545,6 +568,9 @@ class DocumentService:
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
elif document_data["data_source"]["type"] == "website_crawl":
|
||||
website_info = document_data["data_source"]['info_list']['website_info_list']
|
||||
count = len(website_info['urls'])
|
||||
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
@@ -683,12 +709,12 @@ class DocumentService:
|
||||
exist_document[data_source_info['notion_page_id']] = document.id
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
@@ -717,6 +743,28 @@ class DocumentService:
|
||||
# delete not selected documents
|
||||
if len(exist_document) > 0:
|
||||
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
|
||||
elif document_data["data_source"]["type"] == "website_crawl":
|
||||
website_info = document_data["data_source"]['info_list']['website_info_list']
|
||||
urls = website_info['urls']
|
||||
for url in urls:
|
||||
data_source_info = {
|
||||
'url': url,
|
||||
'provider': website_info['provider'],
|
||||
'job_id': website_info['job_id'],
|
||||
'only_main_content': website_info.get('only_main_content', False),
|
||||
'mode': 'crawl',
|
||||
}
|
||||
document = DocumentService.build_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, url, batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
db.session.commit()
|
||||
|
||||
# trigger async task
|
||||
@@ -818,12 +866,12 @@ class DocumentService:
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
@@ -835,6 +883,17 @@ class DocumentService:
|
||||
"notion_page_icon": page['page_icon'],
|
||||
"type": page['type']
|
||||
}
|
||||
elif document_data["data_source"]["type"] == "website_crawl":
|
||||
website_info = document_data["data_source"]['info_list']['website_info_list']
|
||||
urls = website_info['urls']
|
||||
for url in urls:
|
||||
data_source_info = {
|
||||
'url': url,
|
||||
'provider': website_info['provider'],
|
||||
'job_id': website_info['job_id'],
|
||||
'only_main_content': website_info.get('only_main_content', False),
|
||||
'mode': 'crawl',
|
||||
}
|
||||
document.data_source_type = document_data["data_source"]["type"]
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.name = file_name
|
||||
@@ -873,6 +932,9 @@ class DocumentService:
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
elif document_data["data_source"]["type"] == "website_crawl":
|
||||
website_info = document_data["data_source"]['info_list']['website_info_list']
|
||||
count = len(website_info['urls'])
|
||||
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
@@ -973,6 +1035,10 @@ class DocumentService:
|
||||
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
|
||||
'notion_info_list']:
|
||||
raise ValueError("Notion source info is required")
|
||||
if args['data_source']['type'] == 'website_crawl':
|
||||
if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
|
||||
'website_info_list']:
|
||||
raise ValueError("Website source info is required")
|
||||
|
||||
@classmethod
|
||||
def process_rule_args_validate(cls, args: dict):
|
||||
|
||||
171
api/services/website_service.py
Normal file
171
api/services/website_service.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
|
||||
class WebsiteService:
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
if 'url' not in args or not args['url']:
|
||||
raise ValueError('url is required')
|
||||
if 'options' not in args or not args['options']:
|
||||
raise ValueError('options is required')
|
||||
if 'limit' not in args['options'] or not args['options']['limit']:
|
||||
raise ValueError('limit is required')
|
||||
|
||||
@classmethod
|
||||
def crawl_url(cls, args: dict) -> dict:
|
||||
provider = args.get('provider')
|
||||
url = args.get('url')
|
||||
options = args.get('options')
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
crawl_sub_pages = options.get('crawl_sub_pages', False)
|
||||
only_main_content = options.get('only_main_content', False)
|
||||
if not crawl_sub_pages:
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"includes": [],
|
||||
"excludes": [],
|
||||
"generateImgAltText": True,
|
||||
"limit": 1,
|
||||
'returnOnlyUrls': False,
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
includes = options.get('includes').split(',') if options.get('includes') else []
|
||||
excludes = options.get('excludes').split(',') if options.get('excludes') else []
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"includes": includes if includes else [],
|
||||
"excludes": excludes if excludes else [],
|
||||
"generateImgAltText": True,
|
||||
"limit": options.get('limit', 1),
|
||||
'returnOnlyUrls': False,
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
}
|
||||
}
|
||||
if options.get('max_depth'):
|
||||
params['crawlerOptions']['maxDepth'] = options.get('max_depth')
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
website_crawl_time_cache_key = f'website_crawl_{job_id}'
|
||||
time = str(datetime.datetime.now().timestamp())
|
||||
redis_client.setex(website_crawl_time_cache_key, 3600, time)
|
||||
return {
|
||||
'status': 'active',
|
||||
'job_id': job_id
|
||||
}
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
|
||||
@classmethod
|
||||
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data = {
|
||||
'status': result.get('status', 'active'),
|
||||
'job_id': job_id,
|
||||
'total': result.get('total', 0),
|
||||
'current': result.get('current', 0),
|
||||
'data': result.get('data', [])
|
||||
}
|
||||
if crawl_status_data['status'] == 'completed':
|
||||
website_crawl_time_cache_key = f'website_crawl_{job_id}'
|
||||
start_time = redis_client.get(website_crawl_time_cache_key)
|
||||
if start_time:
|
||||
end_time = datetime.datetime.now().timestamp()
|
||||
time_consuming = abs(end_time - float(start_time))
|
||||
crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
|
||||
redis_client.delete(website_crawl_time_cache_key)
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
file_key = 'website_files/' + job_id + '.txt'
|
||||
if storage.exists(file_key):
|
||||
data = storage.load_once(file_key)
|
||||
if data:
|
||||
data = json.loads(data.decode('utf-8'))
|
||||
else:
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get('status') != 'completed':
|
||||
raise ValueError('Crawl job is not completed')
|
||||
data = result.get('data')
|
||||
if data:
|
||||
for item in data:
|
||||
if item.get('source_url') == url:
|
||||
return item
|
||||
return None
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
|
||||
@classmethod
|
||||
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
params = {
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
}
|
||||
result = firecrawl_app.scrape_url(url, params)
|
||||
return result
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
Reference in New Issue
Block a user