feat(api): Introduce workflow pause state management (#27298)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -25,6 +25,7 @@ from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
@@ -61,11 +62,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
app: App,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
@@ -195,6 +198,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
|
||||
@@ -135,6 +135,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@@ -27,6 +27,7 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
@@ -69,10 +70,12 @@ class WorkflowBasedAppRunner:
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
app_id: str,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
):
|
||||
self._queue_manager = queue_manager
|
||||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
self._graph_engine_layers = graph_engine_layers
|
||||
|
||||
def _init_graph(
|
||||
self,
|
||||
|
||||
71
api/core/app/layers/pause_state_persist_layer.py
Normal file
71
api/core/app/layers/pause_state_persist_layer.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str):
|
||||
"""Create a PauseStatePersistenceLayer.
|
||||
|
||||
The `state_owner_user_id` is used when creating state file for pause.
|
||||
It generally should id of the creator of workflow.
|
||||
"""
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(session_factory)
|
||||
self._session_maker = session_factory
|
||||
self._state_owner_user_id = state_owner_user_id
|
||||
|
||||
def _get_repo(self) -> APIWorkflowRunRepository:
|
||||
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
|
||||
This is called after the engine has been initialized but before any nodes
|
||||
are executed. Layers can use this to set up resources or log start information.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
This method receives all events generated during graph execution, including:
|
||||
- Graph lifecycle events (start, success, failure)
|
||||
- Node execution events (start, success, failure, retry)
|
||||
- Stream events for response nodes
|
||||
- Container events (iteration, loop)
|
||||
|
||||
Args:
|
||||
event: The event emitted by the engine
|
||||
"""
|
||||
if not isinstance(event, GraphRunPausedEvent):
|
||||
return
|
||||
|
||||
assert self.graph_runtime_state is not None
|
||||
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
|
||||
assert workflow_run_id is not None
|
||||
repo = self._get_repo()
|
||||
repo.create_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=self._state_owner_user_id,
|
||||
state=self.graph_runtime_state.dumps(),
|
||||
)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
This is called after all nodes have been executed or when execution is
|
||||
aborted. Layers can use this to clean up resources or log final state.
|
||||
|
||||
Args:
|
||||
error: The exception that caused execution to fail, or None if successful
|
||||
"""
|
||||
pass
|
||||
@@ -4,6 +4,7 @@ from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_pause import WorkflowPauseEntity
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
@@ -12,4 +13,5 @@ __all__ = [
|
||||
"VariablePool",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
"WorkflowPauseEntity",
|
||||
]
|
||||
|
||||
49
api/core/workflow/entities/pause_reason.py
Normal file
49
api/core/workflow/entities/pause_reason.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, ClassVar, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Discriminator, Tag
|
||||
|
||||
|
||||
class _PauseReasonType(StrEnum):
|
||||
HUMAN_INPUT_REQUIRED = auto()
|
||||
SCHEDULED_PAUSE = auto()
|
||||
|
||||
|
||||
class _PauseReasonBase(BaseModel):
|
||||
TYPE: ClassVar[_PauseReasonType]
|
||||
|
||||
|
||||
class HumanInputRequired(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
|
||||
class SchedulingPause(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.SCHEDULED_PAUSE
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
|
||||
if isinstance(v, _PauseReasonBase):
|
||||
return v.TYPE
|
||||
elif isinstance(v, dict):
|
||||
reason_type_str = v.get("TYPE")
|
||||
if reason_type_str is None:
|
||||
return None
|
||||
try:
|
||||
reason_type = _PauseReasonType(reason_type_str)
|
||||
except ValueError:
|
||||
return None
|
||||
return reason_type
|
||||
else:
|
||||
# return None if the discriminator value isn't found
|
||||
return None
|
||||
|
||||
|
||||
PauseReason: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
|
||||
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
|
||||
),
|
||||
Discriminator(_get_pause_reason_discriminator),
|
||||
]
|
||||
61
api/core/workflow/entities/workflow_pause.py
Normal file
61
api/core/workflow/entities/workflow_pause.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Domain entities for workflow pause management.
|
||||
|
||||
This module contains the domain model for workflow pause, which is used
|
||||
by the core workflow module. These models are independent of the storage mechanism
|
||||
and don't contain implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class WorkflowPauseEntity(ABC):
|
||||
"""
|
||||
Abstract base class for workflow pause entities.
|
||||
|
||||
This domain model represents a paused workflow execution state,
|
||||
without implementation details like tenant_id, app_id, etc.
|
||||
It provides the interface for managing workflow pause/resume operations
|
||||
and state persistence through file storage.
|
||||
|
||||
The `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times,
|
||||
it will generate multiple `WorkflowPauseEntity` records.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def id(self) -> str:
|
||||
"""The identifier of current WorkflowPauseEntity"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def workflow_execution_id(self) -> str:
|
||||
"""The identifier of the workflow execution record the pause associated with.
|
||||
Correspond to `WorkflowExecution.id`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_state(self) -> bytes:
|
||||
"""
|
||||
Retrieve the serialized workflow state from storage.
|
||||
|
||||
This method should load and return the workflow execution state
|
||||
that was saved when the workflow was paused. The state contains
|
||||
all necessary information to resume the workflow execution.
|
||||
|
||||
Returns:
|
||||
bytes: The serialized workflow state containing
|
||||
execution context, variable values, node states, etc.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def resumed_at(self) -> datetime | None:
|
||||
"""`resumed_at` return the resumption time of the current pause, or `None` if
|
||||
the pause is not resumed yet.
|
||||
"""
|
||||
pass
|
||||
@@ -92,13 +92,111 @@ class WorkflowType(StrEnum):
|
||||
|
||||
|
||||
class WorkflowExecutionStatus(StrEnum):
|
||||
# State diagram for the workflw status:
|
||||
# (@) means start, (*) means end
|
||||
#
|
||||
# ┌------------------>------------------------->------------------->--------------┐
|
||||
# | |
|
||||
# | ┌-----------------------<--------------------┐ |
|
||||
# ^ | | |
|
||||
# | | ^ |
|
||||
# | V | |
|
||||
# ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V
|
||||
# | Scheduled |------->| Running |---------------------->| paused | |
|
||||
# └-----------┘ └-----------------------┘ └-----------┘ |
|
||||
# | | | | | | |
|
||||
# | | | | | | |
|
||||
# ^ | | | V V |
|
||||
# | | | | | ┌---------┐ |
|
||||
# (@) | | | └------------------------>| Stopped |<----┘
|
||||
# | | | └---------┘
|
||||
# | | | |
|
||||
# | | V V
|
||||
# | | ┌-----------┐ |
|
||||
# | | | Succeeded |------------->--------------┤
|
||||
# | | └-----------┘ |
|
||||
# | V V
|
||||
# | +--------┐ |
|
||||
# | | Failed |---------------------->----------------┤
|
||||
# | └--------┘ |
|
||||
# V V
|
||||
# ┌---------------------┐ |
|
||||
# | Partially Succeeded |---------------------->-----------------┘--------> (*)
|
||||
# └---------------------┘
|
||||
#
|
||||
# Mermaid diagram:
|
||||
#
|
||||
# ---
|
||||
# title: State diagram for Workflow run state
|
||||
# ---
|
||||
# stateDiagram-v2
|
||||
# scheduled: Scheduled
|
||||
# running: Running
|
||||
# succeeded: Succeeded
|
||||
# failed: Failed
|
||||
# partial_succeeded: Partial Succeeded
|
||||
# paused: Paused
|
||||
# stopped: Stopped
|
||||
#
|
||||
# [*] --> scheduled:
|
||||
# scheduled --> running: Start Execution
|
||||
# running --> paused: Human input required
|
||||
# paused --> running: human input added
|
||||
# paused --> stopped: User stops execution
|
||||
# running --> succeeded: Execution finishes without any error
|
||||
# running --> failed: Execution finishes with errors
|
||||
# running --> stopped: User stops execution
|
||||
# running --> partial_succeeded: some execution occurred and handled during execution
|
||||
#
|
||||
# scheduled --> stopped: User stops execution
|
||||
#
|
||||
# succeeded --> [*]
|
||||
# failed --> [*]
|
||||
# partial_succeeded --> [*]
|
||||
# stopped --> [*]
|
||||
|
||||
# `SCHEDULED` means that the workflow is scheduled to run, but has not
|
||||
# started running yet. (maybe due to possible worker saturation.)
|
||||
#
|
||||
# This enum value is currently unused.
|
||||
SCHEDULED = "scheduled"
|
||||
|
||||
# `RUNNING` means the workflow is exeuting.
|
||||
RUNNING = "running"
|
||||
|
||||
# `SUCCEEDED` means the execution of workflow succeed without any error.
|
||||
SUCCEEDED = "succeeded"
|
||||
|
||||
# `FAILED` means the execution of workflow failed without some errors.
|
||||
FAILED = "failed"
|
||||
|
||||
# `STOPPED` means the execution of workflow was stopped, either manually
|
||||
# by the user, or automatically by the Dify application (E.G. the moderation
|
||||
# mechanism.)
|
||||
STOPPED = "stopped"
|
||||
|
||||
# `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
|
||||
# execution, but they were successfully handled (e.g., by using an error
|
||||
# strategy such as "fail branch" or "default value").
|
||||
PARTIAL_SUCCEEDED = "partial-succeeded"
|
||||
|
||||
# `PAUSED` indicates that the workflow execution is temporarily paused
|
||||
# (e.g., awaiting human input) and is expected to resume later.
|
||||
PAUSED = "paused"
|
||||
|
||||
def is_ended(self) -> bool:
|
||||
return self in _END_STATE
|
||||
|
||||
|
||||
_END_STATE = frozenset(
|
||||
[
|
||||
WorkflowExecutionStatus.SUCCEEDED,
|
||||
WorkflowExecutionStatus.FAILED,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
WorkflowExecutionStatus.STOPPED,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from .command_processor import CommandHandler
|
||||
@@ -25,4 +27,7 @@ class PauseCommandHandler(CommandHandler):
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, PauseCommand)
|
||||
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
|
||||
execution.pause(command.reason)
|
||||
# Convert string reason to PauseReason if needed
|
||||
reason = command.reason
|
||||
pause_reason = SchedulingPause(message=reason)
|
||||
execution.pause(pause_reason)
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
from .node_execution import NodeExecution
|
||||
@@ -41,7 +42,7 @@ class GraphExecutionState(BaseModel):
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
paused: bool = Field(default=False)
|
||||
pause_reason: str | None = Field(default=None)
|
||||
pause_reason: PauseReason | None = Field(default=None)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
@@ -106,7 +107,7 @@ class GraphExecution:
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
paused: bool = False
|
||||
pause_reason: str | None = None
|
||||
pause_reason: PauseReason | None = None
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
@@ -130,7 +131,7 @@ class GraphExecution:
|
||||
self.aborted = True
|
||||
self.error = RuntimeError(f"Aborted: {reason}")
|
||||
|
||||
def pause(self, reason: str | None = None) -> None:
|
||||
def pause(self, reason: PauseReason) -> None:
|
||||
"""Pause the graph execution without marking it complete."""
|
||||
if self.completed:
|
||||
raise RuntimeError("Cannot pause execution that has completed")
|
||||
|
||||
@@ -36,4 +36,4 @@ class PauseCommand(GraphEngineCommand):
|
||||
"""Command to pause a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||
reason: str | None = Field(default=None, description="Optional reason for pause")
|
||||
reason: str = Field(default="unknown reason", description="reason for pause")
|
||||
|
||||
@@ -210,7 +210,7 @@ class EventHandler:
|
||||
def _(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
"""Handle pause requests emitted by nodes."""
|
||||
|
||||
pause_reason = event.reason or "Awaiting human input"
|
||||
pause_reason = event.reason
|
||||
self._graph_execution.pause(pause_reason)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
if event.node_id in self._graph.nodes:
|
||||
|
||||
@@ -247,8 +247,11 @@ class GraphEngine:
|
||||
|
||||
# Handle completion
|
||||
if self._graph_execution.is_paused:
|
||||
pause_reason = self._graph_execution.pause_reason
|
||||
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
|
||||
# Ensure we have a valid PauseReason for the event
|
||||
paused_event = GraphRunPausedEvent(
|
||||
reason=self._graph_execution.pause_reason,
|
||||
reason=pause_reason,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(paused_event)
|
||||
|
||||
@@ -216,7 +216,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.PAUSED
|
||||
execution.error_message = event.reason or "Workflow execution paused"
|
||||
execution.outputs = event.outputs
|
||||
self._populate_completion_statistics(execution, update_finished=False)
|
||||
|
||||
@@ -296,7 +295,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.PAUSED,
|
||||
error=event.reason,
|
||||
error="",
|
||||
update_outputs=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.graph_events import BaseGraphEvent
|
||||
|
||||
|
||||
@@ -44,7 +45,8 @@ class GraphRunAbortedEvent(BaseGraphEvent):
|
||||
class GraphRunPausedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is paused by user command."""
|
||||
|
||||
reason: str | None = Field(default=None, description="reason for pause")
|
||||
# reason: str | None = Field(default=None, description="reason for pause")
|
||||
reason: PauseReason = Field(..., description="reason for pause")
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs available to the client while the run is paused.",
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
|
||||
@@ -54,4 +55,4 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
|
||||
|
||||
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
||||
reason: str | None = Field(default=None, description="Optional pause reason")
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
from .base import NodeEventBase
|
||||
@@ -43,4 +44,4 @@ class StreamCompletedEvent(NodeEventBase):
|
||||
|
||||
|
||||
class PauseRequestedEvent(NodeEventBase):
|
||||
reason: str | None = Field(default=None, description="Optional pause reason")
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
@@ -64,7 +65,7 @@ class HumanInputNode(Node):
|
||||
return self._pause_generator()
|
||||
|
||||
def _pause_generator(self):
|
||||
yield PauseRequestedEvent(reason=self._node_data.pause_reason)
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired())
|
||||
|
||||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Protocol
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.system_variable import SystemVariableReadOnlyView
|
||||
|
||||
|
||||
class ReadOnlyVariablePool(Protocol):
|
||||
@@ -30,6 +31,9 @@ class ReadOnlyGraphRuntimeState(Protocol):
|
||||
All methods return defensive copies to ensure immutability.
|
||||
"""
|
||||
|
||||
@property
|
||||
def system_variable(self) -> SystemVariableReadOnlyView: ...
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||
"""Get read-only access to the variable pool."""
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.system_variable import SystemVariableReadOnlyView
|
||||
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .variable_pool import VariablePool
|
||||
@@ -42,6 +43,10 @@ class ReadOnlyGraphRuntimeStateWrapper:
|
||||
self._state = state
|
||||
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
|
||||
|
||||
@property
|
||||
def system_variable(self) -> SystemVariableReadOnlyView:
|
||||
return self._state.variable_pool.system_variables.as_view()
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
|
||||
return self._variable_pool_wrapper
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
||||
@@ -108,3 +109,102 @@ class SystemVariable(BaseModel):
|
||||
if self.invoke_from is not None:
|
||||
d[SystemVariableKey.INVOKE_FROM] = self.invoke_from
|
||||
return d
|
||||
|
||||
def as_view(self) -> "SystemVariableReadOnlyView":
|
||||
return SystemVariableReadOnlyView(self)
|
||||
|
||||
|
||||
class SystemVariableReadOnlyView:
|
||||
"""
|
||||
A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol.
|
||||
|
||||
This class wraps a SystemVariable instance and provides read-only access to all its fields.
|
||||
It always reads the latest data from the wrapped instance and prevents any write operations.
|
||||
"""
|
||||
|
||||
def __init__(self, system_variable: SystemVariable) -> None:
|
||||
"""
|
||||
Initialize the read-only view with a SystemVariable instance.
|
||||
|
||||
Args:
|
||||
system_variable: The SystemVariable instance to wrap
|
||||
"""
|
||||
self._system_variable = system_variable
|
||||
|
||||
@property
|
||||
def user_id(self) -> str | None:
|
||||
return self._system_variable.user_id
|
||||
|
||||
@property
|
||||
def app_id(self) -> str | None:
|
||||
return self._system_variable.app_id
|
||||
|
||||
@property
|
||||
def workflow_id(self) -> str | None:
|
||||
return self._system_variable.workflow_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str | None:
|
||||
return self._system_variable.workflow_execution_id
|
||||
|
||||
@property
|
||||
def query(self) -> str | None:
|
||||
return self._system_variable.query
|
||||
|
||||
@property
|
||||
def conversation_id(self) -> str | None:
|
||||
return self._system_variable.conversation_id
|
||||
|
||||
@property
|
||||
def dialogue_count(self) -> int | None:
|
||||
return self._system_variable.dialogue_count
|
||||
|
||||
@property
|
||||
def document_id(self) -> str | None:
|
||||
return self._system_variable.document_id
|
||||
|
||||
@property
|
||||
def original_document_id(self) -> str | None:
|
||||
return self._system_variable.original_document_id
|
||||
|
||||
@property
|
||||
def dataset_id(self) -> str | None:
|
||||
return self._system_variable.dataset_id
|
||||
|
||||
@property
|
||||
def batch(self) -> str | None:
|
||||
return self._system_variable.batch
|
||||
|
||||
@property
|
||||
def datasource_type(self) -> str | None:
|
||||
return self._system_variable.datasource_type
|
||||
|
||||
@property
|
||||
def invoke_from(self) -> str | None:
|
||||
return self._system_variable.invoke_from
|
||||
|
||||
@property
|
||||
def files(self) -> Sequence[File]:
|
||||
"""
|
||||
Get a copy of the files from the wrapped SystemVariable.
|
||||
|
||||
Returns:
|
||||
A defensive copy of the files sequence to prevent modification
|
||||
"""
|
||||
return tuple(self._system_variable.files) # Convert to immutable tuple
|
||||
|
||||
@property
|
||||
def datasource_info(self) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Get a copy of the datasource info from the wrapped SystemVariable.
|
||||
|
||||
Returns:
|
||||
A view of the datasource info mapping to prevent modification
|
||||
"""
|
||||
if self._system_variable.datasource_info is None:
|
||||
return None
|
||||
return MappingProxyType(self._system_variable.datasource_info)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the read-only view."""
|
||||
return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})"
|
||||
|
||||
Reference in New Issue
Block a user