feat: knowledge pipeline (#25360)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
This commit is contained in:
-LAN-
2025-09-18 12:49:10 +08:00
committed by GitHub
parent 7dadb33003
commit 85cda47c70
1772 changed files with 102407 additions and 31710 deletions

View File

@@ -1,7 +1,6 @@
from collections.abc import Mapping
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field
from pydantic import AfterValidator, BaseModel, Field, field_validator
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
@@ -39,15 +38,18 @@ class LoopVariableData(BaseModel):
class LoopNodeData(BaseLoopNodeData):
"""
Loop Node Data.
"""
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: Mapping[str, Any] | None = None
outputs: dict[str, Any] = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@classmethod
def validate_outputs(cls, v):
if v is None:
return {}
return v
class LoopStartNodeData(BaseNodeData):
@@ -72,7 +74,7 @@ class LoopState(BaseLoopState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Any | None = None
current_output: Any = None
class MetaData(BaseLoopState.MetaData):
"""
@@ -81,7 +83,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Any | None:
def get_last_output(self) -> Any:
"""
Get last output.
"""
@@ -89,7 +91,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Any | None:
def get_current_output(self) -> Any:
"""
Get current output.
"""

View File

@@ -1,20 +1,19 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(BaseNode):
class LoopEndNode(Node):
"""
Loop End Node.
"""
_node_type = NodeType.LOOP_END
node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData

View File

@@ -1,58 +1,52 @@
import json
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import (
IntegerSegment,
Segment,
SegmentType,
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
InNodeEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.node_events import (
LoopFailedEvent,
LoopNextEvent,
LoopStartedEvent,
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
class LoopNode(BaseNode):
class LoopNode(Node):
"""
Loop Node.
"""
_node_type = NodeType.LOOP
node_type = NodeType.LOOP
_node_data: LoopNodeData
execution_type = NodeExecutionType.CONTAINER
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopNodeData.model_validate(data)
@@ -79,7 +73,7 @@ class LoopNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def _run(self) -> Generator:
"""Run the node."""
# Get inputs
loop_count = self._node_data.loop_count
@@ -89,144 +83,128 @@ class LoopNode(BaseNode):
inputs = {"loop_count": loop_count}
if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
# Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
root_node_id = self._node_data.start_node_id
# Initialize variable pool
variable_pool = self.graph_runtime_state.variable_pool
variable_pool.add([self.node_id, "index"], 0)
# Initialize loop variables
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
if self._node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
if isinstance(var.value, list)
else None,
}
for loop_variable in self._node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
}
if loop_variable.value_type not in value_processor:
raise ValueError(
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
)
processed_segment = value_processor[loop_variable.value_type]()
processed_segment = value_processor[loop_variable.value_type](loop_variable)
if not processed_segment:
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
variable_selector = [self.node_id, loop_variable.label]
variable_pool.add(variable_selector, processed_segment.value)
variable_selector = [self._node_id, loop_variable.label]
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=loop_graph,
graph_config=self.graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
)
start_at = naive_utc_now()
condition_processor = ConditionProcessor()
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
# Start Loop event
yield LoopRunStartedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
yield LoopStartedEvent(
start_at=start_at,
inputs=inputs,
metadata={"loop_length": loop_count},
predecessor_node_id=self.previous_node_id,
)
# yield LoopRunNextEvent(
# loop_id=self.id,
# loop_node_id=self.node_id,
# loop_node_type=self.node_type,
# loop_node_data=self.node_data,
# index=0,
# pre_loop_output=None,
# )
loop_duration_map = {}
single_loop_variable_map = {} # single loop variable output
try:
check_break_result = False
for i in range(loop_count):
loop_start_time = naive_utc_now()
# run single loop
loop_result = yield from self._run_single_loop(
graph_engine=graph_engine,
loop_graph=loop_graph,
variable_pool=variable_pool,
loop_variable_selectors=loop_variable_selectors,
break_conditions=break_conditions,
logical_operator=logical_operator,
condition_processor=condition_processor,
current_index=i,
start_at=start_at,
inputs=inputs,
reach_break_condition = False
if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
loop_end_time = naive_utc_now()
if reach_break_condition:
loop_count = 0
cost_tokens = 0
for i in range(loop_count):
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
loop_start_time = naive_utc_now()
reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i)
# Track loop duration
loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds()
# Accumulate outputs from the sub-graph's response nodes
for key, value in graph_engine.graph_runtime_state.outputs.items():
if key == "answer":
# Concatenate answer outputs with newline
existing_answer = self.graph_runtime_state.get_output("answer", "")
if existing_answer:
self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}")
else:
self.graph_runtime_state.set_output("answer", value)
else:
# For other outputs, just update
self.graph_runtime_state.set_output(key, value)
# Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens
# Collect loop variable values after iteration
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
item = variable_pool.get(selector)
if item:
single_loop_variable[key] = item.value
else:
single_loop_variable[key] = None
segment = self.graph_runtime_state.variable_pool.get(selector)
single_loop_variable[key] = segment.value if segment else None
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
single_loop_variable_map[str(i)] = single_loop_variable
check_break_result = loop_result.get("check_break_result", False)
if check_break_result:
if reach_break_node:
break
if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if reach_break_condition:
break
yield LoopNextEvent(
index=i + 1,
pre_loop_output=self._node_data.outputs,
)
self.graph_runtime_state.total_tokens += cost_tokens
# Loop completed successfully
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
yield LoopSucceededEvent(
start_at=start_at,
inputs=inputs,
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "loop_break" if check_break_result else "loop_completed",
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
@@ -236,18 +214,12 @@ class LoopNode(BaseNode):
)
except Exception as e:
# Loop failed
logger.exception("Loop run failed")
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
yield LoopFailedEvent(
start_at=start_at,
inputs=inputs,
steps=loop_count,
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
"completed_reason": "error",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@@ -255,215 +227,60 @@ class LoopNode(BaseNode):
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
)
finally:
# Clean up
variable_pool.remove([self.node_id, "index"])
def _run_single_loop(
self,
*,
graph_engine: "GraphEngine",
loop_graph: Graph,
variable_pool: "VariablePool",
loop_variable_selectors: dict,
break_conditions: list,
logical_operator: Literal["and", "or"],
condition_processor: ConditionProcessor,
current_index: int,
start_at: datetime,
inputs: dict,
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
"""Run a single loop iteration.
Returns:
dict: {'check_break_result': bool}
"""
condition_selectors = self._extract_selectors_from_conditions(break_conditions)
extended_selectors = {**loop_variable_selectors, **condition_selectors}
# Run workflow
rst = graph_engine.run()
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"loop {self.node_id} current index not found")
current_index = current_index_variable.value
) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]:
reach_break_node = False
for event in graph_engine.run():
if isinstance(event, GraphNodeEventBase):
self._append_loop_info_to_event(event=event, loop_run_index=current_index)
check_break_result = False
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id: # ty: ignore [unresolved-attribute]
event.in_loop_id = self.node_id # ty: ignore [unresolved-attribute]
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.LOOP_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START:
continue
if isinstance(event, GraphNodeEventBase):
yield event
if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
reach_break_node = True
if isinstance(event, GraphRunFailedEvent):
raise Exception(event.error)
if (
isinstance(event, NodeRunSucceededEvent)
and event.node_type == NodeType.LOOP_END
and not isinstance(event, NodeRunStreamChunkEvent)
):
check_break_result = True
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
break
for loop_var in self._node_data.loop_variables or []:
key, sel = loop_var.label, [self._node_id, loop_var.label]
segment = self.graph_runtime_state.variable_pool.get(sel)
self._node_data.outputs[key] = segment.value if segment else None
self._node_data.outputs["loop_round"] = current_index + 1
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
return reach_break_node
# Check if all variables in break conditions exist
exists_variable = False
for condition in break_conditions:
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
exists_variable = False
break
else:
exists_variable = True
if exists_variable:
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if check_break_result:
break
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# Loop run failed
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
),
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
)
},
)
)
return {"check_break_result": True}
elif isinstance(event, NodeRunFailedEvent):
# Loop run failed
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
},
)
)
return {"check_break_result": True}
else:
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
_outputs: dict[str, Segment | int | None] = {}
for loop_variable_key, loop_variable_selector in extended_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment
else:
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
self._node_data.outputs = _outputs
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
if check_break_result:
return {"check_break_result": True}
# Move to next loop
next_index = current_index + 1
variable_pool.add([self.node_id, "index"], next_index)
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
index=next_index,
pre_loop_output=self._node_data.outputs,
)
return {"check_break_result": False}
def _extract_selectors_from_conditions(self, conditions: list) -> dict[str, list[str]]:
return {
condition.variable_selector[1]: condition.variable_selector
for condition in conditions
if condition.variable_selector and len(condition.variable_selector) >= 2
def _append_loop_info_to_event(
self,
event: GraphNodeEventBase,
loop_run_index: int,
):
event.in_loop_id = self._node_id
loop_metadata = {
WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id,
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index,
}
def _handle_event_metadata(
self,
*,
event: BaseNodeEvent | InNodeEvent,
iter_run_index: int,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
"""
if not isinstance(event, BaseNodeEvent):
return event
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata:
metadata = {
**metadata,
WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
return event
current_metadata = event.node_run_result.metadata
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -479,12 +296,43 @@ class LoopNode(BaseNode):
variable_mapping = {}
# init graph
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
loop_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not loop_graph:
raise ValueError("loop graph not found")
for sub_node_id, sub_node_config in loop_graph.node_id_config_mapping.items():
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
continue
@@ -560,3 +408,47 @@ class LoopNode(BaseNode):
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
user_id=self.user_id,
user_from=self.user_from.value,
invoke_from=self.invoke_from.value,
call_depth=self.workflow_call_depth,
)
# Create a new GraphRuntimeState for this iteration
graph_runtime_state_copy = GraphRuntimeState(
variable_pool=self.graph_runtime_state.variable_pool,
start_at=start_at.timestamp(),
)
# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)
# Initialize the loop graph with the new node factory
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
workflow_id=self.workflow_id,
graph=loop_graph,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
)
return graph_engine

View File

@@ -1,20 +1,19 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(BaseNode):
class LoopStartNode(Node):
"""
Loop Start Node.
"""
_node_type = NodeType.LOOP_START
node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData