refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
-LAN-
2025-07-18 10:08:51 +08:00
committed by GitHub
parent 54c56f2d05
commit 460a825ef1
65 changed files with 2305 additions and 1146 deletions

View File

@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@@ -19,10 +18,10 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
@@ -37,14 +36,18 @@ from .exc import (
)
class ToolNode(BaseNode[ToolNodeData]):
class ToolNode(BaseNode):
"""
Tool Node
"""
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ToolNodeData.model_validate(data)
@classmethod
def version(cls) -> str:
return "1"
@@ -54,7 +57,7 @@ class ToolNode(BaseNode[ToolNodeData]):
Run the tool node
"""
node_data = cast(ToolNodeData, self.node_data)
node_data = cast(ToolNodeData, self._node_data)
# fetch tool icon
tool_info = {
@@ -67,9 +70,9 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
from core.tools.tool_manager import ToolManager
variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None
variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool
self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield RunCompletedEvent(
@@ -88,12 +91,12 @@ class ToolNode(BaseNode[ToolNodeData]):
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
node_data=self._node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
node_data=self._node_data,
for_log=True,
)
# get conversation id
@@ -124,7 +127,14 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self.node_id,
)
except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -191,7 +201,9 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
agent_thoughts: Optional[list] = None,
user_id: str,
tenant_id: str,
node_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@@ -199,8 +211,8 @@ class ToolNode(BaseNode[ToolNodeData]):
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
@@ -208,9 +220,6 @@ class ToolNode(BaseNode[ToolNodeData]):
files: list[File] = []
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {}
for message in message_stream:
@@ -243,7 +252,7 @@ class ToolNode(BaseNode[ToolNodeData]):
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
@@ -266,45 +275,36 @@ class ToolNode(BaseNode[ToolNodeData]):
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
# JSON message handling for tool node
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
)
else:
variables[variable_name] = variable_value
@@ -319,7 +319,7 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
@@ -334,8 +334,8 @@ class ToolNode(BaseNode[ToolNodeData]):
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
self.user_id,
self.tenant_id,
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
@@ -347,57 +347,10 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=self.id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=self.node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage)
yield RunRetrieverResourceEvent(
retriever_resources=message.message.retriever_resources,
context=message.message.context,
)
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
json_output.extend(json)
@@ -409,12 +362,9 @@ class ToolNode(BaseNode[ToolNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)
@@ -424,7 +374,7 @@ class ToolNode(BaseNode[ToolNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -433,9 +383,12 @@ class ToolNode(BaseNode[ToolNodeData]):
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data)
result = {}
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
@@ -449,3 +402,29 @@ class ToolNode(BaseNode[ToolNodeData]):
result = {node_id + "." + key: value for key, value in result.items()}
return result
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled