Refactor workflow nodes to use generic node_data (#28782)

This commit is contained in:
-LAN-
2025-11-27 20:46:56 +08:00
committed by GitHub
parent 002d8769b0
commit 8b761319f6
28 changed files with 121 additions and 170 deletions

View File

@@ -11,8 +11,6 @@ class LoopEndNode(Node[LoopEndNodeData]):
node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData
@classmethod
def version(cls) -> str:
return "1"

View File

@@ -46,7 +46,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
"""
node_type = NodeType.LOOP
_node_data: LoopNodeData
execution_type = NodeExecutionType.CONTAINER
@classmethod
@@ -56,27 +55,27 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _run(self) -> Generator:
"""Run the node."""
# Get inputs
loop_count = self._node_data.loop_count
break_conditions = self._node_data.break_conditions
logical_operator = self._node_data.logical_operator
loop_count = self.node_data.loop_count
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
inputs = {"loop_count": loop_count}
if not self._node_data.start_node_id:
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
root_node_id = self._node_data.start_node_id
root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
if self._node_data.loop_variables:
if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
if isinstance(var.value, list)
else None,
}
for loop_variable in self._node_data.loop_variables:
for loop_variable in self.node_data.loop_variables:
if loop_variable.value_type not in value_processor:
raise ValueError(
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
@@ -164,7 +163,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
yield LoopNextEvent(
index=i + 1,
pre_loop_output=self._node_data.outputs,
pre_loop_output=self.node_data.outputs,
)
self._accumulate_usage(loop_usage)
@@ -172,7 +171,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
yield LoopSucceededEvent(
start_at=start_at,
inputs=inputs,
outputs=self._node_data.outputs,
outputs=self.node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
@@ -194,7 +193,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self._node_data.outputs,
outputs=self.node_data.outputs,
inputs=inputs,
llm_usage=loop_usage,
)
@@ -252,11 +251,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
if isinstance(event, GraphRunFailedEvent):
raise Exception(event.error)
for loop_var in self._node_data.loop_variables or []:
for loop_var in self.node_data.loop_variables or []:
key, sel = loop_var.label, [self._node_id, loop_var.label]
segment = self.graph_runtime_state.variable_pool.get(sel)
self._node_data.outputs[key] = segment.value if segment else None
self._node_data.outputs["loop_round"] = current_index + 1
self.node_data.outputs[key] = segment.value if segment else None
self.node_data.outputs["loop_round"] = current_index + 1
return reach_break_node

View File

@@ -11,8 +11,6 @@ class LoopStartNode(Node[LoopStartNodeData]):
node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData
@classmethod
def version(cls) -> str:
return "1"