Feat/dify rag (#2528)

Co-authored-by: jyong <jyong@dify.ai>
Author: Jyong
Date: 2024-02-22 23:31:57 +08:00
Committed by: GitHub
Parent: 97fe817186
Commit: 6c4e6bf1d6

119 changed files with 3181 additions and 5892 deletions

api/core/rag/__init__.py (new empty file, 0 additions)

@@ -0,0 +1,38 @@
import re


class CleanProcessor:

    @classmethod
    def clean(cls, text: str, process_rule: dict) -> str:
        # Default cleaning: strip "<|" / "|>" marker fragments.
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        # Remove ASCII control characters (except \t and \n) and DEL.
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)
        # Remove the Unicode non-character U+FFFE.
        text = re.sub('\uFFFE', '', text)

        rules = process_rule['rules'] if process_rule else {}
        if rules and 'pre_processing_rules' in rules:
            pre_processing_rules = rules['pre_processing_rules']
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule['id'] == 'remove_extra_spaces' and pre_processing_rule['enabled'] is True:
                    # Collapse three or more consecutive newlines into two.
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    # Collapse runs of spaces (including Unicode space characters) into one space.
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule['id'] == 'remove_urls_emails' and pre_processing_rule['enabled'] is True:
                    # Remove email addresses.
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)
                    # Remove URLs.
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)
        return text

    def filter_string(self, text):
        return text
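A usage sketch showing the process_rule shape the cleaner expects (the rule ids mirror the two handled above; the sample text is illustrative):

process_rule = {
    'rules': {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': True},
        ]
    }
}
raw = 'Contact   me at user@example.com\n\n\n\nDocs: https://example.com/guide'
print(CleanProcessor.clean(raw, process_rule))
# -> 'Contact me at \n\nDocs: '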


@@ -0,0 +1,12 @@
"""Abstract interface for document cleaner implementations."""
from abc import ABC, abstractmethod
class BaseCleaner(ABC):
"""Interface for clean chunk content.
"""
@abstractmethod
def clean(self, content: str):
raise NotImplementedError


@@ -0,0 +1,12 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.core import clean_extra_whitespace
# Returns "ITEM 1A: RISK FACTORS"
return clean_extra_whitespace(content)


@@ -0,0 +1,15 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
import re
from unstructured.cleaners.core import group_broken_paragraphs
para_split_re = re.compile(r"(\s*\n\s*){3}")
return group_broken_paragraphs(content, paragraph_split=para_split_re)


@@ -0,0 +1,12 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.core import clean_non_ascii_chars
# Returns "This text containsnon-ascii characters!"
return clean_non_ascii_chars(content)


@@ -0,0 +1,11 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""Replaces unicode quote characters, such as the \x91 character in a string."""
from unstructured.cleaners.core import replace_unicode_quotes
return replace_unicode_quotes(content)


@@ -0,0 +1,11 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredTranslateTextCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.translate import translate_text
return translate_text(content)


@@ -0,0 +1,49 @@
from typing import Optional
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document
from core.rerank.rerank import RerankRunner
class DataPostProcessor:
"""Interface for data post-processing document.
"""
def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)
return documents
def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
if reranking_model:
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
)
except InvokeAuthorizationError:
return None
return RerankRunner(rerank_model_instance)
return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
if reorder_enabled:
return ReorderRunner()
return None
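A usage sketch of the post-processor (the tenant id, provider, and model names below are placeholders, not values from this commit):

from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.models.document import Document

# Candidate chunks as they might come back from a retriever (illustrative).
candidates = [Document(page_content=f'chunk {i}', metadata={'doc_id': str(i)}) for i in range(8)]

processor = DataPostProcessor(
    tenant_id='tenant-123',                             # placeholder
    reranking_model={
        'reranking_provider_name': 'cohere',            # assumed provider
        'reranking_model_name': 'rerank-english-v2.0',  # assumed model
    },
    reorder_enabled=True,
)
documents = processor.invoke(query='What is RAG?', documents=candidates, score_threshold=0.5, top_n=5)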


@@ -0,0 +1,19 @@
from langchain.schema import Document


class ReorderRunner:

    def run(self, documents: list[Document]) -> list[Document]:
        # Elements at odd positions (1st, 3rd, 5th, ...; indices 0, 2, 4, ...)
        odd_elements = documents[::2]
        # Elements at even positions (2nd, 4th, 6th, ...; indices 1, 3, 5, ...)
        even_elements = documents[1::2]
        # Reverse the even-position elements so the weakest matches end up
        # in the middle of the final list
        even_elements_reversed = even_elements[::-1]
        new_documents = odd_elements + even_elements_reversed
        return new_documents
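A worked example of the reordering (plain strings stand in for Document objects; describing the strongest-at-both-ends pattern as a "lost in the middle" mitigation is an editorial gloss, not wording from this commit):

docs = ['d0', 'd1', 'd2', 'd3', 'd4', 'd5']   # ranked best-first

odd = docs[::2]              # ['d0', 'd2', 'd4']
even_rev = docs[1::2][::-1]  # ['d5', 'd3', 'd1']
print(odd + even_rev)        # ['d0', 'd2', 'd4', 'd5', 'd3', 'd1']
# The strongest hits sit at both ends; the weakest sit in the middle.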


@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
class Embeddings(ABC):
"""Interface for embedding models."""
@abstractmethod
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronously embed search docs."""
        raise NotImplementedError

    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronously embed query text."""
        raise NotImplementedError
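A minimal sketch of a concrete implementation, using deterministic hash-derived toy vectors (illustrative only; real implementations call a model provider):

import hashlib


class ToyEmbeddings(Embeddings):
    """Deterministic toy embeddings for testing; not part of this commit."""

    def _embed(self, text: str) -> list[float]:
        digest = hashlib.sha256(text.encode('utf-8')).digest()
        # Map the first 8 bytes of the digest to floats in [0, 1].
        return [b / 255.0 for b in digest[:8]]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self._embed(t) for t in texts]

    def embed_query(self, text: str) -> list[float]:
        return self._embed(text)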


@@ -0,0 +1,232 @@
import json
from collections import defaultdict
from typing import Any, Optional
from pydantic import BaseModel
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
class Jieba(BaseKeyword):
def __init__(self, dataset: Dataset):
super().__init__(dataset)
self._config = KeywordTableConfig()
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get('keywords_list', None)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
else:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
    def text_exists(self, id: str) -> bool:
        keyword_table = self._get_dataset_keyword_table()
        if not keyword_table:
            # set.union(*...) raises TypeError on an empty argument list
            return False
        return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def delete_by_document_id(self, document_id: str):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()
ids = [segment.index_node_id for segment in segments]
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
k = kwargs.get('top_k', 4)
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.index_node_id == chunk_index
).first()
if segment:
documents.append(Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
))
return documents
def delete(self) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": keyword_table
}
}
self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
db.session.commit()
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
if dataset_keyword_table.keyword_table_dict:
return dataset_keyword_table.keyword_table_dict['__data__']['table']
else:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in keyword_table.items():
if node_idxs_to_delete.intersection(node_idxs):
keyword_table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not keyword_table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del keyword_table[keyword]
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
# go through text chunks in order of most matching keywords
chunk_indices_count: dict[str, int] = defaultdict(int)
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1
sorted_chunk_indices = sorted(
list(chunk_indices_count.keys()),
key=lambda x: chunk_indices_count[x],
reverse=True,
)
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id == node_id
).first()
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)
db.session.commit()
def create_segment_keywords(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(self.dataset.id, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data['segment']
if pre_segment_data['keywords']:
segment.keywords = pre_segment_data['keywords']
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
pre_segment_data['keywords'])
else:
keywords = keyword_table_handler.extract_keywords(segment.content,
self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return super().default(obj)
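The keyword table maps each keyword to a set of segment node ids; SetEncoder exists so those sets survive json.dumps. A small round-trip sketch (keyword and id values are illustrative):

import json

table = {'retrieval': {'node-1', 'node-2'}, 'rag': {'node-2'}}
encoded = json.dumps(table, cls=SetEncoder, sort_keys=True)
# Sets come back as plain lists after decoding.
decoded = json.loads(encoded)
assert sorted(decoded['retrieval']) == ['node-1', 'node-2']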


@@ -0,0 +1,32 @@
import re
import jieba
from jieba.analyse import default_tfidf
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
)
return set(self._expand_tokens_with_subtokens(keywords))
    def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
        """Get subtokens from a list of tokens, filtering out stopwords."""
        results = set()
        for token in tokens:
            results.add(token)
            sub_tokens = re.findall(r"\w+", token)
            if len(sub_tokens) > 1:
                results.update({w for w in sub_tokens if w not in STOPWORDS})
        return results
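A quick usage sketch (the sample sentence is illustrative, and the exact keywords returned depend on jieba's TF-IDF dictionary):

handler = JiebaKeywordTableHandler()
keywords = handler.extract_keywords('Dify 是一个开源的 LLM 应用开发平台', max_keywords_per_chunk=10)
print(keywords)  # e.g. {'Dify', 'LLM', '开源', '应用', '开发', '平台'} (illustrative)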


@@ -0,0 +1,90 @@
STOPWORDS = {
"during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've",
"ours", "who", "where", "ourselves", "has", "to", "was", "didn't", "themselves", "if", "against", "through", "her",
"an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t",
"theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven",
"for", "you'll", "that", "we", "am", "d", "by", "having", "wasn't", "than", "weren't", "out", "from", "now",
"their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which",
"m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't",
"such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves",
"been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because",
"down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't",
"as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after",
"over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there",
"himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below",
"人民", "末##末", "", "", "", "哎呀", "哎哟", "", "", "俺们", "", "按照", "", "吧哒", "", "罢了", "", "",
"本着", "", "比方", "比如", "鄙人", "", "彼此", "", "", "别的", "别说", "", "并且", "不比", "不成", "不单", "不但",
"不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "", "朝着",
"", "趁着", "", "", "", "除此之外", "除非", "除了", "", "此间", "此外", "", "从而", "", "", "", "但是", "",
"当着", "", "", "", "的话", "", "等等", "", "", "叮咚", "", "对于", "", "多少", "", "而况", "而且", "而是",
"而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "", "嘎登", "", "", "", "",
"各个", "各位", "各种", "各自", "", "根据", "", "", "故此", "固然", "关于", "", "", "果然", "果真", "", "",
"哈哈", "", "", "", "何处", "何况", "何时", "", "", "哼唷", "呼哧", "", "", "还是", "还有", "换句话说", "换言之",
"", "或是", "或者", "极了", "", "及其", "及至", "", "即便", "即或", "即令", "即若", "即使", "", "几时", "", "",
"既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "", "", "较之", "", "接着", "结果", "", "紧接着",
"进而", "", "尽管", "", "经过", "", "就是", "就是说", "", "具体地说", "具体说来", "开始", "开外", "", "", "",
"可见", "可是", "可以", "况且", "", "", "来着", "", "例如", "", "", "连同", "两者", "", "", "", "另外",
"另一方面", "", "", "", "慢说", "漫说", "", "", "", "每当", "", "莫若", "", "某个", "某些", "", "",
"哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "", "那边", "那儿", "那个", "那会儿", "那里", "那么",
"那么些", "那么样", "那时", "那些", "那样", "", "乃至", "", "", "", "你们", "", "", "宁可", "宁肯", "宁愿", "",
"", "啪达", "旁人", "", "", "凭借", "", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "", "起见", "岂但",
"恰恰相反", "前后", "前者", "", "然而", "然后", "然则", "", "人家", "", "任何", "任凭", "", "如此", "如果", "如何",
"如其", "如若", "如上所述", "", "若非", "若是", "", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候",
"什么", "什么样", "使得", "", "是的", "首先", "", "谁知", "", "顺着", "似的", "", "虽然", "虽说", "虽则", "", "随着",
"", "所以", "", "他们", "他人", "", "它们", "", "她们", "", "倘或", "倘然", "倘若", "倘使", "", "", "通过", "",
"同时", "", "万一", "", "", "", "为何", "为了", "为什么", "为着", "", "嗡嗡", "", "我们", "", "呜呼", "乌乎",
"无论", "无宁", "毋宁", "", "", "相对而言", "", "", "向着", "", "", "", "沿", "沿着", "", "要不", "要不然",
"要不是", "要么", "要是", "", "也罢", "也好", "", "一般", "一旦", "一方面", "一来", "一切", "一样", "一则", "", "依照",
"", "", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "", "因此", "因而", "因为", "", "", "",
"由此可见", "由于", "", "有的", "有关", "有些", "", "", "于是", "于是乎", "", "与此同时", "与否", "与其", "越是",
"云云", "", "再说", "再者", "", "在下", "", "咱们", "", "", "怎么", "怎么办", "怎么样", "怎样", "", "", "照着",
"", "", "这边", "这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样",
"正如", "", "", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "", "至于", "诸位", "", "着呢", "", "自从",
"自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "", "纵令",
"纵然", "纵使", "遵照", "作为", "", "", "", "", "", "", "", "喔唷", "", "", "", "~", "!", ".", ":",
"\"", "'", "(", ")", "*", "A", "", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_",
"+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "", "", "——", "", "", "·", "...", "", "", "", "", "",
" ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "︿", "", "", "", "", "", "",
"", "", "", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时",
"按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "", "", "保管", "保险", "", "背地里", "背靠背", "倍感", "倍加",
"本人", "本身", "", "比起", "比如说", "比照", "毕竟", "", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没",
"并没有", "并排", "并无", "勃然", "", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭",
"不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了",
"不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要",
"不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", "不知不觉", "不止",
"不止一次", "不至于", "", "才能", "策略地", "差不多", "差一点", "", "常常", "常言道", "常言说", "常言说得好", "长此下去",
"长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心",
"乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "", "", "", "出来", "出去",
"除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "", "传说", "传闻", "串行", "", "纯粹",
"此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速",
"从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "", "存心", "达旦", "打从",
"打开天窗说亮话", "", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事",
"大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "", "", "待到", "", "单纯", "单单", "但愿", "弹指之间", "当场",
"当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿",
"到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "", "动不动", "动辄", "陡然", "", "",
"独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等",
"二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "", "方才", "方能", "放量", "非常", "非得",
"分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "", "", "", "嘎嘎", "该当", "", "赶快", "赶早不赶晚", "",
"敢情", "敢于", "", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "", "更加", "更进一步", "更为",
"公然", "", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "", "", "怪不得", "惯常", "", "光是", "归根到底",
"归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须",
"何止", "", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "", "互相", "哗啦", "话说", "", "恍然", "", "豁然",
"", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆",
"即将", "即刻", "即是说", "几度", "几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之",
"简直", "", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此",
"借以", "届时", "", "仅仅", "", "进来", "进去", "", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量",
"尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "", "竟然", "究竟", "就此", "就地", "就算", "居然", "局外",
"举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "", "绝不", "绝顶", "绝对", "绝非",
"", "", "", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "", "快要", "来不及", "来得及", "来讲",
"来看", "拦腰", "牢牢", "", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "", "", "立地", "立刻",
"立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "", "屡次",
"屡次三番", "屡屡", "缕缕", "率尔", "率然", "", "略加", "略微", "略为", "论说", "马上", "", "", "", "没有", "每逢",
"每每", "每时每刻", "猛然", "猛然间", "", "莫不", "莫非", "莫如", "默默地", "默然", "", "那末", "", "难道", "难得", "难怪",
"难说", "", "年复一年", "凝神", "偶而", "偶尔", "", "", "碰巧", "譬如", "偏偏", "", "平素", "", "迫于", "扑通",
"其后", "其实", "", "", "起初", "起来", "起首", "起头", "起先", "", "岂非", "岂止", "", "恰逢", "恰好", "恰恰", "恰巧",
"恰如", "恰似", "", "千万", "千万千万", "", "切不可", "切莫", "切切", "切勿", "", "亲口", "亲身", "亲手", "亲眼", "亲自",
"", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "", "权时", "全都", "全力", "全年", "全然", "全身心", "",
"人人", "", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述",
"如上", "如下", "", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "", "上来", "上去", "一个", "", "", "\n"
}


@@ -0,0 +1,54 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from core.rag.models.document import Document
from models.dataset import Dataset
class BaseKeyword(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@abstractmethod
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str) -> None:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        # Build a new list; removing from a list while iterating it skips elements.
        return [
            text for text in texts
            if not self.text_exists(text.metadata['doc_id'])
        ]
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]


@@ -0,0 +1,60 @@
from typing import Any, cast
from flask import current_app
from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from models.dataset import Dataset
class Keyword:
def __init__(self, dataset: Dataset):
self._dataset = dataset
self._keyword_processor = self._init_keyword()
def _init_keyword(self) -> BaseKeyword:
config = cast(dict, current_app.config)
keyword_type = config.get('KEYWORD_STORE')
if not keyword_type:
raise ValueError("Keyword store must be specified.")
if keyword_type == "jieba":
return Jieba(
dataset=self._dataset
)
        else:
            raise ValueError(f"Keyword store {keyword_type} is not supported.")
def create(self, texts: list[Document], **kwargs):
self._keyword_processor.create(texts, **kwargs)
def add_texts(self, texts: list[Document], **kwargs):
self._keyword_processor.add_texts(texts, **kwargs)
def text_exists(self, id: str) -> bool:
return self._keyword_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
self._keyword_processor.delete_by_ids(ids)
def delete_by_document_id(self, document_id: str) -> None:
self._keyword_processor.delete_by_document_id(document_id)
def delete(self) -> None:
self._keyword_processor.delete()
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
return self._keyword_processor.search(query, **kwargs)
def __getattr__(self, name):
if self._keyword_processor is not None:
method = getattr(self._keyword_processor, name)
if callable(method):
return method
raise AttributeError(f"'Keyword' object has no attribute '{name}'")


@@ -0,0 +1,165 @@
import threading
from typing import Optional
from flask import Flask, current_app
from flask_login import current_user
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
all_documents = []
threads = []
# retrieval_model source with keyword
if retrival_method == 'keyword_search':
            keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset_id,
                'query': query,
                'top_k': top_k,
                # keyword_search requires this list to collect its results
                'all_documents': all_documents
            })
threads.append(keyword_thread)
keyword_thread.start()
# retrieval_model source with semantic
if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
'retrival_method': retrival_method
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval source with full text
if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'retrival_method': retrival_method,
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
'all_documents': all_documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
if retrival_method == 'hybrid_search':
data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
)
return all_documents
@classmethod
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
keyword = Keyword(
dataset=dataset
)
documents = keyword.search(
query,
k=top_k
)
all_documents.extend(documents)
@classmethod
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, retrival_method: str):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
vector = Vector(
dataset=dataset
)
documents = vector.search_by_vector(
query,
search_type='similarity_score_threshold',
k=top_k,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
}
)
if documents:
if reranking_model and retrival_method == 'semantic_search':
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents)
))
else:
all_documents.extend(documents)
@classmethod
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, retrival_method: str):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
vector_processor = Vector(
dataset=dataset,
)
documents = vector_processor.search_by_full_text(
query,
top_k=top_k
)
if documents:
if reranking_model and retrival_method == 'full_text_search':
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents)
))
else:
all_documents.extend(documents)
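A hedged usage sketch of the entry point (the dataset id and model names are placeholders; a Flask request context and a logged-in user are assumed, since hybrid search reads current_user for reranking):

documents = RetrievalService.retrieve(
    retrival_method='hybrid_search',
    dataset_id='11111111-2222-3333-4444-555555555555',  # placeholder
    query='How does keyword search work?',
    top_k=4,
    score_threshold=0.0,
    reranking_model={
        'reranking_provider_name': 'cohere',            # assumed provider
        'reranking_model_name': 'rerank-english-v2.0',  # assumed model
    },
)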


@@ -0,0 +1,10 @@
from enum import Enum
class Field(Enum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
TEXT_KEY = "text"
PRIMARY_KEY = " id"


@@ -0,0 +1,214 @@
import logging
from typing import Any, Optional
from uuid import uuid4
from pydantic import BaseModel, root_validator
from pymilvus import MilvusClient, MilvusException, connections
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._fields = []
def get_type(self) -> str:
return 'milvus'
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
metadatas = [d.metadata for d in texts]
# Grab the existing collection if it exists
from pymilvus import utility
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
if not utility.has_collection(self._collection_name, using=alias):
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
}
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
raise e
return pks
def delete_by_document_id(self, document_id: str):
ids = self.get_ids_by_metadata_field('document_id', document_id)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
if result:
return [item["id"] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, doc_ids: list[str]) -> None:
self._client.delete(collection_name=self._collection_name, pks=doc_ids)
def delete(self) -> None:
from pymilvus import utility
utility.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result['entity'].get(Field.METADATA_KEY.value)
metadata['score'] = result['distance']
            score_threshold = kwargs.get('score_threshold') or 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
) -> str:
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
# Determine embedding dim
dim = len(embeddings[0])
fields = []
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
# Create the schema for the collection
schema = CollectionSchema(fields)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
self._fields.remove(Field.PRIMARY_KEY.value)
# Create the collection
collection_name = self._collection_name
self._client.create_collection_with_schema(collection_name=collection_name,
schema=schema, index_param=index_params,
consistency_level=self._consistency_level)
return collection_name
def _init_client(self, config) -> MilvusClient:
if config.secure:
uri = "https://" + str(config.host) + ":" + str(config.port)
else:
uri = "http://" + str(config.host) + ":" + str(config.port)
client = MilvusClient(uri=uri, user=config.user, password=config.password)
return client
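A configuration sketch (connection values are placeholders, not from this commit): the pydantic root validator enforces the required fields, and `secure` merely switches the client URI between http and https.

config = MilvusConfig(
    host='127.0.0.1',            # placeholder
    port=19530,                  # placeholder
    user='milvus_user',          # placeholder
    password='milvus_password',  # placeholder
    secure=False,                # False -> http://127.0.0.1:19530
)
vector = MilvusVector(collection_name='Vector_index_example_Node', config=config)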


@@ -0,0 +1,360 @@
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import qdrant_client
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
FilterSelector,
HnswConfigDiff,
PayloadSchemaType,
TextIndexParams,
TextIndexType,
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
}
class QdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
self._distance_func = distance_func.upper()
self._group_id = group_id
def get_type(self) -> str:
return 'qdrant'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# get embedding vector size
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
)
# create payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
        # create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, metadatas, uuids, 64, self._group_id
):
self._client.upsert(
collection_name=self._collection_name, points=points
)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = list(islice(embeddings_iterator, batch_size))
points = [
rest.PointStruct(
id=point_id,
vector=vector,
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
Field.GROUP_KEY.value,
),
)
]
yield batch_ids, points
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append(
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def delete(self):
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
for node_id in ids:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def text_exists(self, id: str) -> bool:
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[id]
)
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
query_filter=filter,
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", .0)
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
if result.score > score_threshold:
metadata['score'] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar by bm25.
Returns:
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get('top_k', 2),
with_payload=True,
with_vectors=True
)
results = response[0]
documents = []
for result in results:
if result:
documents.append(self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
))
return documents
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
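The `path:` prefix on QdrantConfig.endpoint selects local on-disk mode; any other endpoint is treated as a server URL. A small sketch (the paths, URL, and API key are placeholders):

# Local, file-backed Qdrant.
local = QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app/api')
print(local.to_qdrant_params())   # {'path': '/app/api/storage/qdrant'}

# Remote Qdrant server.
remote = QdrantConfig(endpoint='https://qdrant.example.com:6333', api_key='qdrant-api-key')
print(remote.to_qdrant_params())  # {'url': ..., 'api_key': ..., 'timeout': 20}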


@@ -0,0 +1,62 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name
@abstractmethod
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError
@abstractmethod
def search_by_vector(
self,
query_vector: list[float],
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        # Build a new list; removing from a list while iterating it skips elements.
        return [
            text for text in texts
            if not self.text_exists(text.metadata['doc_id'])
        ]
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
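For a feel of the contract, a minimal in-memory sketch (illustrative only; the real backends in this commit are the Weaviate, Qdrant, and Milvus implementations):

from typing import Any

from core.rag.models.document import Document


class InMemoryVector(BaseVector):
    """Toy backend: brute-force dot-product search. Not part of this commit."""

    def __init__(self, collection_name: str):
        super().__init__(collection_name)
        self._store: dict[str, tuple[Document, list[float]]] = {}

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        for doc, emb in zip(documents, embeddings):
            self._store[doc.metadata['doc_id']] = (doc, emb)

    def text_exists(self, id: str) -> bool:
        return id in self._store

    def delete_by_ids(self, ids: list[str]) -> None:
        for id in ids:
            self._store.pop(id, None)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._store = {k: v for k, v in self._store.items() if v[0].metadata.get(key) != value}

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        # Rank stored documents by dot product with the query vector.
        scored = sorted(
            self._store.values(),
            key=lambda pair: sum(a * b for a, b in zip(query_vector, pair[1])),
            reverse=True,
        )
        return [doc for doc, _ in scored[:kwargs.get('top_k', 4)]]

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # Naive substring match in place of BM25.
        return [doc for doc, _ in self._store.values() if query in doc.page_content]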


@@ -0,0 +1,171 @@
from typing import Any, cast
from flask import current_app
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class Vector:
def __init__(self, dataset: Dataset, attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
self._vector_processor = self._init_vector()
def _init_vector(self) -> BaseVector:
config = cast(dict, current_app.config)
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError("Vector store must be specified.")
if vector_type == "weaviate":
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
attributes=self._attributes
)
elif vector_type == "qdrant":
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
if self._dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
                else:
                    raise ValueError('Dataset collection binding does not exist!')
else:
                if self._dataset.index_struct_dict:
                    class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
                    collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return QdrantVector(
collection_name=collection_name,
group_id=self._dataset.id,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
)
)
elif vector_type == "milvus":
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
)
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def create(self, texts: list = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(
texts=texts,
embeddings=embeddings,
**kwargs
)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get('duplicate_check', False):
documents = self._filter_duplicate_texts(documents)
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.add_texts(
documents=documents,
embeddings=embeddings,
**kwargs
)
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
self._vector_processor.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(
self, query: str,
**kwargs: Any
) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self) -> None:
self._vector_processor.delete()
def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model
)
return CacheEmbedding(embedding_model)
    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        # Build a new list; removing from a list while iterating it skips elements.
        return [
            text for text in texts
            if not self.text_exists(text.metadata['doc_id'])
        ]
def __getattr__(self, name):
if self._vector_processor is not None:
method = getattr(self._vector_processor, name)
if callable(method):
return method
        raise AttributeError(f"'Vector' object has no attribute '{name}'")
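A short usage sketch for the factory (a persisted Dataset row with an embedding model configured, plus a Flask app context, are assumed; the VECTOR_STORE config key selects the backend):

# 'my_dataset' and 'documents' are placeholders for real objects.
vector = Vector(dataset=my_dataset)
vector.add_texts(documents, duplicate_check=True)  # embeds, then writes to the store
hits = vector.search_by_vector('what is a keyword table?', top_k=4, score_threshold=0.5)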


@@ -0,0 +1,235 @@
import datetime
from typing import Any, Optional
import requests
import weaviate
from pydantic import BaseModel, root_validator
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
super().__init__(collection_name)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_collection_name(self, dataset: Dataset) -> str:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
schema = self._default_schema(self._collection_name)
# check whether the index already exists
if not self._client.schema.contains(schema):
# create collection
self._client.schema.create_class(schema)
# create vector
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
ids = []
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
if metadatas is not None:
for key, val in metadatas[i].items():
data_properties[key] = self._json_serializable(val)
batch.add_data_object(
data_object=data_properties,
class_name=self._collection_name,
uuid=uuids[i],
vector=embeddings[i] if embeddings else None,
)
ids.append(uuids[i])
return ids
def delete_by_metadata_field(self, key: str, value: str):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
def delete(self):
self._client.schema.delete_class(self._collection_name)
def text_exists(self, id: str) -> bool:
collection_name = self._collection_name
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][collection_name]
if len(entries) == 0:
return False
return True
def delete_by_ids(self, ids: list[str]) -> None:
self._client.data_object.delete(
ids,
class_name=self._collection_name
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
        properties = list(self._attributes)  # copy so repeated calls don't grow the shared list
        properties.append(Field.TEXT_KEY.value)
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
result = (
query_obj.with_near_vector(vector)
.with_limit(kwargs.get("top_k", 4))
.with_additional(["vector", "distance"])
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
docs = []
for doc, score in docs_and_scores:
            score_threshold = kwargs.get('score_threshold') or 0.0
# check score threshold
if score > score_threshold:
doc.metadata['score'] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
            top_k: Number of Documents to return. Defaults to 2.
Returns:
List of Documents most similar to the query.
"""
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
        properties = list(self._attributes)  # copy to avoid mutating the attribute list
        properties.append(Field.TEXT_KEY.value)
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
docs.append(Document(page_content=text, metadata=res))
return docs
def _default_schema(self, index_name: str) -> dict:
return {
"class": index_name,
"properties": [
{
"name": "text",
"dataType": ["text"],
}
],
}
def _json_serializable(self, value: Any) -> Any:
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
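
For reference, a minimal usage sketch of WeaviateVector as defined above. The endpoint, API key, collection name, and embedding vectors are all placeholders, and a reachable Weaviate instance is assumed.

# Usage sketch only; endpoint, API key, collection name and vectors are placeholders.
config = WeaviateConfig(endpoint='http://localhost:8080', api_key='placeholder-key', batch_size=100)
vector = WeaviateVector('Vector_index_example_Node', config, attributes=['doc_id'])
docs = [Document(page_content='hello world', metadata={'doc_id': 'doc-1'})]
embeddings = [[0.1, 0.2, 0.3]]  # one vector per document, produced by an embedding model
vector.create(docs, embeddings)  # creates the class on first use, then batch-inserts
results = vector.search_by_vector([0.1, 0.2, 0.3], top_k=4, score_threshold=0.5)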

View File

@@ -0,0 +1,166 @@
"""Schema for Blobs and Blob Loaders.
The goal is to facilitate decoupling of content loading from content parsing code.
In addition, content loading code should provide a lazy loading interface by default.
"""
from __future__ import annotations
import contextlib
import mimetypes
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Mapping
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Optional, Union
from pydantic import BaseModel, root_validator
PathLike = Union[str, PurePath]
class Blob(BaseModel):
"""A blob is used to represent raw data by either reference or value.
Provides an interface to materialize the blob in different representations, and
help to decouple the development of data loaders from the downstream parsing of
the raw data.
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
"""
data: Union[bytes, str, None] # Raw data
mimetype: Optional[str] = None # Not to be confused with a file extension
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
# Location where the original content was found
# Represent location on the local file system
# Useful for situations where downstream code assumes it must work with file paths
# rather than in-memory content.
path: Optional[PathLike] = None
class Config:
arbitrary_types_allowed = True
frozen = True
@property
def source(self) -> Optional[str]:
"""The source location of the blob as string if known otherwise none."""
return str(self.path) if self.path else None
@root_validator(pre=True)
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:
raise ValueError("Either data or path must be provided")
return values
def as_string(self) -> str:
"""Read data as a string."""
if self.data is None and self.path:
with open(str(self.path), encoding=self.encoding) as f:
return f.read()
elif isinstance(self.data, bytes):
return self.data.decode(self.encoding)
elif isinstance(self.data, str):
return self.data
else:
raise ValueError(f"Unable to get string for blob {self}")
def as_bytes(self) -> bytes:
"""Read data as bytes."""
if isinstance(self.data, bytes):
return self.data
elif isinstance(self.data, str):
return self.data.encode(self.encoding)
elif self.data is None and self.path:
with open(str(self.path), "rb") as f:
return f.read()
else:
raise ValueError(f"Unable to get bytes for blob {self}")
@contextlib.contextmanager
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
"""Read data as a byte stream."""
if isinstance(self.data, bytes):
yield BytesIO(self.data)
elif self.data is None and self.path:
with open(str(self.path), "rb") as f:
yield f
else:
raise NotImplementedError(f"Unable to convert blob {self}")
@classmethod
def from_path(
cls,
path: PathLike,
*,
encoding: str = "utf-8",
mime_type: Optional[str] = None,
guess_type: bool = True,
) -> Blob:
"""Load the blob from a path like object.
Args:
path: path like object to file to be read
encoding: Encoding to use if decoding the bytes into a string
mime_type: if provided, will be set as the mime-type of the data
guess_type: If True, the mimetype will be guessed from the file extension,
if a mime-type was not provided
Returns:
Blob instance
"""
        if mime_type is None and guess_type:
            _mimetype = mimetypes.guess_type(path)[0]
        else:
            _mimetype = mime_type
# We do not load the data immediately, instead we treat the blob as a
# reference to the underlying data.
return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path)
@classmethod
def from_data(
cls,
data: Union[str, bytes],
*,
encoding: str = "utf-8",
mime_type: Optional[str] = None,
path: Optional[str] = None,
) -> Blob:
"""Initialize the blob from in-memory data.
Args:
data: the in-memory data associated with the blob
encoding: Encoding to use if decoding the bytes into a string
mime_type: if provided, will be set as the mime-type of the data
path: if provided, will be set as the source from which the data came
Returns:
Blob instance
"""
return cls(data=data, mimetype=mime_type, encoding=encoding, path=path)
def __repr__(self) -> str:
"""Define the blob representation."""
str_repr = f"Blob {id(self)}"
if self.source:
str_repr += f" {self.source}"
return str_repr
class BlobLoader(ABC):
"""Abstract interface for blob loaders implementation.
Implementer should be able to load raw content from a datasource system according
to some criteria and return the raw content lazily as a stream of blobs.
"""
@abstractmethod
def yield_blobs(
self,
) -> Iterable[Blob]:
"""A lazy loader for raw data represented by LangChain's Blob object.
Returns:
A generator over blobs
"""

View File

@@ -0,0 +1,71 @@
"""Abstract interface for document loader implementations."""
import csv
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class CSVExtractor(BaseExtractor):
"""Load CSV files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None,
csv_args: Optional[dict] = None,
):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
self.source_column = source_column
self.csv_args = csv_args or {}
def extract(self) -> list[Document]:
"""Load data into document objects."""
try:
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
                detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
return docs
def _read_from_file(self, csvfile) -> list[Document]:
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs
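
A minimal usage sketch, assuming a CSV file with a header row; the path and source column are placeholders.

# Hypothetical file; 'title' must be one of the CSV header columns.
extractor = CSVExtractor('/tmp/articles.csv', autodetect_encoding=True, source_column='title')
docs = extractor.extract()
# each row becomes one Document with "header: value" lines and metadata {'source', 'row'}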

View File

@@ -0,0 +1,6 @@
from enum import Enum
class DatasourceType(Enum):
FILE = "upload_file"
NOTION = "notion_import"

View File

@@ -0,0 +1,36 @@
from typing import Optional
from pydantic import BaseModel
from models.dataset import Document
from models.model import UploadFile
class NotionInfo(BaseModel):
"""
Notion import info.
"""
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str
    document: Optional[Document] = None
class Config:
arbitrary_types_allowed = True
def __init__(self, **data) -> None:
super().__init__(**data)
class ExtractSetting(BaseModel):
"""
    Model class for document extraction settings.
    """
    datasource_type: str
    upload_file: Optional[UploadFile] = None
    notion_info: Optional[NotionInfo] = None
    document_model: Optional[str] = None
class Config:
arbitrary_types_allowed = True
def __init__(self, **data) -> None:
super().__init__(**data)
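
For illustration, how these settings are typically populated for the two datasource types; every identifier below is a placeholder.

# Placeholders throughout; real values come from upload records and Notion bindings.
file_setting = ExtractSetting(
    datasource_type='upload_file',
    upload_file=upload_file,  # an UploadFile row loaded from the database
    document_model='text_model'
)
notion_setting = ExtractSetting(
    datasource_type='notion_import',
    notion_info=NotionInfo(
        notion_workspace_id='ws-123',
        notion_obj_id='page-456',
        notion_page_type='page'
    ),
    document_model='text_model'
)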

View File

@@ -0,0 +1,50 @@
"""Abstract interface for document loader implementations."""
from typing import Optional
from openpyxl.reader.excel import load_workbook
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class ExcelExtractor(BaseExtractor):
"""Load Excel files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False
):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def extract(self) -> list[Document]:
"""Load from file path."""
data = []
keys = []
wb = load_workbook(filename=self._file_path, read_only=True)
# loop over all sheets
for sheet in wb:
if 'A1:A1' == sheet.calculate_dimension():
sheet.reset_dimensions()
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
                if not keys:
keys = list(map(str, row))
else:
row_dict = dict(zip(keys, list(map(str, row))))
row_dict = {k: v for k, v in row_dict.items() if v}
item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
document = Document(page_content=item, metadata={'source': self._file_path})
data.append(document)
return data

View File

@@ -0,0 +1,139 @@
import tempfile
from pathlib import Path
from typing import Optional, Union
import requests
from flask import current_app
from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
from models.model import UploadFile
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
class ExtractProcessor:
@classmethod
def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
-> Union[list[Document], str]:
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=upload_file,
document_model='text_model'
)
if return_text:
delimiter = '\n'
return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
else:
return cls.extract(extract_setting, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)
extract_setting = ExtractSetting(
datasource_type="upload_file",
document_model='text_model'
)
if return_text:
delimiter = '\n'
return delimiter.join([document.page_content for document in cls.extract(
extract_setting=extract_setting, file_path=file_path)])
else:
return cls.extract(extract_setting=extract_setting, file_path=file_path)
@classmethod
def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
                file_path: Optional[str] = None) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE.value:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
upload_file: UploadFile = extract_setting.upload_file
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
extractor = ExcelExtractor(file_path)
elif file_extension == '.pdf':
extractor = PdfExtractor(file_path)
elif file_extension in ['.md', '.markdown']:
extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
else MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
elif file_extension == '.eml':
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
elif file_extension == '.ppt':
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url)
elif file_extension == '.pptx':
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
elif file_extension == '.xml':
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
else:
# txt
extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
else TextExtractor(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
extractor = ExcelExtractor(file_path)
elif file_extension == '.pdf':
extractor = PdfExtractor(file_path)
elif file_extension in ['.md', '.markdown']:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
extractor = WordExtractor(file_path)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
else:
# txt
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document
)
return extractor.extract()
else:
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")

View File

@@ -0,0 +1,12 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
class BaseExtractor(ABC):
"""Interface for extract files.
"""
@abstractmethod
def extract(self):
raise NotImplementedError

View File

@@ -0,0 +1,46 @@
"""Document loader helpers."""
import concurrent.futures
from typing import NamedTuple, Optional, cast
class FileEncoding(NamedTuple):
"""A file encoding as the NamedTuple."""
encoding: Optional[str]
"""The encoding of the file."""
confidence: float
"""The confidence of the encoding."""
language: Optional[str]
"""The language of the file."""
def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
"""Try to detect the file encoding.
Returns a list of `FileEncoding` tuples with the detected encodings ordered
by confidence.
Args:
file_path: The path to the file to detect the encoding for.
timeout: The timeout in seconds for the encoding detection.
"""
import chardet
def read_and_detect(file_path: str) -> list[dict]:
with open(file_path, "rb") as f:
rawdata = f.read()
return cast(list[dict], chardet.detect_all(rawdata))
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(read_and_detect, file_path)
try:
encodings = future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
raise TimeoutError(
f"Timeout reached while detecting encoding for {file_path}"
)
if all(encoding["encoding"] is None for encoding in encodings):
raise RuntimeError(f"Could not detect encoding for {file_path}")
return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
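
A small sketch of the helper in use; the path is a placeholder.

# Hypothetical path; try candidate encodings in confidence order.
encodings = detect_file_encodings('/tmp/unknown.txt', timeout=5)
for enc in encodings:
    try:
        with open('/tmp/unknown.txt', encoding=enc.encoding) as f:
            text = f.read()
        break
    except UnicodeDecodeError:
        continue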

View File

@@ -0,0 +1,71 @@
"""Abstract interface for document loader implementations."""
import csv
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class HtmlExtractor(BaseExtractor):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None,
csv_args: Optional[dict] = None,
):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
self.source_column = source_column
self.csv_args = csv_args or {}
def extract(self) -> list[Document]:
"""Load data into document objects."""
try:
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
return docs
def _read_from_file(self, csvfile) -> list[Document]:
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs

View File

@@ -0,0 +1,122 @@
"""Abstract interface for document loader implementations."""
import re
from typing import Optional, cast
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class MarkdownExtractor(BaseExtractor):
"""Load Markdown files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
remove_hyperlinks: bool = True,
remove_images: bool = True,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
"""Initialize with file path."""
self._file_path = file_path
self._remove_hyperlinks = remove_hyperlinks
self._remove_images = remove_images
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def extract(self) -> list[Document]:
"""Load from file path."""
tups = self.parse_tups(self._file_path)
documents = []
for header, value in tups:
value = value.strip()
if header is None:
documents.append(Document(page_content=value))
else:
documents.append(Document(page_content=f"\n\n{header}\n{value}"))
return documents
def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary.
The keys are the headers and the values are the text under each header.
"""
markdown_tups: list[tuple[Optional[str], str]] = []
lines = markdown_text.split("\n")
current_header = None
current_text = ""
for line in lines:
header_match = re.match(r"^#+\s", line)
if header_match:
if current_header is not None:
markdown_tups.append((current_header, current_text))
current_header = line
current_text = ""
else:
current_text += line + "\n"
markdown_tups.append((current_header, current_text))
if current_header is not None:
# pass linting, assert keys are defined
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]
else:
markdown_tups = [
(key, re.sub("\n", "", value)) for key, value in markdown_tups
]
return markdown_tups
def remove_images(self, content: str) -> str:
"""Get a dictionary of a markdown file from its path."""
pattern = r"!{1}\[\[(.*)\]\]"
content = re.sub(pattern, "", content)
return content
def remove_hyperlinks(self, content: str) -> str:
"""Get a dictionary of a markdown file from its path."""
pattern = r"\[(.*?)\]\((.*?)\)"
content = re.sub(pattern, r"\1", content)
return content
def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]:
"""Parse file into tuples."""
content = ""
try:
with open(filepath, encoding=self._encoding) as f:
content = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings:
try:
with open(filepath, encoding=encoding.encoding) as f:
content = f.read()
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {filepath}") from e
except Exception as e:
raise RuntimeError(f"Error loading {filepath}") from e
if self._remove_hyperlinks:
content = self.remove_hyperlinks(content)
if self._remove_images:
content = self.remove_images(content)
return self.markdown_to_tups(content)
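
To make the header/body pairing concrete, a sketch of markdown_to_tups on an in-memory string; the file path is a placeholder and is not read here.

extractor = MarkdownExtractor('/tmp/doc.md')  # placeholder path, not opened below
tups = extractor.markdown_to_tups("# Intro\nfirst part\n## Usage\nsecond part")
# -> [('Intro', 'first part\n'), ('Usage', 'second part\n')]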

View File

@@ -0,0 +1,366 @@
import json
import logging
from typing import Any, Optional
import requests
from flask import current_app
from flask_login import current_user
from core.rag.models.document import Document
from core.rag.extractor.extractor_base import BaseExtractor
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding
logger = logging.getLogger(__name__)
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
class NotionExtractor(BaseExtractor):
def __init__(
self,
notion_workspace_id: str,
notion_obj_id: str,
notion_page_type: str,
document_model: Optional[DocumentModel] = None,
notion_access_token: Optional[str] = None
):
self._notion_access_token = None
self._document_model = document_model
self._notion_workspace_id = notion_workspace_id
self._notion_obj_id = notion_obj_id
self._notion_page_type = notion_page_type
if notion_access_token:
self._notion_access_token = notion_access_token
else:
self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
self._notion_workspace_id)
if not self._notion_access_token:
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment "
"variable `NOTION_INTEGRATION_TOKEN`."
)
self._notion_access_token = integration_token
def extract(self) -> list[Document]:
self.update_last_edited_time(
self._document_model
)
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
return text_docs
def _load_data_as_documents(
self, notion_obj_id: str, notion_page_type: str
) -> list[Document]:
docs = []
if notion_page_type == 'database':
# get all the pages in the database
page_text_documents = self._get_notion_database_data(notion_obj_id)
docs.extend(page_text_documents)
elif notion_page_type == 'page':
page_text_list = self._get_notion_block_data(notion_obj_id)
for page_text in page_text_list:
docs.append(Document(page_content=page_text))
else:
raise ValueError("notion page type not supported")
return docs
def _get_notion_database_data(
self, database_id: str, query_dict: dict[str, Any] = {}
) -> list[Document]:
"""Get all the pages from a Notion database."""
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict,
)
data = res.json()
database_content_list = []
if 'results' not in data or data["results"] is None:
return []
for result in data["results"]:
properties = result['properties']
            row_data = {}  # renamed to avoid shadowing the response payload above
for property_name, property_value in properties.items():
type = property_value['type']
if type == 'multi_select':
value = []
multi_select_list = property_value[type]
for multi_select in multi_select_list:
value.append(multi_select['name'])
elif type == 'rich_text' or type == 'title':
if len(property_value[type]) > 0:
value = property_value[type][0]['plain_text']
else:
value = ''
elif type == 'select' or type == 'status':
if property_value[type]:
value = property_value[type]['name']
else:
value = ''
else:
value = property_value[type]
                row_data[property_name] = value
            row_dict = {k: v for k, v in row_data.items() if v}
row_content = ''
for key, value in row_dict.items():
if isinstance(value, dict):
value_dict = {k: v for k, v in value.items() if v}
value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items())
row_content = row_content + f'{key}:{value_content}\n'
else:
row_content = row_content + f'{key}:{value}\n'
document = Document(page_content=row_content)
database_content_list.append(document)
return database_content_list
def _get_notion_block_data(self, page_id: str) -> list[str]:
result_lines_arr = []
cur_block_id = page_id
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: dict[str, Any] = {}
res = requests.request(
"GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
# current block's heading
heading = ''
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
cur_result_text_arr = []
if result_type == 'table':
result_block_id = result["id"]
text = self._read_table_rows(result_block_id)
text += "\n\n"
result_lines_arr.append(text)
else:
if "rich_text" in result_obj:
for rich_text in result_obj["rich_text"]:
# skip if doesn't have text object
if "text" in rich_text:
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
if result_type in HEADING_TYPE:
heading = text
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
if has_children and block_type != 'child_page':
children_text = self._read_block(
result_block_id, num_tabs=1
)
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text)
else:
result_lines_arr.append(f'{heading}\n{cur_result_text}')
if data["next_cursor"] is None:
break
else:
cur_block_id = data["next_cursor"]
return result_lines_arr
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block."""
result_lines_arr = []
cur_block_id = block_id
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: dict[str, Any] = {}
res = requests.request(
"GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
if 'results' not in data or data["results"] is None:
break
heading = ''
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
cur_result_text_arr = []
if result_type == 'table':
result_block_id = result["id"]
text = self._read_table_rows(result_block_id)
result_lines_arr.append(text)
else:
if "rich_text" in result_obj:
for rich_text in result_obj["rich_text"]:
# skip if doesn't have text object
if "text" in rich_text:
text = rich_text["text"]["content"]
prefix = "\t" * num_tabs
cur_result_text_arr.append(prefix + text)
if result_type in HEADING_TYPE:
heading = text
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
if has_children and block_type != 'child_page':
children_text = self._read_block(
result_block_id, num_tabs=num_tabs + 1
)
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text)
else:
result_lines_arr.append(f'{heading}\n{cur_result_text}')
if data["next_cursor"] is None:
break
else:
cur_block_id = data["next_cursor"]
result_lines = "\n".join(result_lines_arr)
return result_lines
def _read_table_rows(self, block_id: str) -> str:
"""Read table rows."""
done = False
result_lines_arr = []
cur_block_id = block_id
while not done:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: dict[str, Any] = {}
res = requests.request(
"GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
# get table headers text
table_header_cell_texts = []
            table_header_cells = data["results"][0]['table_row']['cells']
            for table_header_cell in table_header_cells:
                if table_header_cell:
                    for table_header_cell_text in table_header_cell:
text = table_header_cell_text["text"]["content"]
table_header_cell_texts.append(text)
# get table columns text and format
results = data["results"]
for i in range(len(results) - 1):
column_texts = []
                table_column_cells = data["results"][i + 1]['table_row']['cells']
                for j in range(len(table_column_cells)):
                    if table_column_cells[j]:
                        for table_column_cell_text in table_column_cells[j]:
column_text = table_column_cell_text["text"]["content"]
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
cur_result_text = "\n".join(column_texts)
result_lines_arr.append(cur_result_text)
if data["next_cursor"] is None:
done = True
break
else:
cur_block_id = data["next_cursor"]
result_lines = "\n".join(result_lines_arr)
return result_lines
def update_last_edited_time(self, document_model: DocumentModel):
if not document_model:
return
last_edited_time = self.get_notion_last_edited_time()
data_source_info = document_model.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
DocumentModel.data_source_info: json.dumps(data_source_info)
}
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
db.session.commit()
def get_notion_last_edited_time(self) -> str:
obj_id = self._notion_obj_id
page_type = self._notion_page_type
if page_type == 'database':
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
else:
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
query_dict: dict[str, Any] = {}
res = requests.request(
"GET",
retrieve_page_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
return data["last_edited_time"]
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
)
).first()
if not data_source_binding:
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
f'and notion workspace {notion_workspace_id}')
return data_source_binding.access_token
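
A sketch of a direct invocation that bypasses the data-source binding lookup by passing a token; all IDs and the token are placeholders, and the call hits the Notion API.

# All values are placeholders; network access to api.notion.com is required.
extractor = NotionExtractor(
    notion_workspace_id='workspace-id',
    notion_obj_id='page-id',
    notion_page_type='page',
    notion_access_token='secret_xxx'
)
docs = extractor.extract()  # list[Document], one per block group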

View File

@@ -0,0 +1,72 @@
"""Abstract interface for document loader implementations."""
from collections.abc import Iterator
from typing import Optional
from core.rag.extractor.blod.blod import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
class PdfExtractor(BaseExtractor):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
file_cache_key: Optional[str] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
        plaintext_file_exists = False
if self._file_cache_key:
try:
text = storage.load(self._file_cache_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = list(self.load())
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)
# save plaintext file for caching
        if not plaintext_file_exists and self._file_cache_key:
            storage.save(self._file_cache_key, text.encode('utf-8'))
return documents
def load(
self,
) -> Iterator[Document]:
"""Lazy load given path as pages."""
blob = Blob.from_path(self._file_path)
yield from self.parse(blob)
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
try:
for page_number, page in enumerate(pdf_reader):
text_page = page.get_textpage()
content = text_page.get_text_range()
text_page.close()
page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata)
finally:
pdf_reader.close()
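
A sketch of the extractor with the plaintext cache enabled; the path and cache key are placeholders.

# Placeholders; on a cache hit the cached text is returned without parsing.
extractor = PdfExtractor('/tmp/sample.pdf', file_cache_key='plain/sample.txt')
docs = extractor.extract()  # first run parses via pypdfium2 and writes the cache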

View File

@@ -0,0 +1,50 @@
"""Abstract interface for document loader implementations."""
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class TextExtractor(BaseExtractor):
"""Load text files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False
):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def extract(self) -> list[Document]:
"""Load from file path."""
text = ""
try:
with open(self._file_path, encoding=self._encoding) as f:
text = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, encoding=encoding.encoding) as f:
text = f.read()
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
except Exception as e:
raise RuntimeError(f"Error loading {self._file_path}") from e
metadata = {"source": self._file_path}
return [Document(page_content=text, metadata=metadata)]

View File

@@ -0,0 +1,61 @@
import logging
import os
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredWordExtractor(BaseExtractor):
"""Loader that uses unstructured to load word documents.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.__version__ import __version__ as __unstructured_version__
from unstructured.file_utils.filetype import FileType, detect_filetype
unstructured_version = tuple(
[int(x) for x in __unstructured_version__.split(".")]
)
# check the file extension
try:
import magic # noqa: F401
is_doc = detect_filetype(self._file_path) == FileType.DOC
except ImportError:
_, extension = os.path.splitext(str(self._file_path))
is_doc = extension == ".doc"
if is_doc and unstructured_version < (0, 4, 11):
raise ValueError(
f"You are on unstructured version {__unstructured_version__}. "
"Partitioning .doc files is only supported in unstructured>=0.4.11. "
"Please upgrade the unstructured package and try again."
)
if is_doc:
from unstructured.partition.doc import partition_doc
elements = partition_doc(filename=self._file_path)
else:
from unstructured.partition.docx import partition_docx
elements = partition_docx(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,51 @@
import base64
import logging
from bs4 import BeautifulSoup
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredEmailExtractor(BaseExtractor):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.partition.email import partition_email
elements = partition_email(filename=self._file_path, api_url=self._api_url)
# noinspection PyBroadException
try:
for element in elements:
element_text = element.text.strip()
                padding_needed = (4 - len(element_text) % 4) % 4
                element_text += '=' * padding_needed
element_decode = base64.b64decode(element_text)
soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser')
element.text = soup.get_text()
except Exception:
pass
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,47 @@
import logging
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredMarkdownExtractor(BaseExtractor):
"""Load md files.
Args:
file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
remove_images: Whether to remove images from the text.
encoding: File encoding to use. If `None`, the file will be loaded
with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.partition.md import partition_md
elements = partition_md(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,37 @@
import logging
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredMsgExtractor(BaseExtractor):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.partition.msg import partition_msg
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,44 @@
import logging
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredPPTExtractor(BaseExtractor):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.partition.ppt import partition_ppt
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
text_by_page = {}
for element in elements:
page = element.metadata.page_number
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text
else:
text_by_page[page] = text
combined_texts = list(text_by_page.values())
documents = []
for combined_text in combined_texts:
text = combined_text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,45 @@
import logging
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredPPTXExtractor(BaseExtractor):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.partition.pptx import partition_pptx
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
text_by_page = {}
for element in elements:
page = element.metadata.page_number
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text
else:
text_by_page[page] = text
combined_texts = list(text_by_page.values())
documents = []
for combined_text in combined_texts:
text = combined_text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,37 @@
import logging
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredTextExtractor(BaseExtractor):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.partition.text import partition_text
elements = partition_text(filename=self._file_path, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,37 @@
import logging
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredXmlExtractor(BaseExtractor):
"""Load msg files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.partition.xml import partition_xml
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents

View File

@@ -0,0 +1,62 @@
"""Abstract interface for document loader implementations."""
import os
import tempfile
from urllib.parse import urlparse
import requests
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class WordExtractor(BaseExtractor):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(self, file_path: str):
"""Initialize with file path."""
self.file_path = file_path
if "~" in self.file_path:
self.file_path = os.path.expanduser(self.file_path)
# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
r = requests.get(self.file_path)
if r.status_code != 200:
raise ValueError(
"Check the url of your file; returned status code %s"
% r.status_code
)
self.web_path = self.file_path
self.temp_file = tempfile.NamedTemporaryFile()
self.temp_file.write(r.content)
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):
raise ValueError("File path %s is not a valid file or url" % self.file_path)
def __del__(self) -> None:
if hasattr(self, "temp_file"):
self.temp_file.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""
import docx2txt
return [
Document(
page_content=docx2txt.process(self.file_path),
metadata={"source": self.file_path},
)
]
@staticmethod
def _is_valid_url(url: str) -> bool:
"""Check if the url is valid."""
parsed = urlparse(url)
return bool(parsed.netloc) and bool(parsed.scheme)
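
A sketch showing the local-path and URL modes; both targets are placeholders.

# Placeholders; URL mode downloads to a temp file before docx2txt parses it.
local_docs = WordExtractor('~/docs/report.docx').extract()
remote_docs = WordExtractor('https://example.com/report.docx').extract()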

View File

View File

@@ -0,0 +1,8 @@
from enum import Enum
class IndexType(Enum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "parent_child_index"
SUMMARY_INDEX = "summary_index"

View File

@@ -0,0 +1,70 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
from typing import Optional
from langchain.text_splitter import TextSplitter
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from models.dataset import Dataset, DatasetProcessRule
class BaseIndexProcessor(ABC):
"""Interface for extract files.
"""
@abstractmethod
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
raise NotImplementedError
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
raise NotImplementedError
@abstractmethod
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict) -> list[Document]:
raise NotImplementedError
def _get_splitter(self, processing_rule: dict,
embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
"""
        Get the TextSplitter object according to the processing rule.
"""
if processing_rule['mode'] == "custom":
# The user-defined segmentation rule
rules = processing_rule['rules']
segmentation = rules["segmentation"]
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
raise ValueError("Custom segment length should be between 50 and 1000.")
separator = segmentation["separator"]
if separator:
separator = separator.replace('\\n', '\n')
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=0,
fixed_separator=separator,
separators=["\n\n", "", ".", " ", ""],
embedding_model_instance=embedding_model_instance
)
else:
# Automatic segmentation
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=0,
separators=["\n\n", "", ".", " ", ""],
embedding_model_instance=embedding_model_instance
)
return character_splitter

View File

@@ -0,0 +1,28 @@
"""Abstract interface for document loader implementations."""
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
class IndexProcessorFactory:
"""IndexProcessorInit.
"""
def __init__(self, index_type: str):
self._index_type = index_type
def init_index_processor(self) -> BaseIndexProcessor:
"""Init index processor."""
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX.value:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX.value:
return QAIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@@ -0,0 +1,92 @@
"""Paragraph index processor."""
import uuid
from typing import Optional
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from libs import helper
from models.dataset import Dataset
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
is_automatic=kwargs.get('process_rule_mode') == "automatic")
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
# Split the text documents into nodes.
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
embedding_model_instance=kwargs.get('embedding_model_instance'))
all_documents = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
                    # strip the leading splitter character, if any
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
document_node.page_content = page_content
split_documents.append(document_node)
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset)
vector.create(documents)
if with_keywords:
keyword = Keyword(dataset)
keyword.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
else:
vector.delete()
if with_keywords:
keyword = Keyword(dataset)
if node_ids:
keyword.delete_by_ids(node_ids)
else:
keyword.delete()
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata['score'] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
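
Putting the pieces together, a sketch of the extract, transform and load flow for a paragraph index; the dataset and the rule values are placeholders, shaped to match what CleanProcessor.clean and _get_splitter read.

# Placeholders throughout; `dataset` is a Dataset row and `extract_setting`
# an ExtractSetting as shown earlier.
process_rule = {
    'mode': 'custom',
    'rules': {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': False},
        ],
        'segmentation': {'max_tokens': 500, 'separator': '\\n'},
    },
}
processor = ParagraphIndexProcessor()
text_docs = processor.extract(extract_setting)
nodes = processor.transform(text_docs, process_rule=process_rule,
                            embedding_model_instance=None)
processor.load(dataset, nodes)  # vectors for 'high_quality' datasets, plus keywords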

View File

@@ -0,0 +1,161 @@
"""Paragraph index processor."""
import logging
import re
import threading
import uuid
from typing import Optional
import pandas as pd
from flask import Flask, current_app
from flask_login import current_user
from werkzeug.datastructures import FileStorage
from core.generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from libs import helper
from models.dataset import Dataset
class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
is_automatic=kwargs.get('process_rule_mode') == "automatic")
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
embedding_model_instance=None)
# Split the text documents into nodes.
all_documents = []
all_qa_documents = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
# delete leading splitter character ('.' or the fullwidth '。'; the second check restores a mis-encoded character)
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith("。"):
page_content = page_content[1:]
document_node.page_content = page_content
split_documents.append(document_node)
all_documents.extend(split_documents)
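# Generate QA pairs concurrently: batches of 10 chunks, one worker thread per chunk.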
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i:i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
'flask_app': current_app._get_current_object(),
'tenant_id': current_user.current_tenant.id,
'document_node': doc,
'all_qa_documents': all_qa_documents,
'document_language': kwargs.get('document_language', 'English')})
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
return all_qa_documents
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
# pandas treats the first row as the header, so data rows start from the second line
df = pd.read_csv(file)
text_docs = []
for index, row in df.iterrows():
data = Document(page_content=row.iloc[0], metadata={'answer': row.iloc[1]})
text_docs.append(data)
if len(text_docs) == 0:
raise ValueError("The CSV file is empty.")
except Exception as e:
raise ValueError(str(e))
return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
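# Note: QA documents are only vector-indexed; no keyword index is built even when with_keywords is True.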
if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset)
vector.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
else:
vector.delete()
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict):
# Set search parameters.
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata['score'] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
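# Worker threads must push their own Flask app context to use app-bound resources safely.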
with flask_app.app_context():
try:
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
document_qa_list = self._format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result['question'])
qa_document.metadata['answer'] = result['answer']
qa_document.metadata['doc_id'] = doc_id
qa_document.metadata['doc_hash'] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception(e)
all_qa_documents.extend(format_documents)
def _format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE)
return [
{
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
}
for q, a in matches if q and a
]
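For reference, the Q/A layout this regex expects from the model output looks like the following (a hedged, made-up sample, not a captured LLM response):

sample = (
    "Q1: What does the index processor do?\n"
    "A1: It cleans, splits and indexes documents.\n"
    "Q2: Which index types are supported?\n"
    "A2: Vector and keyword indexes.\n"
)
# _format_split_text(sample) then yields:
# [{'question': 'What does the index processor do?',
#   'answer': 'It cleans, splits and indexes documents.'},
#  {'question': 'Which index types are supported?',
#   'answer': 'Vector and keyword indexes.'}]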

View File

View File

@@ -0,0 +1,16 @@
from typing import Optional
from pydantic import BaseModel, Field
class Document(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
metadata: Optional[dict] = Field(default_factory=dict)
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
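The processors above treat this model as a plain mutable container; a quick hedged sketch of that usage (field values are illustrative):

import uuid

doc = Document(page_content="What is RAG?",
               metadata={"answer": "Retrieval-augmented generation."})
# processors enrich metadata in place with ids and content hashes
doc.metadata["doc_id"] = str(uuid.uuid4())
assert doc.metadata["answer"] == "Retrieval-augmented generation."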