Feat/support multimodal embedding (#29115)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Author: Jyong
Date: 2025-12-09 14:41:46 +08:00 (committed by GitHub)
parent 77cf8f6c27
commit 9affc546c6
78 changed files with 3230 additions and 713 deletions

@@ -7,7 +7,7 @@ import time
 import uuid
 from collections import Counter
 from collections.abc import Sequence
-from typing import Any, Literal
+from typing import Any, Literal, cast

 import sqlalchemy as sa
 from redis.exceptions import LockNotOwnedError
@@ -19,9 +19,10 @@ from configs import dify_config
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.helper.name_generator import generate_incremental_name
 from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.rag.index_processor.constant.built_in_field import BuiltInField
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from enums.cloud_plan import CloudPlan
 from events.dataset_event import dataset_was_deleted
@@ -46,6 +47,7 @@ from models.dataset import (
     DocumentSegment,
     ExternalKnowledgeBindings,
     Pipeline,
+    SegmentAttachmentBinding,
 )
 from models.model import UploadFile
 from models.provider_ids import ModelProviderID
@@ -363,6 +365,27 @@ class DatasetService:
         except ProviderTokenNotInitError as ex:
             raise ValueError(ex.description)

+    @staticmethod
+    def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
+        try:
+            model_manager = ModelManager()
+            model_instance = model_manager.get_model_instance(
+                tenant_id=tenant_id,
+                provider=model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=model,
+            )
+            text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
+            model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
+            if not model_schema:
+                raise ValueError("Model schema not found")
+            if model_schema.features and ModelFeature.VISION in model_schema.features:
+                return True
+            else:
+                return False
+        except LLMBadRequestError:
+            raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
+
     @staticmethod
     def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
         try:
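For orientation, the new helper resolves the tenant's text-embedding model and reports whether its schema advertises the `VISION` feature. A hedged usage sketch; the provider and model names below are illustrative, not taken from the commit:

```python
# Illustrative only: provider/model names are assumptions, not from the diff.
# Returns True only when the embedding model's schema lists ModelFeature.VISION.
is_multimodal = DatasetService.check_is_multimodal_model(
    tenant_id="tenant-id",
    model_provider="openai",
    model="text-embedding-3-large",
)
print(is_multimodal)
```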
@@ -402,13 +425,13 @@ class DatasetService:
         if not dataset:
             raise ValueError("Dataset not found")

         # check if dataset name already exists
-        if DatasetService._has_dataset_same_name(
-            tenant_id=dataset.tenant_id,
-            dataset_id=dataset_id,
-            name=data.get("name", dataset.name),
-        ):
-            raise ValueError("Dataset name already exists")
+        if data.get("name") and data.get("name") != dataset.name:
+            if DatasetService._has_dataset_same_name(
+                tenant_id=dataset.tenant_id,
+                dataset_id=dataset_id,
+                name=data.get("name", dataset.name),
+            ):
+                raise ValueError("Dataset name already exists")

         # Verify user has permission to update this dataset
         DatasetService.check_dataset_permission(dataset, user)
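With the reworked guard, the duplicate-name lookup only runs when the request actually changes the name. A hedged illustration with hypothetical payloads:

```python
# Hypothetical update payloads against the new guard (assume the stored
# dataset.name is "Research Notes"):
unchanged = {"description": "edited"}     # no "name" key: lookup skipped
same_name = {"name": "Research Notes"}    # name unchanged: lookup skipped
renamed = {"name": "Other Dataset"}       # name changed: duplicate check runs
```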
@@ -844,6 +867,12 @@ class DatasetService:
                     model_type=ModelType.TEXT_EMBEDDING,
                     model=knowledge_configuration.embedding_model or "",
                 )
+                is_multimodal = DatasetService.check_is_multimodal_model(
+                    current_user.current_tenant_id,
+                    knowledge_configuration.embedding_model_provider,
+                    knowledge_configuration.embedding_model,
+                )
+                dataset.is_multimodal = is_multimodal
                 dataset.embedding_model = embedding_model.model
                 dataset.embedding_model_provider = embedding_model.provider
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@@ -880,6 +909,12 @@ class DatasetService:
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                     embedding_model.provider, embedding_model.model
                 )
+                is_multimodal = DatasetService.check_is_multimodal_model(
+                    current_user.current_tenant_id,
+                    knowledge_configuration.embedding_model_provider,
+                    knowledge_configuration.embedding_model,
+                )
+                dataset.is_multimodal = is_multimodal
                 dataset.collection_binding_id = dataset_collection_binding.id
                 dataset.indexing_technique = knowledge_configuration.indexing_technique
             except LLMBadRequestError:
@@ -937,6 +972,12 @@ class DatasetService:
                     )
                 )
                 dataset.collection_binding_id = dataset_collection_binding.id
+                is_multimodal = DatasetService.check_is_multimodal_model(
+                    current_user.current_tenant_id,
+                    knowledge_configuration.embedding_model_provider,
+                    knowledge_configuration.embedding_model,
+                )
+                dataset.is_multimodal = is_multimodal
             except LLMBadRequestError:
                 raise ValueError(
                     "No Embedding Model available. Please configure a valid provider "
@@ -2305,6 +2346,7 @@ class DocumentService:
             embedding_model_provider=knowledge_config.embedding_model_provider,
             collection_binding_id=dataset_collection_binding_id,
             retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
+            is_multimodal=knowledge_config.is_multimodal,
         )
         db.session.add(dataset)
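`is_multimodal` is a new flag on `Dataset`; its actual column definition and migration live elsewhere in this commit's 78 files. A minimal sketch of the assumed shape on the model:

```python
# Assumed column shape on the Dataset model, inferred from usage in this
# diff; the real definition and its migration are elsewhere in the commit.
import sqlalchemy as sa

is_multimodal = sa.Column(sa.Boolean, nullable=False, server_default=sa.text("false"))
```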
@@ -2685,6 +2727,13 @@ class SegmentService:
if "content" not in args or not args["content"] or not args["content"].strip():
raise ValueError("Content is empty")
if args.get("attachment_ids"):
if not isinstance(args["attachment_ids"], list):
raise ValueError("Attachment IDs is invalid")
single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT
if len(args["attachment_ids"]) > single_chunk_attachment_limit:
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
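A quick illustration of the new attachment guard; the limit value below is an assumption for the sketch, since `SINGLE_CHUNK_ATTACHMENT_LIMIT` is whatever the deployment configures:

```python
# Assuming dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT = 10 for this sketch.
args = {
    "content": "chunk text",
    "attachment_ids": [f"file-{i}" for i in range(11)],  # 11 > 10
}
# The guard above would raise:
#   ValueError("Exceeded maximum attachment limit of 10")
# and a non-list value such as "file-0" would raise:
#   ValueError("Attachment IDs is invalid")
```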
@@ -2731,11 +2780,23 @@ class SegmentService:
                 segment_document.word_count += len(args["answer"])
                 segment_document.answer = args["answer"]

-            db.session.add(segment_document)
-            # update document word count
-            assert document.word_count is not None
-            document.word_count += segment_document.word_count
-            db.session.add(document)
+            db.session.add(segment_document)
+            # update document word count
+            assert document.word_count is not None
+            document.word_count += segment_document.word_count
+            db.session.add(document)
+            db.session.commit()
+
+            if args["attachment_ids"]:
+                for attachment_id in args["attachment_ids"]:
+                    binding = SegmentAttachmentBinding(
+                        tenant_id=current_user.current_tenant_id,
+                        dataset_id=document.dataset_id,
+                        document_id=document.id,
+                        segment_id=segment_document.id,
+                        attachment_id=attachment_id,
+                    )
+                    db.session.add(binding)
+            db.session.commit()

             # save vector index
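`SegmentAttachmentBinding` is one of the new models in this commit; only its constructor arguments are visible here, so the sketch below infers a plausible shape (the table name and column types are assumptions):

```python
# Inferred from the constructor call in create_segment; not the committed
# definition, which lives in models/dataset.py within this commit.
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class SegmentAttachmentBinding(Base):
    __tablename__ = "segment_attachment_bindings"  # assumed name

    id = sa.Column(sa.String(36), primary_key=True)
    tenant_id = sa.Column(sa.String(36), nullable=False)
    dataset_id = sa.Column(sa.String(36), nullable=False)
    document_id = sa.Column(sa.String(36), nullable=False)
    segment_id = sa.Column(sa.String(36), nullable=False)
    attachment_id = sa.Column(sa.String(36), nullable=False)
```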
@@ -2899,7 +2960,7 @@ class SegmentService:
             document.word_count = max(0, document.word_count + word_count_change)
             db.session.add(document)
             # update segment index task
-            if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+            if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                 # regenerate child chunks
                 # get embedding model instance
                 if dataset.indexing_technique == "high_quality":
@@ -2926,12 +2987,11 @@ class SegmentService:
                     .where(DatasetProcessRule.id == document.dataset_process_rule_id)
                     .first()
                 )
-                if not processing_rule:
-                    raise ValueError("No processing rule found.")
-                VectorService.generate_child_chunks(
-                    segment, document, dataset, embedding_model_instance, processing_rule, True
-                )
-            elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
+                if processing_rule:
+                    VectorService.generate_child_chunks(
+                        segment, document, dataset, embedding_model_instance, processing_rule, True
+                    )
+            elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
                 if args.enabled or keyword_changed:
                     # update segment vector index
                     VectorService.update_segment_vector(args.keywords, segment, dataset)
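The `IndexType` → `IndexStructureType` rename recurs through the rest of the diff. Judging by the members referenced here, the renamed enum looks roughly like this; the member values are assumptions based on the doc_form strings Dify already uses, not confirmed by this diff:

```python
# Assumed sketch of the renamed enum; values mirror Dify's existing
# doc_form strings but are not shown in this commit's hunks.
from enum import StrEnum


class IndexStructureType(StrEnum):
    PARAGRAPH_INDEX = "text_model"
    QA_INDEX = "qa_model"
    PARENT_CHILD_INDEX = "hierarchical_model"
```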
@@ -2976,7 +3036,7 @@ class SegmentService:
             db.session.add(document)
             db.session.add(segment)
             db.session.commit()
-            if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+            if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                 # get embedding model instance
                 if dataset.indexing_technique == "high_quality":
                     # check embedding model setting
@@ -3002,15 +3062,15 @@ class SegmentService:
                     .where(DatasetProcessRule.id == document.dataset_process_rule_id)
                     .first()
                 )
-                if not processing_rule:
-                    raise ValueError("No processing rule found.")
-                VectorService.generate_child_chunks(
-                    segment, document, dataset, embedding_model_instance, processing_rule, True
-                )
-            elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
+                if processing_rule:
+                    VectorService.generate_child_chunks(
+                        segment, document, dataset, embedding_model_instance, processing_rule, True
+                    )
+            elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
                 # update segment vector index
                 VectorService.update_segment_vector(args.keywords, segment, dataset)
+                # update multimodal vector index
+                VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
         except Exception as e:
             logger.exception("update segment index failed")
             segment.enabled = False
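`VectorService.update_multimodel_vector` (note the spelling of the identifier) is new in this commit; only its call signature is visible here. A hedged reading of the call, with the internal behavior assumed rather than shown:

```python
# Signature as used above; fetching the attachment files and embedding them
# with the vision-capable model is assumed, not visible in this hunk.
VectorService.update_multimodel_vector(
    segment,                    # the DocumentSegment whose index is refreshed
    args.attachment_ids or [],  # current attachment set; [] clears image vectors
    dataset,                    # dataset that owns the vector collection
)
```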
@@ -3048,7 +3108,9 @@ class SegmentService:
             )
             child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
-            delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
+            delete_segment_from_index_task.delay(
+                [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
+            )

         db.session.delete(segment)
         # update document word count
@@ -3097,7 +3159,9 @@ class SegmentService:
         # Start async cleanup with both parent and child node IDs
         if index_node_ids or child_node_ids:
-            delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
+            delete_segment_from_index_task.delay(
+                index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids
+            )

         if document.word_count is None:
             document.word_count = 0
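Both deletion call sites now pass the segment database IDs alongside the index node IDs, presumably so the async task can also remove the new attachment bindings. A sketch of the implied task signature; the decorator style follows Dify's existing tasks, and the body is assumed:

```python
# Implied Celery task signature; the real definition is elsewhere in this
# commit. The new segment_ids argument would let the task delete
# SegmentAttachmentBinding rows in addition to vector-index entries.
from celery import shared_task


@shared_task(queue="dataset")
def delete_segment_from_index_task(
    index_node_ids: list[str],
    dataset_id: str,
    document_id: str,
    segment_ids: list[str],
    child_node_ids: list[str] | None = None,
):
    ...  # remove vectors/keywords, then clean up attachment bindings (assumed)
```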