feat(api): Introduce workflow pause state management (#27298)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
QuantumGhost
2025-10-30 14:41:09 +08:00
committed by GitHub
parent fd7c4e8a6d
commit a1c0bd7a1c
43 changed files with 3834 additions and 44 deletions

View File

@@ -1,6 +1,6 @@
import logging
import time
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
@@ -25,6 +25,7 @@ from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -61,11 +62,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app: App,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
graph_engine_layers=graph_engine_layers,
)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
@@ -195,6 +198,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
generator = workflow_entry.run()

View File

@@ -135,6 +135,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
generator = workflow_entry.run()

View File

@@ -1,5 +1,5 @@
import time
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -27,6 +27,7 @@ from core.app.entities.queue_entities import (
)
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
@@ -69,10 +70,12 @@ class WorkflowBasedAppRunner:
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
self._queue_manager = queue_manager
self._variable_loader = variable_loader
self._app_id = app_id
self._graph_engine_layers = graph_engine_layers
def _init_graph(
self,

View File

@@ -0,0 +1,71 @@
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str):
"""Create a PauseStatePersistenceLayer.
The `state_owner_user_id` is used when creating state file for pause.
It generally should id of the creator of workflow.
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
def on_graph_start(self) -> None:
"""
Called when graph execution starts.
This is called after the engine has been initialized but before any nodes
are executed. Layers can use this to set up resources or log start information.
"""
pass
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
This method receives all events generated during graph execution, including:
- Graph lifecycle events (start, success, failure)
- Node execution events (start, success, failure, retry)
- Stream events for response nodes
- Container events (iteration, loop)
Args:
event: The event emitted by the engine
"""
if not isinstance(event, GraphRunPausedEvent):
return
assert self.graph_runtime_state is not None
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id is not None
repo = self._get_repo()
repo.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=self.graph_runtime_state.dumps(),
)
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.
This is called after all nodes have been executed or when execution is
aborted. Layers can use this to clean up resources or log final state.
Args:
error: The exception that caused execution to fail, or None if successful
"""
pass

View File

@@ -4,6 +4,7 @@ from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity
__all__ = [
"AgentNodeStrategyInit",
@@ -12,4 +13,5 @@ __all__ = [
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowPauseEntity",
]

View File

@@ -0,0 +1,49 @@
from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias
from pydantic import BaseModel, Discriminator, Tag
class _PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto()
class _PauseReasonBase(BaseModel):
TYPE: ClassVar[_PauseReasonType]
class HumanInputRequired(_PauseReasonBase):
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
class SchedulingPause(_PauseReasonBase):
TYPE = _PauseReasonType.SCHEDULED_PAUSE
message: str
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
if isinstance(v, _PauseReasonBase):
return v.TYPE
elif isinstance(v, dict):
reason_type_str = v.get("TYPE")
if reason_type_str is None:
return None
try:
reason_type = _PauseReasonType(reason_type_str)
except ValueError:
return None
return reason_type
else:
# return None if the discriminator value isn't found
return None
PauseReason: TypeAlias = Annotated[
(
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
),
Discriminator(_get_pause_reason_discriminator),
]

View File

@@ -0,0 +1,61 @@
"""
Domain entities for workflow pause management.
This module contains the domain model for workflow pause, which is used
by the core workflow module. These models are independent of the storage mechanism
and don't contain implementation details like tenant_id, app_id, etc.
"""
from abc import ABC, abstractmethod
from datetime import datetime
class WorkflowPauseEntity(ABC):
"""
Abstract base class for workflow pause entities.
This domain model represents a paused workflow execution state,
without implementation details like tenant_id, app_id, etc.
It provides the interface for managing workflow pause/resume operations
and state persistence through file storage.
The `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times,
it will generate multiple `WorkflowPauseEntity` records.
"""
@property
@abstractmethod
def id(self) -> str:
"""The identifier of current WorkflowPauseEntity"""
pass
@property
@abstractmethod
def workflow_execution_id(self) -> str:
"""The identifier of the workflow execution record the pause associated with.
Correspond to `WorkflowExecution.id`.
"""
@abstractmethod
def get_state(self) -> bytes:
"""
Retrieve the serialized workflow state from storage.
This method should load and return the workflow execution state
that was saved when the workflow was paused. The state contains
all necessary information to resume the workflow execution.
Returns:
bytes: The serialized workflow state containing
execution context, variable values, node states, etc.
"""
...
@property
@abstractmethod
def resumed_at(self) -> datetime | None:
"""`resumed_at` return the resumption time of the current pause, or `None` if
the pause is not resumed yet.
"""
pass

View File

@@ -92,13 +92,111 @@ class WorkflowType(StrEnum):
class WorkflowExecutionStatus(StrEnum):
# State diagram for the workflw status:
# (@) means start, (*) means end
#
# ┌------------------>------------------------->------------------->--------------┐
# | |
# | ┌-----------------------<--------------------┐ |
# ^ | | |
# | | ^ |
# | V | |
# ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V
# | Scheduled |------->| Running |---------------------->| paused | |
# └-----------┘ └-----------------------┘ └-----------┘ |
# | | | | | | |
# | | | | | | |
# ^ | | | V V |
# | | | | | ┌---------┐ |
# (@) | | | └------------------------>| Stopped |<----┘
# | | | └---------┘
# | | | |
# | | V V
# | | ┌-----------┐ |
# | | | Succeeded |------------->--------------┤
# | | └-----------┘ |
# | V V
# | +--------┐ |
# | | Failed |---------------------->----------------┤
# | └--------┘ |
# V V
# ┌---------------------┐ |
# | Partially Succeeded |---------------------->-----------------┘--------> (*)
# └---------------------┘
#
# Mermaid diagram:
#
# ---
# title: State diagram for Workflow run state
# ---
# stateDiagram-v2
# scheduled: Scheduled
# running: Running
# succeeded: Succeeded
# failed: Failed
# partial_succeeded: Partial Succeeded
# paused: Paused
# stopped: Stopped
#
# [*] --> scheduled:
# scheduled --> running: Start Execution
# running --> paused: Human input required
# paused --> running: human input added
# paused --> stopped: User stops execution
# running --> succeeded: Execution finishes without any error
# running --> failed: Execution finishes with errors
# running --> stopped: User stops execution
# running --> partial_succeeded: some execution occurred and handled during execution
#
# scheduled --> stopped: User stops execution
#
# succeeded --> [*]
# failed --> [*]
# partial_succeeded --> [*]
# stopped --> [*]
# `SCHEDULED` means that the workflow is scheduled to run, but has not
# started running yet. (maybe due to possible worker saturation.)
#
# This enum value is currently unused.
SCHEDULED = "scheduled"
# `RUNNING` means the workflow is exeuting.
RUNNING = "running"
# `SUCCEEDED` means the execution of workflow succeed without any error.
SUCCEEDED = "succeeded"
# `FAILED` means the execution of workflow failed without some errors.
FAILED = "failed"
# `STOPPED` means the execution of workflow was stopped, either manually
# by the user, or automatically by the Dify application (E.G. the moderation
# mechanism.)
STOPPED = "stopped"
# `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
# execution, but they were successfully handled (e.g., by using an error
# strategy such as "fail branch" or "default value").
PARTIAL_SUCCEEDED = "partial-succeeded"
# `PAUSED` indicates that the workflow execution is temporarily paused
# (e.g., awaiting human input) and is expected to resume later.
PAUSED = "paused"
def is_ended(self) -> bool:
return self in _END_STATE
_END_STATE = frozenset(
[
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
WorkflowExecutionStatus.STOPPED,
]
)
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""

View File

@@ -3,6 +3,8 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from .command_processor import CommandHandler
@@ -25,4 +27,7 @@ class PauseCommandHandler(CommandHandler):
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, PauseCommand)
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
execution.pause(command.reason)
# Convert string reason to PauseReason if needed
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)

View File

@@ -8,6 +8,7 @@ from typing import Literal
from pydantic import BaseModel, Field
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeState
from .node_execution import NodeExecution
@@ -41,7 +42,7 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
pause_reason: str | None = Field(default=None)
pause_reason: PauseReason | None = Field(default=None)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@@ -106,7 +107,7 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
paused: bool = False
pause_reason: str | None = None
pause_reason: PauseReason | None = None
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
@@ -130,7 +131,7 @@ class GraphExecution:
self.aborted = True
self.error = RuntimeError(f"Aborted: {reason}")
def pause(self, reason: str | None = None) -> None:
def pause(self, reason: PauseReason) -> None:
"""Pause the graph execution without marking it complete."""
if self.completed:
raise RuntimeError("Cannot pause execution that has completed")

View File

@@ -36,4 +36,4 @@ class PauseCommand(GraphEngineCommand):
"""Command to pause a running workflow execution."""
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str | None = Field(default=None, description="Optional reason for pause")
reason: str = Field(default="unknown reason", description="reason for pause")

View File

@@ -210,7 +210,7 @@ class EventHandler:
def _(self, event: NodeRunPauseRequestedEvent) -> None:
"""Handle pause requests emitted by nodes."""
pause_reason = event.reason or "Awaiting human input"
pause_reason = event.reason
self._graph_execution.pause(pause_reason)
self._state_manager.finish_execution(event.node_id)
if event.node_id in self._graph.nodes:

View File

@@ -247,8 +247,11 @@ class GraphEngine:
# Handle completion
if self._graph_execution.is_paused:
pause_reason = self._graph_execution.pause_reason
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
# Ensure we have a valid PauseReason for the event
paused_event = GraphRunPausedEvent(
reason=self._graph_execution.pause_reason,
reason=pause_reason,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)

View File

@@ -216,7 +216,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
execution = self._get_workflow_execution()
execution.status = WorkflowExecutionStatus.PAUSED
execution.error_message = event.reason or "Workflow execution paused"
execution.outputs = event.outputs
self._populate_completion_statistics(execution, update_finished=False)
@@ -296,7 +295,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.PAUSED,
error=event.reason,
error="",
update_outputs=False,
)

View File

@@ -1,5 +1,6 @@
from pydantic import Field
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.graph_events import BaseGraphEvent
@@ -44,7 +45,8 @@ class GraphRunAbortedEvent(BaseGraphEvent):
class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command."""
reason: str | None = Field(default=None, description="reason for pause")
# reason: str | None = Field(default=None, description="reason for pause")
reason: PauseReason = Field(..., description="reason for pause")
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs available to the client while the run is paused.",

View File

@@ -5,6 +5,7 @@ from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@@ -54,4 +55,4 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: str | None = Field(default=None, description="Optional pause reason")
reason: PauseReason = Field(..., description="pause reason")

View File

@@ -5,6 +5,7 @@ from pydantic import Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult
from .base import NodeEventBase
@@ -43,4 +44,4 @@ class StreamCompletedEvent(NodeEventBase):
class PauseRequestedEvent(NodeEventBase):
reason: str | None = Field(default=None, description="Optional pause reason")
reason: PauseReason = Field(..., description="pause reason")

View File

@@ -1,6 +1,7 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
@@ -64,7 +65,7 @@ class HumanInputNode(Node):
return self._pause_generator()
def _pause_generator(self):
yield PauseRequestedEvent(reason=self._node_data.pause_reason)
yield PauseRequestedEvent(reason=HumanInputRequired())
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""

View File

@@ -3,6 +3,7 @@ from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView
class ReadOnlyVariablePool(Protocol):
@@ -30,6 +31,9 @@ class ReadOnlyGraphRuntimeState(Protocol):
All methods return defensive copies to ensure immutability.
"""
@property
def system_variable(self) -> SystemVariableReadOnlyView: ...
@property
def variable_pool(self) -> ReadOnlyVariablePool:
"""Get read-only access to the variable pool."""

View File

@@ -6,6 +6,7 @@ from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView
from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool
@@ -42,6 +43,10 @@ class ReadOnlyGraphRuntimeStateWrapper:
self._state = state
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
@property
def system_variable(self) -> SystemVariableReadOnlyView:
return self._state.variable_pool.system_variables.as_view()
@property
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
return self._variable_pool_wrapper

View File

@@ -1,4 +1,5 @@
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
@@ -108,3 +109,102 @@ class SystemVariable(BaseModel):
if self.invoke_from is not None:
d[SystemVariableKey.INVOKE_FROM] = self.invoke_from
return d
def as_view(self) -> "SystemVariableReadOnlyView":
return SystemVariableReadOnlyView(self)
class SystemVariableReadOnlyView:
"""
A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol.
This class wraps a SystemVariable instance and provides read-only access to all its fields.
It always reads the latest data from the wrapped instance and prevents any write operations.
"""
def __init__(self, system_variable: SystemVariable) -> None:
"""
Initialize the read-only view with a SystemVariable instance.
Args:
system_variable: The SystemVariable instance to wrap
"""
self._system_variable = system_variable
@property
def user_id(self) -> str | None:
return self._system_variable.user_id
@property
def app_id(self) -> str | None:
return self._system_variable.app_id
@property
def workflow_id(self) -> str | None:
return self._system_variable.workflow_id
@property
def workflow_execution_id(self) -> str | None:
return self._system_variable.workflow_execution_id
@property
def query(self) -> str | None:
return self._system_variable.query
@property
def conversation_id(self) -> str | None:
return self._system_variable.conversation_id
@property
def dialogue_count(self) -> int | None:
return self._system_variable.dialogue_count
@property
def document_id(self) -> str | None:
return self._system_variable.document_id
@property
def original_document_id(self) -> str | None:
return self._system_variable.original_document_id
@property
def dataset_id(self) -> str | None:
return self._system_variable.dataset_id
@property
def batch(self) -> str | None:
return self._system_variable.batch
@property
def datasource_type(self) -> str | None:
return self._system_variable.datasource_type
@property
def invoke_from(self) -> str | None:
return self._system_variable.invoke_from
@property
def files(self) -> Sequence[File]:
"""
Get a copy of the files from the wrapped SystemVariable.
Returns:
A defensive copy of the files sequence to prevent modification
"""
return tuple(self._system_variable.files) # Convert to immutable tuple
@property
def datasource_info(self) -> Mapping[str, Any] | None:
"""
Get a copy of the datasource info from the wrapped SystemVariable.
Returns:
A view of the datasource info mapping to prevent modification
"""
if self._system_variable.datasource_info is None:
return None
return MappingProxyType(self._system_variable.datasource_info)
def __repr__(self) -> str:
"""Return a string representation of the read-only view."""
return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})"