Feat/support multimodal embedding (#29115)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -7,7 +7,7 @@ import time
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
@@ -19,9 +19,10 @@ from configs import dify_config
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from events.dataset_event import dataset_was_deleted
|
||||
@@ -46,6 +47,7 @@ from models.dataset import (
|
||||
DocumentSegment,
|
||||
ExternalKnowledgeBindings,
|
||||
Pipeline,
|
||||
SegmentAttachmentBinding,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.provider_ids import ModelProviderID
|
||||
@@ -363,6 +365,27 @@ class DatasetService:
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
|
||||
@staticmethod
|
||||
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=model,
|
||||
)
|
||||
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
|
||||
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
if model_schema.features and ModelFeature.VISION in model_schema.features:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except LLMBadRequestError:
|
||||
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
|
||||
|
||||
@staticmethod
|
||||
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
|
||||
try:
|
||||
@@ -402,13 +425,13 @@ class DatasetService:
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
# check if dataset name is exists
|
||||
|
||||
if DatasetService._has_dataset_same_name(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
name=data.get("name", dataset.name),
|
||||
):
|
||||
raise ValueError("Dataset name already exists")
|
||||
if data.get("name") and data.get("name") != dataset.name:
|
||||
if DatasetService._has_dataset_same_name(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
name=data.get("name", dataset.name),
|
||||
):
|
||||
raise ValueError("Dataset name already exists")
|
||||
|
||||
# Verify user has permission to update this dataset
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
@@ -844,6 +867,12 @@ class DatasetService:
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=knowledge_configuration.embedding_model or "",
|
||||
)
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
current_user.current_tenant_id,
|
||||
knowledge_configuration.embedding_model_provider,
|
||||
knowledge_configuration.embedding_model,
|
||||
)
|
||||
dataset.is_multimodal = is_multimodal
|
||||
dataset.embedding_model = embedding_model.model
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
@@ -880,6 +909,12 @@ class DatasetService:
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
)
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
current_user.current_tenant_id,
|
||||
knowledge_configuration.embedding_model_provider,
|
||||
knowledge_configuration.embedding_model,
|
||||
)
|
||||
dataset.is_multimodal = is_multimodal
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
dataset.indexing_technique = knowledge_configuration.indexing_technique
|
||||
except LLMBadRequestError:
|
||||
@@ -937,6 +972,12 @@ class DatasetService:
|
||||
)
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
current_user.current_tenant_id,
|
||||
knowledge_configuration.embedding_model_provider,
|
||||
knowledge_configuration.embedding_model,
|
||||
)
|
||||
dataset.is_multimodal = is_multimodal
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
@@ -2305,6 +2346,7 @@ class DocumentService:
|
||||
embedding_model_provider=knowledge_config.embedding_model_provider,
|
||||
collection_binding_id=dataset_collection_binding_id,
|
||||
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
|
||||
is_multimodal=knowledge_config.is_multimodal,
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
@@ -2685,6 +2727,13 @@ class SegmentService:
|
||||
if "content" not in args or not args["content"] or not args["content"].strip():
|
||||
raise ValueError("Content is empty")
|
||||
|
||||
if args.get("attachment_ids"):
|
||||
if not isinstance(args["attachment_ids"], list):
|
||||
raise ValueError("Attachment IDs is invalid")
|
||||
single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT
|
||||
if len(args["attachment_ids"]) > single_chunk_attachment_limit:
|
||||
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
|
||||
|
||||
@classmethod
|
||||
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
@@ -2731,11 +2780,23 @@ class SegmentService:
|
||||
segment_document.word_count += len(args["answer"])
|
||||
segment_document.answer = args["answer"]
|
||||
|
||||
db.session.add(segment_document)
|
||||
# update document word count
|
||||
assert document.word_count is not None
|
||||
document.word_count += segment_document.word_count
|
||||
db.session.add(document)
|
||||
db.session.add(segment_document)
|
||||
# update document word count
|
||||
assert document.word_count is not None
|
||||
document.word_count += segment_document.word_count
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
if args["attachment_ids"]:
|
||||
for attachment_id in args["attachment_ids"]:
|
||||
binding = SegmentAttachmentBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
document_id=document.id,
|
||||
segment_id=segment_document.id,
|
||||
attachment_id=attachment_id,
|
||||
)
|
||||
db.session.add(binding)
|
||||
db.session.commit()
|
||||
|
||||
# save vector index
|
||||
@@ -2899,7 +2960,7 @@ class SegmentService:
|
||||
document.word_count = max(0, document.word_count + word_count_change)
|
||||
db.session.add(document)
|
||||
# update segment index task
|
||||
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
||||
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
||||
# regenerate child chunks
|
||||
# get embedding model instance
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
@@ -2926,12 +2987,11 @@ class SegmentService:
|
||||
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
|
||||
.first()
|
||||
)
|
||||
if not processing_rule:
|
||||
raise ValueError("No processing rule found.")
|
||||
VectorService.generate_child_chunks(
|
||||
segment, document, dataset, embedding_model_instance, processing_rule, True
|
||||
)
|
||||
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
|
||||
if processing_rule:
|
||||
VectorService.generate_child_chunks(
|
||||
segment, document, dataset, embedding_model_instance, processing_rule, True
|
||||
)
|
||||
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
|
||||
if args.enabled or keyword_changed:
|
||||
# update segment vector index
|
||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||
@@ -2976,7 +3036,7 @@ class SegmentService:
|
||||
db.session.add(document)
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
||||
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
|
||||
# get embedding model instance
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
# check embedding model setting
|
||||
@@ -3002,15 +3062,15 @@ class SegmentService:
|
||||
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
|
||||
.first()
|
||||
)
|
||||
if not processing_rule:
|
||||
raise ValueError("No processing rule found.")
|
||||
VectorService.generate_child_chunks(
|
||||
segment, document, dataset, embedding_model_instance, processing_rule, True
|
||||
)
|
||||
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
|
||||
if processing_rule:
|
||||
VectorService.generate_child_chunks(
|
||||
segment, document, dataset, embedding_model_instance, processing_rule, True
|
||||
)
|
||||
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
|
||||
# update segment vector index
|
||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||
|
||||
# update multimodel vector index
|
||||
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
|
||||
except Exception as e:
|
||||
logger.exception("update segment index failed")
|
||||
segment.enabled = False
|
||||
@@ -3048,7 +3108,9 @@ class SegmentService:
|
||||
)
|
||||
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
|
||||
|
||||
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
|
||||
delete_segment_from_index_task.delay(
|
||||
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
|
||||
)
|
||||
|
||||
db.session.delete(segment)
|
||||
# update document word count
|
||||
@@ -3097,7 +3159,9 @@ class SegmentService:
|
||||
|
||||
# Start async cleanup with both parent and child node IDs
|
||||
if index_node_ids or child_node_ids:
|
||||
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
|
||||
delete_segment_from_index_task.delay(
|
||||
index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids
|
||||
)
|
||||
|
||||
if document.word_count is None:
|
||||
document.word_count = 0
|
||||
|
||||
Reference in New Issue
Block a user