Feat/support multimodal embedding (#29115)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
@@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import (
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
|
||||
from models import UploadFile
|
||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
@@ -99,7 +105,8 @@ class DatasetRetrieval:
|
||||
message_id: str,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
inputs: Mapping[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
vision_enabled: bool = False,
|
||||
) -> tuple[str | None, list[File] | None]:
|
||||
"""
|
||||
Retrieve dataset.
|
||||
:param app_id: app_id
|
||||
@@ -118,7 +125,7 @@ class DatasetRetrieval:
|
||||
"""
|
||||
dataset_ids = config.dataset_ids
|
||||
if len(dataset_ids) == 0:
|
||||
return None
|
||||
return None, []
|
||||
retrieve_config = config.retrieve_config
|
||||
|
||||
# check model is support tool calling
|
||||
@@ -136,7 +143,7 @@ class DatasetRetrieval:
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
return None
|
||||
return None, []
|
||||
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
@@ -182,8 +189,8 @@ class DatasetRetrieval:
|
||||
tenant_id,
|
||||
user_id,
|
||||
user_from,
|
||||
available_datasets,
|
||||
query,
|
||||
available_datasets,
|
||||
model_instance,
|
||||
model_config,
|
||||
planning_strategy,
|
||||
@@ -213,6 +220,7 @@ class DatasetRetrieval:
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
document_context_list: list[DocumentContext] = []
|
||||
context_files: list[File] = []
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
@@ -248,6 +256,31 @@ class DatasetRetrieval:
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
if vision_enabled:
|
||||
attachments_with_bindings = db.session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||
.where(
|
||||
SegmentAttachmentBinding.segment_id == segment.id,
|
||||
)
|
||||
).all()
|
||||
if attachments_with_bindings:
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
attchment_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=segment.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=sign_upload_file(upload_file.id, upload_file.extension),
|
||||
)
|
||||
context_files.append(attchment_info)
|
||||
if show_retrieve_source:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
@@ -288,8 +321,10 @@ class DatasetRetrieval:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
||||
return ""
|
||||
return str(
|
||||
"\n".join([document_context.content for document_context in document_context_list])
|
||||
), context_files
|
||||
return "", context_files
|
||||
|
||||
def single_retrieve(
|
||||
self,
|
||||
@@ -297,8 +332,8 @@ class DatasetRetrieval:
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
available_datasets: list,
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
@@ -336,7 +371,7 @@ class DatasetRetrieval:
|
||||
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
|
||||
|
||||
self._record_usage(router_usage)
|
||||
|
||||
timer = None
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
@@ -406,10 +441,19 @@ class DatasetRetrieval:
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
if results:
|
||||
self._on_retrieval_end(results, message_id, timer)
|
||||
thread = threading.Thread(
|
||||
target=self._on_retrieval_end,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"documents": results,
|
||||
"message_id": message_id,
|
||||
"timer": timer,
|
||||
},
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return results
|
||||
return []
|
||||
@@ -421,7 +465,7 @@ class DatasetRetrieval:
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
query: str | None,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
@@ -431,10 +475,11 @@ class DatasetRetrieval:
|
||||
message_id: str | None = None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
return []
|
||||
threads = []
|
||||
all_threads = []
|
||||
all_documents: list[Document] = []
|
||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||
index_type_check = all(
|
||||
@@ -467,131 +512,226 @@ class DatasetRetrieval:
|
||||
0
|
||||
].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
document_ids_filter = None
|
||||
if dataset.provider != "external":
|
||||
if metadata_condition and not metadata_filter_document_ids:
|
||||
continue
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
continue
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset.id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
"metadata_condition": metadata_condition,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
with measure_time() as timer:
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
|
||||
if query:
|
||||
query_thread = threading.Thread(
|
||||
target=self._multiple_retrieve_thread,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"available_datasets": available_datasets,
|
||||
"metadata_condition": metadata_condition,
|
||||
"metadata_filter_document_ids": metadata_filter_document_ids,
|
||||
"all_documents": all_documents,
|
||||
"tenant_id": tenant_id,
|
||||
"reranking_enable": reranking_enable,
|
||||
"reranking_mode": reranking_mode,
|
||||
"reranking_model": reranking_model,
|
||||
"weights": weights,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"query": query,
|
||||
"attachment_id": None,
|
||||
},
|
||||
)
|
||||
else:
|
||||
if index_type == "economy":
|
||||
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
|
||||
elif index_type == "high_quality":
|
||||
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
|
||||
else:
|
||||
all_documents = all_documents[:top_k] if top_k else all_documents
|
||||
|
||||
self._on_query(query, dataset_ids, app_id, user_from, user_id)
|
||||
all_threads.append(query_thread)
|
||||
query_thread.start()
|
||||
if attachment_ids:
|
||||
for attachment_id in attachment_ids:
|
||||
attachment_thread = threading.Thread(
|
||||
target=self._multiple_retrieve_thread,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"available_datasets": available_datasets,
|
||||
"metadata_condition": metadata_condition,
|
||||
"metadata_filter_document_ids": metadata_filter_document_ids,
|
||||
"all_documents": all_documents,
|
||||
"tenant_id": tenant_id,
|
||||
"reranking_enable": reranking_enable,
|
||||
"reranking_mode": reranking_mode,
|
||||
"reranking_model": reranking_model,
|
||||
"weights": weights,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"query": None,
|
||||
"attachment_id": attachment_id,
|
||||
},
|
||||
)
|
||||
all_threads.append(attachment_thread)
|
||||
attachment_thread.start()
|
||||
for thread in all_threads:
|
||||
thread.join()
|
||||
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
|
||||
|
||||
if all_documents:
|
||||
self._on_retrieval_end(all_documents, message_id, timer)
|
||||
|
||||
return all_documents
|
||||
|
||||
def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None):
|
||||
"""Handle retrieval end."""
|
||||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
for document in dify_documents:
|
||||
if document.metadata is not None:
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
)
|
||||
dataset_document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk_stmt = select(ChildChunk).where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
_ = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
||||
.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# get tracing instance
|
||||
trace_manager: TraceQueueManager | None = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||
)
|
||||
# add thread to call _on_retrieval_end
|
||||
retrieval_end_thread = threading.Thread(
|
||||
target=self._on_retrieval_end,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"documents": all_documents,
|
||||
"message_id": message_id,
|
||||
"timer": timer,
|
||||
},
|
||||
)
|
||||
retrieval_end_thread.start()
|
||||
retrieval_resource_list = []
|
||||
doc_ids_filter = []
|
||||
for document in all_documents:
|
||||
if document.provider == "dify":
|
||||
doc_id = document.metadata.get("doc_id")
|
||||
if doc_id and doc_id not in doc_ids_filter:
|
||||
doc_ids_filter.append(doc_id)
|
||||
retrieval_resource_list.append(document)
|
||||
elif document.provider == "external":
|
||||
retrieval_resource_list.append(document)
|
||||
return retrieval_resource_list
|
||||
|
||||
def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str):
|
||||
def _on_retrieval_end(
|
||||
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
|
||||
):
|
||||
"""Handle retrieval end."""
|
||||
with flask_app.app_context():
|
||||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
segment_ids = []
|
||||
segment_index_node_ids = []
|
||||
with Session(db.engine) as session:
|
||||
for document in dify_documents:
|
||||
if document.metadata is not None:
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
)
|
||||
dataset_document = session.scalar(dataset_document_stmt)
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
segment_id = None
|
||||
if (
|
||||
"doc_type" not in document.metadata
|
||||
or document.metadata.get("doc_type") == DocType.TEXT
|
||||
):
|
||||
child_chunk_stmt = select(ChildChunk).where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
child_chunk = session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
segment_id = child_chunk.segment_id
|
||||
elif (
|
||||
"doc_type" in document.metadata
|
||||
and document.metadata.get("doc_type") == DocType.IMAGE
|
||||
):
|
||||
attachment_info_dict = RetrievalService.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
if segment_id:
|
||||
if segment_id not in segment_ids:
|
||||
segment_ids.append(segment_id)
|
||||
_ = (
|
||||
session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id)
|
||||
.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = None
|
||||
if (
|
||||
"doc_type" not in document.metadata
|
||||
or document.metadata.get("doc_type") == DocType.TEXT
|
||||
):
|
||||
if document.metadata["doc_id"] not in segment_index_node_ids:
|
||||
segment = (
|
||||
session.query(DocumentSegment)
|
||||
.where(DocumentSegment.index_node_id == document.metadata["doc_id"])
|
||||
.first()
|
||||
)
|
||||
if segment:
|
||||
segment_index_node_ids.append(document.metadata["doc_id"])
|
||||
segment_ids.append(segment.id)
|
||||
query = session.query(DocumentSegment).where(
|
||||
DocumentSegment.id == segment.id
|
||||
)
|
||||
elif (
|
||||
"doc_type" in document.metadata
|
||||
and document.metadata.get("doc_type") == DocType.IMAGE
|
||||
):
|
||||
attachment_info_dict = RetrievalService.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
if segment_id not in segment_ids:
|
||||
segment_ids.append(segment_id)
|
||||
query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
|
||||
if query:
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.where(
|
||||
DocumentSegment.dataset_id == document.metadata["dataset_id"]
|
||||
)
|
||||
|
||||
# add hit count to document segment
|
||||
query.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# get tracing instance
|
||||
trace_manager: TraceQueueManager | None = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||
)
|
||||
)
|
||||
|
||||
def _on_query(
|
||||
self,
|
||||
query: str | None,
|
||||
attachment_ids: list[str] | None,
|
||||
dataset_ids: list[str],
|
||||
app_id: str,
|
||||
user_from: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Handle query.
|
||||
"""
|
||||
if not query:
|
||||
if not query and not attachment_ids:
|
||||
return
|
||||
dataset_queries = []
|
||||
for dataset_id in dataset_ids:
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=query,
|
||||
source="app",
|
||||
source_app_id=app_id,
|
||||
created_by_role=user_from,
|
||||
created_by=user_id,
|
||||
)
|
||||
dataset_queries.append(dataset_query)
|
||||
if dataset_queries:
|
||||
db.session.add_all(dataset_queries)
|
||||
contents = []
|
||||
if query:
|
||||
contents.append({"content_type": QueryType.TEXT_QUERY, "content": query})
|
||||
if attachment_ids:
|
||||
for attachment_id in attachment_ids:
|
||||
contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id})
|
||||
if contents:
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=json.dumps(contents),
|
||||
source="app",
|
||||
source_app_id=app_id,
|
||||
created_by_role=user_from,
|
||||
created_by=user_id,
|
||||
)
|
||||
dataset_queries.append(dataset_query)
|
||||
if dataset_queries:
|
||||
db.session.add_all(dataset_queries)
|
||||
db.session.commit()
|
||||
|
||||
def _retriever(
|
||||
@@ -603,6 +743,7 @@ class DatasetRetrieval:
|
||||
all_documents: list,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
@@ -611,7 +752,7 @@ class DatasetRetrieval:
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
if dataset.provider == "external":
|
||||
if dataset.provider == "external" and query:
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
@@ -663,6 +804,7 @@ class DatasetRetrieval:
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
@@ -1222,3 +1364,86 @@ class DatasetRetrieval:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
return full_text, usage
|
||||
|
||||
def _multiple_retrieve_thread(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
available_datasets: list,
|
||||
metadata_condition: MetadataCondition | None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None,
|
||||
all_documents: list[Document],
|
||||
tenant_id: str,
|
||||
reranking_enable: bool,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None,
|
||||
weights: dict[str, Any] | None,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
query: str | None,
|
||||
attachment_id: str | None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
threads = []
|
||||
all_documents_item: list[Document] = []
|
||||
index_type = None
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
document_ids_filter = None
|
||||
if dataset.provider != "external":
|
||||
if metadata_condition and not metadata_filter_document_ids:
|
||||
continue
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
continue
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": flask_app,
|
||||
"dataset_id": dataset.id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents_item,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
"metadata_condition": metadata_condition,
|
||||
"attachment_ids": [attachment_id] if attachment_id else None,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
|
||||
Reference in New Issue
Block a user