feat(graph_engine): Support pausing workflow graph executions (#26585)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-10-19 21:33:41 +08:00
committed by GitHub
parent 9a5f214623
commit 578247ffbc
112 changed files with 3766 additions and 2415 deletions

View File

@@ -25,7 +25,6 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeType,
@@ -44,6 +43,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy

View File

@@ -6,7 +6,7 @@ from typing import Any, ClassVar
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
@@ -20,6 +20,7 @@ from core.workflow.graph_events import (
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
@@ -37,10 +38,12 @@ from core.workflow.node_events import (
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
PauseRequestedEvent,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
@@ -385,6 +388,16 @@ class Node:
f"Node {self._node_id} does not support status {event.node_run_result.status}"
)
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
return NodeRunPauseRequestedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
reason=event.reason,
)
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(

View File

@@ -19,7 +19,6 @@ from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -27,6 +26,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile

View File

@@ -15,7 +15,7 @@ from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from .entities import (
HttpRequestNodeAuthorization,

View File

@@ -0,0 +1,3 @@
from .human_input_node import HumanInputNode
__all__ = ["HumanInputNode"]

View File

@@ -0,0 +1,10 @@
from pydantic import Field
from core.workflow.nodes.base import BaseNodeData
class HumanInputNodeData(BaseNodeData):
"""Configuration schema for the HumanInput node."""
required_variables: list[str] = Field(default_factory=list)
pause_reason: str | None = Field(default=None)

View File

@@ -0,0 +1,132 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import HumanInputNodeData
class HumanInputNode(Node):
node_type = NodeType.HUMAN_INPUT
execution_type = NodeExecutionType.BRANCH
_BRANCH_SELECTION_KEYS: tuple[str, ...] = (
"edge_source_handle",
"edgeSourceHandle",
"source_handle",
"selected_branch",
"selectedBranch",
"branch",
"branch_id",
"branchId",
"handle",
)
_node_data: HumanInputNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = HumanInputNodeData(**data)
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def _run(self): # type: ignore[override]
if self._is_completion_ready():
branch_handle = self._resolve_branch_selection()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={},
edge_source_handle=branch_handle or "source",
)
return self._pause_generator()
def _pause_generator(self):
yield PauseRequestedEvent(reason=self._node_data.pause_reason)
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""
if not self._node_data.required_variables:
return False
variable_pool = self.graph_runtime_state.variable_pool
for selector_str in self._node_data.required_variables:
parts = selector_str.split(".")
if len(parts) != 2:
return False
segment = variable_pool.get(parts)
if segment is None:
return False
return True
def _resolve_branch_selection(self) -> str | None:
"""Determine the branch handle selected by human input if available."""
variable_pool = self.graph_runtime_state.variable_pool
for key in self._BRANCH_SELECTION_KEYS:
handle = self._extract_branch_handle(variable_pool.get((self.id, key)))
if handle:
return handle
default_values = self._node_data.default_value_dict
for key in self._BRANCH_SELECTION_KEYS:
handle = self._normalize_branch_value(default_values.get(key))
if handle:
return handle
return None
@staticmethod
def _extract_branch_handle(segment: Any) -> str | None:
if segment is None:
return None
candidate = getattr(segment, "to_object", None)
raw_value = candidate() if callable(candidate) else getattr(segment, "value", None)
if raw_value is None:
return None
return HumanInputNode._normalize_branch_value(raw_value)
@staticmethod
def _normalize_branch_value(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
stripped = value.strip()
return stripped or None
if isinstance(value, Mapping):
for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
candidate = value.get(key)
if isinstance(candidate, str) and candidate:
return candidate
return None

View File

@@ -3,12 +3,12 @@ from typing import Any, Literal
from typing_extensions import deprecated
from core.workflow.entities import VariablePool
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.runtime import VariablePool
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor

View File

@@ -12,7 +12,6 @@ from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
@@ -38,6 +37,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
@@ -557,11 +557,12 @@ class IterationNode(Node):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(

View File

@@ -9,13 +9,13 @@ from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

View File

@@ -67,7 +67,7 @@ from .exc import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

View File

@@ -15,9 +15,9 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation

View File

@@ -52,7 +52,7 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeType,
@@ -71,6 +71,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from . import llm_utils
from .entities import (
@@ -93,7 +94,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

View File

@@ -406,11 +406,12 @@ class LoopNode(Node):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(

View File

@@ -10,7 +10,8 @@ from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
@final

View File

@@ -9,6 +9,7 @@ from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
@@ -134,6 +135,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
"2": AgentNode,
"1": AgentNode,
},
NodeType.HUMAN_INPUT: {
LATEST_VERSION: HumanInputNode,
"1": HumanInputNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,

View File

@@ -27,13 +27,13 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData

View File

@@ -41,7 +41,7 @@ from .template_prompts import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
class QuestionClassifierNode(Node):

View File

@@ -36,7 +36,7 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
class ToolNode(Node):

View File

@@ -18,7 +18,7 @@ from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]