0
api/core/rag/__init__.py
Normal file
0
api/core/rag/__init__.py
Normal file
38
api/core/rag/cleaner/clean_processor.py
Normal file
38
api/core/rag/cleaner/clean_processor.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import re
|
||||
|
||||
|
||||
class CleanProcessor:
|
||||
|
||||
@classmethod
|
||||
def clean(cls, text: str, process_rule: dict) -> str:
|
||||
# default clean
|
||||
# remove invalid symbol
|
||||
text = re.sub(r'<\|', '<', text)
|
||||
text = re.sub(r'\|>', '>', text)
|
||||
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
|
||||
# Unicode U+FFFE
|
||||
text = re.sub('\uFFFE', '', text)
|
||||
|
||||
rules = process_rule['rules'] if process_rule else None
|
||||
if 'pre_processing_rules' in rules:
|
||||
pre_processing_rules = rules["pre_processing_rules"]
|
||||
for pre_processing_rule in pre_processing_rules:
|
||||
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
|
||||
# Remove extra spaces
|
||||
pattern = r'\n{3,}'
|
||||
text = re.sub(pattern, '\n\n', text)
|
||||
pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
|
||||
text = re.sub(pattern, ' ', text)
|
||||
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
|
||||
# Remove email
|
||||
pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
|
||||
text = re.sub(pattern, '', text)
|
||||
|
||||
# Remove URL
|
||||
pattern = r'https?://[^\s]+'
|
||||
text = re.sub(pattern, '', text)
|
||||
return text
|
||||
|
||||
def filter_string(self, text):
|
||||
|
||||
return text
|
||||
12
api/core/rag/cleaner/cleaner_base.py
Normal file
12
api/core/rag/cleaner/cleaner_base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Abstract interface for document cleaner implementations."""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseCleaner(ABC):
|
||||
"""Interface for clean chunk content.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def clean(self, content: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""clean document content."""
|
||||
from unstructured.cleaners.core import clean_extra_whitespace
|
||||
|
||||
# Returns "ITEM 1A: RISK FACTORS"
|
||||
return clean_extra_whitespace(content)
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""clean document content."""
|
||||
import re
|
||||
|
||||
from unstructured.cleaners.core import group_broken_paragraphs
|
||||
|
||||
para_split_re = re.compile(r"(\s*\n\s*){3}")
|
||||
|
||||
return group_broken_paragraphs(content, paragraph_split=para_split_re)
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""clean document content."""
|
||||
from unstructured.cleaners.core import clean_non_ascii_chars
|
||||
|
||||
# Returns "This text containsnon-ascii characters!"
|
||||
return clean_non_ascii_chars(content)
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""Replaces unicode quote characters, such as the \x91 character in a string."""
|
||||
|
||||
from unstructured.cleaners.core import replace_unicode_quotes
|
||||
return replace_unicode_quotes(content)
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredTranslateTextCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""clean document content."""
|
||||
from unstructured.cleaners.translate import translate_text
|
||||
|
||||
return translate_text(content)
|
||||
0
api/core/rag/data_post_processor/__init__.py
Normal file
0
api/core/rag/data_post_processor/__init__.py
Normal file
49
api/core/rag/data_post_processor/data_post_processor.py
Normal file
49
api/core/rag/data_post_processor/data_post_processor.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.models.document import Document
|
||||
from core.rerank.rerank import RerankRunner
|
||||
|
||||
|
||||
class DataPostProcessor:
|
||||
"""Interface for data post-processing document.
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
|
||||
self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
|
||||
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
|
||||
|
||||
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
|
||||
if self.rerank_runner:
|
||||
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
|
||||
|
||||
if self.reorder_runner:
|
||||
documents = self.reorder_runner.run(documents)
|
||||
|
||||
return documents
|
||||
|
||||
def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
|
||||
if reranking_model:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=reranking_model['reranking_provider_name'],
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model['reranking_model_name']
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return None
|
||||
return RerankRunner(rerank_model_instance)
|
||||
return None
|
||||
|
||||
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
|
||||
if reorder_enabled:
|
||||
return ReorderRunner()
|
||||
return None
|
||||
|
||||
|
||||
19
api/core/rag/data_post_processor/reorder.py
Normal file
19
api/core/rag/data_post_processor/reorder.py
Normal file
@@ -0,0 +1,19 @@
|
||||
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
class ReorderRunner:
|
||||
|
||||
def run(self, documents: list[Document]) -> list[Document]:
|
||||
# Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list
|
||||
odd_elements = documents[::2]
|
||||
|
||||
# Retrieve elements from even indices (1, 3, 5, etc.) of the documents list
|
||||
even_elements = documents[1::2]
|
||||
|
||||
# Reverse the list of elements from even indices
|
||||
even_elements_reversed = even_elements[::-1]
|
||||
|
||||
new_documents = odd_elements + even_elements_reversed
|
||||
|
||||
return new_documents
|
||||
0
api/core/rag/datasource/__init__.py
Normal file
0
api/core/rag/datasource/__init__.py
Normal file
21
api/core/rag/datasource/entity/embedding.py
Normal file
21
api/core/rag/datasource/entity/embedding.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Embeddings(ABC):
|
||||
"""Interface for embedding models."""
|
||||
|
||||
@abstractmethod
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs."""
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Asynchronous Embed search docs."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Asynchronous Embed query text."""
|
||||
raise NotImplementedError
|
||||
0
api/core/rag/datasource/keyword/__init__.py
Normal file
0
api/core/rag/datasource/keyword/__init__.py
Normal file
0
api/core/rag/datasource/keyword/jieba/__init__.py
Normal file
0
api/core/rag/datasource/keyword/jieba/__init__.py
Normal file
232
api/core/rag/datasource/keyword/jieba/jieba.py
Normal file
232
api/core/rag/datasource/keyword/jieba/jieba.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
|
||||
|
||||
|
||||
class KeywordTableConfig(BaseModel):
|
||||
max_keywords_per_chunk: int = 10
|
||||
|
||||
|
||||
class Jieba(BaseKeyword):
|
||||
def __init__(self, dataset: Dataset):
|
||||
super().__init__(dataset)
|
||||
self._config = KeywordTableConfig()
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
return self
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keywords_list = kwargs.get('keywords_list', None)
|
||||
for i in range(len(texts)):
|
||||
text = texts[i]
|
||||
if keywords_list:
|
||||
keywords = keywords_list[i]
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
return id in set.union(*keyword_table.values())
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
# get segment ids by document_id
|
||||
segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self.dataset.id,
|
||||
DocumentSegment.document_id == document_id
|
||||
).all()
|
||||
|
||||
ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
k = kwargs.get('top_k', 4)
|
||||
|
||||
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
|
||||
|
||||
documents = []
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self.dataset.id,
|
||||
DocumentSegment.index_node_id == chunk_index
|
||||
).first()
|
||||
|
||||
if segment:
|
||||
documents.append(Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": chunk_index,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
))
|
||||
|
||||
return documents
|
||||
|
||||
def delete(self) -> None:
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
db.session.delete(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
def _save_dataset_keyword_table(self, keyword_table):
|
||||
keyword_table_dict = {
|
||||
'__type__': 'keyword_table',
|
||||
'__data__': {
|
||||
"index_id": self.dataset.id,
|
||||
"summary": None,
|
||||
"table": keyword_table
|
||||
}
|
||||
}
|
||||
self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
|
||||
db.session.commit()
|
||||
|
||||
def _get_dataset_keyword_table(self) -> Optional[dict]:
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
if dataset_keyword_table.keyword_table_dict:
|
||||
return dataset_keyword_table.keyword_table_dict['__data__']['table']
|
||||
else:
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
dataset_id=self.dataset.id,
|
||||
keyword_table=json.dumps({
|
||||
'__type__': 'keyword_table',
|
||||
'__data__': {
|
||||
"index_id": self.dataset.id,
|
||||
"summary": None,
|
||||
"table": {}
|
||||
}
|
||||
}, cls=SetEncoder)
|
||||
)
|
||||
db.session.add(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
return {}
|
||||
|
||||
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
|
||||
for keyword in keywords:
|
||||
if keyword not in keyword_table:
|
||||
keyword_table[keyword] = set()
|
||||
keyword_table[keyword].add(id)
|
||||
return keyword_table
|
||||
|
||||
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
|
||||
# get set of ids that correspond to node
|
||||
node_idxs_to_delete = set(ids)
|
||||
|
||||
# delete node_idxs from keyword to node idxs mapping
|
||||
keywords_to_delete = set()
|
||||
for keyword, node_idxs in keyword_table.items():
|
||||
if node_idxs_to_delete.intersection(node_idxs):
|
||||
keyword_table[keyword] = node_idxs.difference(
|
||||
node_idxs_to_delete
|
||||
)
|
||||
if not keyword_table[keyword]:
|
||||
keywords_to_delete.add(keyword)
|
||||
|
||||
for keyword in keywords_to_delete:
|
||||
del keyword_table[keyword]
|
||||
|
||||
return keyword_table
|
||||
|
||||
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keywords = keyword_table_handler.extract_keywords(query)
|
||||
|
||||
# go through text chunks in order of most matching keywords
|
||||
chunk_indices_count: dict[str, int] = defaultdict(int)
|
||||
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
|
||||
for keyword in keywords:
|
||||
for node_id in keyword_table[keyword]:
|
||||
chunk_indices_count[node_id] += 1
|
||||
|
||||
sorted_chunk_indices = sorted(
|
||||
list(chunk_indices_count.keys()),
|
||||
key=lambda x: chunk_indices_count[x],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return sorted_chunk_indices[: k]
|
||||
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
|
||||
document_segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id == node_id
|
||||
).first()
|
||||
if document_segment:
|
||||
document_segment.keywords = keywords
|
||||
db.session.add(document_segment)
|
||||
db.session.commit()
|
||||
|
||||
def create_segment_keywords(self, node_id: str, keywords: list[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
self._update_segment_keywords(self.dataset.id, node_id, keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def multi_create_segment_keywords(self, pre_segment_data_list: list):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for pre_segment_data in pre_segment_data_list:
|
||||
segment = pre_segment_data['segment']
|
||||
if pre_segment_data['keywords']:
|
||||
segment.keywords = pre_segment_data['keywords']
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
|
||||
pre_segment_data['keywords'])
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(segment.content,
|
||||
self._config.max_keywords_per_chunk)
|
||||
segment.keywords = list(keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
|
||||
class SetEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, set):
|
||||
return list(obj)
|
||||
return super().default(obj)
|
||||
@@ -0,0 +1,32 @@
|
||||
import re
|
||||
|
||||
import jieba
|
||||
from jieba.analyse import default_tfidf
|
||||
|
||||
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
|
||||
|
||||
|
||||
class JiebaKeywordTableHandler:
|
||||
|
||||
def __init__(self):
|
||||
default_tfidf.stop_words = STOPWORDS
|
||||
|
||||
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
keywords = jieba.analyse.extract_tags(
|
||||
sentence=text,
|
||||
topK=max_keywords_per_chunk,
|
||||
)
|
||||
|
||||
return set(self._expand_tokens_with_subtokens(keywords))
|
||||
|
||||
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
|
||||
"""Get subtokens from a list of tokens., filtering for stopwords."""
|
||||
results = set()
|
||||
for token in tokens:
|
||||
results.add(token)
|
||||
sub_tokens = re.findall(r"\w+", token)
|
||||
if len(sub_tokens) > 1:
|
||||
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
|
||||
|
||||
return results
|
||||
90
api/core/rag/datasource/keyword/jieba/stopwords.py
Normal file
90
api/core/rag/datasource/keyword/jieba/stopwords.py
Normal file
@@ -0,0 +1,90 @@
|
||||
STOPWORDS = {
|
||||
"during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've",
|
||||
"ours", "who", "where", "ourselves", "has", "to", "was", "didn't", "themselves", "if", "against", "through", "her",
|
||||
"an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t",
|
||||
"theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven",
|
||||
"for", "you'll", "that", "we", "am", "d", "by", "having", "wasn't", "than", "weren't", "out", "from", "now",
|
||||
"their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which",
|
||||
"m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't",
|
||||
"such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves",
|
||||
"been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because",
|
||||
"down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't",
|
||||
"as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after",
|
||||
"over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there",
|
||||
"himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below",
|
||||
"人民", "末##末", "啊", "阿", "哎", "哎呀", "哎哟", "唉", "俺", "俺们", "按", "按照", "吧", "吧哒", "把", "罢了", "被", "本",
|
||||
"本着", "比", "比方", "比如", "鄙人", "彼", "彼此", "边", "别", "别的", "别说", "并", "并且", "不比", "不成", "不单", "不但",
|
||||
"不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "朝", "朝着",
|
||||
"趁", "趁着", "乘", "冲", "除", "除此之外", "除非", "除了", "此", "此间", "此外", "从", "从而", "打", "待", "但", "但是", "当",
|
||||
"当着", "到", "得", "的", "的话", "等", "等等", "地", "第", "叮咚", "对", "对于", "多", "多少", "而", "而况", "而且", "而是",
|
||||
"而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "嘎", "嘎登", "该", "赶", "个", "各",
|
||||
"各个", "各位", "各种", "各自", "给", "根据", "跟", "故", "故此", "固然", "关于", "管", "归", "果然", "果真", "过", "哈",
|
||||
"哈哈", "呵", "和", "何", "何处", "何况", "何时", "嘿", "哼", "哼唷", "呼哧", "乎", "哗", "还是", "还有", "换句话说", "换言之",
|
||||
"或", "或是", "或者", "极了", "及", "及其", "及至", "即", "即便", "即或", "即令", "即若", "即使", "几", "几时", "己", "既",
|
||||
"既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "将", "较", "较之", "叫", "接着", "结果", "借", "紧接着",
|
||||
"进而", "尽", "尽管", "经", "经过", "就", "就是", "就是说", "据", "具体地说", "具体说来", "开始", "开外", "靠", "咳", "可",
|
||||
"可见", "可是", "可以", "况且", "啦", "来", "来着", "离", "例如", "哩", "连", "连同", "两者", "了", "临", "另", "另外",
|
||||
"另一方面", "论", "嘛", "吗", "慢说", "漫说", "冒", "么", "每", "每当", "们", "莫若", "某", "某个", "某些", "拿", "哪",
|
||||
"哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "那", "那边", "那儿", "那个", "那会儿", "那里", "那么",
|
||||
"那么些", "那么样", "那时", "那些", "那样", "乃", "乃至", "呢", "能", "你", "你们", "您", "宁", "宁可", "宁肯", "宁愿", "哦",
|
||||
"呕", "啪达", "旁人", "呸", "凭", "凭借", "其", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "起", "起见", "岂但",
|
||||
"恰恰相反", "前后", "前者", "且", "然而", "然后", "然则", "让", "人家", "任", "任何", "任凭", "如", "如此", "如果", "如何",
|
||||
"如其", "如若", "如上所述", "若", "若非", "若是", "啥", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候",
|
||||
"什么", "什么样", "使得", "是", "是的", "首先", "谁", "谁知", "顺", "顺着", "似的", "虽", "虽然", "虽说", "虽则", "随", "随着",
|
||||
"所", "所以", "他", "他们", "他人", "它", "它们", "她", "她们", "倘", "倘或", "倘然", "倘若", "倘使", "腾", "替", "通过", "同",
|
||||
"同时", "哇", "万一", "往", "望", "为", "为何", "为了", "为什么", "为着", "喂", "嗡嗡", "我", "我们", "呜", "呜呼", "乌乎",
|
||||
"无论", "无宁", "毋宁", "嘻", "吓", "相对而言", "像", "向", "向着", "嘘", "呀", "焉", "沿", "沿着", "要", "要不", "要不然",
|
||||
"要不是", "要么", "要是", "也", "也罢", "也好", "一", "一般", "一旦", "一方面", "一来", "一切", "一样", "一则", "依", "依照",
|
||||
"矣", "以", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "因", "因此", "因而", "因为", "哟", "用", "由",
|
||||
"由此可见", "由于", "有", "有的", "有关", "有些", "又", "于", "于是", "于是乎", "与", "与此同时", "与否", "与其", "越是",
|
||||
"云云", "哉", "再说", "再者", "在", "在下", "咱", "咱们", "则", "怎", "怎么", "怎么办", "怎么样", "怎样", "咋", "照", "照着",
|
||||
"者", "这", "这边", "这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样",
|
||||
"正如", "吱", "之", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "至", "至于", "诸位", "着", "着呢", "自", "自从",
|
||||
"自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "纵", "纵令",
|
||||
"纵然", "纵使", "遵照", "作为", "兮", "呃", "呗", "咚", "咦", "喏", "啐", "喔唷", "嗬", "嗯", "嗳", "~", "!", ".", ":",
|
||||
"\"", "'", "(", ")", "*", "A", "白", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_",
|
||||
"+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "(", ")", "——", "—", "¥", "·", "...", "‘", "’", "〉", "〈", "…",
|
||||
" ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "二",
|
||||
"三", "四", "五", "六", "七", "八", "九", "零", ">", "<", "@", "#", "$", "%", "︿", "&", "*", "+", "~", "|", "[",
|
||||
"]", "{", "}", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时",
|
||||
"按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "半", "梆", "保管", "保险", "饱", "背地里", "背靠背", "倍感", "倍加",
|
||||
"本人", "本身", "甭", "比起", "比如说", "比照", "毕竟", "必", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没",
|
||||
"并没有", "并排", "并无", "勃然", "不", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭",
|
||||
"不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了",
|
||||
"不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要",
|
||||
"不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", "不知不觉", "不止",
|
||||
"不止一次", "不至于", "才", "才能", "策略地", "差不多", "差一点", "常", "常常", "常言道", "常言说", "常言说得好", "长此下去",
|
||||
"长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心",
|
||||
"乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "臭", "初", "出", "出来", "出去",
|
||||
"除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "传", "传说", "传闻", "串行", "纯", "纯粹",
|
||||
"此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速",
|
||||
"从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "粗", "存心", "达旦", "打从",
|
||||
"打开天窗说亮话", "大", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事",
|
||||
"大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "带", "殆", "待到", "单", "单纯", "单单", "但愿", "弹指之间", "当场",
|
||||
"当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿",
|
||||
"到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "定", "动不动", "动辄", "陡然", "都", "独",
|
||||
"独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等",
|
||||
"二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "方", "方才", "方能", "放量", "非常", "非得",
|
||||
"分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "逢", "弗", "甫", "嘎嘎", "该当", "概", "赶快", "赶早不赶晚", "敢",
|
||||
"敢情", "敢于", "刚", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "更", "更加", "更进一步", "更为",
|
||||
"公然", "共", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "固", "怪", "怪不得", "惯常", "光", "光是", "归根到底",
|
||||
"归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须",
|
||||
"何止", "很", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "互", "互相", "哗啦", "话说", "还", "恍然", "会", "豁然",
|
||||
"活", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "极", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆",
|
||||
"即将", "即刻", "即是说", "几度", "几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之",
|
||||
"简直", "见", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此",
|
||||
"借以", "届时", "仅", "仅仅", "谨", "进来", "进去", "近", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量",
|
||||
"尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "竟", "竟然", "究竟", "就此", "就地", "就算", "居然", "局外",
|
||||
"举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "绝", "绝不", "绝顶", "绝对", "绝非",
|
||||
"均", "喀", "看", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "快", "快要", "来不及", "来得及", "来讲",
|
||||
"来看", "拦腰", "牢牢", "老", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "历", "立", "立地", "立刻",
|
||||
"立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "屡", "屡次",
|
||||
"屡次三番", "屡屡", "缕缕", "率尔", "率然", "略", "略加", "略微", "略为", "论说", "马上", "蛮", "满", "没", "没有", "每逢",
|
||||
"每每", "每时每刻", "猛然", "猛然间", "莫", "莫不", "莫非", "莫如", "默默地", "默然", "呐", "那末", "奈", "难道", "难得", "难怪",
|
||||
"难说", "内", "年复一年", "凝神", "偶而", "偶尔", "怕", "砰", "碰巧", "譬如", "偏偏", "乒", "平素", "颇", "迫于", "扑通",
|
||||
"其后", "其实", "奇", "齐", "起初", "起来", "起首", "起头", "起先", "岂", "岂非", "岂止", "迄", "恰逢", "恰好", "恰恰", "恰巧",
|
||||
"恰如", "恰似", "千", "千万", "千万千万", "切", "切不可", "切莫", "切切", "切勿", "窃", "亲口", "亲身", "亲手", "亲眼", "亲自",
|
||||
"顷", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "去", "权时", "全都", "全力", "全年", "全然", "全身心", "然",
|
||||
"人人", "仍", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述",
|
||||
"如上", "如下", "汝", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "上", "上来", "上去", "一个", "月", "日", "\n"
|
||||
}
|
||||
54
api/core/rag/datasource/keyword/keyword_base.py
Normal file
54
api/core/rag/datasource/keyword/keyword_base.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class BaseKeyword(ABC):
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
@abstractmethod
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def text_exists(self, id: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_document_id(self, document_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts:
|
||||
doc_id = text.metadata['doc_id']
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
|
||||
return texts
|
||||
|
||||
def _get_uuids(self, texts: list[Document]) -> list[str]:
|
||||
return [text.metadata['doc_id'] for text in texts]
|
||||
60
api/core/rag/datasource/keyword/keyword_factory.py
Normal file
60
api/core/rag/datasource/keyword/keyword_factory.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.datasource.keyword.jieba.jieba import Jieba
|
||||
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class Keyword:
|
||||
def __init__(self, dataset: Dataset):
|
||||
self._dataset = dataset
|
||||
self._keyword_processor = self._init_keyword()
|
||||
|
||||
def _init_keyword(self) -> BaseKeyword:
|
||||
config = cast(dict, current_app.config)
|
||||
keyword_type = config.get('KEYWORD_STORE')
|
||||
|
||||
if not keyword_type:
|
||||
raise ValueError("Keyword store must be specified.")
|
||||
|
||||
if keyword_type == "jieba":
|
||||
return Jieba(
|
||||
dataset=self._dataset
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
def create(self, texts: list[Document], **kwargs):
|
||||
self._keyword_processor.create(texts, **kwargs)
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
self._keyword_processor.add_texts(texts, **kwargs)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self._keyword_processor.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._keyword_processor.delete_by_ids(ids)
|
||||
|
||||
def delete_by_document_id(self, document_id: str) -> None:
|
||||
self._keyword_processor.delete_by_document_id(document_id)
|
||||
|
||||
def delete(self) -> None:
|
||||
self._keyword_processor.delete()
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
return self._keyword_processor.search(query, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self._keyword_processor is not None:
|
||||
method = getattr(self._keyword_processor, name)
|
||||
if callable(method):
|
||||
return method
|
||||
|
||||
raise AttributeError(f"'Keyword' object has no attribute '{name}'")
|
||||
165
api/core/rag/datasource/retrieval_service.py
Normal file
165
api/core/rag/datasource/retrieval_service.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from flask_login import current_user
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
|
||||
class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
|
||||
all_documents = []
|
||||
threads = []
|
||||
# retrieval_model source with keyword
|
||||
if retrival_method == 'keyword_search':
|
||||
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k
|
||||
})
|
||||
threads.append(keyword_thread)
|
||||
keyword_thread.start()
|
||||
# retrieval_model source with semantic
|
||||
if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'score_threshold': score_threshold,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents,
|
||||
'retrival_method': retrival_method
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval source with full text
|
||||
if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'retrival_method': retrival_method,
|
||||
'score_threshold': score_threshold,
|
||||
'top_k': top_k,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if retrival_method == 'hybrid_search':
|
||||
data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||
top_k: int, all_documents: list):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
keyword = Keyword(
|
||||
dataset=dataset
|
||||
)
|
||||
|
||||
documents = keyword.search(
|
||||
query,
|
||||
k=top_k
|
||||
)
|
||||
all_documents.extend(documents)
|
||||
|
||||
@classmethod
|
||||
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, retrival_method: str):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
vector = Vector(
|
||||
dataset=dataset
|
||||
)
|
||||
|
||||
documents = vector.search_by_vector(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
filter={
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
)
|
||||
|
||||
if documents:
|
||||
if reranking_model and retrival_method == 'semantic_search':
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
|
||||
all_documents.extend(data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents)
|
||||
))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
|
||||
@classmethod
|
||||
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, retrival_method: str):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
vector_processor = Vector(
|
||||
dataset=dataset,
|
||||
)
|
||||
|
||||
documents = vector_processor.search_by_full_text(
|
||||
query,
|
||||
top_k=top_k
|
||||
)
|
||||
if documents:
|
||||
if reranking_model and retrival_method == 'full_text_search':
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
|
||||
all_documents.extend(data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents)
|
||||
))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
0
api/core/rag/datasource/vdb/__init__.py
Normal file
0
api/core/rag/datasource/vdb/__init__.py
Normal file
10
api/core/rag/datasource/vdb/field.py
Normal file
10
api/core/rag/datasource/vdb/field.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Field(Enum):
|
||||
CONTENT_KEY = "page_content"
|
||||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
VECTOR = "vector"
|
||||
TEXT_KEY = "text"
|
||||
PRIMARY_KEY = " id"
|
||||
0
api/core/rag/datasource/vdb/milvus/__init__.py
Normal file
0
api/core/rag/datasource/vdb/milvus/__init__.py
Normal file
214
api/core/rag/datasource/vdb/milvus/milvus_vector.py
Normal file
214
api/core/rag/datasource/vdb/milvus/milvus_vector.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pymilvus import MilvusClient, MilvusException, connections
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
secure: bool = False
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
raise ValueError("config MILVUS_HOST is required")
|
||||
if not values['port']:
|
||||
raise ValueError("config MILVUS_PORT is required")
|
||||
if not values['user']:
|
||||
raise ValueError("config MILVUS_USER is required")
|
||||
if not values['password']:
|
||||
raise ValueError("config MILVUS_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
return {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'user': self.user,
|
||||
'password': self.password,
|
||||
'secure': self.secure
|
||||
}
|
||||
|
||||
|
||||
class MilvusVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: MilvusConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = self._init_client(config)
|
||||
self._consistency_level = 'Session'
|
||||
self._fields = []
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'milvus'
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_params = {
|
||||
'metric_type': 'IP',
|
||||
'index_type': "HNSW",
|
||||
'params': {"M": 8, "efConstruction": 64}
|
||||
}
|
||||
metadatas = [d.metadata for d in texts]
|
||||
|
||||
# Grab the existing collection if it exists
|
||||
from pymilvus import utility
|
||||
alias = uuid4().hex
|
||||
if self._client_config.secure:
|
||||
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||
else:
|
||||
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
|
||||
if not utility.has_collection(self._collection_name, using=alias):
|
||||
self.create_collection(embeddings, metadatas, index_params)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
insert_dict_list = []
|
||||
for i in range(len(documents)):
|
||||
insert_dict = {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata
|
||||
}
|
||||
insert_dict_list.append(insert_dict)
|
||||
# Total insert count
|
||||
total_count = len(insert_dict_list)
|
||||
|
||||
pks: list[str] = []
|
||||
|
||||
for i in range(0, total_count, 1000):
|
||||
batch_insert_list = insert_dict_list[i:i + 1000]
|
||||
# Insert into the collection.
|
||||
try:
|
||||
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
||||
pks.extend(ids)
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to insert batch starting at entity: %s/%s", i, total_count
|
||||
)
|
||||
raise e
|
||||
return pks
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
|
||||
ids = self.get_ids_by_metadata_field('document_id', document_id)
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = self._client.query(collection_name=self._collection_name,
|
||||
filter=f'metadata["{key}"] == "{value}"',
|
||||
output_fields=["id"])
|
||||
if result:
|
||||
return [item["id"] for item in result]
|
||||
else:
|
||||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
||||
|
||||
self._client.delete(collection_name=self._collection_name, pks=doc_ids)
|
||||
|
||||
def delete(self) -> None:
|
||||
|
||||
from pymilvus import utility
|
||||
utility.drop_collection(self._collection_name, None)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
|
||||
result = self._client.query(collection_name=self._collection_name,
|
||||
filter=f'metadata["doc_id"] == "{id}"',
|
||||
output_fields=["id"])
|
||||
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
|
||||
# Set search parameters.
|
||||
results = self._client.search(collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
limit=kwargs.get('top_k', 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result['entity'].get(Field.METADATA_KEY.value)
|
||||
metadata['score'] = result['distance']
|
||||
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
|
||||
if result['distance'] > score_threshold:
|
||||
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
|
||||
metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# milvus/zilliz doesn't support bm25 search
|
||||
return []
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
) -> str:
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema
|
||||
from pymilvus.orm.types import infer_dtype_bydata
|
||||
|
||||
# Determine embedding dim
|
||||
dim = len(embeddings[0])
|
||||
fields = []
|
||||
if metadatas:
|
||||
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field
|
||||
fields.append(
|
||||
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
|
||||
)
|
||||
# Create the primary key field
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
|
||||
)
|
||||
)
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(
|
||||
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
|
||||
)
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
|
||||
for x in schema.fields:
|
||||
self._fields.append(x.name)
|
||||
# Since primary field is auto-id, no need to track it
|
||||
self._fields.remove(Field.PRIMARY_KEY.value)
|
||||
|
||||
# Create the collection
|
||||
collection_name = self._collection_name
|
||||
self._client.create_collection_with_schema(collection_name=collection_name,
|
||||
schema=schema, index_param=index_params,
|
||||
consistency_level=self._consistency_level)
|
||||
return collection_name
|
||||
|
||||
def _init_client(self, config) -> MilvusClient:
|
||||
if config.secure:
|
||||
uri = "https://" + str(config.host) + ":" + str(config.port)
|
||||
else:
|
||||
uri = "http://" + str(config.host) + ":" + str(config.port)
|
||||
client = MilvusClient(uri=uri, user=config.user, password=config.password)
|
||||
return client
|
||||
0
api/core/rag/datasource/vdb/qdrant/__init__.py
Normal file
0
api/core/rag/datasource/vdb/qdrant/__init__.py
Normal file
360
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
Normal file
360
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
Normal file
@@ -0,0 +1,360 @@
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
import qdrant_client
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client.http import models as rest
|
||||
from qdrant_client.http.models import (
|
||||
FilterSelector,
|
||||
HnswConfigDiff,
|
||||
PayloadSchemaType,
|
||||
TextIndexParams,
|
||||
TextIndexType,
|
||||
TokenizerType,
|
||||
)
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions import common_types
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
DictFilter = dict[str, Union[str, int, bool, dict, list]]
|
||||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
timeout: float = 20
|
||||
root_path: Optional[str]
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith('path:'):
|
||||
path = self.endpoint.replace('path:', '')
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {
|
||||
'path': path
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'url': self.endpoint,
|
||||
'api_key': self.api_key,
|
||||
'timeout': self.timeout
|
||||
}
|
||||
|
||||
|
||||
class QdrantVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
||||
self._distance_func = distance_func.upper()
|
||||
self._group_id = group_id
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'qdrant'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if texts:
|
||||
# get embedding vector size
|
||||
vector_size = len(embeddings[0])
|
||||
# get collection name
|
||||
collection_name = self._collection_name
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
all_collection_name = []
|
||||
collections_response = self._client.get_collections()
|
||||
collection_list = collections_response.collections
|
||||
for collection in collection_list:
|
||||
all_collection_name.append(collection.name)
|
||||
if collection_name not in all_collection_name:
|
||||
# create collection
|
||||
self.create_collection(collection_name, vector_size)
|
||||
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(self, collection_name: str, vector_size: int):
|
||||
from qdrant_client.http import models as rest
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[self._distance_func],
|
||||
)
|
||||
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
|
||||
max_indexing_threads=0, on_disk=False)
|
||||
self._client.recreate_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=vectors_config,
|
||||
hnsw_config=hnsw_config,
|
||||
timeout=int(self._client_config.timeout),
|
||||
)
|
||||
|
||||
# create payload index
|
||||
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
|
||||
field_schema=PayloadSchemaType.KEYWORD,
|
||||
field_type=PayloadSchemaType.KEYWORD)
|
||||
# creat full text index
|
||||
text_index_params = TextIndexParams(
|
||||
type=TextIndexType.TEXT,
|
||||
tokenizer=TokenizerType.MULTILINGUAL,
|
||||
min_token_len=2,
|
||||
max_token_len=20,
|
||||
lowercase=True
|
||||
)
|
||||
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
|
||||
field_schema=text_index_params)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
added_ids = []
|
||||
for batch_ids, points in self._generate_rest_batches(
|
||||
texts, embeddings, metadatas, uuids, 64, self._group_id
|
||||
):
|
||||
self._client.upsert(
|
||||
collection_name=self._collection_name, points=points
|
||||
)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
return added_ids
|
||||
|
||||
def _generate_rest_batches(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
|
||||
from qdrant_client.http import models as rest
|
||||
texts_iterator = iter(texts)
|
||||
embeddings_iterator = iter(embeddings)
|
||||
metadatas_iterator = iter(metadatas or [])
|
||||
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
|
||||
while batch_texts := list(islice(texts_iterator, batch_size)):
|
||||
# Take the corresponding metadata and id for each text in a batch
|
||||
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
|
||||
batch_ids = list(islice(ids_iterator, batch_size))
|
||||
|
||||
# Generate the embeddings for all the texts in a batch
|
||||
batch_embeddings = list(islice(embeddings_iterator, batch_size))
|
||||
|
||||
points = [
|
||||
rest.PointStruct(
|
||||
id=point_id,
|
||||
vector=vector,
|
||||
payload=payload,
|
||||
)
|
||||
for point_id, vector, payload in zip(
|
||||
batch_ids,
|
||||
batch_embeddings,
|
||||
self._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.METADATA_KEY.value,
|
||||
group_id,
|
||||
Field.GROUP_KEY.value,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
yield batch_ids, points
|
||||
|
||||
@classmethod
|
||||
def _build_payloads(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]],
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
group_payload_key: str
|
||||
) -> list[dict]:
|
||||
payloads = []
|
||||
for i, text in enumerate(texts):
|
||||
if text is None:
|
||||
raise ValueError(
|
||||
"At least one of the texts is None. Please remove it before "
|
||||
"calling .from_texts or .add_texts on Qdrant instance."
|
||||
)
|
||||
metadata = metadatas[i] if metadatas is not None else None
|
||||
payloads.append(
|
||||
{
|
||||
content_payload_key: text,
|
||||
metadata_payload_key: metadata,
|
||||
group_payload_key: group_id
|
||||
}
|
||||
)
|
||||
|
||||
return payloads
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self._reload_if_needed()
|
||||
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
|
||||
def delete(self):
|
||||
from qdrant_client.http import models
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
from qdrant_client.http import models
|
||||
for node_id in ids:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
response = self._client.retrieve(
|
||||
collection_name=self._collection_name,
|
||||
ids=[id]
|
||||
)
|
||||
|
||||
return len(response) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from qdrant_client.http import models
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=filter,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
score_threshold=kwargs.get("score_threshold", .0)
|
||||
)
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
if result.score > score_threshold:
|
||||
metadata['score'] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Return docs most similar by bm25.
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
"""
|
||||
from qdrant_client.http import models
|
||||
scroll_filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
models.FieldCondition(
|
||||
key="page_content",
|
||||
match=models.MatchText(text=query),
|
||||
)
|
||||
]
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
limit=kwargs.get('top_k', 2),
|
||||
with_payload=True,
|
||||
with_vectors=True
|
||||
|
||||
)
|
||||
results = response[0]
|
||||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
documents.append(self._document_from_scored_point(
|
||||
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
|
||||
))
|
||||
|
||||
return documents
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client = cast(QdrantLocal, self._client)
|
||||
self._client._load()
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
cls,
|
||||
scored_point: Any,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
) -> Document:
|
||||
return Document(
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||
)
|
||||
62
api/core/rag/datasource/vdb/vector_base.py
Normal file
62
api/core/rag/datasource/vdb/vector_base.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class BaseVector(ABC):
|
||||
|
||||
def __init__(self, collection_name: str):
|
||||
self._collection_name = collection_name
|
||||
|
||||
@abstractmethod
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def text_exists(self, id: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search_by_vector(
|
||||
self,
|
||||
query_vector: list[float],
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search_by_full_text(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts:
|
||||
doc_id = text.metadata['doc_id']
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
|
||||
return texts
|
||||
|
||||
def _get_uuids(self, texts: list[Document]) -> list[str]:
|
||||
return [text.metadata['doc_id'] for text in texts]
|
||||
171
api/core/rag/datasource/vdb/vector_factory.py
Normal file
171
api/core/rag/datasource/vdb/vector_factory.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list = None):
|
||||
if attributes is None:
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
|
||||
self._dataset = dataset
|
||||
self._embeddings = self._get_embeddings()
|
||||
self._attributes = attributes
|
||||
self._vector_processor = self._init_vector()
|
||||
|
||||
def _init_vector(self) -> BaseVector:
|
||||
config = cast(dict, current_app.config)
|
||||
vector_type = config.get('VECTOR_STORE')
|
||||
|
||||
if self._dataset.index_struct_dict:
|
||||
vector_type = self._dataset.index_struct_dict['type']
|
||||
|
||||
if not vector_type:
|
||||
raise ValueError("Vector store must be specified.")
|
||||
|
||||
if vector_type == "weaviate":
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||
if self._dataset.index_struct_dict:
|
||||
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = self._dataset.id
|
||||
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
return WeaviateVector(
|
||||
collection_name=collection_name,
|
||||
config=WeaviateConfig(
|
||||
endpoint=config.get('WEAVIATE_ENDPOINT'),
|
||||
api_key=config.get('WEAVIATE_API_KEY'),
|
||||
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
|
||||
),
|
||||
attributes=self._attributes
|
||||
)
|
||||
elif vector_type == "qdrant":
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
|
||||
if self._dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
else:
|
||||
if self._dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = self._dataset.id
|
||||
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
return QdrantVector(
|
||||
collection_name=collection_name,
|
||||
group_id=self._dataset.id,
|
||||
config=QdrantConfig(
|
||||
endpoint=config.get('QDRANT_URL'),
|
||||
api_key=config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path,
|
||||
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
|
||||
)
|
||||
)
|
||||
elif vector_type == "milvus":
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
|
||||
if self._dataset.index_struct_dict:
|
||||
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = self._dataset.id
|
||||
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
return MilvusVector(
|
||||
collection_name=collection_name,
|
||||
config=MilvusConfig(
|
||||
host=config.get('MILVUS_HOST'),
|
||||
port=config.get('MILVUS_PORT'),
|
||||
user=config.get('MILVUS_USER'),
|
||||
password=config.get('MILVUS_PASSWORD'),
|
||||
secure=config.get('MILVUS_SECURE'),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
def create(self, texts: list = None, **kwargs):
|
||||
if texts:
|
||||
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
|
||||
self._vector_processor.create(
|
||||
texts=texts,
|
||||
embeddings=embeddings,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def add_texts(self, documents: list[Document], **kwargs):
|
||||
if kwargs.get('duplicate_check', False):
|
||||
documents = self._filter_duplicate_texts(documents)
|
||||
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
|
||||
self._vector_processor.add_texts(
|
||||
documents=documents,
|
||||
embeddings=embeddings,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self._vector_processor.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._vector_processor.delete_by_ids(ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._vector_processor.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
query_vector = self._embeddings.embed_query(query)
|
||||
return self._vector_processor.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
def search_by_full_text(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
return self._vector_processor.search_by_full_text(query, **kwargs)
|
||||
|
||||
def delete(self) -> None:
|
||||
self._vector_processor.delete()
|
||||
|
||||
def _get_embeddings(self) -> Embeddings:
|
||||
model_manager = ModelManager()
|
||||
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=self._dataset.embedding_model
|
||||
|
||||
)
|
||||
return CacheEmbedding(embedding_model)
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts:
|
||||
doc_id = text.metadata['doc_id']
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
|
||||
return texts
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self._vector_processor is not None:
|
||||
method = getattr(self._vector_processor, name)
|
||||
if callable(method):
|
||||
return method
|
||||
|
||||
raise AttributeError(f"'vector_processor' object has no attribute '{name}'")
|
||||
0
api/core/rag/datasource/vdb/weaviate/__init__.py
Normal file
0
api/core/rag/datasource/vdb/weaviate/__init__.py
Normal file
235
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
Normal file
235
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
import weaviate
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
return values
|
||||
|
||||
|
||||
class WeaviateVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
|
||||
super().__init__(collection_name)
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
|
||||
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
|
||||
|
||||
weaviate.connect.connection.has_grpc = False
|
||||
|
||||
try:
|
||||
client = weaviate.Client(
|
||||
url=config.endpoint,
|
||||
auth_client_secret=auth_config,
|
||||
timeout_config=(5, 60),
|
||||
startup_period=None
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise ConnectionError("Vector database connection error")
|
||||
|
||||
client.batch.configure(
|
||||
# `batch_size` takes an `int` value to enable auto-batching
|
||||
# (`None` is used for manual batching)
|
||||
batch_size=config.batch_size,
|
||||
# dynamically update the `batch_size` based on import speed
|
||||
dynamic=True,
|
||||
# `timeout_retries` takes an `int` value to retry on time outs
|
||||
timeout_retries=3,
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'weaviate'
|
||||
|
||||
def get_collection_name(self, dataset: Dataset) -> str:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
|
||||
return class_prefix
|
||||
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name}
|
||||
}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
|
||||
schema = self._default_schema(self._collection_name)
|
||||
|
||||
# check whether the index already exists
|
||||
if not self._client.schema.contains(schema):
|
||||
# create collection
|
||||
self._client.schema.create_class(schema)
|
||||
# create vector
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
ids = []
|
||||
|
||||
with self._client.batch as batch:
|
||||
for i, text in enumerate(texts):
|
||||
data_properties = {Field.TEXT_KEY.value: text}
|
||||
if metadatas is not None:
|
||||
for key, val in metadatas[i].items():
|
||||
data_properties[key] = self._json_serializable(val)
|
||||
|
||||
batch.add_data_object(
|
||||
data_object=data_properties,
|
||||
class_name=self._collection_name,
|
||||
uuid=uuids[i],
|
||||
vector=embeddings[i] if embeddings else None,
|
||||
)
|
||||
ids.append(uuids[i])
|
||||
return ids
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
where_filter = {
|
||||
"operator": "Equal",
|
||||
"path": [key],
|
||||
"valueText": value
|
||||
}
|
||||
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._collection_name,
|
||||
where=where_filter,
|
||||
output='minimal'
|
||||
)
|
||||
|
||||
def delete(self):
|
||||
self._client.schema.delete_class(self._collection_name)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
collection_name = self._collection_name
|
||||
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
|
||||
"path": ["doc_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": id,
|
||||
}).with_limit(1).do()
|
||||
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
|
||||
entries = result["data"]["Get"][collection_name]
|
||||
if len(entries) == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._client.data_object.delete(
|
||||
ids,
|
||||
class_name=self._collection_name
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Look up similar documents by embedding vector in Weaviate."""
|
||||
collection_name = self._collection_name
|
||||
properties = self._attributes
|
||||
properties.append(Field.TEXT_KEY.value)
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
|
||||
vector = {"vector": query_vector}
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
result = (
|
||||
query_obj.with_near_vector(vector)
|
||||
.with_limit(kwargs.get("top_k", 4))
|
||||
.with_additional(["vector", "distance"])
|
||||
.do()
|
||||
)
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
|
||||
docs_and_scores = []
|
||||
for res in result["data"]["Get"][collection_name]:
|
||||
text = res.pop(Field.TEXT_KEY.value)
|
||||
score = 1 - res["_additional"]["distance"]
|
||||
docs_and_scores.append((Document(page_content=text, metadata=res), score))
|
||||
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
# check score threshold
|
||||
if score > score_threshold:
|
||||
doc.metadata['score'] = score
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Return docs using BM25F.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
collection_name = self._collection_name
|
||||
content: dict[str, Any] = {"concepts": [query]}
|
||||
properties = self._attributes
|
||||
properties.append(Field.TEXT_KEY.value)
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
properties = ['text']
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][collection_name]:
|
||||
text = res.pop(Field.TEXT_KEY.value)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
def _default_schema(self, index_name: str) -> dict:
|
||||
return {
|
||||
"class": index_name,
|
||||
"properties": [
|
||||
{
|
||||
"name": "text",
|
||||
"dataType": ["text"],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _json_serializable(self, value: Any) -> Any:
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value.isoformat()
|
||||
return value
|
||||
166
api/core/rag/extractor/blod/blod.py
Normal file
166
api/core/rag/extractor/blod/blod.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Schema for Blobs and Blob Loaders.
|
||||
|
||||
The goal is to facilitate decoupling of content loading from content parsing code.
|
||||
|
||||
In addition, content loading code should provide a lazy loading interface by default.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import mimetypes
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Iterable, Mapping
|
||||
from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
PathLike = Union[str, PurePath]
|
||||
|
||||
|
||||
class Blob(BaseModel):
|
||||
"""A blob is used to represent raw data by either reference or value.
|
||||
|
||||
Provides an interface to materialize the blob in different representations, and
|
||||
help to decouple the development of data loaders from the downstream parsing of
|
||||
the raw data.
|
||||
|
||||
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
|
||||
"""
|
||||
|
||||
data: Union[bytes, str, None] # Raw data
|
||||
mimetype: Optional[str] = None # Not to be confused with a file extension
|
||||
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
|
||||
# Location where the original content was found
|
||||
# Represent location on the local file system
|
||||
# Useful for situations where downstream code assumes it must work with file paths
|
||||
# rather than in-memory content.
|
||||
path: Optional[PathLike] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
|
||||
@property
|
||||
def source(self) -> Optional[str]:
|
||||
"""The source location of the blob as string if known otherwise none."""
|
||||
return str(self.path) if self.path else None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Verify that either data or path is provided."""
|
||||
if "data" not in values and "path" not in values:
|
||||
raise ValueError("Either data or path must be provided")
|
||||
return values
|
||||
|
||||
def as_string(self) -> str:
|
||||
"""Read data as a string."""
|
||||
if self.data is None and self.path:
|
||||
with open(str(self.path), encoding=self.encoding) as f:
|
||||
return f.read()
|
||||
elif isinstance(self.data, bytes):
|
||||
return self.data.decode(self.encoding)
|
||||
elif isinstance(self.data, str):
|
||||
return self.data
|
||||
else:
|
||||
raise ValueError(f"Unable to get string for blob {self}")
|
||||
|
||||
def as_bytes(self) -> bytes:
|
||||
"""Read data as bytes."""
|
||||
if isinstance(self.data, bytes):
|
||||
return self.data
|
||||
elif isinstance(self.data, str):
|
||||
return self.data.encode(self.encoding)
|
||||
elif self.data is None and self.path:
|
||||
with open(str(self.path), "rb") as f:
|
||||
return f.read()
|
||||
else:
|
||||
raise ValueError(f"Unable to get bytes for blob {self}")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
|
||||
"""Read data as a byte stream."""
|
||||
if isinstance(self.data, bytes):
|
||||
yield BytesIO(self.data)
|
||||
elif self.data is None and self.path:
|
||||
with open(str(self.path), "rb") as f:
|
||||
yield f
|
||||
else:
|
||||
raise NotImplementedError(f"Unable to convert blob {self}")
|
||||
|
||||
@classmethod
|
||||
def from_path(
|
||||
cls,
|
||||
path: PathLike,
|
||||
*,
|
||||
encoding: str = "utf-8",
|
||||
mime_type: Optional[str] = None,
|
||||
guess_type: bool = True,
|
||||
) -> Blob:
|
||||
"""Load the blob from a path like object.
|
||||
|
||||
Args:
|
||||
path: path like object to file to be read
|
||||
encoding: Encoding to use if decoding the bytes into a string
|
||||
mime_type: if provided, will be set as the mime-type of the data
|
||||
guess_type: If True, the mimetype will be guessed from the file extension,
|
||||
if a mime-type was not provided
|
||||
|
||||
Returns:
|
||||
Blob instance
|
||||
"""
|
||||
if mime_type is None and guess_type:
|
||||
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None
|
||||
else:
|
||||
_mimetype = mime_type
|
||||
# We do not load the data immediately, instead we treat the blob as a
|
||||
# reference to the underlying data.
|
||||
return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path)
|
||||
|
||||
@classmethod
|
||||
def from_data(
|
||||
cls,
|
||||
data: Union[str, bytes],
|
||||
*,
|
||||
encoding: str = "utf-8",
|
||||
mime_type: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
) -> Blob:
|
||||
"""Initialize the blob from in-memory data.
|
||||
|
||||
Args:
|
||||
data: the in-memory data associated with the blob
|
||||
encoding: Encoding to use if decoding the bytes into a string
|
||||
mime_type: if provided, will be set as the mime-type of the data
|
||||
path: if provided, will be set as the source from which the data came
|
||||
|
||||
Returns:
|
||||
Blob instance
|
||||
"""
|
||||
return cls(data=data, mimetype=mime_type, encoding=encoding, path=path)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Define the blob representation."""
|
||||
str_repr = f"Blob {id(self)}"
|
||||
if self.source:
|
||||
str_repr += f" {self.source}"
|
||||
return str_repr
|
||||
|
||||
|
||||
class BlobLoader(ABC):
|
||||
"""Abstract interface for blob loaders implementation.
|
||||
|
||||
Implementer should be able to load raw content from a datasource system according
|
||||
to some criteria and return the raw content lazily as a stream of blobs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def yield_blobs(
|
||||
self,
|
||||
) -> Iterable[Blob]:
|
||||
"""A lazy loader for raw data represented by LangChain's Blob object.
|
||||
|
||||
Returns:
|
||||
A generator over blobs
|
||||
"""
|
||||
71
api/core/rag/extractor/csv_extractor.py
Normal file
71
api/core/rag/extractor/csv_extractor.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
import csv
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class CSVExtractor(BaseExtractor):
|
||||
"""Load CSV files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
self.source_column = source_column
|
||||
self.csv_args = csv_args or {}
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load data into document objects."""
|
||||
try:
|
||||
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_filze_encodings(self._file_path)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
|
||||
return docs
|
||||
|
||||
def _read_from_file(self, csvfile) -> list[Document]:
|
||||
docs = []
|
||||
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv_reader):
|
||||
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
|
||||
try:
|
||||
source = (
|
||||
row[self.source_column]
|
||||
if self.source_column is not None
|
||||
else ''
|
||||
)
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Source column '{self.source_column}' not found in CSV file."
|
||||
)
|
||||
metadata = {"source": source, "row": i}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
6
api/core/rag/extractor/entity/datasource_type.py
Normal file
6
api/core/rag/extractor/entity/datasource_type.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DatasourceType(Enum):
|
||||
FILE = "upload_file"
|
||||
NOTION = "notion_import"
|
||||
36
api/core/rag/extractor/entity/extract_setting.py
Normal file
36
api/core/rag/extractor/entity/extract_setting.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.dataset import Document
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
class NotionInfo(BaseModel):
|
||||
"""
|
||||
Notion import info.
|
||||
"""
|
||||
notion_workspace_id: str
|
||||
notion_obj_id: str
|
||||
notion_page_type: str
|
||||
document: Document = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class ExtractSetting(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
datasource_type: str
|
||||
upload_file: UploadFile = None
|
||||
notion_info: NotionInfo = None
|
||||
document_model: str = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
50
api/core/rag/extractor/excel_extractor.py
Normal file
50
api/core/rag/extractor/excel_extractor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from typing import Optional
|
||||
|
||||
from openpyxl.reader.excel import load_workbook
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class ExcelExtractor(BaseExtractor):
|
||||
"""Load Excel files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load from file path."""
|
||||
data = []
|
||||
keys = []
|
||||
wb = load_workbook(filename=self._file_path, read_only=True)
|
||||
# loop over all sheets
|
||||
for sheet in wb:
|
||||
if 'A1:A1' == sheet.calculate_dimension():
|
||||
sheet.reset_dimensions()
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
if all(v is None for v in row):
|
||||
continue
|
||||
if keys == []:
|
||||
keys = list(map(str, row))
|
||||
else:
|
||||
row_dict = dict(zip(keys, list(map(str, row))))
|
||||
row_dict = {k: v for k, v in row_dict.items() if v}
|
||||
item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
|
||||
document = Document(page_content=item, metadata={'source': self._file_path})
|
||||
data.append(document)
|
||||
|
||||
return data
|
||||
139
api/core/rag/extractor/extract_processor.py
Normal file
139
api/core/rag/extractor/extract_processor.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.extractor.csv_extractor import CSVExtractor
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.excel_extractor import ExcelExtractor
|
||||
from core.rag.extractor.html_extractor import HtmlExtractor
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from core.rag.extractor.pdf_extractor import PdfExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
|
||||
from core.rag.extractor.word_extractor import WordExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
|
||||
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
|
||||
|
||||
class ExtractProcessor:
|
||||
@classmethod
|
||||
def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
|
||||
-> Union[list[Document], str]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=upload_file,
|
||||
document_model='text_model'
|
||||
)
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
|
||||
else:
|
||||
return cls.extract(extract_setting, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
|
||||
response = requests.get(url, headers={
|
||||
"User-Agent": USER_AGENT
|
||||
})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(url).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
document_model='text_model'
|
||||
)
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
return delimiter.join([document.page_content for document in cls.extract(
|
||||
extract_setting=extract_setting, file_path=file_path)])
|
||||
else:
|
||||
return cls.extract(extract_setting=extract_setting, file_path=file_path)
|
||||
|
||||
@classmethod
|
||||
def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
|
||||
file_path: str = None) -> list[Document]:
|
||||
if extract_setting.datasource_type == DatasourceType.FILE.value:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
if not file_path:
|
||||
upload_file: UploadFile = extract_setting.upload_file
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
storage.download(upload_file.key, file_path)
|
||||
input_file = Path(file_path)
|
||||
file_extension = input_file.suffix.lower()
|
||||
etl_type = current_app.config['ETL_TYPE']
|
||||
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
|
||||
if etl_type == 'Unstructured':
|
||||
if file_extension == '.xlsx':
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
else MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.csv':
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == '.msg':
|
||||
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.eml':
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.ppt':
|
||||
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.pptx':
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.xml':
|
||||
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
|
||||
else:
|
||||
# txt
|
||||
extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
else TextExtractor(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
if file_extension == '.xlsx':
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
extractor = WordExtractor(file_path)
|
||||
elif file_extension == '.csv':
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# txt
|
||||
extractor = TextExtractor(file_path, autodetect_encoding=True)
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
|
||||
notion_obj_id=extract_setting.notion_info.notion_obj_id,
|
||||
notion_page_type=extract_setting.notion_info.notion_page_type,
|
||||
document_model=extract_setting.notion_info.document
|
||||
)
|
||||
return extractor.extract()
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")
|
||||
12
api/core/rag/extractor/extractor_base.py
Normal file
12
api/core/rag/extractor/extractor_base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseExtractor(ABC):
|
||||
"""Interface for extract files.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self):
|
||||
raise NotImplementedError
|
||||
|
||||
46
api/core/rag/extractor/helpers.py
Normal file
46
api/core/rag/extractor/helpers.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Document loader helpers."""
|
||||
|
||||
import concurrent.futures
|
||||
from typing import NamedTuple, Optional, cast
|
||||
|
||||
|
||||
class FileEncoding(NamedTuple):
|
||||
"""A file encoding as the NamedTuple."""
|
||||
|
||||
encoding: Optional[str]
|
||||
"""The encoding of the file."""
|
||||
confidence: float
|
||||
"""The confidence of the encoding."""
|
||||
language: Optional[str]
|
||||
"""The language of the file."""
|
||||
|
||||
|
||||
def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
|
||||
"""Try to detect the file encoding.
|
||||
|
||||
Returns a list of `FileEncoding` tuples with the detected encodings ordered
|
||||
by confidence.
|
||||
|
||||
Args:
|
||||
file_path: The path to the file to detect the encoding for.
|
||||
timeout: The timeout in seconds for the encoding detection.
|
||||
"""
|
||||
import chardet
|
||||
|
||||
def read_and_detect(file_path: str) -> list[dict]:
|
||||
with open(file_path, "rb") as f:
|
||||
rawdata = f.read()
|
||||
return cast(list[dict], chardet.detect_all(rawdata))
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(read_and_detect, file_path)
|
||||
try:
|
||||
encodings = future.result(timeout=timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Timeout reached while detecting encoding for {file_path}"
|
||||
)
|
||||
|
||||
if all(encoding["encoding"] is None for encoding in encodings):
|
||||
raise RuntimeError(f"Could not detect encoding for {file_path}")
|
||||
return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
|
||||
71
api/core/rag/extractor/html_extractor.py
Normal file
71
api/core/rag/extractor/html_extractor.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class HtmlExtractor(BaseExtractor):
|
||||
"""Load html files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
self.source_column = source_column
|
||||
self.csv_args = csv_args or {}
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load data into document objects."""
|
||||
try:
|
||||
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(self._file_path)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
|
||||
return docs
|
||||
|
||||
def _read_from_file(self, csvfile) -> list[Document]:
|
||||
docs = []
|
||||
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv_reader):
|
||||
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
|
||||
try:
|
||||
source = (
|
||||
row[self.source_column]
|
||||
if self.source_column is not None
|
||||
else ''
|
||||
)
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Source column '{self.source_column}' not found in CSV file."
|
||||
)
|
||||
metadata = {"source": source, "row": i}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
122
api/core/rag/extractor/markdown_extractor.py
Normal file
122
api/core/rag/extractor/markdown_extractor.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
import re
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class MarkdownExtractor(BaseExtractor):
|
||||
"""Load Markdown files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
remove_hyperlinks: bool = True,
|
||||
remove_images: bool = True,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._remove_hyperlinks = remove_hyperlinks
|
||||
self._remove_images = remove_images
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load from file path."""
|
||||
tups = self.parse_tups(self._file_path)
|
||||
documents = []
|
||||
for header, value in tups:
|
||||
value = value.strip()
|
||||
if header is None:
|
||||
documents.append(Document(page_content=value))
|
||||
else:
|
||||
documents.append(Document(page_content=f"\n\n{header}\n{value}"))
|
||||
|
||||
return documents
|
||||
|
||||
def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]:
|
||||
"""Convert a markdown file to a dictionary.
|
||||
|
||||
The keys are the headers and the values are the text under each header.
|
||||
|
||||
"""
|
||||
markdown_tups: list[tuple[Optional[str], str]] = []
|
||||
lines = markdown_text.split("\n")
|
||||
|
||||
current_header = None
|
||||
current_text = ""
|
||||
|
||||
for line in lines:
|
||||
header_match = re.match(r"^#+\s", line)
|
||||
if header_match:
|
||||
if current_header is not None:
|
||||
markdown_tups.append((current_header, current_text))
|
||||
|
||||
current_header = line
|
||||
current_text = ""
|
||||
else:
|
||||
current_text += line + "\n"
|
||||
markdown_tups.append((current_header, current_text))
|
||||
|
||||
if current_header is not None:
|
||||
# pass linting, assert keys are defined
|
||||
markdown_tups = [
|
||||
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
|
||||
for key, value in markdown_tups
|
||||
]
|
||||
else:
|
||||
markdown_tups = [
|
||||
(key, re.sub("\n", "", value)) for key, value in markdown_tups
|
||||
]
|
||||
|
||||
return markdown_tups
|
||||
|
||||
def remove_images(self, content: str) -> str:
|
||||
"""Get a dictionary of a markdown file from its path."""
|
||||
pattern = r"!{1}\[\[(.*)\]\]"
|
||||
content = re.sub(pattern, "", content)
|
||||
return content
|
||||
|
||||
def remove_hyperlinks(self, content: str) -> str:
|
||||
"""Get a dictionary of a markdown file from its path."""
|
||||
pattern = r"\[(.*?)\]\((.*?)\)"
|
||||
content = re.sub(pattern, r"\1", content)
|
||||
return content
|
||||
|
||||
def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]:
|
||||
"""Parse file into tuples."""
|
||||
content = ""
|
||||
try:
|
||||
with open(filepath, encoding=self._encoding) as f:
|
||||
content = f.read()
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(filepath)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(filepath, encoding=encoding.encoding) as f:
|
||||
content = f.read()
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {filepath}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error loading {filepath}") from e
|
||||
|
||||
if self._remove_hyperlinks:
|
||||
content = self.remove_hyperlinks(content)
|
||||
|
||||
if self._remove_images:
|
||||
content = self.remove_images(content)
|
||||
|
||||
return self.markdown_to_tups(content)
|
||||
366
api/core/rag/extractor/notion_extractor.py
Normal file
366
api/core/rag/extractor/notion_extractor.py
Normal file
@@ -0,0 +1,366 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.source import DataSourceBinding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
|
||||
SEARCH_URL = "https://api.notion.com/v1/search"
|
||||
|
||||
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
|
||||
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
||||
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
|
||||
|
||||
|
||||
class NotionExtractor(BaseExtractor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
notion_workspace_id: str,
|
||||
notion_obj_id: str,
|
||||
notion_page_type: str,
|
||||
document_model: Optional[DocumentModel] = None,
|
||||
notion_access_token: Optional[str] = None
|
||||
):
|
||||
self._notion_access_token = None
|
||||
self._document_model = document_model
|
||||
self._notion_workspace_id = notion_workspace_id
|
||||
self._notion_obj_id = notion_obj_id
|
||||
self._notion_page_type = notion_page_type
|
||||
if notion_access_token:
|
||||
self._notion_access_token = notion_access_token
|
||||
else:
|
||||
self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
|
||||
self._notion_workspace_id)
|
||||
if not self._notion_access_token:
|
||||
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
|
||||
if integration_token is None:
|
||||
raise ValueError(
|
||||
"Must specify `integration_token` or set environment "
|
||||
"variable `NOTION_INTEGRATION_TOKEN`."
|
||||
)
|
||||
|
||||
self._notion_access_token = integration_token
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
self.update_last_edited_time(
|
||||
self._document_model
|
||||
)
|
||||
|
||||
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
|
||||
|
||||
return text_docs
|
||||
|
||||
def _load_data_as_documents(
|
||||
self, notion_obj_id: str, notion_page_type: str
|
||||
) -> list[Document]:
|
||||
docs = []
|
||||
if notion_page_type == 'database':
|
||||
# get all the pages in the database
|
||||
page_text_documents = self._get_notion_database_data(notion_obj_id)
|
||||
docs.extend(page_text_documents)
|
||||
elif notion_page_type == 'page':
|
||||
page_text_list = self._get_notion_block_data(notion_obj_id)
|
||||
for page_text in page_text_list:
|
||||
docs.append(Document(page_content=page_text))
|
||||
else:
|
||||
raise ValueError("notion page type not supported")
|
||||
|
||||
return docs
|
||||
|
||||
def _get_notion_database_data(
|
||||
self, database_id: str, query_dict: dict[str, Any] = {}
|
||||
) -> list[Document]:
|
||||
"""Get all the pages from a Notion database."""
|
||||
res = requests.post(
|
||||
DATABASE_URL_TMPL.format(database_id=database_id),
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict,
|
||||
)
|
||||
|
||||
data = res.json()
|
||||
|
||||
database_content_list = []
|
||||
if 'results' not in data or data["results"] is None:
|
||||
return []
|
||||
for result in data["results"]:
|
||||
properties = result['properties']
|
||||
data = {}
|
||||
for property_name, property_value in properties.items():
|
||||
type = property_value['type']
|
||||
if type == 'multi_select':
|
||||
value = []
|
||||
multi_select_list = property_value[type]
|
||||
for multi_select in multi_select_list:
|
||||
value.append(multi_select['name'])
|
||||
elif type == 'rich_text' or type == 'title':
|
||||
if len(property_value[type]) > 0:
|
||||
value = property_value[type][0]['plain_text']
|
||||
else:
|
||||
value = ''
|
||||
elif type == 'select' or type == 'status':
|
||||
if property_value[type]:
|
||||
value = property_value[type]['name']
|
||||
else:
|
||||
value = ''
|
||||
else:
|
||||
value = property_value[type]
|
||||
data[property_name] = value
|
||||
row_dict = {k: v for k, v in data.items() if v}
|
||||
row_content = ''
|
||||
for key, value in row_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value_dict = {k: v for k, v in value.items() if v}
|
||||
value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items())
|
||||
row_content = row_content + f'{key}:{value_content}\n'
|
||||
else:
|
||||
row_content = row_content + f'{key}:{value}\n'
|
||||
document = Document(page_content=row_content)
|
||||
database_content_list.append(document)
|
||||
|
||||
return database_content_list
|
||||
|
||||
def _get_notion_block_data(self, page_id: str) -> list[str]:
|
||||
result_lines_arr = []
|
||||
cur_block_id = page_id
|
||||
while True:
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
block_url,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
)
|
||||
data = res.json()
|
||||
# current block's heading
|
||||
heading = ''
|
||||
for result in data["results"]:
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
cur_result_text_arr = []
|
||||
if result_type == 'table':
|
||||
result_block_id = result["id"]
|
||||
text = self._read_table_rows(result_block_id)
|
||||
text += "\n\n"
|
||||
result_lines_arr.append(text)
|
||||
else:
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
# skip if doesn't have text object
|
||||
if "text" in rich_text:
|
||||
text = rich_text["text"]["content"]
|
||||
cur_result_text_arr.append(text)
|
||||
if result_type in HEADING_TYPE:
|
||||
heading = text
|
||||
|
||||
result_block_id = result["id"]
|
||||
has_children = result["has_children"]
|
||||
block_type = result["type"]
|
||||
if has_children and block_type != 'child_page':
|
||||
children_text = self._read_block(
|
||||
result_block_id, num_tabs=1
|
||||
)
|
||||
cur_result_text_arr.append(children_text)
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
cur_result_text += "\n\n"
|
||||
if result_type in HEADING_TYPE:
|
||||
result_lines_arr.append(cur_result_text)
|
||||
else:
|
||||
result_lines_arr.append(f'{heading}\n{cur_result_text}')
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
else:
|
||||
cur_block_id = data["next_cursor"]
|
||||
return result_lines_arr
|
||||
|
||||
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
|
||||
"""Read a block."""
|
||||
result_lines_arr = []
|
||||
cur_block_id = block_id
|
||||
while True:
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
block_url,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
)
|
||||
data = res.json()
|
||||
if 'results' not in data or data["results"] is None:
|
||||
break
|
||||
heading = ''
|
||||
for result in data["results"]:
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
cur_result_text_arr = []
|
||||
if result_type == 'table':
|
||||
result_block_id = result["id"]
|
||||
text = self._read_table_rows(result_block_id)
|
||||
result_lines_arr.append(text)
|
||||
else:
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
# skip if doesn't have text object
|
||||
if "text" in rich_text:
|
||||
text = rich_text["text"]["content"]
|
||||
prefix = "\t" * num_tabs
|
||||
cur_result_text_arr.append(prefix + text)
|
||||
if result_type in HEADING_TYPE:
|
||||
heading = text
|
||||
result_block_id = result["id"]
|
||||
has_children = result["has_children"]
|
||||
block_type = result["type"]
|
||||
if has_children and block_type != 'child_page':
|
||||
children_text = self._read_block(
|
||||
result_block_id, num_tabs=num_tabs + 1
|
||||
)
|
||||
cur_result_text_arr.append(children_text)
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
if result_type in HEADING_TYPE:
|
||||
result_lines_arr.append(cur_result_text)
|
||||
else:
|
||||
result_lines_arr.append(f'{heading}\n{cur_result_text}')
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
else:
|
||||
cur_block_id = data["next_cursor"]
|
||||
|
||||
result_lines = "\n".join(result_lines_arr)
|
||||
return result_lines
|
||||
|
||||
def _read_table_rows(self, block_id: str) -> str:
|
||||
"""Read table rows."""
|
||||
done = False
|
||||
result_lines_arr = []
|
||||
cur_block_id = block_id
|
||||
while not done:
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
block_url,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
)
|
||||
data = res.json()
|
||||
# get table headers text
|
||||
table_header_cell_texts = []
|
||||
tabel_header_cells = data["results"][0]['table_row']['cells']
|
||||
for tabel_header_cell in tabel_header_cells:
|
||||
if tabel_header_cell:
|
||||
for table_header_cell_text in tabel_header_cell:
|
||||
text = table_header_cell_text["text"]["content"]
|
||||
table_header_cell_texts.append(text)
|
||||
# get table columns text and format
|
||||
results = data["results"]
|
||||
for i in range(len(results) - 1):
|
||||
column_texts = []
|
||||
tabel_column_cells = data["results"][i + 1]['table_row']['cells']
|
||||
for j in range(len(tabel_column_cells)):
|
||||
if tabel_column_cells[j]:
|
||||
for table_column_cell_text in tabel_column_cells[j]:
|
||||
column_text = table_column_cell_text["text"]["content"]
|
||||
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
|
||||
|
||||
cur_result_text = "\n".join(column_texts)
|
||||
result_lines_arr.append(cur_result_text)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
done = True
|
||||
break
|
||||
else:
|
||||
cur_block_id = data["next_cursor"]
|
||||
|
||||
result_lines = "\n".join(result_lines_arr)
|
||||
return result_lines
|
||||
|
||||
def update_last_edited_time(self, document_model: DocumentModel):
|
||||
if not document_model:
|
||||
return
|
||||
|
||||
last_edited_time = self.get_notion_last_edited_time()
|
||||
data_source_info = document_model.data_source_info_dict
|
||||
data_source_info['last_edited_time'] = last_edited_time
|
||||
update_params = {
|
||||
DocumentModel.data_source_info: json.dumps(data_source_info)
|
||||
}
|
||||
|
||||
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
|
||||
db.session.commit()
|
||||
|
||||
def get_notion_last_edited_time(self) -> str:
|
||||
obj_id = self._notion_obj_id
|
||||
page_type = self._notion_page_type
|
||||
if page_type == 'database':
|
||||
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
|
||||
else:
|
||||
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
|
||||
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
retrieve_page_url,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
)
|
||||
|
||||
data = res.json()
|
||||
return data["last_edited_time"]
|
||||
|
||||
@classmethod
|
||||
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
|
||||
if not data_source_binding:
|
||||
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
|
||||
f'and notion workspace {notion_workspace_id}')
|
||||
|
||||
return data_source_binding.access_token
|
||||
72
api/core/rag/extractor/pdf_extractor.py
Normal file
72
api/core/rag/extractor/pdf_extractor.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.blod.blod import Blob
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class PdfExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
file_cache_key: Optional[str] = None
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
plaintext_file_key = ''
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
text = storage.load(self._file_cache_key).decode('utf-8')
|
||||
plaintext_file_exists = True
|
||||
return [Document(page_content=text)]
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
documents = list(self.load())
|
||||
text_list = []
|
||||
for document in documents:
|
||||
text_list.append(document.page_content)
|
||||
text = "\n\n".join(text_list)
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode('utf-8'))
|
||||
|
||||
return documents
|
||||
|
||||
def load(
|
||||
self,
|
||||
) -> Iterator[Document]:
|
||||
"""Lazy load given path as pages."""
|
||||
blob = Blob.from_path(self._file_path)
|
||||
yield from self.parse(blob)
|
||||
|
||||
def parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazily parse the blob."""
|
||||
import pypdfium2
|
||||
|
||||
with blob.as_bytes_io() as file_path:
|
||||
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
|
||||
try:
|
||||
for page_number, page in enumerate(pdf_reader):
|
||||
text_page = page.get_textpage()
|
||||
content = text_page.get_text_range()
|
||||
text_page.close()
|
||||
page.close()
|
||||
metadata = {"source": blob.source, "page": page_number}
|
||||
yield Document(page_content=content, metadata=metadata)
|
||||
finally:
|
||||
pdf_reader.close()
|
||||
50
api/core/rag/extractor/text_extractor.py
Normal file
50
api/core/rag/extractor/text_extractor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class TextExtractor(BaseExtractor):
|
||||
"""Load text files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load from file path."""
|
||||
text = ""
|
||||
try:
|
||||
with open(self._file_path, encoding=self._encoding) as f:
|
||||
text = f.read()
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(self._file_path)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(self._file_path, encoding=encoding.encoding) as f:
|
||||
text = f.read()
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
|
||||
metadata = {"source": self._file_path}
|
||||
return [Document(page_content=text, metadata=metadata)]
|
||||
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredWordExtractor(BaseExtractor):
|
||||
"""Loader that uses unstructured to load word documents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.__version__ import __version__ as __unstructured_version__
|
||||
from unstructured.file_utils.filetype import FileType, detect_filetype
|
||||
|
||||
unstructured_version = tuple(
|
||||
[int(x) for x in __unstructured_version__.split(".")]
|
||||
)
|
||||
# check the file extension
|
||||
try:
|
||||
import magic # noqa: F401
|
||||
|
||||
is_doc = detect_filetype(self._file_path) == FileType.DOC
|
||||
except ImportError:
|
||||
_, extension = os.path.splitext(str(self._file_path))
|
||||
is_doc = extension == ".doc"
|
||||
|
||||
if is_doc and unstructured_version < (0, 4, 11):
|
||||
raise ValueError(
|
||||
f"You are on unstructured version {__unstructured_version__}. "
|
||||
"Partitioning .doc files is only supported in unstructured>=0.4.11. "
|
||||
"Please upgrade the unstructured package and try again."
|
||||
)
|
||||
|
||||
if is_doc:
|
||||
from unstructured.partition.doc import partition_doc
|
||||
|
||||
elements = partition_doc(filename=self._file_path)
|
||||
else:
|
||||
from unstructured.partition.docx import partition_docx
|
||||
|
||||
elements = partition_docx(filename=self._file_path)
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
return documents
|
||||
@@ -0,0 +1,51 @@
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredEmailExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.email import partition_email
|
||||
elements = partition_email(filename=self._file_path, api_url=self._api_url)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
for element in elements:
|
||||
element_text = element.text.strip()
|
||||
|
||||
padding_needed = 4 - len(element_text) % 4
|
||||
element_text += '=' * padding_needed
|
||||
|
||||
element_decode = base64.b64decode(element_text)
|
||||
soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser')
|
||||
element.text = soup.get_text()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
return documents
|
||||
@@ -0,0 +1,47 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredMarkdownExtractor(BaseExtractor):
|
||||
"""Load md files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
|
||||
remove_hyperlinks: Whether to remove hyperlinks from the text.
|
||||
|
||||
remove_images: Whether to remove images from the text.
|
||||
|
||||
encoding: File encoding to use. If `None`, the file will be loaded
|
||||
with the default system encoding.
|
||||
|
||||
autodetect_encoding: Whether to try to autodetect the file encoding
|
||||
if the specified encoding fails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.md import partition_md
|
||||
|
||||
elements = partition_md(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@@ -0,0 +1,37 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredMsgExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
|
||||
text_by_page = {}
|
||||
for element in elements:
|
||||
page = element.metadata.page_number
|
||||
text = element.text
|
||||
if page in text_by_page:
|
||||
text_by_page[page] += "\n" + text
|
||||
else:
|
||||
text_by_page[page] = text
|
||||
|
||||
combined_texts = list(text_by_page.values())
|
||||
documents = []
|
||||
for combined_text in combined_texts:
|
||||
text = combined_text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
return documents
|
||||
@@ -0,0 +1,45 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTXExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
|
||||
text_by_page = {}
|
||||
for element in elements:
|
||||
page = element.metadata.page_number
|
||||
text = element.text
|
||||
if page in text_by_page:
|
||||
text_by_page[page] += "\n" + text
|
||||
else:
|
||||
text_by_page[page] = text
|
||||
|
||||
combined_texts = list(text_by_page.values())
|
||||
documents = []
|
||||
for combined_text in combined_texts:
|
||||
text = combined_text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@@ -0,0 +1,37 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredTextExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.text import partition_text
|
||||
|
||||
elements = partition_text(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@@ -0,0 +1,37 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredXmlExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.xml import partition_xml
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
62
api/core/rag/extractor/word_extractor.py
Normal file
62
api/core/rag/extractor/word_extractor.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
import os
|
||||
import tempfile
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class WordExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
"""Initialize with file path."""
|
||||
self.file_path = file_path
|
||||
if "~" in self.file_path:
|
||||
self.file_path = os.path.expanduser(self.file_path)
|
||||
|
||||
# If the file is a web path, download it to a temporary file, and use that
|
||||
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
|
||||
r = requests.get(self.file_path)
|
||||
|
||||
if r.status_code != 200:
|
||||
raise ValueError(
|
||||
"Check the url of your file; returned status code %s"
|
||||
% r.status_code
|
||||
)
|
||||
|
||||
self.web_path = self.file_path
|
||||
self.temp_file = tempfile.NamedTemporaryFile()
|
||||
self.temp_file.write(r.content)
|
||||
self.file_path = self.temp_file.name
|
||||
elif not os.path.isfile(self.file_path):
|
||||
raise ValueError("File path %s is not a valid file or url" % self.file_path)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, "temp_file"):
|
||||
self.temp_file.close()
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
import docx2txt
|
||||
|
||||
return [
|
||||
Document(
|
||||
page_content=docx2txt.process(self.file_path),
|
||||
metadata={"source": self.file_path},
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_url(url: str) -> bool:
|
||||
"""Check if the url is valid."""
|
||||
parsed = urlparse(url)
|
||||
return bool(parsed.netloc) and bool(parsed.scheme)
|
||||
0
api/core/rag/index_processor/__init__.py
Normal file
0
api/core/rag/index_processor/__init__.py
Normal file
0
api/core/rag/index_processor/constant/__init__.py
Normal file
0
api/core/rag/index_processor/constant/__init__.py
Normal file
8
api/core/rag/index_processor/constant/index_type.py
Normal file
8
api/core/rag/index_processor/constant/index_type.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class IndexType(Enum):
|
||||
PARAGRAPH_INDEX = "text_model"
|
||||
QA_INDEX = "qa_model"
|
||||
PARENT_CHILD_INDEX = "parent_child_index"
|
||||
SUMMARY_INDEX = "summary_index"
|
||||
70
api/core/rag/index_processor/index_processor_base.py
Normal file
70
api/core/rag/index_processor/index_processor_base.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.models.document import Document
|
||||
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
|
||||
|
||||
class BaseIndexProcessor(ABC):
|
||||
"""Interface for extract files.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
raise NotImplementedError
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_splitter(self, processing_rule: dict,
|
||||
embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
|
||||
"""
|
||||
Get the NodeParser object according to the processing rule.
|
||||
"""
|
||||
if processing_rule['mode'] == "custom":
|
||||
# The user-defined segmentation rule
|
||||
rules = processing_rule['rules']
|
||||
segmentation = rules["segmentation"]
|
||||
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
|
||||
raise ValueError("Custom segment length should be between 50 and 1000.")
|
||||
|
||||
separator = segmentation["separator"]
|
||||
if separator:
|
||||
separator = separator.replace('\\n', '\n')
|
||||
|
||||
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
|
||||
chunk_size=segmentation["max_tokens"],
|
||||
chunk_overlap=0,
|
||||
fixed_separator=separator,
|
||||
separators=["\n\n", "。", ".", " ", ""],
|
||||
embedding_model_instance=embedding_model_instance
|
||||
)
|
||||
else:
|
||||
# Automatic segmentation
|
||||
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
|
||||
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
|
||||
chunk_overlap=0,
|
||||
separators=["\n\n", "。", ".", " ", ""],
|
||||
embedding_model_instance=embedding_model_instance
|
||||
)
|
||||
|
||||
return character_splitter
|
||||
28
api/core/rag/index_processor/index_processor_factory.py
Normal file
28
api/core/rag/index_processor/index_processor_factory.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
|
||||
|
||||
|
||||
class IndexProcessorFactory:
|
||||
"""IndexProcessorInit.
|
||||
"""
|
||||
|
||||
def __init__(self, index_type: str):
|
||||
self._index_type = index_type
|
||||
|
||||
def init_index_processor(self) -> BaseIndexProcessor:
|
||||
"""Init index processor."""
|
||||
|
||||
if not self._index_type:
|
||||
raise ValueError("Index type must be specified.")
|
||||
|
||||
if self._index_type == IndexType.PARAGRAPH_INDEX.value:
|
||||
return ParagraphIndexProcessor()
|
||||
elif self._index_type == IndexType.QA_INDEX.value:
|
||||
|
||||
return QAIndexProcessor()
|
||||
else:
|
||||
raise ValueError(f"Index type {self._index_type} is not supported.")
|
||||
0
api/core/rag/index_processor/processor/__init__.py
Normal file
0
api/core/rag/index_processor/processor/__init__.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Paragraph index processor."""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
|
||||
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
|
||||
is_automatic=kwargs.get('process_rule_mode') == "automatic")
|
||||
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
# Split the text documents into nodes.
|
||||
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
|
||||
embedding_model_instance=kwargs.get('embedding_model_instance'))
|
||||
all_documents = []
|
||||
for document in documents:
|
||||
# document clean
|
||||
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
|
||||
document.page_content = document_text
|
||||
# parse document to nodes
|
||||
document_nodes = splitter.split_documents([document])
|
||||
split_documents = []
|
||||
for document_node in document_nodes:
|
||||
|
||||
if document_node.page_content.strip():
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
# delete Spliter character
|
||||
page_content = document_node.page_content
|
||||
if page_content.startswith(".") or page_content.startswith("。"):
|
||||
page_content = page_content[1:]
|
||||
else:
|
||||
page_content = page_content
|
||||
document_node.page_content = page_content
|
||||
split_documents.append(document_node)
|
||||
all_documents.extend(split_documents)
|
||||
return all_documents
|
||||
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if with_keywords:
|
||||
keyword = Keyword(dataset)
|
||||
keyword.create(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
vector.delete_by_ids(node_ids)
|
||||
else:
|
||||
vector.delete()
|
||||
if with_keywords:
|
||||
keyword = Keyword(dataset)
|
||||
if node_ids:
|
||||
keyword.delete_by_ids(node_ids)
|
||||
else:
|
||||
keyword.delete()
|
||||
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata['score'] = result.score
|
||||
if result.score > score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
161
api/core/rag/index_processor/processor/qa_index_processor.py
Normal file
161
api/core/rag/index_processor/processor/qa_index_processor.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Paragraph index processor."""
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from flask import Flask, current_app
|
||||
from flask_login import current_user
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class QAIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
|
||||
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
|
||||
is_automatic=kwargs.get('process_rule_mode') == "automatic")
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
|
||||
embedding_model_instance=None)
|
||||
|
||||
# Split the text documents into nodes.
|
||||
all_documents = []
|
||||
all_qa_documents = []
|
||||
for document in documents:
|
||||
# document clean
|
||||
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
|
||||
document.page_content = document_text
|
||||
|
||||
# parse document to nodes
|
||||
document_nodes = splitter.split_documents([document])
|
||||
split_documents = []
|
||||
for document_node in document_nodes:
|
||||
|
||||
if document_node.page_content.strip():
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
# delete Spliter character
|
||||
page_content = document_node.page_content
|
||||
if page_content.startswith(".") or page_content.startswith("。"):
|
||||
page_content = page_content[1:]
|
||||
else:
|
||||
page_content = page_content
|
||||
document_node.page_content = page_content
|
||||
split_documents.append(document_node)
|
||||
all_documents.extend(split_documents)
|
||||
for i in range(0, len(all_documents), 10):
|
||||
threads = []
|
||||
sub_documents = all_documents[i:i + 10]
|
||||
for doc in sub_documents:
|
||||
document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'tenant_id': current_user.current_tenant.id,
|
||||
'document_node': doc,
|
||||
'all_qa_documents': all_qa_documents,
|
||||
'document_language': kwargs.get('document_language', 'English')})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
return all_qa_documents
|
||||
|
||||
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
|
||||
|
||||
# check file type
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
try:
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file)
|
||||
text_docs = []
|
||||
for index, row in df.iterrows():
|
||||
data = Document(page_content=row[0], metadata={'answer': row[1]})
|
||||
text_docs.append(data)
|
||||
if len(text_docs) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
return text_docs
|
||||
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
vector.delete_by_ids(node_ids)
|
||||
else:
|
||||
vector.delete()
|
||||
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict):
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata['score'] = result.score
|
||||
if result.score > score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
|
||||
format_documents = []
|
||||
if document_node.page_content is None or not document_node.page_content.strip():
|
||||
return
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
|
||||
document_qa_list = self._format_split_text(response)
|
||||
qa_documents = []
|
||||
for result in document_qa_list:
|
||||
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(result['question'])
|
||||
qa_document.metadata['answer'] = result['answer']
|
||||
qa_document.metadata['doc_id'] = doc_id
|
||||
qa_document.metadata['doc_hash'] = hash
|
||||
qa_documents.append(qa_document)
|
||||
format_documents.extend(qa_documents)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
all_qa_documents.extend(format_documents)
|
||||
|
||||
def _format_split_text(self, text):
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
|
||||
matches = re.findall(regex, text, re.UNICODE)
|
||||
|
||||
return [
|
||||
{
|
||||
"question": q,
|
||||
"answer": re.sub(r"\n\s*", "\n", a.strip())
|
||||
}
|
||||
for q, a in matches if q and a
|
||||
]
|
||||
0
api/core/rag/models/__init__.py
Normal file
0
api/core/rag/models/__init__.py
Normal file
16
api/core/rag/models/document.py
Normal file
16
api/core/rag/models/document.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""Class for storing a piece of text and associated metadata."""
|
||||
|
||||
page_content: str
|
||||
|
||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||
documents, etc.).
|
||||
"""
|
||||
metadata: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user