Feat/support multimodal embedding (#29115)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Jyong
2025-12-09 14:41:46 +08:00
committed by GitHub
parent 77cf8f6c27
commit 9affc546c6
78 changed files with 3230 additions and 713 deletions

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class DocType(StrEnum):
TEXT = "text"
IMAGE = "image"

View File

@@ -1,7 +1,12 @@
from enum import StrEnum
class IndexType(StrEnum):
class IndexStructureType(StrEnum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "hierarchical_model"
class IndexTechniqueType(StrEnum):
ECONOMY = "economy"
HIGH_QUALITY = "high_quality"

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class QueryType(StrEnum):
TEXT_QUERY = "text_query"
IMAGE_QUERY = "image_query"
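
The three new constant files above are plain StrEnum classes, so their members compare equal to the raw strings already stored on datasets and in chunk metadata. A minimal sketch, not part of the commit, illustrating that behavior (the query_type module path and the doc_type-to-query-type pairing are assumptions for illustration):

from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexTechniqueType
# Module path assumed by analogy with the other constants above:
from core.rag.index_processor.constant.query_type import QueryType

# StrEnum members compare equal to their raw string values, so checks such as
# `dataset.indexing_technique == "high_quality"` in the processors below keep working.
assert DocType.IMAGE == "image"
assert IndexTechniqueType.HIGH_QUALITY == "high_quality"
assert QueryType.IMAGE_QUERY == "image_query"

# Chunk metadata tagged with a DocType can drive the query type
# (illustrative pairing, not taken from this commit):
metadata = {"doc_type": DocType.IMAGE}
query_type = QueryType.IMAGE_QUERY if metadata["doc_type"] == DocType.IMAGE else QueryType.TEXT_QUERY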

View File

@@ -1,20 +1,34 @@
"""Abstract interface for document loader implementations."""
import cgi
import logging
import mimetypes
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import Account, ToolFile
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
if TYPE_CHECKING:
from core.model_manager import ModelInstance
@@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError
@abstractmethod
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
raise NotImplementedError
@abstractmethod
@@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC):
)
return character_splitter # type: ignore
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""
Get the content files from the document.
"""
multimodal_documents: list[AttachmentDocument] = []
text = document.page_content
images = self._extract_markdown_images(text)
if not images:
return multimodal_documents
upload_file_id_list = []
for image in images:
# Collect all upload_file_ids including duplicates to preserve occurrence count
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
# Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
match = re.search(pattern, image)
if match:
if current_user:
tool_file_id = match.group(1)
upload_file_id = self._download_tool_file(tool_file_id, current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
continue
if current_user:
upload_file_id = self._download_image(image.split(" ")[0], current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
if not upload_file_id_list:
return multimodal_documents
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
# Create a mapping from ID to UploadFile for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
# Create a Document for each occurrence (including duplicates)
for upload_file_id in upload_file_id_list:
upload_file = upload_file_map.get(upload_file_id)
if upload_file:
multimodal_documents.append(
AttachmentDocument(
page_content=upload_file.name,
metadata={
"doc_id": upload_file.id,
"doc_hash": "",
"document_id": document.metadata.get("document_id"),
"dataset_id": document.metadata.get("dataset_id"),
"doc_type": DocType.IMAGE,
},
)
)
return multimodal_documents
def _extract_markdown_images(self, text: str) -> list[str]:
"""
Extract the markdown images from the text.
"""
pattern = r"!\[.*?\]\((.*?)\)"
return re.findall(pattern, text)
def _download_image(self, image_url: str, current_user: Account) -> str | None:
"""
Download the image from the URL.
Image size must not exceed the configured limit (ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, in MB).
"""
from services.file_service import FileService
MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT
try:
# Download with timeout
response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT)
response.raise_for_status()
# Check Content-Length header if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_IMAGE_SIZE:
logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length)
return None
filename = None
content_disposition = response.headers.get("content-disposition")
if content_disposition:
_, params = cgi.parse_header(content_disposition)
if "filename" in params:
filename = params["filename"]
filename = unquote(filename)
if not filename:
parsed_url = urlparse(image_url)
# unquote handles non-ASCII (e.g., Chinese) characters in the URL path
path = unquote(parsed_url.path)
filename = os.path.basename(path)
if not filename:
filename = "downloaded_image_file"
name, current_ext = os.path.splitext(filename)
content_type = response.headers.get("content-type", "").split(";")[0].strip()
real_ext = mimetypes.guess_extension(content_type)
if real_ext and (not current_ext or current_ext in [".php", ".jsp", ".asp", ".html"]):
filename = f"{name}{real_ext}"
# Download content with size limit
blob = b""
for chunk in response.iter_bytes(chunk_size=8192):
blob += chunk
if len(blob) > MAX_IMAGE_SIZE:
logging.warning("Image from %s exceeds 2MB limit during download", image_url)
return None
if not blob:
logging.warning("Image from %s is empty", image_url)
return None
upload_file = FileService(db.engine).upload_file(
filename=filename,
content=blob,
mimetype=content_type,
user=current_user,
)
return upload_file.id
except httpx.TimeoutException:
logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT)
return None
except httpx.RequestError as e:
logging.warning("Error downloading image from %s: %s", image_url, str(e))
return None
except Exception:
logging.exception("Unexpected error downloading image from %s", image_url)
return None
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:
"""
Load the tool file identified by the given ID from storage and re-upload it as an UploadFile.
"""
from services.file_service import FileService
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
upload_file = FileService(db.engine).upload_file(
filename=tool_file.name,
content=blob,
mimetype=tool_file.mimetype,
user=current_user,
)
return upload_file.id
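
_get_content_files above resolves images in two steps: it pulls every markdown image URL out of the segment text, then maps each URL back to an upload file id via the three patterns (legacy image-preview, current file-preview, tool files), falling back to downloading external URLs. A standalone sketch of that matching step, reusing the same regexes with made-up sample ids and URLs:

import re

# Regexes copied from _extract_markdown_images / _get_content_files above.
MARKDOWN_IMAGE = r"!\[.*?\]\((.*?)\)"
PATTERNS = [
    ("pre-v0.10.0 preview", r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"),
    ("post-v0.10.0 preview", r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"),
    ("tool file", r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"),
]

# Sample markdown with placeholder ids.
text = (
    "Intro ![a](/files/0b1d2c3e-4f56-7a89-b0c1-d2e3f4a5b6c7/image-preview) "
    "![b](/files/1a2b3c4d-5e6f-7a8b-9c0d-1e2f3a4b5c6d/file-preview?timestamp=1) "
    "![c](/files/tools/9f8e7d6c-5b4a-3c2d-1e0f-a1b2c3d4e5f6.png) "
    "![d](https://example.com/photo.jpg)"
)

for url in re.findall(MARKDOWN_IMAGE, text):
    for label, pattern in PATTERNS:
        match = re.search(pattern, url)
        if match:
            print(label, "->", match.group(1))  # upload_file_id / tool_file_id
            break
    else:
        # Anything unmatched would be fetched via _download_image(url, current_user).
        print("external URL ->", url.split(" ")[0])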

View File

@@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
@@ -19,11 +19,11 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX:
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX:
elif self._index_type == IndexStructureType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
@@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if document_node.metadata is not None:
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
multimodal_documents = (
self._get_content_files(document_node, current_user) if document_node.metadata else None
)
if multimodal_documents:
document_node.attachments = multimodal_documents
# remove splitter characters
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
@@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
with_keywords = False
if with_keywords:
keywords_list = kwargs.get("keywords_list")
@@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
documents: list[Any] = []
all_multimodal_documents: list[Any] = []
if isinstance(chunks, list):
documents = []
for content in chunks:
metadata = {
"dataset_id": dataset.id,
@@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
"doc_hash": helper.generate_text_hash(content),
}
doc = Document(page_content=content, metadata=metadata)
attachments = self._get_content_files(doc)
if attachments:
doc.attachments = attachments
all_multimodal_documents.extend(attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
else:
raise ValueError("Chunks is not a list")
multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks)
for general_chunk in multimodal_general_structure.general_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(general_chunk.content),
}
doc = Document(page_content=general_chunk.content, metadata=metadata)
if general_chunk.files:
attachments = []
for file in general_chunk.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(
page_content=file.filename or "image_file", metadata=file_metadata
)
attachments.append(file_document)
all_multimodal_documents.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if all_multimodal_documents:
vector.create_multimodal(all_multimodal_documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
if isinstance(chunks, list):
preview = []
for content in chunks:
preview.append({"content": content})
return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)}
return {
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
"preview": preview,
"total_segments": len(chunks),
}
else:
raise ValueError("Chunks is not a list")

View File

@@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from libs import helper
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
page_content = page_content
if len(page_content) > 0:
document_node.page_content = page_content
multimodal_documents = self._get_content_files(document_node, current_user)
if multimodal_documents:
document_node.attachments = multimodal_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
elif rules.parent_mode == ParentMode.FULL_DOC:
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
multimodal_documents = self._get_content_files(document)
if multimodal_documents:
document.attachments = multimodal_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
@@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
Document.model_validate(child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
@@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
}
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
if parent_child.files and len(parent_child.files) > 0:
attachments = []
for file in parent_child.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata)
attachments.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
documents.append(doc)
if documents:
# update document parent mode
@@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
doc_store.add_documents(docs=documents, save_child=True)
if dataset.indexing_technique == "high_quality":
all_child_documents = []
all_multimodal_documents = []
for doc in documents:
if doc.children:
all_child_documents.extend(doc.children)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
vector = Vector(dataset)
if all_child_documents:
vector = Vector(dataset)
vector.create(all_child_documents)
if all_multimodal_documents:
vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
@@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {
"chunk_structure": IndexType.PARENT_CHILD_INDEX,
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
"parent_mode": parent_childs.parent_mode,
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),

View File

@@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, QAStructureChunk
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor):
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
if not process_rule:
@@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor):
try:
# Skip the first row
df = pd.read_csv(file)
df = pd.read_csv(file) # type: ignore
text_docs = []
for _, row in df.iterrows():
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
@@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e))
return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
vector = Vector(dataset)
@@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {
"chunk_structure": IndexType.QA_INDEX,
"chunk_structure": IndexStructureType.QA_INDEX,
"qa_preview": preview,
"total_segments": len(qa_chunks.qa_chunks),
}
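
For reference, the preview payloads returned by the three format_preview() implementations now carry the renamed structure constants; their shapes look like this (all values below are placeholders):

# Placeholder values; the key names and chunk_structure constants come from the
# format_preview() implementations above.
paragraph_preview = {
    "chunk_structure": IndexStructureType.PARAGRAPH_INDEX,     # "text_model"
    "preview": [{"content": "chunk text"}],
    "total_segments": 1,
}
parent_child_preview = {
    "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,  # "hierarchical_model"
    "parent_mode": "paragraph",                                 # placeholder ParentMode value
    "preview": [{"content": "parent text", "child_chunks": ["child text"]}],
    "total_segments": 1,
}
qa_preview = {
    "chunk_structure": IndexStructureType.QA_INDEX,             # "qa_model"
    "qa_preview": [{"question": "Q?", "answer": "A."}],
    "total_segments": 1,
}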