Enhanced GraphEngine Pause Handling (#28196)

This commit: 

1. Converts `pause_reason` to `pause_reasons` in `GraphExecution` and related classes. The field changes from a single scalar value to a list that can hold multiple `PauseReason` objects, so that every pause event is captured.
2. Introduces a new `WorkflowPauseReason` model to record the reasons associated with a specific `WorkflowPause`.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
QuantumGhost
2025-11-26 19:59:34 +08:00
committed by GitHub
parent b353a126d8
commit 1c1f124891
24 changed files with 275 additions and 185 deletions

View File

@@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=state.dumps(),
pause_reasons=event.reasons,
)
def on_graph_end(self, error: Exception | None) -> None:

View File

@@ -1,17 +1,11 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowPauseEntity",
]

View File

@@ -1,49 +1,26 @@
from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias
from typing import Annotated, Literal, TypeAlias
from pydantic import BaseModel, Discriminator, Tag
from pydantic import BaseModel, Field
class _PauseReasonType(StrEnum):
class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto()
class _PauseReasonBase(BaseModel):
TYPE: ClassVar[_PauseReasonType]
class HumanInputRequired(BaseModel):
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
# The identifier of the human input node causing the pause.
node_id: str
class HumanInputRequired(_PauseReasonBase):
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
class SchedulingPause(_PauseReasonBase):
TYPE = _PauseReasonType.SCHEDULED_PAUSE
class SchedulingPause(BaseModel):
TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
message: str
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
if isinstance(v, _PauseReasonBase):
return v.TYPE
elif isinstance(v, dict):
reason_type_str = v.get("TYPE")
if reason_type_str is None:
return None
try:
reason_type = _PauseReasonType(reason_type_str)
except ValueError:
return None
return reason_type
else:
# return None if the discriminator value isn't found
return None
PauseReason: TypeAlias = Annotated[
(
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
),
Discriminator(_get_pause_reason_discriminator),
]
PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]

View File

@@ -1,61 +0,0 @@
"""
Domain entities for workflow pause management.
This module contains the domain model for workflow pause, which is used
by the core workflow module. These models are independent of the storage mechanism
and don't contain implementation details like tenant_id, app_id, etc.
"""
from abc import ABC, abstractmethod
from datetime import datetime
class WorkflowPauseEntity(ABC):
"""
Abstract base class for workflow pause entities.
This domain model represents a paused workflow execution state,
without implementation details like tenant_id, app_id, etc.
It provides the interface for managing workflow pause/resume operations
and state persistence through file storage.
A `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times,
it will generate multiple `WorkflowPauseEntity` records.
"""
@property
@abstractmethod
def id(self) -> str:
"""The identifier of the current `WorkflowPauseEntity`."""
pass
@property
@abstractmethod
def workflow_execution_id(self) -> str:
"""The identifier of the workflow execution record this pause is associated with.
Corresponds to `WorkflowExecution.id`.
"""
@abstractmethod
def get_state(self) -> bytes:
"""
Retrieve the serialized workflow state from storage.
Implementations should load and return the workflow execution state
that was saved when the workflow was paused. The state contains
all the information necessary to resume the workflow execution.
Returns:
bytes: The serialized workflow state containing
execution context, variable values, node states, etc.
"""
...
@property
@abstractmethod
def resumed_at(self) -> datetime | None:
"""Return the resumption time of the current pause, or `None` if
the pause has not been resumed yet.
"""
pass

View File

@@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
pause_reason: PauseReason | None = Field(default=None)
pause_reasons: list[PauseReason] = Field(default_factory=list)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@@ -107,7 +107,7 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
paused: bool = False
pause_reason: PauseReason | None = None
pause_reasons: list[PauseReason] = field(default_factory=list)
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
@@ -137,10 +137,8 @@ class GraphExecution:
raise RuntimeError("Cannot pause execution that has completed")
if self.aborted:
raise RuntimeError("Cannot pause execution that has been aborted")
if self.paused:
return
self.paused = True
self.pause_reason = reason
self.pause_reasons.append(reason)
def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""
@@ -195,7 +193,7 @@ class GraphExecution:
completed=self.completed,
aborted=self.aborted,
paused=self.paused,
pause_reason=self.pause_reason,
pause_reasons=self.pause_reasons,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
@@ -221,7 +219,7 @@ class GraphExecution:
self.completed = state.completed
self.aborted = state.aborted
self.paused = state.paused
self.pause_reason = state.pause_reason
self.pause_reasons = state.pause_reasons
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {

View File

@@ -110,7 +110,13 @@ class EventManager:
"""
with self._lock.write_lock():
self._events.append(event)
self._notify_layers(event)
# NOTE: `_notify_layers` is intentionally called outside the critical section
# to minimize lock contention and avoid blocking other readers or writers.
#
# The public `notify_layers` method also does not use a write lock,
# so protecting `_notify_layers` with a lock here is unnecessary.
self._notify_layers(event)
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
"""

View File

@@ -232,7 +232,7 @@ class GraphEngine:
self._graph_execution.start()
else:
self._graph_execution.paused = False
self._graph_execution.pause_reason = None
self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent()
self._event_manager.notify_layers(start_event)
@@ -246,11 +246,11 @@ class GraphEngine:
# Handle completion
if self._graph_execution.is_paused:
pause_reason = self._graph_execution.pause_reason
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
pause_reasons = self._graph_execution.pause_reasons
assert pause_reasons, "pause_reasons should not be empty when execution is paused."
# Ensure we have a valid PauseReason for the event
paused_event = GraphRunPausedEvent(
reason=pause_reason,
reasons=pause_reasons,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)

View File

@@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent):
class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command."""
# reason: str | None = Field(default=None, description="reason for pause")
reason: PauseReason = Field(..., description="reason for pause")
reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs available to the client while the run is paused.",

View File

@@ -65,7 +65,8 @@ class HumanInputNode(Node):
return self._pause_generator()
def _pause_generator(self):
yield PauseRequestedEvent(reason=HumanInputRequired())
# TODO(QuantumGhost): yield a real form id.
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""

View File

@@ -10,6 +10,7 @@ from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.runtime.variable_pool import VariablePool
@@ -46,7 +47,11 @@ class ReadyQueueProtocol(Protocol):
class GraphExecutionProtocol(Protocol):
"""Structural interface for graph execution aggregate."""
"""Structural interface for graph execution aggregate.
Defines the minimal set of attributes and methods required from a GraphExecution entity
for runtime orchestration and state management.
"""
workflow_id: str
started: bool
@@ -54,6 +59,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""