feat: knowledge pipeline (#25360)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
This commit is contained in:
-LAN-
2025-09-18 12:49:10 +08:00
committed by GitHub
parent 7dadb33003
commit 85cda47c70
1772 changed files with 102407 additions and 31710 deletions

View File

@@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Any | None = None
current_output: Any = None
class MetaData(BaseIterationState.MetaData):
"""
@@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
iterator_length: int
def get_last_output(self) -> Any | None:
def get_last_output(self) -> Any:
"""
Get last output.
"""
@@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
return self.outputs[-1]
return None
def get_current_output(self) -> Any | None:
def get_current_output(self) -> Any:
"""
Get current output.
"""

View File

@@ -1,48 +1,39 @@
import contextvars
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
from datetime import datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, cast
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from configs import dify_config
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import (
NodeRunResult,
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
InNodeEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
GraphRunSucceededEvent,
)
from core.workflow.node_events import (
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
NodeEventBase,
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@@ -54,17 +45,20 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
# Distinct static type for an ArraySegment; it is named in the TypeIs return
# annotation of `IterationNode._is_empty_iteration`, next to NoneSegment, to
# describe the "array with no items" case. NewType adds no runtime behavior.
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(BaseNode):
class IterationNode(Node):
"""
Iteration Node.
"""
_node_type = NodeType.ITERATION
node_type = NodeType.ITERATION
execution_type = NodeExecutionType.CONTAINER
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]):
@@ -89,7 +83,7 @@ class IterationNode(BaseNode):
return self._node_data
@classmethod
def get_default_config(cls, filters: dict | None = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "iteration",
"config": {
@@ -103,10 +97,53 @@ class IterationNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:  # type: ignore
    """
    Run the iteration node.

    Resolves the iterator variable and, when there is nothing to iterate
    over, delegates to the empty-iteration handler and stops. Otherwise an
    IterationStartedEvent is emitted, the iterations are executed, and the
    run is closed by the success handler, or by the failure handler when an
    IterationNodeError is raised while iterating.
    """
    iterator_variable = self._get_iterator_variable()

    # Guard clause: an empty iterator needs no sub-graph execution at all.
    if self._is_empty_iteration(iterator_variable):
        yield from self._handle_empty_iteration(iterator_variable)
        return

    items = self._validate_and_get_iterator_list(iterator_variable)
    node_inputs = {"iterator_selector": items}

    # Checked before the start event is emitted, so a missing start node
    # raises without any iteration event having been produced.
    self._validate_start_node()

    run_started_at = naive_utc_now()
    # Filled in place by the execution helpers and read by both handlers below.
    durations: dict[str, float] = {}
    collected_outputs: list[object] = []

    yield IterationStartedEvent(
        start_at=run_started_at,
        inputs=node_inputs,
        metadata={"iteration_length": len(items)},
    )

    try:
        yield from self._execute_iterations(
            iterator_list_value=items,
            outputs=collected_outputs,
            iter_run_map=durations,
        )
        yield from self._handle_iteration_success(
            started_at=run_started_at,
            inputs=node_inputs,
            outputs=collected_outputs,
            iterator_list_value=items,
            iter_run_map=durations,
        )
    except IterationNodeError as iteration_error:
        # Report whatever was collected so far together with the error.
        yield from self._handle_iteration_failure(
            started_at=run_started_at,
            inputs=node_inputs,
            outputs=collected_outputs,
            iterator_list_value=items,
            iter_run_map=durations,
            error=iteration_error,
        )
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
if not variable:
@@ -115,213 +152,211 @@ class IterationNode(BaseNode):
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneSegment) or len(variable.value) == 0:
# Try our best to preserve the type informat.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
)
)
return
return variable
def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
    """
    Tell whether there is nothing to iterate over.

    Returns True for a NoneSegment, or for a segment whose `value` has
    length zero; False otherwise.
    """
    if isinstance(variable, NoneSegment):
        return True
    return len(variable.value) == 0
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
    """
    Finish the node successfully with an empty array as its output.

    Yields a single StreamCompletedEvent whose result has status SUCCEEDED
    and an `output` entry holding the empty segment.
    """
    # Try our best to preserve the type information: an array segment is
    # copied with its value replaced by an empty list, anything else falls
    # back to an empty ArrayAnySegment.
    empty_output = (
        variable.model_copy(update={"value": []})
        if isinstance(variable, ArraySegment)
        else ArrayAnySegment(value=[])
    )

    result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        # TODO(QuantumGhost): is it possible to compute the type of `output`
        # from graph definition?
        outputs={"output": empty_output},
    )
    yield StreamCompletedEvent(node_run_result=result)
def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
iterator_list_value = variable.to_object()
if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
inputs = {"iterator_selector": iterator_list_value}
graph_config = self.graph_config
return cast(list[object], iterator_list_value)
def _validate_start_node(self) -> None:
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
root_node_id = self._node_data.start_node_id
def _execute_iterations(
self,
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
# Parallel mode execution
yield from self._execute_parallel_iterations(
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
)
else:
# Sequential mode execution
for index, item in enumerate(iterator_list_value):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
yield IterationNextEvent(index=index)
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
graph_engine = self._create_graph_engine(index, item)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool
variable_pool.add([self.node_id, "index"], 0)
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
)
start_at = naive_utc_now()
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)},
predecessor_node_id=self.previous_node_id,
)
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=0,
pre_iteration_output=None,
duration=None,
)
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self._node_data.is_parallel:
futures: list[Future] = []
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(
max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
# Run the iteration
yield from self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs,
graph_engine=graph_engine,
)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
flask_app=current_app._get_current_object(), # type: ignore
q=q,
context=contextvars.copy_context(),
iterator_list_value=iterator_list_value,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
index=index,
item=item,
iter_run_map=iter_run_map,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
if isinstance(event, IterationRunNextEvent):
succeeded_count += 1
if succeeded_count == len(futures):
q.put(None)
yield event
if isinstance(event, RunCompletedEvent):
q.put(None)
for f in futures:
if not f.done():
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
def _execute_parallel_iterations(
self,
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
# Initialize outputs list with None values to maintain order
outputs.extend([None] * len(iterator_list_value))
# Determine the number of parallel workers
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
for index, item in enumerate(iterator_list_value):
yield IterationNextEvent(index=index)
future = executor.submit(
self._execute_single_iteration_parallel,
index=index,
item=item,
)
future_to_index[future] = index
# Process completed iterations as they finish
for future in as_completed(future_to_index):
index = future_to_index[future]
try:
result = future.result()
iter_start_at, events, output_value, tokens_used = result
# Update outputs at the correct index
outputs[index] = output_value
# Yield all events from this iteration
yield from events
# Update tokens and timing
self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
except Exception as e:
# Handle errors based on error_handle_mode
match self._node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
# Cancel remaining futures and re-raise
for f in future_to_index:
if f != future:
f.cancel()
yield event
if isinstance(event, IterationRunFailedEvent):
q.put(None)
yield event
except Empty:
continue
raise IterationNodeError(str(e))
case ErrorHandleMode.CONTINUE_ON_ERROR:
outputs[index] = None
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[index] = None # Will be filtered later
# wait all threads
wait(futures)
else:
for _ in range(len(iterator_list_value)):
yield from self._run_single_iter(
iterator_list_value=iterator_list_value,
variable_pool=variable_pool,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[:] = [output for output in outputs if output is not None]
# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]
output_segment = build_segment(outputs)
def _execute_single_iteration_parallel(
self,
index: int,
item: object,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
"""Execute a single iteration in parallel mode and return results."""
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
graph_engine = self._create_graph_engine(index, item)
# Collect events instead of yielding them directly
for event in self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs_temp,
graph_engine=graph_engine,
):
events.append(event)
# Get the output value from the temporary outputs list
output_value = outputs_temp[0] if outputs_temp else None
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
def _handle_iteration_success(
self,
started_at: datetime,
inputs: dict[str, Sequence[object]],
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
) -> Generator[NodeEventBase, None, None]:
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
)
# Yield final success event
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
},
)
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": output_segment},
metadata={
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
},
)
def _handle_iteration_failure(
self,
started_at: datetime,
inputs: dict[str, Sequence[object]],
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
yield IterationFailedEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
error=str(error),
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(error),
)
except IterationNodeError as e:
# iteration run failed
logger.warning("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
finally:
# remove iteration variable (item, index) from variable pool after iteration run completed
variable_pool.remove([self.node_id, "index"])
variable_pool.remove([self.node_id, "item"])
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -339,12 +374,45 @@ class IterationNode(BaseNode):
}
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis
from core.workflow.entities import VariablePool
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
iteration_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
continue
@@ -382,297 +450,111 @@ class IterationNode(BaseNode):
return variable_mapping
def _handle_event_metadata(
def _append_iteration_info_to_event(
self,
*,
event: BaseNodeEvent | InNodeEvent,
event: GraphNodeEventBase,
iter_run_index: int,
parallel_mode_run_id: str | None,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
ensures iteration context (ID, index/parallel_run_id) is added to metadata,
"""
if not isinstance(event, BaseNodeEvent):
return event
if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
):
event.in_iteration_id = self._node_id
iter_metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id,
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
}
if parallel_mode_run_id:
# for parallel, the specific branch ID is more important than the sequential index
iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
if event.route_node_state.node_run_result:
current_metadata = event.route_node_state.node_run_result.metadata or {}
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
return event
current_metadata = event.node_run_result.metadata
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **iter_metadata}
def _run_single_iter(
self,
*,
iterator_list_value: Sequence[str],
variable_pool: VariablePool,
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
outputs: list[object],
graph_engine: "GraphEngine",
iteration_graph: Graph,
iter_run_map: dict[str, float],
parallel_mode_run_id: str | None = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
"""
iter_start_at = naive_utc_now()
) -> Generator[GraphNodeEventBase, None, None]:
rst = graph_engine.run()
# get current iteration index
index_variable = variable_pool.get([self._node_id, "index"])
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found")
current_index = index_variable.value
for event in rst:
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START:
continue
try:
rst = graph_engine.run()
# get current iteration index
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
current_index = index_variable.value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: # ty: ignore [unresolved-attribute]
event.in_iteration_id = self.node_id # ty: ignore [unresolved-attribute]
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
continue
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
else:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
if isinstance(event, GraphNodeEventBase):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event
elif isinstance(event, GraphRunSucceededEvent):
result = variable_pool.get(self._node_data.output_selector)
if result is None:
outputs.append(None)
else:
outputs.append(result.to_object())
return
elif isinstance(event, GraphRunFailedEvent):
match self._node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
raise IterationNodeError(event.error)
case ErrorHandleMode.CONTINUE_ON_ERROR:
outputs.append(None)
return
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
return
elif isinstance(event, InNodeEvent):
# event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs[current_index] = None
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs[current_index] = None
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
# clean nodes resources
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
user_id=self.user_id,
user_from=self.user_from.value,
invoke_from=self.invoke_from.value,
call_depth=self.workflow_call_depth,
)
# Create a deep copy of the variable pool for each iteration
variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True)
# iteration run failed
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
else:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
# append iteration variable (item, index) to variable pool
variable_pool_copy.add([self._node_id, "index"], index)
variable_pool_copy.add([self._node_id, "item"], item)
# stop the iterator
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
yield metadata_event
# Create a new GraphRuntimeState for this iteration
graph_runtime_state_copy = GraphRuntimeState(
variable_pool=variable_pool_copy,
start_at=self.graph_runtime_state.start_at,
total_tokens=0,
node_run_steps=0,
)
current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
outputs[current_index] = current_iteration_output
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)
# move to next iteration
variable_pool.add([self.node_id, "index"], next_index)
# Initialize the iteration graph with the new node factory
iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None,
duration=duration,
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
except IterationNodeError as e:
logger.warning("Iteration run failed:%s", str(e))
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
workflow_id=self.workflow_id,
graph=iteration_graph,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
)
def _run_single_iter_parallel(
self,
*,
flask_app: Flask,
context: contextvars.Context,
q: Queue,
iterator_list_value: Sequence[str],
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
index: int,
item: Any,
iter_run_map: dict[str, float],
):
"""
run single iteration in parallel mode
"""
with preserve_flask_contexts(flask_app, context_vars=context):
parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy()
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
variable_pool_copy.add([self.node_id, "index"], index)
variable_pool_copy.add([self.node_id, "item"], item)
for event in self._run_single_iter(
iterator_list_value=iterator_list_value,
variable_pool=variable_pool_copy,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine_copy,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
parallel_mode_run_id=parallel_mode_run_id,
):
q.put(event)
graph_engine.graph_runtime_state.total_tokens += graph_engine_copy.graph_runtime_state.total_tokens
return graph_engine

View File

@@ -1,20 +1,19 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(BaseNode):
class IterationStartNode(Node):
"""
Iteration Start Node.
"""
_node_type = NodeType.ITERATION_START
node_type = NodeType.ITERATION_START
_node_data: IterationStartNodeData