Feat/support multimodal embedding (#29115)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Jyong
2025-12-09 14:41:46 +08:00
committed by GitHub
parent 77cf8f6c27
commit 9affc546c6
78 changed files with 3230 additions and 713 deletions

View File

@@ -8,6 +8,7 @@ from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
DatasetEntity,
@@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
@@ -99,7 +105,8 @@ class DatasetRetrieval:
message_id: str,
memory: TokenBufferMemory | None = None,
inputs: Mapping[str, Any] | None = None,
) -> str | None:
vision_enabled: bool = False,
) -> tuple[str | None, list[File] | None]:
"""
Retrieve dataset.
:param app_id: app_id
@@ -118,7 +125,7 @@ class DatasetRetrieval:
"""
dataset_ids = config.dataset_ids
if len(dataset_ids) == 0:
return None
return None, []
retrieve_config = config.retrieve_config
# check model is support tool calling
@@ -136,7 +143,7 @@ class DatasetRetrieval:
)
if not model_schema:
return None
return None, []
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
@@ -182,8 +189,8 @@ class DatasetRetrieval:
tenant_id,
user_id,
user_from,
available_datasets,
query,
available_datasets,
model_instance,
model_config,
planning_strategy,
@@ -213,6 +220,7 @@ class DatasetRetrieval:
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list: list[DocumentContext] = []
context_files: list[File] = []
retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents
for item in external_documents:
@@ -248,6 +256,31 @@ class DatasetRetrieval:
score=record.score,
)
)
if vision_enabled:
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment.id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attchment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=segment.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attchment_info)
if show_retrieve_source:
for record in records:
segment = record.segment
@@ -288,8 +321,10 @@ class DatasetRetrieval:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
return str("\n".join([document_context.content for document_context in document_context_list]))
return ""
return str(
"\n".join([document_context.content for document_context in document_context_list])
), context_files
return "", context_files
def single_retrieve(
self,
@@ -297,8 +332,8 @@ class DatasetRetrieval:
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
available_datasets: list,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
@@ -336,7 +371,7 @@ class DatasetRetrieval:
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
timer = None
if dataset_id:
# get retrieval model config
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@@ -406,10 +441,19 @@ class DatasetRetrieval:
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
if results:
self._on_retrieval_end(results, message_id, timer)
thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": results,
"message_id": message_id,
"timer": timer,
},
)
thread.start()
return results
return []
@@ -421,7 +465,7 @@ class DatasetRetrieval:
user_id: str,
user_from: str,
available_datasets: list,
query: str,
query: str | None,
top_k: int,
score_threshold: float,
reranking_mode: str,
@@ -431,10 +475,11 @@ class DatasetRetrieval:
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
if not available_datasets:
return []
threads = []
all_threads = []
all_documents: list[Document] = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type_check = all(
@@ -467,131 +512,226 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
with measure_time() as timer:
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
},
)
else:
if index_type == "economy":
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
elif index_type == "high_quality":
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
else:
all_documents = all_documents[:top_k] if top_k else all_documents
self._on_query(query, dataset_ids, app_id, user_from, user_id)
all_threads.append(query_thread)
query_thread.start()
if attachment_ids:
for attachment_id in attachment_ids:
attachment_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
for thread in all_threads:
thread.join()
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
self._on_retrieval_end(all_documents, message_id, timer)
return all_documents
def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None):
"""Handle retrieval end."""
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
else:
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
# add thread to call _on_retrieval_end
retrieval_end_thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": all_documents,
"message_id": message_id,
"timer": timer,
},
)
retrieval_end_thread.start()
retrieval_resource_list = []
doc_ids_filter = []
for document in all_documents:
if document.provider == "dify":
doc_id = document.metadata.get("doc_id")
if doc_id and doc_id not in doc_ids_filter:
doc_ids_filter.append(doc_id)
retrieval_resource_list.append(document)
elif document.provider == "external":
retrieval_resource_list.append(document)
return retrieval_resource_list
def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str):
def _on_retrieval_end(
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
):
"""Handle retrieval end."""
with flask_app.app_context():
dify_documents = [document for document in documents if document.provider == "dify"]
segment_ids = []
segment_index_node_ids = []
with Session(db.engine) as session:
for document in dify_documents:
if document.metadata is not None:
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
segment_id = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = session.scalar(child_chunk_stmt)
if child_chunk:
segment_id = child_chunk.segment_id
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
if segment_id:
if segment_id not in segment_ids:
segment_ids.append(segment_id)
_ = (
session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
else:
query = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
if document.metadata["doc_id"] not in segment_index_node_ids:
segment = (
session.query(DocumentSegment)
.where(DocumentSegment.index_node_id == document.metadata["doc_id"])
.first()
)
if segment:
segment_index_node_ids.append(document.metadata["doc_id"])
segment_ids.append(segment.id)
query = session.query(DocumentSegment).where(
DocumentSegment.id == segment.id
)
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
if segment_id not in segment_ids:
segment_ids.append(segment_id)
query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
if query:
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.where(
DocumentSegment.dataset_id == document.metadata["dataset_id"]
)
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
)
def _on_query(
self,
query: str | None,
attachment_ids: list[str] | None,
dataset_ids: list[str],
app_id: str,
user_from: str,
user_id: str,
):
"""
Handle query.
"""
if not query:
if not query and not attachment_ids:
return
dataset_queries = []
for dataset_id in dataset_ids:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
db.session.add_all(dataset_queries)
contents = []
if query:
contents.append({"content_type": QueryType.TEXT_QUERY, "content": query})
if attachment_ids:
for attachment_id in attachment_ids:
contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id})
if contents:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=json.dumps(contents),
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
db.session.add_all(dataset_queries)
db.session.commit()
def _retriever(
@@ -603,6 +743,7 @@ class DatasetRetrieval:
all_documents: list,
document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
with flask_app.app_context():
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@@ -611,7 +752,7 @@ class DatasetRetrieval:
if not dataset:
return []
if dataset.provider == "external":
if dataset.provider == "external" and query:
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
@@ -663,6 +804,7 @@ class DatasetRetrieval:
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
attachment_ids=attachment_ids,
)
all_documents.extend(documents)
@@ -1222,3 +1364,86 @@ class DatasetRetrieval:
usage = LLMUsage.empty_usage()
return full_text, usage
def _multiple_retrieve_thread(
self,
flask_app: Flask,
available_datasets: list,
metadata_condition: MetadataCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document],
tenant_id: str,
reranking_enable: bool,
reranking_mode: str,
reranking_model: dict | None,
weights: dict[str, Any] | None,
top_k: int,
score_threshold: float,
query: str | None,
attachment_id: str | None,
):
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)