Feat/support multimodal embedding (#29115)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Jyong
2025-12-09 14:41:46 +08:00
committed by GitHub
parent 77cf8f6c27
commit 9affc546c6
78 changed files with 3230 additions and 713 deletions

View File

@@ -7,8 +7,10 @@ import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArraySegment,
@@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from . import llm_utils
from .entities import (
@@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]):
# fetch context value
generator = self._fetch_context(node_data=self.node_data)
context = None
context_files: list[File] = []
for event in generator:
context = event.context
context_files = event.context_files or []
yield event
if context:
node_inputs["#context#"] = context
if context_files:
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
# fetch model config
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self.node_data.model,
@@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# handle invoke result
@@ -654,10 +666,13 @@ class LLMNode(Node[LLMNodeData]):
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
if context_value_variable:
if isinstance(context_value_variable, StringSegment):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
yield RunRetrieverResourceEvent(
retriever_resources=[], context=context_value_variable.value, context_files=[]
)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource: list[RetrievalSourceMetadata] = []
context_files: list[File] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
@@ -670,9 +685,34 @@ class LLMNode(Node[LLMNodeData]):
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attchment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attchment_info)
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource, context=context_str.strip()
retriever_resources=original_retriever_resource,
context=context_str.strip(),
context_files=context_files,
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
@@ -700,6 +740,7 @@ class LLMNode(Node[LLMNodeData]):
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
)
return source
@@ -741,6 +782,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
context_files: list["File"] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -853,6 +895,23 @@ class LLMNode(Node[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# The context_files
if vision_enabled and context_files:
file_prompts = []
for file in context_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages: