Refactor workflow nodes to use generic node_data (#28782)

This commit is contained in:
-LAN-
2025-11-27 20:46:56 +08:00
committed by GitHub
parent 002d8769b0
commit 8b761319f6
28 changed files with 121 additions and 170 deletions

View File

@@ -65,7 +65,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
node_type = NodeType.ITERATION
execution_type = NodeExecutionType.CONTAINER
_node_data: IterationNodeData
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -136,10 +135,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
)
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@@ -174,7 +173,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return cast(list[object], iterator_list_value)
def _validate_start_node(self) -> None:
if not self._node_data.start_node_id:
if not self.node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
def _execute_iterations(
@@ -184,7 +183,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
if self.node_data.is_parallel:
# Parallel mode execution
yield from self._execute_parallel_iterations(
iterator_list_value=iterator_list_value,
@@ -231,7 +230,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
outputs.extend([None] * len(iterator_list_value))
# Determine the number of parallel workers
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
max_workers = min(self.node_data.parallel_nums, len(iterator_list_value))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
@@ -287,7 +286,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
except Exception as e:
# Handle errors based on error_handle_mode
match self._node_data.error_handle_mode:
match self.node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
# Cancel remaining futures and re-raise
for f in future_to_index:
@@ -300,7 +299,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
outputs[index] = None # Will be filtered later
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[:] = [output for output in outputs if output is not None]
def _execute_single_iteration_parallel(
@@ -389,7 +388,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
If flatten_output is True (default), flattens the list if all elements are lists.
"""
# If flatten_output is disabled, return outputs as-is
if not self._node_data.flatten_output:
if not self.node_data.flatten_output:
return outputs
if not outputs:
@@ -569,14 +568,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
result = variable_pool.get(self._node_data.output_selector)
result = variable_pool.get(self.node_data.output_selector)
if result is None:
outputs.append(None)
else:
outputs.append(result.to_object())
return
elif isinstance(event, GraphRunFailedEvent):
match self._node_data.error_handle_mode:
match self.node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
raise IterationNodeError(event.error)
case ErrorHandleMode.CONTINUE_ON_ERROR:
@@ -627,7 +626,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# Initialize the iteration graph with the new node factory
iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
)
if not iteration_graph:

View File

@@ -11,8 +11,6 @@ class IterationStartNode(Node[IterationStartNodeData]):
node_type = NodeType.ITERATION_START
_node_data: IterationStartNodeData
@classmethod
def version(cls) -> str:
return "1"