Feat/dify rag (#2528)

Co-authored-by: jyong <jyong@dify.ai>
Jyong authored on 2024-02-22 23:31:57 +08:00 (committed by GitHub)
parent 97fe817186
commit 6c4e6bf1d6
119 changed files with 3181 additions and 5892 deletions

View File: api/services/dataset_service.py

@@ -11,10 +11,11 @@ from flask_login import current_user
 from sqlalchemy import func

 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.index.index import IndexBuilder
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.models.document import Document as RAGDocument
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
@@ -402,7 +403,7 @@ class DocumentService:
     @staticmethod
     def delete_document(document):
         # trigger document_was_deleted signal
-        document_was_deleted.send(document.id, dataset_id=document.dataset_id)
+        document_was_deleted.send(document.id, dataset_id=document.dataset_id, doc_form=document.doc_form)

         db.session.delete(document)
         db.session.commit()
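The new `doc_form` kwarg saves the asynchronous clean-up task a document lookup when deciding which index type to purge. A minimal receiver sketch, assuming the blinker signal defined in events/document_event.py (the handler name and body are illustrative, not from this commit):

```python
from blinker import signal

# assumed definition in events/document_event.py
document_was_deleted = signal('document-was-deleted')

@document_was_deleted.connect
def handle_document_deleted(sender, **kwargs):
    # sender is the document id passed positionally to send()
    dataset_id = kwargs['dataset_id']
    doc_form = kwargs['doc_form']  # new in this commit
    print(f"cleaning {doc_form} indexes for document {sender} in dataset {dataset_id}")
```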
@@ -1060,7 +1061,7 @@ class SegmentService:
         # save vector index
         try:
-            VectorService.create_segment_vector(args['keywords'], segment_document, dataset)
+            VectorService.create_segments_vector([args['keywords']], [segment_document], dataset)
         except Exception as e:
             logging.exception("create segment index failed")
             segment_document.enabled = False
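The singular `create_segment_vector` helper goes away; a single segment is now just a batch of one. A sketch of the batched contract, reusing the names from the hunk above — `keywords_list[i]` is assumed to carry the keywords for `segments[i]`:

```python
# Index-aligned batch of one (names from the surrounding hunk).
segments = [segment_document]        # list[DocumentSegment]
keywords_lists = [args['keywords']]  # one list[str] per segment, same order
VectorService.create_segments_vector(keywords_lists, segments, dataset)
```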
@@ -1087,6 +1088,7 @@ class SegmentService:
         ).scalar()
         pre_segment_data_list = []
         segment_data_list = []
+        keywords_list = []
         for segment_item in segments:
             content = segment_item['content']
             doc_id = str(uuid.uuid4())
@@ -1119,15 +1121,13 @@ class SegmentService:
                 segment_document.answer = segment_item['answer']
             db.session.add(segment_document)
             segment_data_list.append(segment_document)
-            pre_segment_data = {
-                'segment': segment_document,
-                'keywords': segment_item['keywords']
-            }
-            pre_segment_data_list.append(pre_segment_data)
+            pre_segment_data_list.append(segment_document)
+            keywords_list.append(segment_item['keywords'])

         try:
             # save vector index
-            VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
+            VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
         except Exception as e:
             logging.exception("create segment index failed")
             for segment_document in segment_data_list:
@@ -1157,11 +1157,18 @@ class SegmentService:
                 db.session.commit()
                 # update segment index task
                 if args['keywords']:
-                    kw_index = IndexBuilder.get_index(dataset, 'economy')
-                    # delete from keyword index
-                    kw_index.delete_by_ids([segment.index_node_id])
-                    # save keyword index
-                    kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
+                    keyword = Keyword(dataset)
+                    keyword.delete_by_ids([segment.index_node_id])
+                    document = RAGDocument(
+                        page_content=segment.content,
+                        metadata={
+                            "doc_id": segment.index_node_id,
+                            "doc_hash": segment.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        }
+                    )
+                    keyword.add_texts([document], keywords_list=[args['keywords']])
             else:
                 segment_hash = helper.generate_text_hash(content)
                 tokens = 0
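Keyword maintenance now goes through the `Keyword` factory and the rag `Document` model instead of `IndexBuilder`. A rough sketch of the shape that factory is assumed to have (the real implementation lives in core/rag/datasource/keyword/keyword_factory.py):

```python
class Keyword:
    """Facade routing keyword-index calls to the configured backend (assumed shape)."""

    def __init__(self, dataset):
        self._dataset = dataset
        self._keyword_processor = self._init_keyword()

    def _init_keyword(self):
        # the factory is assumed to pick a concrete processor here,
        # e.g. a jieba-based keyword table
        raise NotImplementedError

    def add_texts(self, texts: list, **kwargs):
        # kwargs may carry keywords_list: one list[str] per document in texts
        self._keyword_processor.add_texts(texts, **kwargs)

    def delete_by_ids(self, ids: list):
        self._keyword_processor.delete_by_ids(ids)
```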

View File: api/services/file_service.py

@@ -9,8 +9,8 @@ from flask_login import current_user
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
-from core.data_loader.file_extractor import FileExtractor
 from core.file.upload_file_parser import UploadFileParser
+from core.rag.extractor.extract_processor import ExtractProcessor
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.account import Account
@@ -32,7 +32,8 @@ class FileService:
     def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
         extension = file.filename.split('.')[-1]
         etl_type = current_app.config['ETL_TYPE']
-        allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
+        allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
+            else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()
         elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
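Both branches now append `IMAGE_EXTENSIONS`, so image uploads pass the extension check under either ETL pipeline rather than only via the `only_image` path. A quick illustration with hypothetical constant values:

```python
# Hypothetical values for illustration; the real constants live next to FileService.
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']

allowed_extensions = ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
assert 'png' in allowed_extensions   # now accepted even when only_image is False
assert 'exe' not in allowed_extensions
```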
@@ -136,7 +137,7 @@ class FileService:
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()

-        text = FileExtractor.load(upload_file, return_text=True)
+        text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
         return text
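Judging from the call sites in this commit, `return_text=True` makes the new extractor return plain text for previews, while the default mode returns rag `Document` objects. A hedged sketch of both modes (return shapes assumed, `upload_file` is the persisted UploadFile row from the method above):

```python
# return_text=True -> plain text; default -> list of rag Documents (assumed)
text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
documents = ExtractProcessor.load_from_upload_file(upload_file)
```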
@@ -164,7 +165,7 @@ class FileService:
         return generator, upload_file.mime_type

     @staticmethod
-    def get_public_image_preview(file_id: str) -> str:
+    def get_public_image_preview(file_id: str) -> tuple[Generator, str]:
         upload_file = db.session.query(UploadFile) \
             .filter(UploadFile.id == file_id) \
             .first()
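The old `-> str` annotation was simply wrong; the method yields image bytes plus a MIME type. A sketch of the calling convention the corrected annotation documents (the route wrapper here is hypothetical):

```python
from flask import Response

def preview(file_id: str) -> Response:
    # unpacking matches the corrected tuple[Generator, str] annotation
    generator, mimetype = FileService.get_public_image_preview(file_id)
    return Response(generator, mimetype=mimetype)
```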

View File: api/services/hit_testing_service.py

@@ -1,21 +1,18 @@
 import logging
-import threading
 import time

 import numpy as np
 from flask import current_app
-from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
 from sklearn.manifold import TSNE

 from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.rerank.rerank import RerankRunner
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, DatasetQuery, DocumentSegment
-from services.retrieval_service import RetrievalService

 default_retrieval_model = {
     'search_method': 'semantic_search',
@@ -28,6 +25,7 @@ default_retrieval_model = {
     'score_threshold_enabled': False
 }

+
 class HitTestingService:
     @classmethod
     def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
@@ -57,61 +55,15 @@ class HitTestingService:
         embeddings = CacheEmbedding(embedding_model)

-        all_documents = []
-        threads = []
-
-        # retrieval_model source with semantic
-        if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
-            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': str(dataset.id),
-                'query': query,
-                'top_k': retrieval_model['top_k'],
-                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
-                'all_documents': all_documents,
-                'search_method': retrieval_model['search_method'],
-                'embeddings': embeddings
-            })
-            threads.append(embedding_thread)
-            embedding_thread.start()
-
-        # retrieval source with full text
-        if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
-            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': str(dataset.id),
-                'query': query,
-                'search_method': retrieval_model['search_method'],
-                'embeddings': embeddings,
-                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                'top_k': retrieval_model['top_k'],
-                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
-                'all_documents': all_documents
-            })
-            threads.append(full_text_index_thread)
-            full_text_index_thread.start()
-
-        for thread in threads:
-            thread.join()
-
-        if retrieval_model['search_method'] == 'hybrid_search':
-            model_manager = ModelManager()
-            rerank_model_instance = model_manager.get_model_instance(
-                tenant_id=dataset.tenant_id,
-                provider=retrieval_model['reranking_model']['reranking_provider_name'],
-                model_type=ModelType.RERANK,
-                model=retrieval_model['reranking_model']['reranking_model_name']
-            )
-            rerank_runner = RerankRunner(rerank_model_instance)
-            all_documents = rerank_runner.run(
-                query=query,
-                documents=all_documents,
-                score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                top_n=retrieval_model['top_k'],
-                user=f"account-{account.id}"
-            )
+        all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                  dataset_id=dataset.id,
+                                                  query=query,
+                                                  top_k=retrieval_model['top_k'],
+                                                  score_threshold=retrieval_model['score_threshold']
+                                                  if retrieval_model['score_threshold_enabled'] else None,
+                                                  reranking_model=retrieval_model['reranking_model']
+                                                  if retrieval_model['reranking_enable'] else None
+                                                  )

         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
@@ -203,4 +155,3 @@ class HitTestingService:
         if not query or len(query) > 250:
             raise ValueError('Query is required and cannot exceed 250 characters')
-
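Everything the deleted block did — per-method thread fan-out, join, then an optional rerank pass — is assumed to now live behind the single `RetrievalService.retrieve` entry point in core.rag.datasource.retrieval_service (note that `retrival_method` is the parameter's actual spelling in this commit). A condensed, self-contained sketch of that assumed internal flow, with the search workers stubbed out:

```python
import threading

def _embedding_search(dataset_id: str, query: str, top_k: int, results: list) -> None:
    """Stub: vector-similarity search for `query` in `dataset_id` (elided)."""
    results.extend([])  # the real worker appends up to top_k scored documents

def _full_text_search(dataset_id: str, query: str, top_k: int, results: list) -> None:
    """Stub: full-text search for `query` in `dataset_id` (elided)."""
    results.extend([])

def retrieve(retrival_method: str, dataset_id: str, query: str, top_k: int,
             score_threshold=None, reranking_model=None) -> list:
    all_documents: list = []
    threads = []
    # fan out one worker per enabled search method, as the deleted code did
    if retrival_method in ('semantic_search', 'hybrid_search'):
        threads.append(threading.Thread(target=_embedding_search, kwargs={
            'dataset_id': dataset_id, 'query': query,
            'top_k': top_k, 'results': all_documents}))
    if retrival_method in ('full_text_search', 'hybrid_search'):
        threads.append(threading.Thread(target=_full_text_search, kwargs={
            'dataset_id': dataset_id, 'query': query,
            'top_k': top_k, 'results': all_documents}))
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    # hybrid results from both workers then get one rerank pass (elided),
    # using reranking_model and score_threshold as before
    return all_documents
```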

View File: api/services/retrieval_service.py (deleted)

@@ -1,119 +0,0 @@
-from typing import Optional
-
-from flask import Flask, current_app
-from langchain.embeddings.base import Embeddings
-
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
-from core.rerank.rerank import RerankRunner
-from extensions.ext_database import db
-from models.dataset import Dataset
-
-default_retrieval_model = {
-    'search_method': 'semantic_search',
-    'reranking_enable': False,
-    'reranking_model': {
-        'reranking_provider_name': '',
-        'reranking_model_name': ''
-    },
-    'top_k': 2,
-    'score_threshold_enabled': False
-}
-
-
-class RetrievalService:
-
-    @classmethod
-    def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
-                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                         all_documents: list, search_method: str, embeddings: Embeddings):
-        with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-
-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
-            )
-
-            documents = vector_index.search(
-                query,
-                search_type='similarity_score_threshold',
-                search_kwargs={
-                    'k': top_k,
-                    'score_threshold': score_threshold,
-                    'filter': {
-                        'group_id': [dataset.id]
-                    }
-                }
-            )
-
-            if documents:
-                if reranking_model and search_method == 'semantic_search':
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=reranking_model['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=reranking_model['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return
-
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    all_documents.extend(rerank_runner.run(
-                        query=query,
-                        documents=documents,
-                        score_threshold=score_threshold,
-                        top_n=len(documents)
-                    ))
-                else:
-                    all_documents.extend(documents)
-
-    @classmethod
-    def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
-                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                               all_documents: list, search_method: str, embeddings: Embeddings):
-        with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-
-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
-            )
-
-            documents = vector_index.search_by_full_text_index(
-                query,
-                search_type='similarity_score_threshold',
-                top_k=top_k
-            )
-
-            if documents:
-                if reranking_model and search_method == 'full_text_search':
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=reranking_model['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=reranking_model['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return
-
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    all_documents.extend(rerank_runner.run(
-                        query=query,
-                        documents=documents,
-                        score_threshold=score_threshold,
-                        top_n=len(documents)
-                    ))
-                else:
-                    all_documents.extend(documents)
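Nothing is lost with this deletion: the workers above are assumed to be folded into the new core.rag retrieval service, so downstream callers only swap an import. A hedged migration sketch (the dataset id is a placeholder):

```python
# old import (module deleted in this commit):
# from services.retrieval_service import RetrievalService
from core.rag.datasource.retrieval_service import RetrievalService

documents = RetrievalService.retrieve(
    retrival_method='hybrid_search',  # parameter name as spelled in this commit
    dataset_id='dataset-uuid',        # placeholder for an existing Dataset id
    query='What is RAG?',
    top_k=2,
    score_threshold=None,
    reranking_model=None,
)
```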

View File: api/services/vector_service.py

@@ -1,44 +1,18 @@
 from typing import Optional

-from langchain.schema import Document
-
-from core.index.index import IndexBuilder
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from models.dataset import Dataset, DocumentSegment


 class VectorService:

     @classmethod
-    def create_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
-        document = Document(
-            page_content=segment.content,
-            metadata={
-                "doc_id": segment.index_node_id,
-                "doc_hash": segment.index_node_hash,
-                "document_id": segment.document_id,
-                "dataset_id": segment.dataset_id,
-            }
-        )
-
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts([document], duplicate_check=True)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            if keywords and len(keywords) > 0:
-                index.create_segment_keywords(segment.index_node_id, keywords)
-            else:
-                index.add_texts([document])
-
-    @classmethod
-    def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
+    def create_segments_vector(cls, keywords_list: Optional[list[list[str]]],
+                               segments: list[DocumentSegment], dataset: Dataset):
         documents = []
-        for pre_segment_data in pre_segment_data_list:
-            segment = pre_segment_data['segment']
+        for segment in segments:
             document = Document(
                 page_content=segment.content,
                 metadata={
@@ -49,30 +23,26 @@ class VectorService:
                 }
             )
             documents.append(document)
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts(documents, duplicate_check=True)
+        if dataset.indexing_technique == 'high_quality':
+            # save vector index
+            vector = Vector(
+                dataset=dataset
+            )
+            vector.add_texts(documents, duplicate_check=True)

         # save keyword index
-        keyword_index = IndexBuilder.get_index(dataset, 'economy')
-        if keyword_index:
-            keyword_index.multi_create_segment_keywords(pre_segment_data_list)
+        keyword = Keyword(dataset)
+        if keywords_list and len(keywords_list) > 0:
+            keyword.add_texts(documents, keyword_list=keywords_list)
+        else:
+            keyword.add_texts(documents)

     @classmethod
     def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
         # update segment index task
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
-
-        # delete from vector index
-        if vector_index:
-            vector_index.delete_by_ids([segment.index_node_id])
-
-        # delete from keyword index
-        kw_index.delete_by_ids([segment.index_node_id])
-
-        # add new index
+        # format new index
         document = Document(
             page_content=segment.content,
             metadata={
@@ -82,13 +52,20 @@ class VectorService:
                 "dataset_id": segment.dataset_id,
             }
         )
+        if dataset.indexing_technique == 'high_quality':
+            # update vector index
+            vector = Vector(
+                dataset=dataset
+            )
+            vector.delete_by_ids([segment.index_node_id])
+            vector.add_texts([document], duplicate_check=True)

-        # save vector index
-        if vector_index:
-            vector_index.add_texts([document], duplicate_check=True)
+        # update keyword index
+        keyword = Keyword(dataset)
+        keyword.delete_by_ids([segment.index_node_id])

         # save keyword index
         if keywords and len(keywords) > 0:
-            kw_index.create_segment_keywords(segment.index_node_id, keywords)
+            keyword.add_texts([document], keywords_list=[keywords])
         else:
-            kw_index.add_texts([document])
+            keyword.add_texts([document])
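Like `Keyword`, the new `Vector` facade hides the concrete store behind a single constructor; core/rag/datasource/vdb/vector_factory.py is assumed to dispatch on the configured vector database. A rough sketch of that assumed shape (backend names illustrative):

```python
class Vector:
    """Facade routing vector-index calls to the configured store
    (e.g. Qdrant, Weaviate, Milvus); shape assumed, not from this commit."""

    def __init__(self, dataset, attributes: list = None):
        self._dataset = dataset
        self._vector_processor = self._init_vector()

    def _init_vector(self):
        # pick a concrete vector-DB client from app config (elided)
        raise NotImplementedError

    def add_texts(self, documents: list, duplicate_check: bool = False, **kwargs):
        # duplicate_check is assumed to skip documents whose doc_hash
        # already exists in the index
        self._vector_processor.add_texts(documents, **kwargs)

    def delete_by_ids(self, ids: list):
        self._vector_processor.delete_by_ids(ids)
```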