Enhanced GraphEngine Pause Handling (#28196)
This commit: 1. Convert `pause_reason` to `pause_reasons` in `GraphExecution` and relevant classes. Change the field from a scalar value to a list that can contain multiple `PauseReason` objects, ensuring all pause events are properly captured. 2. Introduce a new `WorkflowPauseReason` model to record reasons associated with a specific `WorkflowPause`. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -31,7 +31,7 @@ class TestDataFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
|
||||
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
|
||||
return GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")], outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_started_event() -> GraphRunStartedEvent:
|
||||
@@ -255,15 +255,17 @@ class TestPauseStatePersistenceLayer:
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_called_once_with(session_factory)
|
||||
mock_repo.create_workflow_pause.assert_called_once_with(
|
||||
workflow_run_id="run-123",
|
||||
state_owner_user_id="owner-123",
|
||||
state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
|
||||
)
|
||||
serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
|
||||
assert mock_repo.create_workflow_pause.call_count == 1
|
||||
call_kwargs = mock_repo.create_workflow_pause.call_args.kwargs
|
||||
assert call_kwargs["workflow_run_id"] == "run-123"
|
||||
assert call_kwargs["state_owner_user_id"] == "owner-123"
|
||||
serialized_state = call_kwargs["state"]
|
||||
resumption_context = WorkflowResumptionContext.loads(serialized_state)
|
||||
assert resumption_context.serialized_graph_runtime_state == expected_state
|
||||
assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
|
||||
pause_reasons = call_kwargs["pause_reasons"]
|
||||
|
||||
assert isinstance(pause_reasons, list)
|
||||
|
||||
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
|
||||
@@ -19,38 +19,18 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
# Create entity
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
# Verify initialization
|
||||
assert entity._pause_model is mock_pause_model
|
||||
assert entity._cached_state is None
|
||||
|
||||
def test_from_models_classmethod(self):
|
||||
"""Test from_models class method."""
|
||||
# Create mock models
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
# Create entity using from_models
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(
|
||||
workflow_pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Verify entity creation
|
||||
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||
assert entity._pause_model is mock_pause_model
|
||||
|
||||
def test_id_property(self):
|
||||
"""Test id property returns pause model ID."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.id == "pause-123"
|
||||
|
||||
@@ -59,9 +39,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.workflow_execution_id == "execution-456"
|
||||
|
||||
@@ -72,9 +50,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = resumed_at
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.resumed_at == resumed_at
|
||||
|
||||
@@ -83,9 +59,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.resumed_at is None
|
||||
|
||||
@@ -98,9 +72,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
# First call should load from storage
|
||||
result = entity.get_state()
|
||||
@@ -118,9 +90,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
# First call
|
||||
result1 = entity.get_state()
|
||||
@@ -139,9 +109,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
# Pre-cache data
|
||||
entity._cached_state = state_data
|
||||
@@ -162,9 +130,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=mock_pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
result = entity.get_state()
|
||||
|
||||
|
||||
@@ -8,12 +8,13 @@ from typing import Any
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph.validation import GraphValidationError
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@@ -178,8 +178,7 @@ def test_pause_command():
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
|
||||
assert len(pause_events) == 1
|
||||
assert pause_events[0].reason == SchedulingPause(message="User requested pause")
|
||||
assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")]
|
||||
|
||||
graph_execution = engine.graph_runtime_state.graph_execution
|
||||
assert graph_execution.paused
|
||||
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
|
||||
assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
|
||||
|
||||
@@ -6,10 +6,10 @@ from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
@@ -129,12 +129,14 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
state=state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||
assert result.id == "pause-123"
|
||||
assert result.workflow_execution_id == workflow_run_id
|
||||
assert result.get_pause_reasons() == []
|
||||
|
||||
# Verify database interactions
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
|
||||
@@ -156,6 +158,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
|
||||
@@ -174,6 +177,7 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
|
||||
@@ -316,19 +320,10 @@ class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test _PrivateWorkflowPauseEntity class."""
|
||||
|
||||
def test_from_models(self, sample_workflow_pause: Mock):
|
||||
"""Test creating _PrivateWorkflowPauseEntity from models."""
|
||||
# Act
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
|
||||
# Assert
|
||||
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||
assert entity._pause_model == sample_workflow_pause
|
||||
|
||||
def test_properties(self, sample_workflow_pause: Mock):
|
||||
"""Test entity properties."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
|
||||
# Act & Assert
|
||||
assert entity.id == sample_workflow_pause.id
|
||||
@@ -338,7 +333,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
|
||||
def test_get_state(self, sample_workflow_pause: Mock):
|
||||
"""Test getting state from storage."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
@@ -354,7 +349,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
|
||||
def test_get_state_caching(self, sample_workflow_pause: Mock):
|
||||
"""Test state caching in get_state method."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from models.workflow import WorkflowPause
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||
from services.workflow_run_service import (
|
||||
@@ -63,7 +64,7 @@ class TestDataFactory:
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowPauseModel object."""
|
||||
mock_pause = MagicMock()
|
||||
mock_pause = MagicMock(spec=WorkflowPause)
|
||||
mock_pause.id = id
|
||||
mock_pause.tenant_id = tenant_id
|
||||
mock_pause.app_id = app_id
|
||||
@@ -77,38 +78,15 @@ class TestDataFactory:
|
||||
|
||||
return mock_pause
|
||||
|
||||
@staticmethod
|
||||
def create_upload_file_mock(
|
||||
id: str = "file-456",
|
||||
key: str = "upload_files/test/state.json",
|
||||
name: str = "state.json",
|
||||
tenant_id: str = "tenant-456",
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock UploadFile object."""
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = id
|
||||
mock_file.key = key
|
||||
mock_file.name = name
|
||||
mock_file.tenant_id = tenant_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_file, key, value)
|
||||
|
||||
return mock_file
|
||||
|
||||
@staticmethod
|
||||
def create_pause_entity_mock(
|
||||
pause_model: MagicMock | None = None,
|
||||
upload_file: MagicMock | None = None,
|
||||
) -> _PrivateWorkflowPauseEntity:
|
||||
"""Create a mock _PrivateWorkflowPauseEntity object."""
|
||||
if pause_model is None:
|
||||
pause_model = TestDataFactory.create_workflow_pause_mock()
|
||||
if upload_file is None:
|
||||
upload_file = TestDataFactory.create_upload_file_mock()
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
|
||||
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=[], human_input_form=[])
|
||||
|
||||
|
||||
class TestWorkflowRunService:
|
||||
|
||||
Reference in New Issue
Block a user