Refactor workflow nodes to use generic node_data (#28782)
This commit is contained in:
@@ -65,7 +65,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
node_type = NodeType.ITERATION
|
||||
execution_type = NodeExecutionType.CONTAINER
|
||||
_node_data: IterationNodeData
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@@ -136,10 +135,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
)
|
||||
|
||||
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
||||
|
||||
if not variable:
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
|
||||
|
||||
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
@@ -174,7 +173,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
return cast(list[object], iterator_list_value)
|
||||
|
||||
def _validate_start_node(self) -> None:
|
||||
if not self._node_data.start_node_id:
|
||||
if not self.node_data.start_node_id:
|
||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||
|
||||
def _execute_iterations(
|
||||
@@ -184,7 +183,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
iter_run_map: dict[str, float],
|
||||
usage_accumulator: list[LLMUsage],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
if self._node_data.is_parallel:
|
||||
if self.node_data.is_parallel:
|
||||
# Parallel mode execution
|
||||
yield from self._execute_parallel_iterations(
|
||||
iterator_list_value=iterator_list_value,
|
||||
@@ -231,7 +230,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
outputs.extend([None] * len(iterator_list_value))
|
||||
|
||||
# Determine the number of parallel workers
|
||||
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
|
||||
max_workers = min(self.node_data.parallel_nums, len(iterator_list_value))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
@@ -287,7 +286,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors based on error_handle_mode
|
||||
match self._node_data.error_handle_mode:
|
||||
match self.node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
# Cancel remaining futures and re-raise
|
||||
for f in future_to_index:
|
||||
@@ -300,7 +299,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
outputs[index] = None # Will be filtered later
|
||||
|
||||
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
||||
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[:] = [output for output in outputs if output is not None]
|
||||
|
||||
def _execute_single_iteration_parallel(
|
||||
@@ -389,7 +388,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
If flatten_output is True (default), flattens the list if all elements are lists.
|
||||
"""
|
||||
# If flatten_output is disabled, return outputs as-is
|
||||
if not self._node_data.flatten_output:
|
||||
if not self.node_data.flatten_output:
|
||||
return outputs
|
||||
|
||||
if not outputs:
|
||||
@@ -569,14 +568,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
|
||||
yield event
|
||||
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
|
||||
result = variable_pool.get(self._node_data.output_selector)
|
||||
result = variable_pool.get(self.node_data.output_selector)
|
||||
if result is None:
|
||||
outputs.append(None)
|
||||
else:
|
||||
outputs.append(result.to_object())
|
||||
return
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
match self._node_data.error_handle_mode:
|
||||
match self.node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
raise IterationNodeError(event.error)
|
||||
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
@@ -627,7 +626,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
# Initialize the iteration graph with the new node factory
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
|
||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
|
||||
@@ -11,8 +11,6 @@ class IterationStartNode(Node[IterationStartNodeData]):
|
||||
|
||||
node_type = NodeType.ITERATION_START
|
||||
|
||||
_node_data: IterationStartNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
Reference in New Issue
Block a user