Enhanced GraphEngine Pause Handling (#28196)
This commit: 1. Convert `pause_reason` to `pause_reasons` in `GraphExecution` and relevant classes. Change the field from a scalar value to a list that can contain multiple `PauseReason` objects, ensuring all pause events are properly captured. 2. Introduce a new `WorkflowPauseReason` model to record reasons associated with a specific `WorkflowPause`. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -118,6 +118,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=self._state_owner_user_id,
|
||||
state=state.dumps(),
|
||||
pause_reasons=event.reasons,
|
||||
)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
from ..runtime.graph_runtime_state import GraphRuntimeState
|
||||
from ..runtime.variable_pool import VariablePool
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_pause import WorkflowPauseEntity
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"GraphRuntimeState",
|
||||
"VariablePool",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
"WorkflowPauseEntity",
|
||||
]
|
||||
|
||||
@@ -1,49 +1,26 @@
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, ClassVar, TypeAlias
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Discriminator, Tag
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class _PauseReasonType(StrEnum):
|
||||
class PauseReasonType(StrEnum):
|
||||
HUMAN_INPUT_REQUIRED = auto()
|
||||
SCHEDULED_PAUSE = auto()
|
||||
|
||||
|
||||
class _PauseReasonBase(BaseModel):
|
||||
TYPE: ClassVar[_PauseReasonType]
|
||||
class HumanInputRequired(BaseModel):
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
form_id: str
|
||||
# The identifier of the human input node causing the pause.
|
||||
node_id: str
|
||||
|
||||
|
||||
class HumanInputRequired(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
|
||||
class SchedulingPause(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.SCHEDULED_PAUSE
|
||||
class SchedulingPause(BaseModel):
|
||||
TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
|
||||
if isinstance(v, _PauseReasonBase):
|
||||
return v.TYPE
|
||||
elif isinstance(v, dict):
|
||||
reason_type_str = v.get("TYPE")
|
||||
if reason_type_str is None:
|
||||
return None
|
||||
try:
|
||||
reason_type = _PauseReasonType(reason_type_str)
|
||||
except ValueError:
|
||||
return None
|
||||
return reason_type
|
||||
else:
|
||||
# return None if the discriminator value isn't found
|
||||
return None
|
||||
|
||||
|
||||
PauseReason: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
|
||||
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
|
||||
),
|
||||
Discriminator(_get_pause_reason_discriminator),
|
||||
]
|
||||
PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
"""
|
||||
Domain entities for workflow pause management.
|
||||
|
||||
This module contains the domain model for workflow pause, which is used
|
||||
by the core workflow module. These models are independent of the storage mechanism
|
||||
and don't contain implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class WorkflowPauseEntity(ABC):
|
||||
"""
|
||||
Abstract base class for workflow pause entities.
|
||||
|
||||
This domain model represents a paused workflow execution state,
|
||||
without implementation details like tenant_id, app_id, etc.
|
||||
It provides the interface for managing workflow pause/resume operations
|
||||
and state persistence through file storage.
|
||||
|
||||
The `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times,
|
||||
it will generate multiple `WorkflowPauseEntity` records.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def id(self) -> str:
|
||||
"""The identifier of current WorkflowPauseEntity"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def workflow_execution_id(self) -> str:
|
||||
"""The identifier of the workflow execution record the pause associated with.
|
||||
Correspond to `WorkflowExecution.id`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_state(self) -> bytes:
|
||||
"""
|
||||
Retrieve the serialized workflow state from storage.
|
||||
|
||||
This method should load and return the workflow execution state
|
||||
that was saved when the workflow was paused. The state contains
|
||||
all necessary information to resume the workflow execution.
|
||||
|
||||
Returns:
|
||||
bytes: The serialized workflow state containing
|
||||
execution context, variable values, node states, etc.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def resumed_at(self) -> datetime | None:
|
||||
"""`resumed_at` return the resumption time of the current pause, or `None` if
|
||||
the pause is not resumed yet.
|
||||
"""
|
||||
pass
|
||||
@@ -42,7 +42,7 @@ class GraphExecutionState(BaseModel):
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
paused: bool = Field(default=False)
|
||||
pause_reason: PauseReason | None = Field(default=None)
|
||||
pause_reasons: list[PauseReason] = Field(default_factory=list)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
@@ -107,7 +107,7 @@ class GraphExecution:
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
paused: bool = False
|
||||
pause_reason: PauseReason | None = None
|
||||
pause_reasons: list[PauseReason] = field(default_factory=list)
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
@@ -137,10 +137,8 @@ class GraphExecution:
|
||||
raise RuntimeError("Cannot pause execution that has completed")
|
||||
if self.aborted:
|
||||
raise RuntimeError("Cannot pause execution that has been aborted")
|
||||
if self.paused:
|
||||
return
|
||||
self.paused = True
|
||||
self.pause_reason = reason
|
||||
self.pause_reasons.append(reason)
|
||||
|
||||
def fail(self, error: Exception) -> None:
|
||||
"""Mark the graph execution as failed."""
|
||||
@@ -195,7 +193,7 @@ class GraphExecution:
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
paused=self.paused,
|
||||
pause_reason=self.pause_reason,
|
||||
pause_reasons=self.pause_reasons,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
@@ -221,7 +219,7 @@ class GraphExecution:
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.paused = state.paused
|
||||
self.pause_reason = state.pause_reason
|
||||
self.pause_reasons = state.pause_reasons
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
|
||||
@@ -110,7 +110,13 @@ class EventManager:
|
||||
"""
|
||||
with self._lock.write_lock():
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
# NOTE: `_notify_layers` is intentionally called outside the critical section
|
||||
# to minimize lock contention and avoid blocking other readers or writers.
|
||||
#
|
||||
# The public `notify_layers` method also does not use a write lock,
|
||||
# so protecting `_notify_layers` with a lock here is unnecessary.
|
||||
self._notify_layers(event)
|
||||
|
||||
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
|
||||
@@ -232,7 +232,7 @@ class GraphEngine:
|
||||
self._graph_execution.start()
|
||||
else:
|
||||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reason = None
|
||||
self._graph_execution.pause_reasons = []
|
||||
|
||||
start_event = GraphRunStartedEvent()
|
||||
self._event_manager.notify_layers(start_event)
|
||||
@@ -246,11 +246,11 @@ class GraphEngine:
|
||||
|
||||
# Handle completion
|
||||
if self._graph_execution.is_paused:
|
||||
pause_reason = self._graph_execution.pause_reason
|
||||
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
|
||||
pause_reasons = self._graph_execution.pause_reasons
|
||||
assert pause_reasons, "pause_reasons should not be empty when execution is paused."
|
||||
# Ensure we have a valid PauseReason for the event
|
||||
paused_event = GraphRunPausedEvent(
|
||||
reason=pause_reason,
|
||||
reasons=pause_reasons,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(paused_event)
|
||||
|
||||
@@ -45,8 +45,7 @@ class GraphRunAbortedEvent(BaseGraphEvent):
|
||||
class GraphRunPausedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is paused by user command."""
|
||||
|
||||
# reason: str | None = Field(default=None, description="reason for pause")
|
||||
reason: PauseReason = Field(..., description="reason for pause")
|
||||
reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs available to the client while the run is paused.",
|
||||
|
||||
@@ -65,7 +65,8 @@ class HumanInputNode(Node):
|
||||
return self._pause_generator()
|
||||
|
||||
def _pause_generator(self):
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired())
|
||||
# TODO(QuantumGhost): yield a real form id.
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
|
||||
|
||||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Protocol
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
|
||||
@@ -46,7 +47,11 @@ class ReadyQueueProtocol(Protocol):
|
||||
|
||||
|
||||
class GraphExecutionProtocol(Protocol):
|
||||
"""Structural interface for graph execution aggregate."""
|
||||
"""Structural interface for graph execution aggregate.
|
||||
|
||||
Defines the minimal set of attributes and methods required from a GraphExecution entity
|
||||
for runtime orchestration and state management.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
started: bool
|
||||
@@ -54,6 +59,7 @@ class GraphExecutionProtocol(Protocol):
|
||||
aborted: bool
|
||||
error: Exception | None
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
|
||||
Reference in New Issue
Block a user