chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -4,6 +4,7 @@ The goal is to facilitate decoupling of content loading from content parsing cod
|
||||
|
||||
In addition, content loading code should provide a lazy loading interface by default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import csv
|
||||
from typing import Optional
|
||||
|
||||
@@ -18,12 +19,12 @@ class CSVExtractor(BaseExtractor):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
@@ -57,7 +58,7 @@ class CSVExtractor(BaseExtractor):
|
||||
docs = []
|
||||
try:
|
||||
# load csv file into pandas dataframe
|
||||
df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args)
|
||||
df = pd.read_csv(csvfile, on_bad_lines="skip", **self.csv_args)
|
||||
|
||||
# check source column exists
|
||||
if self.source_column and self.source_column not in df.columns:
|
||||
@@ -67,7 +68,7 @@ class CSVExtractor(BaseExtractor):
|
||||
|
||||
for i, row in df.iterrows():
|
||||
content = ";".join(f"{col.strip()}: {str(row[col]).strip()}" for col in df.columns)
|
||||
source = row[self.source_column] if self.source_column else ''
|
||||
source = row[self.source_column] if self.source_column else ""
|
||||
metadata = {"source": source, "row": i}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
@@ -10,6 +10,7 @@ class NotionInfo(BaseModel):
|
||||
"""
|
||||
Notion import info.
|
||||
"""
|
||||
|
||||
notion_workspace_id: str
|
||||
notion_obj_id: str
|
||||
notion_page_type: str
|
||||
@@ -25,6 +26,7 @@ class WebsiteInfo(BaseModel):
|
||||
"""
|
||||
website import info.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
job_id: str
|
||||
url: str
|
||||
@@ -43,6 +45,7 @@ class ExtractSetting(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
|
||||
datasource_type: str
|
||||
upload_file: Optional[UploadFile] = None
|
||||
notion_info: Optional[NotionInfo] = None
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
@@ -17,23 +18,18 @@ class ExcelExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
""" Load from Excel file in xls or xlsx format using Pandas and openpyxl."""
|
||||
"""Load from Excel file in xls or xlsx format using Pandas and openpyxl."""
|
||||
documents = []
|
||||
file_extension = os.path.splitext(self._file_path)[-1].lower()
|
||||
|
||||
if file_extension == '.xlsx':
|
||||
if file_extension == ".xlsx":
|
||||
wb = load_workbook(self._file_path, data_only=True)
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
@@ -44,35 +40,38 @@ class ExcelExtractor(BaseExtractor):
|
||||
continue
|
||||
df = pd.DataFrame(data, columns=cols)
|
||||
|
||||
df.dropna(how='all', inplace=True)
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for index, row in df.iterrows():
|
||||
page_content = []
|
||||
for col_index, (k, v) in enumerate(row.items()):
|
||||
if pd.notna(v):
|
||||
cell = sheet.cell(row=index + 2,
|
||||
column=col_index + 1) # +2 to account for header and 1-based index
|
||||
cell = sheet.cell(
|
||||
row=index + 2, column=col_index + 1
|
||||
) # +2 to account for header and 1-based index
|
||||
if cell.hyperlink:
|
||||
value = f"[{v}]({cell.hyperlink.target})"
|
||||
page_content.append(f'"{k}":"{value}"')
|
||||
else:
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(Document(page_content=';'.join(page_content),
|
||||
metadata={'source': self._file_path}))
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
|
||||
elif file_extension == '.xls':
|
||||
excel_file = pd.ExcelFile(self._file_path, engine='xlrd')
|
||||
elif file_extension == ".xls":
|
||||
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
|
||||
for sheet_name in excel_file.sheet_names:
|
||||
df = excel_file.parse(sheet_name=sheet_name)
|
||||
df.dropna(how='all', inplace=True)
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for _, row in df.iterrows():
|
||||
page_content = []
|
||||
for k, v in row.items():
|
||||
if pd.notna(v):
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(Document(page_content=';'.join(page_content),
|
||||
metadata={'source': self._file_path}))
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file extension: {file_extension}")
|
||||
|
||||
|
||||
@@ -29,61 +29,60 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain', 'application/json']
|
||||
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
|
||||
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
|
||||
|
||||
class ExtractProcessor:
|
||||
@classmethod
|
||||
def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
|
||||
-> Union[list[Document], str]:
|
||||
def load_from_upload_file(
|
||||
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
|
||||
) -> Union[list[Document], str]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=upload_file,
|
||||
document_model='text_model'
|
||||
datasource_type="upload_file", upload_file=upload_file, document_model="text_model"
|
||||
)
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
delimiter = "\n"
|
||||
return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
|
||||
else:
|
||||
return cls.extract(extract_setting, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
|
||||
response = ssrf_proxy.get(url, headers={
|
||||
"User-Agent": USER_AGENT
|
||||
})
|
||||
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(url).suffix
|
||||
if not suffix and suffix != '.':
|
||||
if not suffix and suffix != ".":
|
||||
# get content-type
|
||||
if response.headers.get('Content-Type'):
|
||||
suffix = '.' + response.headers.get('Content-Type').split('/')[-1]
|
||||
if response.headers.get("Content-Type"):
|
||||
suffix = "." + response.headers.get("Content-Type").split("/")[-1]
|
||||
else:
|
||||
content_disposition = response.headers.get('Content-Disposition')
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
suffix = '.' + re.search(r'\.(\w+)$', filename).group(1)
|
||||
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
|
||||
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
with open(file_path, 'wb') as file:
|
||||
with open(file_path, "wb") as file:
|
||||
file.write(response.content)
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
document_model='text_model'
|
||||
)
|
||||
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
return delimiter.join([document.page_content for document in cls.extract(
|
||||
extract_setting=extract_setting, file_path=file_path)])
|
||||
delimiter = "\n"
|
||||
return delimiter.join(
|
||||
[
|
||||
document.page_content
|
||||
for document in cls.extract(extract_setting=extract_setting, file_path=file_path)
|
||||
]
|
||||
)
|
||||
else:
|
||||
return cls.extract(extract_setting=extract_setting, file_path=file_path)
|
||||
|
||||
@classmethod
|
||||
def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
|
||||
file_path: str = None) -> list[Document]:
|
||||
def extract(
|
||||
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str = None
|
||||
) -> list[Document]:
|
||||
if extract_setting.datasource_type == DatasourceType.FILE.value:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
if not file_path:
|
||||
@@ -96,50 +95,56 @@ class ExtractProcessor:
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
|
||||
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
|
||||
if etl_type == 'Unstructured':
|
||||
if file_extension == '.xlsx' or file_extension == '.xls':
|
||||
if etl_type == "Unstructured":
|
||||
if file_extension == ".xlsx" or file_extension == ".xls":
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
elif file_extension == ".pdf":
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
elif file_extension in [".md", ".markdown"]:
|
||||
extractor = (
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
|
||||
if is_automatic
|
||||
else MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
)
|
||||
elif file_extension in [".htm", ".html"]:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
elif file_extension in [".docx"]:
|
||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension == '.csv':
|
||||
elif file_extension == ".csv":
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == '.msg':
|
||||
elif file_extension == ".msg":
|
||||
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.eml':
|
||||
elif file_extension == ".eml":
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.ppt':
|
||||
elif file_extension == ".ppt":
|
||||
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == '.pptx':
|
||||
elif file_extension == ".pptx":
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.xml':
|
||||
elif file_extension == ".xml":
|
||||
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == 'epub':
|
||||
elif file_extension == "epub":
|
||||
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url)
|
||||
else:
|
||||
# txt
|
||||
extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
extractor = (
|
||||
UnstructuredTextExtractor(file_path, unstructured_api_url)
|
||||
if is_automatic
|
||||
else TextExtractor(file_path, autodetect_encoding=True)
|
||||
)
|
||||
else:
|
||||
if file_extension == '.xlsx' or file_extension == '.xls':
|
||||
if file_extension == ".xlsx" or file_extension == ".xls":
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
elif file_extension == ".pdf":
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
elif file_extension in [".md", ".markdown"]:
|
||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
elif file_extension in [".htm", ".html"]:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
elif file_extension in [".docx"]:
|
||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension == '.csv':
|
||||
elif file_extension == ".csv":
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == 'epub':
|
||||
elif file_extension == "epub":
|
||||
extractor = UnstructuredEpubExtractor(file_path)
|
||||
else:
|
||||
# txt
|
||||
@@ -155,13 +160,13 @@ class ExtractProcessor:
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
|
||||
if extract_setting.website_info.provider == 'firecrawl':
|
||||
if extract_setting.website_info.provider == "firecrawl":
|
||||
extractor = FirecrawlWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
tenant_id=extract_setting.website_info.tenant_id,
|
||||
mode=extract_setting.website_info.mode,
|
||||
only_main_content=extract_setting.website_info.only_main_content
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
)
|
||||
return extractor.extract()
|
||||
else:
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseExtractor(ABC):
|
||||
"""Interface for extract files.
|
||||
"""
|
||||
"""Interface for extract files."""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -9,108 +9,98 @@ from extensions.ext_storage import storage
|
||||
class FirecrawlApp:
|
||||
def __init__(self, api_key=None, base_url=None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or 'https://api.firecrawl.dev'
|
||||
if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
|
||||
raise ValueError('No API key provided')
|
||||
self.base_url = base_url or "https://api.firecrawl.dev"
|
||||
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def scrape_url(self, url, params=None) -> dict:
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
json_data = {'url': url}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
json_data = {"url": url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = requests.post(
|
||||
f'{self.base_url}/v0/scrape',
|
||||
headers=headers,
|
||||
json=json_data
|
||||
)
|
||||
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
|
||||
if response.status_code == 200:
|
||||
response = response.json()
|
||||
if response['success'] == True:
|
||||
data = response['data']
|
||||
if response["success"] == True:
|
||||
data = response["data"]
|
||||
return {
|
||||
'title': data.get('metadata').get('title'),
|
||||
'description': data.get('metadata').get('description'),
|
||||
'source_url': data.get('metadata').get('sourceURL'),
|
||||
'markdown': data.get('markdown')
|
||||
"title": data.get("metadata").get("title"),
|
||||
"description": data.get("metadata").get("description"),
|
||||
"source_url": data.get("metadata").get("sourceURL"),
|
||||
"markdown": data.get("markdown"),
|
||||
}
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
|
||||
|
||||
elif response.status_code in [402, 409, 500]:
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
|
||||
|
||||
def crawl_url(self, url, params=None) -> str:
|
||||
headers = self._prepare_headers()
|
||||
json_data = {'url': url}
|
||||
json_data = {"url": url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
|
||||
response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
job_id = response.json().get('jobId')
|
||||
job_id = response.json().get("jobId")
|
||||
return job_id
|
||||
else:
|
||||
self._handle_error(response, 'start crawl job')
|
||||
self._handle_error(response, "start crawl job")
|
||||
|
||||
def check_crawl_status(self, job_id) -> dict:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
|
||||
response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers)
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get('status') == 'completed':
|
||||
total = crawl_status_response.get('total', 0)
|
||||
if crawl_status_response.get("status") == "completed":
|
||||
total = crawl_status_response.get("total", 0)
|
||||
if total == 0:
|
||||
raise Exception('Failed to check crawl status. Error: No page found')
|
||||
data = crawl_status_response.get('data', [])
|
||||
raise Exception("Failed to check crawl status. Error: No page found")
|
||||
data = crawl_status_response.get("data", [])
|
||||
url_data_list = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data = {
|
||||
'title': item.get('metadata').get('title'),
|
||||
'description': item.get('metadata').get('description'),
|
||||
'source_url': item.get('metadata').get('sourceURL'),
|
||||
'markdown': item.get('markdown')
|
||||
"title": item.get("metadata").get("title"),
|
||||
"description": item.get("metadata").get("description"),
|
||||
"source_url": item.get("metadata").get("sourceURL"),
|
||||
"markdown": item.get("markdown"),
|
||||
}
|
||||
url_data_list.append(url_data)
|
||||
if url_data_list:
|
||||
file_key = 'website_files/' + job_id + '.txt'
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
|
||||
storage.save(file_key, json.dumps(url_data_list).encode("utf-8"))
|
||||
return {
|
||||
'status': 'completed',
|
||||
'total': crawl_status_response.get('total'),
|
||||
'current': crawl_status_response.get('current'),
|
||||
'data': url_data_list
|
||||
"status": "completed",
|
||||
"total": crawl_status_response.get("total"),
|
||||
"current": crawl_status_response.get("current"),
|
||||
"data": url_data_list,
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
'status': crawl_status_response.get('status'),
|
||||
'total': crawl_status_response.get('total'),
|
||||
'current': crawl_status_response.get('current'),
|
||||
'data': []
|
||||
"status": crawl_status_response.get("status"),
|
||||
"total": crawl_status_response.get("total"),
|
||||
"current": crawl_status_response.get("current"),
|
||||
"data": [],
|
||||
}
|
||||
|
||||
else:
|
||||
self._handle_error(response, 'check crawl status')
|
||||
self._handle_error(response, "check crawl status")
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
|
||||
for attempt in range(retries):
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
if response.status_code == 502:
|
||||
time.sleep(backoff_factor * (2 ** attempt))
|
||||
time.sleep(backoff_factor * (2**attempt))
|
||||
else:
|
||||
return response
|
||||
return response
|
||||
@@ -119,13 +109,11 @@ class FirecrawlApp:
|
||||
for attempt in range(retries):
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 502:
|
||||
time.sleep(backoff_factor * (2 ** attempt))
|
||||
time.sleep(backoff_factor * (2**attempt))
|
||||
else:
|
||||
return response
|
||||
return response
|
||||
|
||||
def _handle_error(self, response, action):
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
|
||||
|
||||
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
|
||||
|
||||
@@ -5,7 +5,7 @@ from services.website_service import WebsiteService
|
||||
|
||||
class FirecrawlWebExtractor(BaseExtractor):
|
||||
"""
|
||||
Crawl and scrape websites and return content in clean llm-ready markdown.
|
||||
Crawl and scrape websites and return content in clean llm-ready markdown.
|
||||
|
||||
|
||||
Args:
|
||||
@@ -15,14 +15,7 @@ class FirecrawlWebExtractor(BaseExtractor):
|
||||
mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
mode: str = 'crawl',
|
||||
only_main_content: bool = False
|
||||
):
|
||||
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
|
||||
"""Initialize with url, api_key, base_url and mode."""
|
||||
self._url = url
|
||||
self.job_id = job_id
|
||||
@@ -33,28 +26,31 @@ class FirecrawlWebExtractor(BaseExtractor):
|
||||
def extract(self) -> list[Document]:
|
||||
"""Extract content from the URL."""
|
||||
documents = []
|
||||
if self.mode == 'crawl':
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
|
||||
if self.mode == "crawl":
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id)
|
||||
if crawl_data is None:
|
||||
return []
|
||||
document = Document(page_content=crawl_data.get('markdown', ''),
|
||||
metadata={
|
||||
'source_url': crawl_data.get('source_url'),
|
||||
'description': crawl_data.get('description'),
|
||||
'title': crawl_data.get('title')
|
||||
}
|
||||
)
|
||||
document = Document(
|
||||
page_content=crawl_data.get("markdown", ""),
|
||||
metadata={
|
||||
"source_url": crawl_data.get("source_url"),
|
||||
"description": crawl_data.get("description"),
|
||||
"title": crawl_data.get("title"),
|
||||
},
|
||||
)
|
||||
documents.append(document)
|
||||
elif self.mode == 'scrape':
|
||||
scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
|
||||
self.only_main_content)
|
||||
elif self.mode == "scrape":
|
||||
scrape_data = WebsiteService.get_scrape_url_data(
|
||||
"firecrawl", self._url, self.tenant_id, self.only_main_content
|
||||
)
|
||||
|
||||
document = Document(page_content=scrape_data.get('markdown', ''),
|
||||
metadata={
|
||||
'source_url': scrape_data.get('source_url'),
|
||||
'description': scrape_data.get('description'),
|
||||
'title': scrape_data.get('title')
|
||||
}
|
||||
)
|
||||
document = Document(
|
||||
page_content=scrape_data.get("markdown", ""),
|
||||
metadata={
|
||||
"source_url": scrape_data.get("source_url"),
|
||||
"description": scrape_data.get("description"),
|
||||
"title": scrape_data.get("title"),
|
||||
},
|
||||
)
|
||||
documents.append(document)
|
||||
return documents
|
||||
|
||||
@@ -37,9 +37,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
|
||||
try:
|
||||
encodings = future.result(timeout=timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Timeout reached while detecting encoding for {file_path}"
|
||||
)
|
||||
raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}")
|
||||
|
||||
if all(encoding["encoding"] is None for encoding in encodings):
|
||||
raise RuntimeError(f"Could not detect encoding for {file_path}")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
@@ -6,7 +7,6 @@ from core.rag.models.document import Document
|
||||
|
||||
|
||||
class HtmlExtractor(BaseExtractor):
|
||||
|
||||
"""
|
||||
Load html files.
|
||||
|
||||
@@ -15,10 +15,7 @@ class HtmlExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str
|
||||
):
|
||||
def __init__(self, file_path: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
|
||||
@@ -27,8 +24,8 @@ class HtmlExtractor(BaseExtractor):
|
||||
|
||||
def _load_as_text(self) -> str:
|
||||
with open(self._file_path, "rb") as fp:
|
||||
soup = BeautifulSoup(fp, 'html.parser')
|
||||
soup = BeautifulSoup(fp, "html.parser")
|
||||
text = soup.get_text()
|
||||
text = text.strip() if text else ''
|
||||
text = text.strip() if text else ""
|
||||
|
||||
return text
|
||||
return text
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import re
|
||||
from typing import Optional, cast
|
||||
|
||||
@@ -16,12 +17,12 @@ class MarkdownExtractor(BaseExtractor):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
remove_hyperlinks: bool = False,
|
||||
remove_images: bool = False,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
self,
|
||||
file_path: str,
|
||||
remove_hyperlinks: bool = False,
|
||||
remove_images: bool = False,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
@@ -78,13 +79,10 @@ class MarkdownExtractor(BaseExtractor):
|
||||
if current_header is not None:
|
||||
# pass linting, assert keys are defined
|
||||
markdown_tups = [
|
||||
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
|
||||
for key, value in markdown_tups
|
||||
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups
|
||||
]
|
||||
else:
|
||||
markdown_tups = [
|
||||
(key, re.sub("\n", "", value)) for key, value in markdown_tups
|
||||
]
|
||||
markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups]
|
||||
|
||||
return markdown_tups
|
||||
|
||||
|
||||
@@ -21,22 +21,21 @@ RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
|
||||
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
||||
# if user want split by headings, use the corresponding splitter
|
||||
HEADING_SPLITTER = {
|
||||
'heading_1': '# ',
|
||||
'heading_2': '## ',
|
||||
'heading_3': '### ',
|
||||
"heading_1": "# ",
|
||||
"heading_2": "## ",
|
||||
"heading_3": "### ",
|
||||
}
|
||||
|
||||
|
||||
class NotionExtractor(BaseExtractor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
notion_workspace_id: str,
|
||||
notion_obj_id: str,
|
||||
notion_page_type: str,
|
||||
tenant_id: str,
|
||||
document_model: Optional[DocumentModel] = None,
|
||||
notion_access_token: Optional[str] = None,
|
||||
|
||||
self,
|
||||
notion_workspace_id: str,
|
||||
notion_obj_id: str,
|
||||
notion_page_type: str,
|
||||
tenant_id: str,
|
||||
document_model: Optional[DocumentModel] = None,
|
||||
notion_access_token: Optional[str] = None,
|
||||
):
|
||||
self._notion_access_token = None
|
||||
self._document_model = document_model
|
||||
@@ -46,46 +45,38 @@ class NotionExtractor(BaseExtractor):
|
||||
if notion_access_token:
|
||||
self._notion_access_token = notion_access_token
|
||||
else:
|
||||
self._notion_access_token = self._get_access_token(tenant_id,
|
||||
self._notion_workspace_id)
|
||||
self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id)
|
||||
if not self._notion_access_token:
|
||||
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
|
||||
if integration_token is None:
|
||||
raise ValueError(
|
||||
"Must specify `integration_token` or set environment "
|
||||
"variable `NOTION_INTEGRATION_TOKEN`."
|
||||
"Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`."
|
||||
)
|
||||
|
||||
self._notion_access_token = integration_token
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
self.update_last_edited_time(
|
||||
self._document_model
|
||||
)
|
||||
self.update_last_edited_time(self._document_model)
|
||||
|
||||
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
|
||||
|
||||
return text_docs
|
||||
|
||||
def _load_data_as_documents(
|
||||
self, notion_obj_id: str, notion_page_type: str
|
||||
) -> list[Document]:
|
||||
def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> list[Document]:
|
||||
docs = []
|
||||
if notion_page_type == 'database':
|
||||
if notion_page_type == "database":
|
||||
# get all the pages in the database
|
||||
page_text_documents = self._get_notion_database_data(notion_obj_id)
|
||||
docs.extend(page_text_documents)
|
||||
elif notion_page_type == 'page':
|
||||
elif notion_page_type == "page":
|
||||
page_text_list = self._get_notion_block_data(notion_obj_id)
|
||||
docs.append(Document(page_content='\n'.join(page_text_list)))
|
||||
docs.append(Document(page_content="\n".join(page_text_list)))
|
||||
else:
|
||||
raise ValueError("notion page type not supported")
|
||||
|
||||
return docs
|
||||
|
||||
def _get_notion_database_data(
|
||||
self, database_id: str, query_dict: dict[str, Any] = {}
|
||||
) -> list[Document]:
|
||||
def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
|
||||
"""Get all the pages from a Notion database."""
|
||||
res = requests.post(
|
||||
DATABASE_URL_TMPL.format(database_id=database_id),
|
||||
@@ -100,50 +91,50 @@ class NotionExtractor(BaseExtractor):
|
||||
data = res.json()
|
||||
|
||||
database_content = []
|
||||
if 'results' not in data or data["results"] is None:
|
||||
if "results" not in data or data["results"] is None:
|
||||
return []
|
||||
for result in data["results"]:
|
||||
properties = result['properties']
|
||||
properties = result["properties"]
|
||||
data = {}
|
||||
for property_name, property_value in properties.items():
|
||||
type = property_value['type']
|
||||
if type == 'multi_select':
|
||||
type = property_value["type"]
|
||||
if type == "multi_select":
|
||||
value = []
|
||||
multi_select_list = property_value[type]
|
||||
for multi_select in multi_select_list:
|
||||
value.append(multi_select['name'])
|
||||
elif type == 'rich_text' or type == 'title':
|
||||
value.append(multi_select["name"])
|
||||
elif type == "rich_text" or type == "title":
|
||||
if len(property_value[type]) > 0:
|
||||
value = property_value[type][0]['plain_text']
|
||||
value = property_value[type][0]["plain_text"]
|
||||
else:
|
||||
value = ''
|
||||
elif type == 'select' or type == 'status':
|
||||
value = ""
|
||||
elif type == "select" or type == "status":
|
||||
if property_value[type]:
|
||||
value = property_value[type]['name']
|
||||
value = property_value[type]["name"]
|
||||
else:
|
||||
value = ''
|
||||
value = ""
|
||||
else:
|
||||
value = property_value[type]
|
||||
data[property_name] = value
|
||||
row_dict = {k: v for k, v in data.items() if v}
|
||||
row_content = ''
|
||||
row_content = ""
|
||||
for key, value in row_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value_dict = {k: v for k, v in value.items() if v}
|
||||
value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items())
|
||||
row_content = row_content + f'{key}:{value_content}\n'
|
||||
value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
|
||||
row_content = row_content + f"{key}:{value_content}\n"
|
||||
else:
|
||||
row_content = row_content + f'{key}:{value}\n'
|
||||
row_content = row_content + f"{key}:{value}\n"
|
||||
database_content.append(row_content)
|
||||
|
||||
return [Document(page_content='\n'.join(database_content))]
|
||||
return [Document(page_content="\n".join(database_content))]
|
||||
|
||||
def _get_notion_block_data(self, page_id: str) -> list[str]:
|
||||
result_lines_arr = []
|
||||
start_cursor = None
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id)
|
||||
while True:
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor}
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
|
||||
res = requests.request(
|
||||
"GET",
|
||||
block_url,
|
||||
@@ -152,14 +143,14 @@ class NotionExtractor(BaseExtractor):
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
params=query_dict
|
||||
params=query_dict,
|
||||
)
|
||||
data = res.json()
|
||||
for result in data["results"]:
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
cur_result_text_arr = []
|
||||
if result_type == 'table':
|
||||
if result_type == "table":
|
||||
result_block_id = result["id"]
|
||||
text = self._read_table_rows(result_block_id)
|
||||
text += "\n\n"
|
||||
@@ -175,17 +166,15 @@ class NotionExtractor(BaseExtractor):
|
||||
result_block_id = result["id"]
|
||||
has_children = result["has_children"]
|
||||
block_type = result["type"]
|
||||
if has_children and block_type != 'child_page':
|
||||
children_text = self._read_block(
|
||||
result_block_id, num_tabs=1
|
||||
)
|
||||
if has_children and block_type != "child_page":
|
||||
children_text = self._read_block(result_block_id, num_tabs=1)
|
||||
cur_result_text_arr.append(children_text)
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
if result_type in HEADING_SPLITTER:
|
||||
result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}")
|
||||
else:
|
||||
result_lines_arr.append(cur_result_text + '\n\n')
|
||||
result_lines_arr.append(cur_result_text + "\n\n")
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
@@ -199,7 +188,7 @@ class NotionExtractor(BaseExtractor):
|
||||
start_cursor = None
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
|
||||
while True:
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor}
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
@@ -209,16 +198,16 @@ class NotionExtractor(BaseExtractor):
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
params=query_dict
|
||||
params=query_dict,
|
||||
)
|
||||
data = res.json()
|
||||
if 'results' not in data or data["results"] is None:
|
||||
if "results" not in data or data["results"] is None:
|
||||
break
|
||||
for result in data["results"]:
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
cur_result_text_arr = []
|
||||
if result_type == 'table':
|
||||
if result_type == "table":
|
||||
result_block_id = result["id"]
|
||||
text = self._read_table_rows(result_block_id)
|
||||
result_lines_arr.append(text)
|
||||
@@ -233,17 +222,15 @@ class NotionExtractor(BaseExtractor):
|
||||
result_block_id = result["id"]
|
||||
has_children = result["has_children"]
|
||||
block_type = result["type"]
|
||||
if has_children and block_type != 'child_page':
|
||||
children_text = self._read_block(
|
||||
result_block_id, num_tabs=num_tabs + 1
|
||||
)
|
||||
if has_children and block_type != "child_page":
|
||||
children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1)
|
||||
cur_result_text_arr.append(children_text)
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
if result_type in HEADING_SPLITTER:
|
||||
result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}')
|
||||
result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}")
|
||||
else:
|
||||
result_lines_arr.append(cur_result_text + '\n\n')
|
||||
result_lines_arr.append(cur_result_text + "\n\n")
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
@@ -260,7 +247,7 @@ class NotionExtractor(BaseExtractor):
|
||||
start_cursor = None
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
|
||||
while not done:
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor}
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
@@ -270,28 +257,28 @@ class NotionExtractor(BaseExtractor):
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
params=query_dict
|
||||
params=query_dict,
|
||||
)
|
||||
data = res.json()
|
||||
# get table headers text
|
||||
table_header_cell_texts = []
|
||||
table_header_cells = data["results"][0]['table_row']['cells']
|
||||
table_header_cells = data["results"][0]["table_row"]["cells"]
|
||||
for table_header_cell in table_header_cells:
|
||||
if table_header_cell:
|
||||
for table_header_cell_text in table_header_cell:
|
||||
text = table_header_cell_text["text"]["content"]
|
||||
table_header_cell_texts.append(text)
|
||||
else:
|
||||
table_header_cell_texts.append('')
|
||||
table_header_cell_texts.append("")
|
||||
# Initialize Markdown table with headers
|
||||
markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n"
|
||||
markdown_table += "| " + " | ".join(['---'] * len(table_header_cell_texts)) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["---"] * len(table_header_cell_texts)) + " |\n"
|
||||
|
||||
# Process data to format each row in Markdown table format
|
||||
results = data["results"]
|
||||
for i in range(len(results) - 1):
|
||||
column_texts = []
|
||||
table_column_cells = data["results"][i + 1]['table_row']['cells']
|
||||
table_column_cells = data["results"][i + 1]["table_row"]["cells"]
|
||||
for j in range(len(table_column_cells)):
|
||||
if table_column_cells[j]:
|
||||
for table_column_cell_text in table_column_cells[j]:
|
||||
@@ -315,10 +302,8 @@ class NotionExtractor(BaseExtractor):
|
||||
|
||||
last_edited_time = self.get_notion_last_edited_time()
|
||||
data_source_info = document_model.data_source_info_dict
|
||||
data_source_info['last_edited_time'] = last_edited_time
|
||||
update_params = {
|
||||
DocumentModel.data_source_info: json.dumps(data_source_info)
|
||||
}
|
||||
data_source_info["last_edited_time"] = last_edited_time
|
||||
update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}
|
||||
|
||||
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
|
||||
db.session.commit()
|
||||
@@ -326,7 +311,7 @@ class NotionExtractor(BaseExtractor):
|
||||
def get_notion_last_edited_time(self) -> str:
|
||||
obj_id = self._notion_obj_id
|
||||
page_type = self._notion_page_type
|
||||
if page_type == 'database':
|
||||
if page_type == "database":
|
||||
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
|
||||
else:
|
||||
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
|
||||
@@ -341,7 +326,7 @@ class NotionExtractor(BaseExtractor):
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
json=query_dict,
|
||||
)
|
||||
|
||||
data = res.json()
|
||||
@@ -352,14 +337,16 @@ class NotionExtractor(BaseExtractor):
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
|
||||
)
|
||||
).first()
|
||||
|
||||
if not data_source_binding:
|
||||
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
|
||||
f'and notion workspace {notion_workspace_id}')
|
||||
raise Exception(
|
||||
f"No notion data source binding found for tenant {tenant_id} "
|
||||
f"and notion workspace {notion_workspace_id}"
|
||||
)
|
||||
|
||||
return data_source_binding.access_token
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional
|
||||
|
||||
@@ -16,21 +17,17 @@ class PdfExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
file_cache_key: Optional[str] = None
|
||||
):
|
||||
def __init__(self, file_path: str, file_cache_key: Optional[str] = None):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
plaintext_file_key = ''
|
||||
plaintext_file_key = ""
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
text = storage.load(self._file_cache_key).decode('utf-8')
|
||||
text = storage.load(self._file_cache_key).decode("utf-8")
|
||||
plaintext_file_exists = True
|
||||
return [Document(page_content=text)]
|
||||
except FileNotFoundError:
|
||||
@@ -43,12 +40,12 @@ class PdfExtractor(BaseExtractor):
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode('utf-8'))
|
||||
storage.save(plaintext_file_key, text.encode("utf-8"))
|
||||
|
||||
return documents
|
||||
|
||||
def load(
|
||||
self,
|
||||
self,
|
||||
) -> Iterator[Document]:
|
||||
"""Lazy load given path as pages."""
|
||||
blob = Blob.from_path(self._file_path)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
@@ -14,12 +15,7 @@ class TextExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
|
||||
@@ -8,13 +8,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredWordExtractor(BaseExtractor):
|
||||
"""Loader that uses unstructured to load word documents.
|
||||
"""
|
||||
"""Loader that uses unstructured to load word documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
@@ -24,9 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor):
|
||||
from unstructured.__version__ import __version__ as __unstructured_version__
|
||||
from unstructured.file_utils.filetype import FileType, detect_filetype
|
||||
|
||||
unstructured_version = tuple(
|
||||
int(x) for x in __unstructured_version__.split(".")
|
||||
)
|
||||
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
|
||||
# check the file extension
|
||||
try:
|
||||
import magic # noqa: F401
|
||||
@@ -53,6 +50,7 @@ class UnstructuredWordExtractor(BaseExtractor):
|
||||
elements = partition_docx(filename=self._file_path)
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@@ -26,6 +26,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.email import partition_email
|
||||
|
||||
elements = partition_email(filename=self._file_path)
|
||||
|
||||
# noinspection PyBroadException
|
||||
@@ -34,15 +35,16 @@ class UnstructuredEmailExtractor(BaseExtractor):
|
||||
element_text = element.text.strip()
|
||||
|
||||
padding_needed = 4 - len(element_text) % 4
|
||||
element_text += '=' * padding_needed
|
||||
element_text += "=" * padding_needed
|
||||
|
||||
element_decode = base64.b64decode(element_text)
|
||||
soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser')
|
||||
soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser")
|
||||
element.text = soup.get_text()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@@ -28,6 +28,7 @@ class UnstructuredEpubExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_epub(filename=self._file_path, xml_keep_tags=True)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@@ -38,6 +38,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_md(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@@ -14,11 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
@@ -28,6 +24,7 @@ class UnstructuredMsgExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_msg(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@@ -14,12 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
api_key: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
@@ -14,11 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
@@ -14,11 +14,7 @@ class UnstructuredTextExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
@@ -28,6 +24,7 @@ class UnstructuredTextExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_text(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@@ -14,11 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
@@ -28,6 +24,7 @@ class UnstructuredXmlExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import mimetypes
|
||||
@@ -21,6 +22,7 @@ from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WordExtractor(BaseExtractor):
|
||||
"""Load docx files.
|
||||
|
||||
@@ -43,9 +45,7 @@ class WordExtractor(BaseExtractor):
|
||||
r = requests.get(self.file_path)
|
||||
|
||||
if r.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Check the url of your file; returned status code {r.status_code}"
|
||||
)
|
||||
raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
|
||||
|
||||
self.web_path = self.file_path
|
||||
self.temp_file = tempfile.NamedTemporaryFile()
|
||||
@@ -60,11 +60,13 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
content = self.parse_docx(self.file_path, 'storage')
|
||||
return [Document(
|
||||
page_content=content,
|
||||
metadata={"source": self.file_path},
|
||||
)]
|
||||
content = self.parse_docx(self.file_path, "storage")
|
||||
return [
|
||||
Document(
|
||||
page_content=content,
|
||||
metadata={"source": self.file_path},
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_url(url: str) -> bool:
|
||||
@@ -84,18 +86,18 @@ class WordExtractor(BaseExtractor):
|
||||
url = rel.reltype
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
image_ext = mimetypes.guess_extension(response.headers['Content-Type'])
|
||||
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
|
||||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
storage.save(file_key, response.content)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
image_ext = rel.target_ref.split('.')[-1]
|
||||
image_ext = rel.target_ref.split(".")[-1]
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
|
||||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
|
||||
storage.save(file_key, rel.target_part.blob)
|
||||
@@ -112,12 +114,14 @@ class WordExtractor(BaseExtractor):
|
||||
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
image_map[rel.target_part] = f""
|
||||
image_map[rel.target_part] = (
|
||||
f""
|
||||
)
|
||||
|
||||
return image_map
|
||||
|
||||
@@ -167,8 +171,8 @@ class WordExtractor(BaseExtractor):
|
||||
def _parse_cell_paragraph(self, paragraph, image_map):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if run.element.xpath('.//a:blip'):
|
||||
for blip in run.element.xpath('.//a:blip'):
|
||||
if run.element.xpath(".//a:blip"):
|
||||
for blip in run.element.xpath(".//a:blip"):
|
||||
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if not image_id:
|
||||
continue
|
||||
@@ -184,16 +188,16 @@ class WordExtractor(BaseExtractor):
|
||||
def _parse_paragraph(self, paragraph, image_map):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if run.element.xpath('.//a:blip'):
|
||||
for blip in run.element.xpath('.//a:blip'):
|
||||
embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
|
||||
if run.element.xpath(".//a:blip"):
|
||||
for blip in run.element.xpath(".//a:blip"):
|
||||
embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if embed_id:
|
||||
rel_target = run.part.rels[embed_id].target_ref
|
||||
if rel_target in image_map:
|
||||
paragraph_content.append(image_map[rel_target])
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
return ' '.join(paragraph_content) if paragraph_content else ''
|
||||
return " ".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
def parse_docx(self, docx_path, image_folder):
|
||||
doc = DocxDocument(docx_path)
|
||||
@@ -204,60 +208,59 @@ class WordExtractor(BaseExtractor):
|
||||
image_map = self._extract_images_from_docx(doc, image_folder)
|
||||
|
||||
hyperlinks_url = None
|
||||
url_pattern = re.compile(r'http://[^\s+]+//|https://[^\s+]+')
|
||||
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
|
||||
for para in doc.paragraphs:
|
||||
for run in para.runs:
|
||||
if run.text and hyperlinks_url:
|
||||
result = f' [{run.text}]({hyperlinks_url}) '
|
||||
result = f" [{run.text}]({hyperlinks_url}) "
|
||||
run.text = result
|
||||
hyperlinks_url = None
|
||||
if 'HYPERLINK' in run.element.xml:
|
||||
if "HYPERLINK" in run.element.xml:
|
||||
try:
|
||||
xml = ET.XML(run.element.xml)
|
||||
x_child = [c for c in xml.iter() if c is not None]
|
||||
for x in x_child:
|
||||
if x_child is None:
|
||||
continue
|
||||
if x.tag.endswith('instrText'):
|
||||
if x.tag.endswith("instrText"):
|
||||
for i in url_pattern.findall(x.text):
|
||||
hyperlinks_url = str(i)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
|
||||
|
||||
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if hasattr(run.element, 'tag') and isinstance(element.tag, str) and run.element.tag.endswith('r'):
|
||||
if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"):
|
||||
drawing_elements = run.element.findall(
|
||||
'.//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing')
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
|
||||
)
|
||||
for drawing in drawing_elements:
|
||||
blip_elements = drawing.findall(
|
||||
'.//{http://schemas.openxmlformats.org/drawingml/2006/main}blip')
|
||||
".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip"
|
||||
)
|
||||
for blip in blip_elements:
|
||||
embed_id = blip.get(
|
||||
'{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
|
||||
)
|
||||
if embed_id:
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
return ''.join(paragraph_content) if paragraph_content else ''
|
||||
return "".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
paragraphs = doc.paragraphs.copy()
|
||||
tables = doc.tables.copy()
|
||||
for element in doc.element.body:
|
||||
if hasattr(element, 'tag'):
|
||||
if isinstance(element.tag, str) and element.tag.endswith('p'): # paragraph
|
||||
if hasattr(element, "tag"):
|
||||
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
|
||||
para = paragraphs.pop(0)
|
||||
parsed_paragraph = parse_paragraph(para)
|
||||
if parsed_paragraph:
|
||||
content.append(parsed_paragraph)
|
||||
elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table
|
||||
elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table
|
||||
table = tables.pop(0)
|
||||
content.append(self._table_to_markdown(table, image_map))
|
||||
return '\n'.join(content)
|
||||
|
||||
return "\n".join(content)
|
||||
|
||||
Reference in New Issue
Block a user