refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Author: -LAN-
Date: 2025-07-18 10:08:51 +08:00
Committed by: GitHub
Parent: 54c56f2d05
Commit: 460a825ef1
65 changed files with 2305 additions and 1146 deletions
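
The shape of the change, in brief: BaseNode was previously generic over its data model (class LLMNode(BaseNode[LLMNodeData]) plus a _node_data_cls hook), which coupled the base class to every subclass's NodeData type. After this commit each node owns a private _node_data, hydrates it from the raw graph mapping via init_node_data(), and exposes only narrow accessors to the rest of the system. A minimal, self-contained sketch of the pattern; MyNode and MyNodeData are illustrative stand-ins, not Dify classes:

    from abc import ABC, abstractmethod
    from collections.abc import Mapping
    from typing import Any, Optional

    from pydantic import BaseModel


    class BaseNodeData(BaseModel):
        title: str = ""
        desc: Optional[str] = None


    class BaseNode(ABC):
        # No generic parameter and no _node_data_cls hook: the base class no
        # longer knows which concrete NodeData type a subclass uses.
        @abstractmethod
        def init_node_data(self, data: Mapping[str, Any]) -> None: ...

        @abstractmethod
        def get_base_node_data(self) -> BaseNodeData: ...


    class MyNodeData(BaseNodeData):
        temperature: float = 0.7


    class MyNode(BaseNode):
        _node_data: MyNodeData  # typed config, private to the subclass

        def init_node_data(self, data: Mapping[str, Any]) -> None:
            # Validation happens at the boundary: a raw mapping goes in,
            # a typed pydantic model comes out.
            self._node_data = MyNodeData.model_validate(data)

        def get_base_node_data(self) -> BaseNodeData:
            return self._node_data


    node = MyNode()
    node.init_node_data({"title": "llm", "temperature": 0.2})
    assert node.get_base_node_data().title == "llm"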


@@ -1,4 +1,4 @@
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from typing import Any, Optional
 
 from pydantic import BaseModel, Field, field_validator
@@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData):
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
-    structured_output: dict | None = None
+    structured_output: Mapping[str, Any] | None = None
     # We used 'structured_output_enabled' in the past, but it's not a good name.
     structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
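
Widening the annotation from dict to Mapping[str, Any] is more than cosmetic: a Mapping-typed field advertises read-only intent, so type checkers reject in-place mutation of configuration that may be shared across the graph, while pydantic still accepts any mapping as input. A small illustrative sketch (use_schema is hypothetical, not part of the codebase):

    from collections.abc import Mapping
    from typing import Any


    def use_schema(structured_output: Mapping[str, Any] | None) -> Any:
        if structured_output is None:
            return None
        # structured_output["schema"] = {}  # flagged by mypy/pyright:
        #                                   # Mapping has no __setitem__
        return structured_output.get("schema", {})


    print(use_schema({"schema": {"type": "object"}}))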


@@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.event import (
     ModelInvokeCompletedEvent,
     NodeEvent,
@@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver
 if TYPE_CHECKING:
     from core.file.models import File
-    from core.workflow.graph_engine.entities.graph import Graph
-    from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
-    from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+    from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 
 logger = logging.getLogger(__name__)
 
-class LLMNode(BaseNode[LLMNodeData]):
-    _node_data_cls = LLMNodeData
+class LLMNode(BaseNode):
     _node_type = NodeType.LLM
 
+    _node_data: LLMNodeData
+
+    # Instance attributes specific to LLMNode.
     # Output variable for file
     _file_outputs: list["File"]
@@ -138,6 +138,27 @@ class LLMNode(BaseNode[LLMNodeData]):
         )
         self._llm_file_saver = llm_file_saver
 
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = LLMNodeData.model_validate(data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
+    @classmethod
+    def version(cls) -> str:
+        return "1"
@@ -152,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]):
         try:
             # init messages template
-            self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
+            self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
 
             # fetch variables and fetch values from variable pool
-            inputs = self._fetch_inputs(node_data=self.node_data)
+            inputs = self._fetch_inputs(node_data=self._node_data)
 
             # fetch jinja2 inputs
-            jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
+            jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
 
             # merge inputs
             inputs.update(jinja_inputs)
@@ -169,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]):
             files = (
                 llm_utils.fetch_files(
                     variable_pool=variable_pool,
-                    selector=self.node_data.vision.configs.variable_selector,
+                    selector=self._node_data.vision.configs.variable_selector,
                 )
-                if self.node_data.vision.enabled
+                if self._node_data.vision.enabled
                 else []
             )
@@ -179,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                 node_inputs["#files#"] = [file.to_dict() for file in files]
 
             # fetch context value
-            generator = self._fetch_context(node_data=self.node_data)
+            generator = self._fetch_context(node_data=self._node_data)
             context = None
             for event in generator:
                 if isinstance(event, RunRetrieverResourceEvent):
@@ -189,44 +210,54 @@ class LLMNode(BaseNode[LLMNodeData]):
                 node_inputs["#context#"] = context
 
             # fetch model config
-            model_instance, model_config = self._fetch_model_config(self.node_data.model)
+            model_instance, model_config = LLMNode._fetch_model_config(
+                node_data_model=self._node_data.model,
+                tenant_id=self.tenant_id,
+            )
 
             # fetch memory
             memory = llm_utils.fetch_memory(
                 variable_pool=variable_pool,
                 app_id=self.app_id,
-                node_data_memory=self.node_data.memory,
+                node_data_memory=self._node_data.memory,
                 model_instance=model_instance,
             )
 
             query = None
-            if self.node_data.memory:
-                query = self.node_data.memory.query_prompt_template
+            if self._node_data.memory:
+                query = self._node_data.memory.query_prompt_template
                 if not query and (
                     query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
                 ):
                     query = query_variable.text
 
-            prompt_messages, stop = self._fetch_prompt_messages(
+            prompt_messages, stop = LLMNode.fetch_prompt_messages(
                 sys_query=query,
                 sys_files=files,
                 context=context,
                 memory=memory,
                 model_config=model_config,
-                prompt_template=self.node_data.prompt_template,
-                memory_config=self.node_data.memory,
-                vision_enabled=self.node_data.vision.enabled,
-                vision_detail=self.node_data.vision.configs.detail,
+                prompt_template=self._node_data.prompt_template,
+                memory_config=self._node_data.memory,
+                vision_enabled=self._node_data.vision.enabled,
+                vision_detail=self._node_data.vision.configs.detail,
                 variable_pool=variable_pool,
-                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
+                jinja2_variables=self._node_data.prompt_config.jinja2_variables,
+                tenant_id=self.tenant_id,
             )
 
             # handle invoke result
-            generator = self._invoke_llm(
-                node_data_model=self.node_data.model,
+            generator = LLMNode.invoke_llm(
+                node_data_model=self._node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop,
+                user_id=self.user_id,
+                structured_output_enabled=self._node_data.structured_output_enabled,
+                structured_output=self._node_data.structured_output,
+                file_saver=self._llm_file_saver,
+                file_outputs=self._file_outputs,
+                node_id=self.node_id,
             )
 
             structured_output: LLMStructuredOutput | None = None
@@ -296,12 +327,19 @@ class LLMNode(BaseNode[LLMNodeData]):
                 )
             )
 
-    def _invoke_llm(
-        self,
+    @staticmethod
+    def invoke_llm(
+        *,
         node_data_model: ModelConfig,
         model_instance: ModelInstance,
         prompt_messages: Sequence[PromptMessage],
         stop: Optional[Sequence[str]] = None,
+        user_id: str,
+        structured_output_enabled: bool,
+        structured_output: Optional[Mapping[str, Any]] = None,
+        file_saver: LLMFileSaver,
+        file_outputs: list["File"],
+        node_id: str,
     ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
         model_schema = model_instance.model_type_instance.get_model_schema(
             node_data_model.name, model_instance.credentials
@@ -309,8 +347,10 @@ class LLMNode(BaseNode[LLMNodeData]):
         if not model_schema:
             raise ValueError(f"Model schema not found for {node_data_model.name}")
 
-        if self.node_data.structured_output_enabled:
-            output_schema = self._fetch_structured_output_schema()
+        if structured_output_enabled:
+            output_schema = LLMNode.fetch_structured_output_schema(
+                structured_output=structured_output or {},
+            )
             invoke_result = invoke_llm_with_structured_output(
                 provider=model_instance.provider,
                 model_schema=model_schema,
@@ -320,7 +360,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                 model_parameters=node_data_model.completion_params,
                 stop=list(stop or []),
                 stream=True,
-                user=self.user_id,
+                user=user_id,
             )
         else:
             invoke_result = model_instance.invoke_llm(
@@ -328,17 +368,31 @@ class LLMNode(BaseNode[LLMNodeData]):
                 model_parameters=node_data_model.completion_params,
                 stop=list(stop or []),
                 stream=True,
-                user=self.user_id,
+                user=user_id,
             )
-        return self._handle_invoke_result(invoke_result=invoke_result)
+        return LLMNode.handle_invoke_result(
+            invoke_result=invoke_result,
+            file_saver=file_saver,
+            file_outputs=file_outputs,
+            node_id=node_id,
+        )
 
-    def _handle_invoke_result(
-        self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
+    @staticmethod
+    def handle_invoke_result(
+        *,
+        invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
+        file_saver: LLMFileSaver,
+        file_outputs: list["File"],
+        node_id: str,
     ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
         # For blocking mode
         if isinstance(invoke_result, LLMResult):
-            event = self._handle_blocking_result(invoke_result=invoke_result)
+            event = LLMNode.handle_blocking_result(
+                invoke_result=invoke_result,
+                saver=file_saver,
+                file_outputs=file_outputs,
+            )
             yield event
             return
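
A practical payoff of turning handle_invoke_result into a keyword-only staticmethod is testability: every collaborator (file_saver, file_outputs, node_id) arrives through the parameter list, so the streaming loop can be driven with fakes and no LLMNode instance. An illustrative sketch with stand-in types (FakeSaver and handle are hypothetical, not Dify's entities or tests):

    from dataclasses import dataclass, field


    @dataclass
    class FakeSaver:
        saved: list[str] = field(default_factory=list)

        def save_remote_url(self, url: str, file_type: str) -> str:
            self.saved.append(url)
            return url


    def handle(results, *, file_saver, file_outputs, node_id):
        # Same shape as the refactored method: everything it touches is a
        # parameter, so this generator is trivially testable in isolation.
        for r in results:
            yield f"[{node_id}] {r}"


    saver = FakeSaver()
    chunks = list(handle(["a", "b"], file_saver=saver, file_outputs=[], node_id="llm-1"))
    assert chunks == ["[llm-1] a", "[llm-1] b"]
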
@@ -356,11 +410,13 @@ class LLMNode(BaseNode[LLMNodeData]):
                 yield result
                 if isinstance(result, LLMResultChunk):
                     contents = result.delta.message.content
-                    for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
+                    for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
+                        contents=contents,
+                        file_saver=file_saver,
+                        file_outputs=file_outputs,
+                    ):
                         full_text_buffer.write(text_part)
-                        yield RunStreamChunkEvent(
-                            chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
-                        )
+                        yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
 
                 # Update the whole metadata
                 if not model and result.model:
@@ -378,7 +434,8 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
 
-    def _image_file_to_markdown(self, file: "File", /):
+    @staticmethod
+    def _image_file_to_markdown(file: "File", /):
         text_chunk = f"![]({file.generate_url()})"
         return text_chunk
@@ -539,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]):
         return None
 
+    @staticmethod
     def _fetch_model_config(
-        self, node_data_model: ModelConfig
+        *,
+        node_data_model: ModelConfig,
+        tenant_id: str,
     ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
         model, model_config_with_cred = llm_utils.fetch_model_config(
-            tenant_id=self.tenant_id, node_data_model=node_data_model
+            tenant_id=tenant_id, node_data_model=node_data_model
         )
         completion_params = model_config_with_cred.parameters
@@ -556,8 +616,8 @@ class LLMNode(BaseNode[LLMNodeData]):
         node_data_model.completion_params = completion_params
         return model, model_config_with_cred
 
-    def _fetch_prompt_messages(
-        self,
+    @staticmethod
+    def fetch_prompt_messages(
         *,
         sys_query: str | None = None,
         sys_files: Sequence["File"],
@@ -570,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]):
         vision_detail: ImagePromptMessageContent.DETAIL,
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
+        tenant_id: str,
     ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
         prompt_messages: list[PromptMessage] = []
 
         if isinstance(prompt_template, list):
             # For chat model
             prompt_messages.extend(
-                self._handle_list_messages(
+                LLMNode.handle_list_messages(
                     messages=prompt_template,
                     context=context,
                     jinja2_variables=jinja2_variables,
@@ -602,7 +663,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                     edition_type="basic",
                 )
                 prompt_messages.extend(
-                    self._handle_list_messages(
+                    LLMNode.handle_list_messages(
                         messages=[message],
                         context="",
                         jinja2_variables=[],
@@ -731,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             )
 
             model = ModelManager().get_model_instance(
-                tenant_id=self.tenant_id,
+                tenant_id=tenant_id,
                 model_type=ModelType.LLM,
                 provider=model_config.provider,
                 model=model_config.model,
@@ -750,10 +811,12 @@ class LLMNode(BaseNode[LLMNodeData]):
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: LLMNodeData,
+        node_data: Mapping[str, Any],
     ) -> Mapping[str, Sequence[str]]:
-        prompt_template = node_data.prompt_template
+        # Create typed NodeData from dict
+        typed_node_data = LLMNodeData.model_validate(node_data)
+
+        prompt_template = typed_node_data.prompt_template
 
         variable_selectors = []
         if isinstance(prompt_template, list) and all(
             isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
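
The same validate-at-the-boundary move appears here: the variable-mapping extractor now takes the raw node_data mapping straight from the graph config and builds the typed LLMNodeData itself, instead of requiring every caller to construct one. A minimal sketch of the convention, with DemoNodeData as an illustrative stand-in:

    from collections.abc import Mapping
    from typing import Any

    from pydantic import BaseModel


    class DemoNodeData(BaseModel):  # stand-in for LLMNodeData
        prompt_template: list[str] = []


    def extract(node_data: Mapping[str, Any]) -> list[str]:
        typed = DemoNodeData.model_validate(node_data)  # dict in, typed model out
        return list(typed.prompt_template)


    assert extract({"prompt_template": ["hello"]}) == ["hello"]
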
@@ -773,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
-        memory = node_data.memory
+        memory = typed_node_data.memory
         if memory and memory.query_prompt_template:
             query_variable_selectors = VariableTemplateParser(
                 template=memory.query_prompt_template
@@ -781,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]):
         for variable_selector in query_variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
-        if node_data.context.enabled:
-            variable_mapping["#context#"] = node_data.context.variable_selector
+        if typed_node_data.context.enabled:
+            variable_mapping["#context#"] = typed_node_data.context.variable_selector
 
-        if node_data.vision.enabled:
-            variable_mapping["#files#"] = node_data.vision.configs.variable_selector
+        if typed_node_data.vision.enabled:
+            variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
 
-        if node_data.memory:
+        if typed_node_data.memory:
             variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
 
-        if node_data.prompt_config:
+        if typed_node_data.prompt_config:
             enable_jinja = False
 
             if isinstance(prompt_template, list):
@@ -803,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                     enable_jinja = True
 
             if enable_jinja:
-                for variable_selector in node_data.prompt_config.jinja2_variables or []:
+                for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
                     variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
         variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@@ -835,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]):
             },
         }
 
-    def _handle_list_messages(
-        self,
+    @staticmethod
+    def handle_list_messages(
         *,
         messages: Sequence[LLMNodeChatModelMessage],
         context: Optional[str],
@@ -897,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]):
         return prompt_messages
 
-    def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
+    @staticmethod
+    def handle_blocking_result(
+        *,
+        invoke_result: LLMResult,
+        saver: LLMFileSaver,
+        file_outputs: list["File"],
+    ) -> ModelInvokeCompletedEvent:
         buffer = io.StringIO()
-        for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
+        for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
+            contents=invoke_result.message.content,
+            file_saver=saver,
+            file_outputs=file_outputs,
+        ):
             buffer.write(text_part)
 
         return ModelInvokeCompletedEvent(
@@ -908,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]):
             finish_reason=None,
         )
 
-    def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
+    @staticmethod
+    def save_multimodal_image_output(
+        *,
+        content: ImagePromptMessageContent,
+        file_saver: LLMFileSaver,
+    ) -> "File":
         """_save_multimodal_output saves multi-modal contents generated by LLM plugins.
 
         There are two kinds of multimodal outputs:
@@ -918,26 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         Currently, only image files are supported.
         """
-        # Inject the saver somehow...
-        _saver = self._llm_file_saver
-
-        # If this
         if content.url != "":
-            saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
+            saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
         else:
-            saved_file = _saver.save_binary_string(
+            saved_file = file_saver.save_binary_string(
                 data=base64.b64decode(content.base64_data),
                 mime_type=content.mime_type,
                 file_type=FileType.IMAGE,
             )
-        self._file_outputs.append(saved_file)
         return saved_file
 
     def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
         """
         Fetch model schema
         """
-        model_name = self.node_data.model.name
+        model_name = self._node_data.model.name
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
             tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
@@ -948,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]):
         model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
         return model_schema
 
-    def _fetch_structured_output_schema(self) -> dict[str, Any]:
+    @staticmethod
+    def fetch_structured_output_schema(
+        *,
+        structured_output: Mapping[str, Any],
+    ) -> dict[str, Any]:
         """
         Fetch the structured output schema from the node data.
 
         Returns:
             dict[str, Any]: The structured output schema
         """
-        if not self.node_data.structured_output:
+        if not structured_output:
             raise LLMNodeError("Please provide a valid structured output schema")
-        structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
+        structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
         if not structured_output_schema:
             raise LLMNodeError("Please provide a valid structured output schema")
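
For reference, the structured_output mapping is expected to carry the JSON schema under a "schema" key; the method serializes it and re-parses it so the result is guaranteed to be plain JSON-compatible data. A runnable sketch that mirrors the logic above (a standalone re-implementation for illustration, not the real method):

    import json
    from collections.abc import Mapping
    from typing import Any


    def fetch_schema(structured_output: Mapping[str, Any]) -> dict[str, Any]:
        if not structured_output:
            raise ValueError("Please provide a valid structured output schema")
        # Round-trip through JSON: nested values come back as plain dicts/lists.
        text = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
        return json.loads(text)


    print(fetch_schema({"schema": {"type": "object", "properties": {}}}))
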
@@ -969,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]):
         except json.JSONDecodeError:
             raise LLMNodeError("structured_output_schema is not valid JSON format")
 
+    @staticmethod
     def _save_multimodal_output_and_convert_result_to_markdown(
-        self,
+        *,
         contents: str | list[PromptMessageContentUnionTypes] | None,
+        file_saver: LLMFileSaver,
+        file_outputs: list["File"],
     ) -> Generator[str, None, None]:
         """Convert intermediate prompt messages into strings and yield them to the caller.
@@ -994,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]):
                 if isinstance(item, TextPromptMessageContent):
                     yield item.data
                 elif isinstance(item, ImagePromptMessageContent):
-                    file = self._save_multimodal_image_output(item)
-                    self._file_outputs.append(file)
-                    yield self._image_file_to_markdown(file)
+                    file = LLMNode.save_multimodal_image_output(
+                        content=item,
+                        file_saver=file_saver,
+                    )
+                    file_outputs.append(file)
+                    yield LLMNode._image_file_to_markdown(file)
                 else:
                     logger.warning("unknown item type encountered, type=%s", type(item))
                     yield str(item)
@@ -1004,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]):
             logger.warning("unknown contents type encountered, type=%s", type(contents))
             yield str(contents)
 
+    @property
+    def continue_on_error(self) -> bool:
+        return self._node_data.error_strategy is not None
+
+    @property
+    def retry(self) -> bool:
+        return self._node_data.retry_config.retry_enabled
+
 
 def _combine_message_content_with_role(
     *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole