|
|
|
|
@@ -1,58 +1,52 @@
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
from collections.abc import Generator, Mapping, Sequence
|
|
|
|
|
from collections.abc import Callable, Generator, Mapping, Sequence
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
from core.variables import (
|
|
|
|
|
IntegerSegment,
|
|
|
|
|
Segment,
|
|
|
|
|
SegmentType,
|
|
|
|
|
from core.variables import Segment, SegmentType
|
|
|
|
|
from core.workflow.enums import (
|
|
|
|
|
ErrorStrategy,
|
|
|
|
|
NodeExecutionType,
|
|
|
|
|
NodeType,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey,
|
|
|
|
|
WorkflowNodeExecutionStatus,
|
|
|
|
|
)
|
|
|
|
|
from core.workflow.entities.node_entities import NodeRunResult
|
|
|
|
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
|
|
|
|
from core.workflow.graph_engine.entities.event import (
|
|
|
|
|
BaseGraphEvent,
|
|
|
|
|
BaseNodeEvent,
|
|
|
|
|
BaseParallelBranchEvent,
|
|
|
|
|
from core.workflow.graph_events import (
|
|
|
|
|
GraphNodeEventBase,
|
|
|
|
|
GraphRunFailedEvent,
|
|
|
|
|
InNodeEvent,
|
|
|
|
|
LoopRunFailedEvent,
|
|
|
|
|
LoopRunNextEvent,
|
|
|
|
|
LoopRunStartedEvent,
|
|
|
|
|
LoopRunSucceededEvent,
|
|
|
|
|
NodeRunFailedEvent,
|
|
|
|
|
NodeRunStartedEvent,
|
|
|
|
|
NodeRunStreamChunkEvent,
|
|
|
|
|
NodeRunSucceededEvent,
|
|
|
|
|
)
|
|
|
|
|
from core.workflow.graph_engine.entities.graph import Graph
|
|
|
|
|
from core.workflow.nodes.base import BaseNode
|
|
|
|
|
from core.workflow.node_events import (
|
|
|
|
|
LoopFailedEvent,
|
|
|
|
|
LoopNextEvent,
|
|
|
|
|
LoopStartedEvent,
|
|
|
|
|
LoopSucceededEvent,
|
|
|
|
|
NodeEventBase,
|
|
|
|
|
NodeRunResult,
|
|
|
|
|
StreamCompletedEvent,
|
|
|
|
|
)
|
|
|
|
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
|
|
|
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
|
|
|
|
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
|
|
|
|
from core.workflow.nodes.loop.entities import LoopNodeData
|
|
|
|
|
from core.workflow.nodes.base.node import Node
|
|
|
|
|
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
|
|
|
|
from core.workflow.utils.condition.processor import ConditionProcessor
|
|
|
|
|
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
|
|
|
|
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
|
|
|
|
from libs.datetime_utils import naive_utc_now
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from core.workflow.entities.variable_pool import VariablePool
|
|
|
|
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
|
|
|
|
from core.workflow.graph_engine import GraphEngine
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoopNode(BaseNode):
|
|
|
|
|
class LoopNode(Node):
|
|
|
|
|
"""
|
|
|
|
|
Loop Node.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
_node_type = NodeType.LOOP
|
|
|
|
|
|
|
|
|
|
node_type = NodeType.LOOP
|
|
|
|
|
_node_data: LoopNodeData
|
|
|
|
|
execution_type = NodeExecutionType.CONTAINER
|
|
|
|
|
|
|
|
|
|
def init_node_data(self, data: Mapping[str, Any]):
|
|
|
|
|
self._node_data = LoopNodeData.model_validate(data)
|
|
|
|
|
@@ -79,7 +73,7 @@ class LoopNode(BaseNode):
|
|
|
|
|
def version(cls) -> str:
|
|
|
|
|
return "1"
|
|
|
|
|
|
|
|
|
|
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
|
|
|
|
def _run(self) -> Generator:
|
|
|
|
|
"""Run the node."""
|
|
|
|
|
# Get inputs
|
|
|
|
|
loop_count = self._node_data.loop_count
|
|
|
|
|
@@ -89,144 +83,128 @@ class LoopNode(BaseNode):
|
|
|
|
|
inputs = {"loop_count": loop_count}
|
|
|
|
|
|
|
|
|
|
if not self._node_data.start_node_id:
|
|
|
|
|
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
|
|
|
|
|
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
|
|
|
|
|
|
|
|
|
# Initialize graph
|
|
|
|
|
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
|
|
|
|
|
if not loop_graph:
|
|
|
|
|
raise ValueError("loop graph not found")
|
|
|
|
|
root_node_id = self._node_data.start_node_id
|
|
|
|
|
|
|
|
|
|
# Initialize variable pool
|
|
|
|
|
variable_pool = self.graph_runtime_state.variable_pool
|
|
|
|
|
variable_pool.add([self.node_id, "index"], 0)
|
|
|
|
|
|
|
|
|
|
# Initialize loop variables
|
|
|
|
|
# Initialize loop variables in the original variable pool
|
|
|
|
|
loop_variable_selectors = {}
|
|
|
|
|
if self._node_data.loop_variables:
|
|
|
|
|
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
|
|
|
|
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
|
|
|
|
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
|
|
|
|
|
if isinstance(var.value, list)
|
|
|
|
|
else None,
|
|
|
|
|
}
|
|
|
|
|
for loop_variable in self._node_data.loop_variables:
|
|
|
|
|
value_processor = {
|
|
|
|
|
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
|
|
|
|
|
"variable": lambda var=loop_variable: variable_pool.get(var.value),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if loop_variable.value_type not in value_processor:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
processed_segment = value_processor[loop_variable.value_type]()
|
|
|
|
|
processed_segment = value_processor[loop_variable.value_type](loop_variable)
|
|
|
|
|
if not processed_segment:
|
|
|
|
|
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
|
|
|
|
|
variable_selector = [self.node_id, loop_variable.label]
|
|
|
|
|
variable_pool.add(variable_selector, processed_segment.value)
|
|
|
|
|
variable_selector = [self._node_id, loop_variable.label]
|
|
|
|
|
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
|
|
|
|
|
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
|
|
|
|
|
loop_variable_selectors[loop_variable.label] = variable_selector
|
|
|
|
|
inputs[loop_variable.label] = processed_segment.value
|
|
|
|
|
|
|
|
|
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
|
|
|
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
|
|
|
|
|
|
|
|
|
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
|
|
|
|
|
|
|
|
|
graph_engine = GraphEngine(
|
|
|
|
|
tenant_id=self.tenant_id,
|
|
|
|
|
app_id=self.app_id,
|
|
|
|
|
workflow_type=self.workflow_type,
|
|
|
|
|
workflow_id=self.workflow_id,
|
|
|
|
|
user_id=self.user_id,
|
|
|
|
|
user_from=self.user_from,
|
|
|
|
|
invoke_from=self.invoke_from,
|
|
|
|
|
call_depth=self.workflow_call_depth,
|
|
|
|
|
graph=loop_graph,
|
|
|
|
|
graph_config=self.graph_config,
|
|
|
|
|
graph_runtime_state=graph_runtime_state,
|
|
|
|
|
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
|
|
|
|
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
|
|
|
|
thread_pool_id=self.thread_pool_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
start_at = naive_utc_now()
|
|
|
|
|
condition_processor = ConditionProcessor()
|
|
|
|
|
|
|
|
|
|
loop_duration_map: dict[str, float] = {}
|
|
|
|
|
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
|
|
|
|
|
|
|
|
|
# Start Loop event
|
|
|
|
|
yield LoopRunStartedEvent(
|
|
|
|
|
loop_id=self.id,
|
|
|
|
|
loop_node_id=self.node_id,
|
|
|
|
|
loop_node_type=self.type_,
|
|
|
|
|
loop_node_data=self._node_data,
|
|
|
|
|
yield LoopStartedEvent(
|
|
|
|
|
start_at=start_at,
|
|
|
|
|
inputs=inputs,
|
|
|
|
|
metadata={"loop_length": loop_count},
|
|
|
|
|
predecessor_node_id=self.previous_node_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# yield LoopRunNextEvent(
|
|
|
|
|
# loop_id=self.id,
|
|
|
|
|
# loop_node_id=self.node_id,
|
|
|
|
|
# loop_node_type=self.node_type,
|
|
|
|
|
# loop_node_data=self.node_data,
|
|
|
|
|
# index=0,
|
|
|
|
|
# pre_loop_output=None,
|
|
|
|
|
# )
|
|
|
|
|
loop_duration_map = {}
|
|
|
|
|
single_loop_variable_map = {} # single loop variable output
|
|
|
|
|
try:
|
|
|
|
|
check_break_result = False
|
|
|
|
|
for i in range(loop_count):
|
|
|
|
|
loop_start_time = naive_utc_now()
|
|
|
|
|
# run single loop
|
|
|
|
|
loop_result = yield from self._run_single_loop(
|
|
|
|
|
graph_engine=graph_engine,
|
|
|
|
|
loop_graph=loop_graph,
|
|
|
|
|
variable_pool=variable_pool,
|
|
|
|
|
loop_variable_selectors=loop_variable_selectors,
|
|
|
|
|
break_conditions=break_conditions,
|
|
|
|
|
logical_operator=logical_operator,
|
|
|
|
|
condition_processor=condition_processor,
|
|
|
|
|
current_index=i,
|
|
|
|
|
start_at=start_at,
|
|
|
|
|
inputs=inputs,
|
|
|
|
|
reach_break_condition = False
|
|
|
|
|
if break_conditions:
|
|
|
|
|
_, _, reach_break_condition = condition_processor.process_conditions(
|
|
|
|
|
variable_pool=self.graph_runtime_state.variable_pool,
|
|
|
|
|
conditions=break_conditions,
|
|
|
|
|
operator=logical_operator,
|
|
|
|
|
)
|
|
|
|
|
loop_end_time = naive_utc_now()
|
|
|
|
|
if reach_break_condition:
|
|
|
|
|
loop_count = 0
|
|
|
|
|
cost_tokens = 0
|
|
|
|
|
|
|
|
|
|
for i in range(loop_count):
|
|
|
|
|
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
|
|
|
|
|
|
|
|
|
|
loop_start_time = naive_utc_now()
|
|
|
|
|
reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i)
|
|
|
|
|
# Track loop duration
|
|
|
|
|
loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds()
|
|
|
|
|
|
|
|
|
|
# Accumulate outputs from the sub-graph's response nodes
|
|
|
|
|
for key, value in graph_engine.graph_runtime_state.outputs.items():
|
|
|
|
|
if key == "answer":
|
|
|
|
|
# Concatenate answer outputs with newline
|
|
|
|
|
existing_answer = self.graph_runtime_state.get_output("answer", "")
|
|
|
|
|
if existing_answer:
|
|
|
|
|
self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}")
|
|
|
|
|
else:
|
|
|
|
|
self.graph_runtime_state.set_output("answer", value)
|
|
|
|
|
else:
|
|
|
|
|
# For other outputs, just update
|
|
|
|
|
self.graph_runtime_state.set_output(key, value)
|
|
|
|
|
|
|
|
|
|
# Update the total tokens from this iteration
|
|
|
|
|
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
|
|
|
|
|
|
|
|
|
# Collect loop variable values after iteration
|
|
|
|
|
single_loop_variable = {}
|
|
|
|
|
for key, selector in loop_variable_selectors.items():
|
|
|
|
|
item = variable_pool.get(selector)
|
|
|
|
|
if item:
|
|
|
|
|
single_loop_variable[key] = item.value
|
|
|
|
|
else:
|
|
|
|
|
single_loop_variable[key] = None
|
|
|
|
|
segment = self.graph_runtime_state.variable_pool.get(selector)
|
|
|
|
|
single_loop_variable[key] = segment.value if segment else None
|
|
|
|
|
|
|
|
|
|
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
|
|
|
|
|
single_loop_variable_map[str(i)] = single_loop_variable
|
|
|
|
|
|
|
|
|
|
check_break_result = loop_result.get("check_break_result", False)
|
|
|
|
|
|
|
|
|
|
if check_break_result:
|
|
|
|
|
if reach_break_node:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
if break_conditions:
|
|
|
|
|
_, _, reach_break_condition = condition_processor.process_conditions(
|
|
|
|
|
variable_pool=self.graph_runtime_state.variable_pool,
|
|
|
|
|
conditions=break_conditions,
|
|
|
|
|
operator=logical_operator,
|
|
|
|
|
)
|
|
|
|
|
if reach_break_condition:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
yield LoopNextEvent(
|
|
|
|
|
index=i + 1,
|
|
|
|
|
pre_loop_output=self._node_data.outputs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.graph_runtime_state.total_tokens += cost_tokens
|
|
|
|
|
# Loop completed successfully
|
|
|
|
|
yield LoopRunSucceededEvent(
|
|
|
|
|
loop_id=self.id,
|
|
|
|
|
loop_node_id=self.node_id,
|
|
|
|
|
loop_node_type=self.type_,
|
|
|
|
|
loop_node_data=self._node_data,
|
|
|
|
|
yield LoopSucceededEvent(
|
|
|
|
|
start_at=start_at,
|
|
|
|
|
inputs=inputs,
|
|
|
|
|
outputs=self._node_data.outputs,
|
|
|
|
|
steps=loop_count,
|
|
|
|
|
metadata={
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
|
|
|
|
"completed_reason": "loop_break" if check_break_result else "loop_completed",
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
|
|
|
|
|
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
yield RunCompletedEvent(
|
|
|
|
|
run_result=NodeRunResult(
|
|
|
|
|
yield StreamCompletedEvent(
|
|
|
|
|
node_run_result=NodeRunResult(
|
|
|
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
|
|
|
metadata={
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
|
|
|
|
},
|
|
|
|
|
@@ -236,18 +214,12 @@ class LoopNode(BaseNode):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# Loop failed
|
|
|
|
|
logger.exception("Loop run failed")
|
|
|
|
|
yield LoopRunFailedEvent(
|
|
|
|
|
loop_id=self.id,
|
|
|
|
|
loop_node_id=self.node_id,
|
|
|
|
|
loop_node_type=self.type_,
|
|
|
|
|
loop_node_data=self._node_data,
|
|
|
|
|
yield LoopFailedEvent(
|
|
|
|
|
start_at=start_at,
|
|
|
|
|
inputs=inputs,
|
|
|
|
|
steps=loop_count,
|
|
|
|
|
metadata={
|
|
|
|
|
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
|
|
|
|
"completed_reason": "error",
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
|
|
|
|
@@ -255,215 +227,60 @@ class LoopNode(BaseNode):
|
|
|
|
|
error=str(e),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
yield RunCompletedEvent(
|
|
|
|
|
run_result=NodeRunResult(
|
|
|
|
|
yield StreamCompletedEvent(
|
|
|
|
|
node_run_result=NodeRunResult(
|
|
|
|
|
status=WorkflowNodeExecutionStatus.FAILED,
|
|
|
|
|
error=str(e),
|
|
|
|
|
metadata={
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
finally:
|
|
|
|
|
# Clean up
|
|
|
|
|
variable_pool.remove([self.node_id, "index"])
|
|
|
|
|
|
|
|
|
|
def _run_single_loop(
|
|
|
|
|
self,
|
|
|
|
|
*,
|
|
|
|
|
graph_engine: "GraphEngine",
|
|
|
|
|
loop_graph: Graph,
|
|
|
|
|
variable_pool: "VariablePool",
|
|
|
|
|
loop_variable_selectors: dict,
|
|
|
|
|
break_conditions: list,
|
|
|
|
|
logical_operator: Literal["and", "or"],
|
|
|
|
|
condition_processor: ConditionProcessor,
|
|
|
|
|
current_index: int,
|
|
|
|
|
start_at: datetime,
|
|
|
|
|
inputs: dict,
|
|
|
|
|
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
|
|
|
|
|
"""Run a single loop iteration.
|
|
|
|
|
Returns:
|
|
|
|
|
dict: {'check_break_result': bool}
|
|
|
|
|
"""
|
|
|
|
|
condition_selectors = self._extract_selectors_from_conditions(break_conditions)
|
|
|
|
|
extended_selectors = {**loop_variable_selectors, **condition_selectors}
|
|
|
|
|
# Run workflow
|
|
|
|
|
rst = graph_engine.run()
|
|
|
|
|
current_index_variable = variable_pool.get([self.node_id, "index"])
|
|
|
|
|
if not isinstance(current_index_variable, IntegerSegment):
|
|
|
|
|
raise ValueError(f"loop {self.node_id} current index not found")
|
|
|
|
|
current_index = current_index_variable.value
|
|
|
|
|
) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]:
|
|
|
|
|
reach_break_node = False
|
|
|
|
|
for event in graph_engine.run():
|
|
|
|
|
if isinstance(event, GraphNodeEventBase):
|
|
|
|
|
self._append_loop_info_to_event(event=event, loop_run_index=current_index)
|
|
|
|
|
|
|
|
|
|
check_break_result = False
|
|
|
|
|
|
|
|
|
|
for event in rst:
|
|
|
|
|
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id: # ty: ignore [unresolved-attribute]
|
|
|
|
|
event.in_loop_id = self.node_id # ty: ignore [unresolved-attribute]
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
|
|
isinstance(event, BaseNodeEvent)
|
|
|
|
|
and event.node_type == NodeType.LOOP_START
|
|
|
|
|
and not isinstance(event, NodeRunStreamChunkEvent)
|
|
|
|
|
):
|
|
|
|
|
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START:
|
|
|
|
|
continue
|
|
|
|
|
if isinstance(event, GraphNodeEventBase):
|
|
|
|
|
yield event
|
|
|
|
|
if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
|
|
|
|
|
reach_break_node = True
|
|
|
|
|
if isinstance(event, GraphRunFailedEvent):
|
|
|
|
|
raise Exception(event.error)
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
|
|
isinstance(event, NodeRunSucceededEvent)
|
|
|
|
|
and event.node_type == NodeType.LOOP_END
|
|
|
|
|
and not isinstance(event, NodeRunStreamChunkEvent)
|
|
|
|
|
):
|
|
|
|
|
check_break_result = True
|
|
|
|
|
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
|
|
|
|
break
|
|
|
|
|
for loop_var in self._node_data.loop_variables or []:
|
|
|
|
|
key, sel = loop_var.label, [self._node_id, loop_var.label]
|
|
|
|
|
segment = self.graph_runtime_state.variable_pool.get(sel)
|
|
|
|
|
self._node_data.outputs[key] = segment.value if segment else None
|
|
|
|
|
self._node_data.outputs["loop_round"] = current_index + 1
|
|
|
|
|
|
|
|
|
|
if isinstance(event, NodeRunSucceededEvent):
|
|
|
|
|
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
|
|
|
|
return reach_break_node
|
|
|
|
|
|
|
|
|
|
# Check if all variables in break conditions exist
|
|
|
|
|
exists_variable = False
|
|
|
|
|
for condition in break_conditions:
|
|
|
|
|
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
|
|
|
|
|
exists_variable = False
|
|
|
|
|
break
|
|
|
|
|
else:
|
|
|
|
|
exists_variable = True
|
|
|
|
|
if exists_variable:
|
|
|
|
|
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
|
|
|
|
|
variable_pool=self.graph_runtime_state.variable_pool,
|
|
|
|
|
conditions=break_conditions,
|
|
|
|
|
operator=logical_operator,
|
|
|
|
|
)
|
|
|
|
|
if check_break_result:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
elif isinstance(event, BaseGraphEvent):
|
|
|
|
|
if isinstance(event, GraphRunFailedEvent):
|
|
|
|
|
# Loop run failed
|
|
|
|
|
yield LoopRunFailedEvent(
|
|
|
|
|
loop_id=self.id,
|
|
|
|
|
loop_node_id=self.node_id,
|
|
|
|
|
loop_node_type=self.type_,
|
|
|
|
|
loop_node_data=self._node_data,
|
|
|
|
|
start_at=start_at,
|
|
|
|
|
inputs=inputs,
|
|
|
|
|
steps=current_index,
|
|
|
|
|
metadata={
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
|
|
|
|
|
graph_engine.graph_runtime_state.total_tokens
|
|
|
|
|
),
|
|
|
|
|
"completed_reason": "error",
|
|
|
|
|
},
|
|
|
|
|
error=event.error,
|
|
|
|
|
)
|
|
|
|
|
yield RunCompletedEvent(
|
|
|
|
|
run_result=NodeRunResult(
|
|
|
|
|
status=WorkflowNodeExecutionStatus.FAILED,
|
|
|
|
|
error=event.error,
|
|
|
|
|
metadata={
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
|
|
|
|
|
graph_engine.graph_runtime_state.total_tokens
|
|
|
|
|
)
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
return {"check_break_result": True}
|
|
|
|
|
elif isinstance(event, NodeRunFailedEvent):
|
|
|
|
|
# Loop run failed
|
|
|
|
|
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
|
|
|
|
yield LoopRunFailedEvent(
|
|
|
|
|
loop_id=self.id,
|
|
|
|
|
loop_node_id=self.node_id,
|
|
|
|
|
loop_node_type=self.type_,
|
|
|
|
|
loop_node_data=self._node_data,
|
|
|
|
|
start_at=start_at,
|
|
|
|
|
inputs=inputs,
|
|
|
|
|
steps=current_index,
|
|
|
|
|
metadata={
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
|
|
|
|
"completed_reason": "error",
|
|
|
|
|
},
|
|
|
|
|
error=event.error,
|
|
|
|
|
)
|
|
|
|
|
yield RunCompletedEvent(
|
|
|
|
|
run_result=NodeRunResult(
|
|
|
|
|
status=WorkflowNodeExecutionStatus.FAILED,
|
|
|
|
|
error=event.error,
|
|
|
|
|
metadata={
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
return {"check_break_result": True}
|
|
|
|
|
else:
|
|
|
|
|
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
|
|
|
|
|
|
|
|
|
|
_outputs: dict[str, Segment | int | None] = {}
|
|
|
|
|
for loop_variable_key, loop_variable_selector in extended_selectors.items():
|
|
|
|
|
_loop_variable_segment = variable_pool.get(loop_variable_selector)
|
|
|
|
|
if _loop_variable_segment:
|
|
|
|
|
_outputs[loop_variable_key] = _loop_variable_segment
|
|
|
|
|
else:
|
|
|
|
|
_outputs[loop_variable_key] = None
|
|
|
|
|
|
|
|
|
|
_outputs["loop_round"] = current_index + 1
|
|
|
|
|
self._node_data.outputs = _outputs
|
|
|
|
|
|
|
|
|
|
# Remove all nodes outputs from variable pool
|
|
|
|
|
for node_id in loop_graph.node_ids:
|
|
|
|
|
variable_pool.remove([node_id])
|
|
|
|
|
|
|
|
|
|
if check_break_result:
|
|
|
|
|
return {"check_break_result": True}
|
|
|
|
|
|
|
|
|
|
# Move to next loop
|
|
|
|
|
next_index = current_index + 1
|
|
|
|
|
variable_pool.add([self.node_id, "index"], next_index)
|
|
|
|
|
|
|
|
|
|
yield LoopRunNextEvent(
|
|
|
|
|
loop_id=self.id,
|
|
|
|
|
loop_node_id=self.node_id,
|
|
|
|
|
loop_node_type=self.type_,
|
|
|
|
|
loop_node_data=self._node_data,
|
|
|
|
|
index=next_index,
|
|
|
|
|
pre_loop_output=self._node_data.outputs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return {"check_break_result": False}
|
|
|
|
|
|
|
|
|
|
def _extract_selectors_from_conditions(self, conditions: list) -> dict[str, list[str]]:
|
|
|
|
|
return {
|
|
|
|
|
condition.variable_selector[1]: condition.variable_selector
|
|
|
|
|
for condition in conditions
|
|
|
|
|
if condition.variable_selector and len(condition.variable_selector) >= 2
|
|
|
|
|
def _append_loop_info_to_event(
|
|
|
|
|
self,
|
|
|
|
|
event: GraphNodeEventBase,
|
|
|
|
|
loop_run_index: int,
|
|
|
|
|
):
|
|
|
|
|
event.in_loop_id = self._node_id
|
|
|
|
|
loop_metadata = {
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def _handle_event_metadata(
|
|
|
|
|
self,
|
|
|
|
|
*,
|
|
|
|
|
event: BaseNodeEvent | InNodeEvent,
|
|
|
|
|
iter_run_index: int,
|
|
|
|
|
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
|
|
|
|
|
"""
|
|
|
|
|
add iteration metadata to event.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(event, BaseNodeEvent):
|
|
|
|
|
return event
|
|
|
|
|
if event.route_node_state.node_run_result:
|
|
|
|
|
metadata = event.route_node_state.node_run_result.metadata
|
|
|
|
|
if not metadata:
|
|
|
|
|
metadata = {}
|
|
|
|
|
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata:
|
|
|
|
|
metadata = {
|
|
|
|
|
**metadata,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id,
|
|
|
|
|
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index,
|
|
|
|
|
}
|
|
|
|
|
event.route_node_state.node_run_result.metadata = metadata
|
|
|
|
|
return event
|
|
|
|
|
current_metadata = event.node_run_result.metadata
|
|
|
|
|
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
|
|
|
|
|
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _extract_variable_selector_to_variable_mapping(
|
|
|
|
|
@@ -479,12 +296,43 @@ class LoopNode(BaseNode):
|
|
|
|
|
variable_mapping = {}
|
|
|
|
|
|
|
|
|
|
# init graph
|
|
|
|
|
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
|
|
|
|
|
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
|
|
|
|
from core.workflow.graph import Graph
|
|
|
|
|
from core.workflow.nodes.node_factory import DifyNodeFactory
|
|
|
|
|
|
|
|
|
|
# Create minimal GraphInitParams for static analysis
|
|
|
|
|
graph_init_params = GraphInitParams(
|
|
|
|
|
tenant_id="",
|
|
|
|
|
app_id="",
|
|
|
|
|
workflow_id="",
|
|
|
|
|
graph_config=graph_config,
|
|
|
|
|
user_id="",
|
|
|
|
|
user_from="",
|
|
|
|
|
invoke_from="",
|
|
|
|
|
call_depth=0,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Create minimal GraphRuntimeState for static analysis
|
|
|
|
|
graph_runtime_state = GraphRuntimeState(
|
|
|
|
|
variable_pool=VariablePool(),
|
|
|
|
|
start_at=0,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Create node factory for static analysis
|
|
|
|
|
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
|
|
|
|
|
|
|
|
|
loop_graph = Graph.init(
|
|
|
|
|
graph_config=graph_config,
|
|
|
|
|
node_factory=node_factory,
|
|
|
|
|
root_node_id=typed_node_data.start_node_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not loop_graph:
|
|
|
|
|
raise ValueError("loop graph not found")
|
|
|
|
|
|
|
|
|
|
for sub_node_id, sub_node_config in loop_graph.node_id_config_mapping.items():
|
|
|
|
|
# Get node configs from graph_config instead of non-existent node_id_config_mapping
|
|
|
|
|
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
|
|
|
|
for sub_node_id, sub_node_config in node_configs.items():
|
|
|
|
|
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
@@ -560,3 +408,47 @@ class LoopNode(BaseNode):
|
|
|
|
|
except ValueError:
|
|
|
|
|
raise type_exc
|
|
|
|
|
return build_segment_with_type(var_type, value)
|
|
|
|
|
|
|
|
|
|
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
|
|
|
|
# Import dependencies
|
|
|
|
|
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
|
|
|
|
from core.workflow.graph import Graph
|
|
|
|
|
from core.workflow.graph_engine import GraphEngine
|
|
|
|
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
|
|
|
|
from core.workflow.nodes.node_factory import DifyNodeFactory
|
|
|
|
|
|
|
|
|
|
# Create GraphInitParams from node attributes
|
|
|
|
|
graph_init_params = GraphInitParams(
|
|
|
|
|
tenant_id=self.tenant_id,
|
|
|
|
|
app_id=self.app_id,
|
|
|
|
|
workflow_id=self.workflow_id,
|
|
|
|
|
graph_config=self.graph_config,
|
|
|
|
|
user_id=self.user_id,
|
|
|
|
|
user_from=self.user_from.value,
|
|
|
|
|
invoke_from=self.invoke_from.value,
|
|
|
|
|
call_depth=self.workflow_call_depth,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Create a new GraphRuntimeState for this iteration
|
|
|
|
|
graph_runtime_state_copy = GraphRuntimeState(
|
|
|
|
|
variable_pool=self.graph_runtime_state.variable_pool,
|
|
|
|
|
start_at=start_at.timestamp(),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Create a new node factory with the new GraphRuntimeState
|
|
|
|
|
node_factory = DifyNodeFactory(
|
|
|
|
|
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Initialize the loop graph with the new node factory
|
|
|
|
|
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
|
|
|
|
|
|
|
|
|
# Create a new GraphEngine for this iteration
|
|
|
|
|
graph_engine = GraphEngine(
|
|
|
|
|
workflow_id=self.workflow_id,
|
|
|
|
|
graph=loop_graph,
|
|
|
|
|
graph_runtime_state=graph_runtime_state_copy,
|
|
|
|
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return graph_engine
|
|
|
|
|
|