Feat/add retriever rerank (#1560)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong
2023-11-17 22:13:37 +08:00
committed by GitHub
parent a4f37220a0
commit 4588831bff
44 changed files with 1899 additions and 164 deletions

View File

@@ -1,5 +1,6 @@
import json
from typing import Type, Optional
import threading
from typing import Type, Optional, List
from flask import current_app
from langchain.tools import BaseTool
@@ -14,6 +15,18 @@ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitE
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document
from services.retrieval_service import RetrievalService
# Fallback retrieval settings, used when a dataset does not carry its own
# `retrieval_model` configuration.
#
# NOTE: there is deliberately no 'score_threshold' entry. The readers of this
# mapping only look up 'score_threshold' when 'score_threshold_enable' is
# truthy, so with the default (False) the key is never accessed.
default_retrieval_model = {
    # One of 'semantic_search', 'full_text_search' or 'hybrid_search'.
    'search_method': 'semantic_search',
    # Reranking is off by default, so the provider/model names stay empty.
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': '',
    },
    # Default number of results -- presumably the per-query hit limit;
    # confirm against the consumers of this setting.
    'top_k': 2,
    'score_threshold_enable': False,
}
class DatasetRetrieverToolInput(BaseModel):
@@ -56,7 +69,9 @@ class DatasetRetrieverTool(BaseTool):
).first()
if not dataset:
return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
return ''
# get the dataset's retrieval model; if none is set, fall back to the default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
@@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool):
return ''
embeddings = CacheEmbedding(embedding_model)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = []
threads = []
if self.top_k > 0:
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': self.top_k,
'score_threshold': self.score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
# retrieval source with semantic
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval_model source with full text
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
# hybrid search: rerank after all documents have been searched
if retrieval_model['search_method'] == 'hybrid_search':
hybrid_rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
model_name=retrieval_model['reranking_model']['reranking_model_name']
)
documents = hybrid_rerank.rerank(query, documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
self.top_k)
else:
documents = []
hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":