Refactor workflow nodes to use generic node_data (#28782)
This commit is contained in:
@@ -11,8 +11,6 @@ class LoopEndNode(Node[LoopEndNodeData]):
|
||||
|
||||
node_type = NodeType.LOOP_END
|
||||
|
||||
_node_data: LoopEndNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@@ -46,7 +46,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
"""
|
||||
|
||||
node_type = NodeType.LOOP
|
||||
_node_data: LoopNodeData
|
||||
execution_type = NodeExecutionType.CONTAINER
|
||||
|
||||
@classmethod
|
||||
@@ -56,27 +55,27 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
def _run(self) -> Generator:
|
||||
"""Run the node."""
|
||||
# Get inputs
|
||||
loop_count = self._node_data.loop_count
|
||||
break_conditions = self._node_data.break_conditions
|
||||
logical_operator = self._node_data.logical_operator
|
||||
loop_count = self.node_data.loop_count
|
||||
break_conditions = self.node_data.break_conditions
|
||||
logical_operator = self.node_data.logical_operator
|
||||
|
||||
inputs = {"loop_count": loop_count}
|
||||
|
||||
if not self._node_data.start_node_id:
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||
|
||||
root_node_id = self._node_data.start_node_id
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# Initialize loop variables in the original variable pool
|
||||
loop_variable_selectors = {}
|
||||
if self._node_data.loop_variables:
|
||||
if self.node_data.loop_variables:
|
||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
|
||||
if isinstance(var.value, list)
|
||||
else None,
|
||||
}
|
||||
for loop_variable in self._node_data.loop_variables:
|
||||
for loop_variable in self.node_data.loop_variables:
|
||||
if loop_variable.value_type not in value_processor:
|
||||
raise ValueError(
|
||||
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
|
||||
@@ -164,7 +163,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
yield LoopNextEvent(
|
||||
index=i + 1,
|
||||
pre_loop_output=self._node_data.outputs,
|
||||
pre_loop_output=self.node_data.outputs,
|
||||
)
|
||||
|
||||
self._accumulate_usage(loop_usage)
|
||||
@@ -172,7 +171,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
yield LoopSucceededEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs=self._node_data.outputs,
|
||||
outputs=self.node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
@@ -194,7 +193,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self._node_data.outputs,
|
||||
outputs=self.node_data.outputs,
|
||||
inputs=inputs,
|
||||
llm_usage=loop_usage,
|
||||
)
|
||||
@@ -252,11 +251,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
raise Exception(event.error)
|
||||
|
||||
for loop_var in self._node_data.loop_variables or []:
|
||||
for loop_var in self.node_data.loop_variables or []:
|
||||
key, sel = loop_var.label, [self._node_id, loop_var.label]
|
||||
segment = self.graph_runtime_state.variable_pool.get(sel)
|
||||
self._node_data.outputs[key] = segment.value if segment else None
|
||||
self._node_data.outputs["loop_round"] = current_index + 1
|
||||
self.node_data.outputs[key] = segment.value if segment else None
|
||||
self.node_data.outputs["loop_round"] = current_index + 1
|
||||
|
||||
return reach_break_node
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@ class LoopStartNode(Node[LoopStartNodeData]):
|
||||
|
||||
node_type = NodeType.LOOP_START
|
||||
|
||||
_node_data: LoopStartNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
Reference in New Issue
Block a user