feat(api): Introduce workflow pause state management (#27298)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
QuantumGhost
2025-10-30 14:41:09 +08:00
committed by GitHub
parent fd7c4e8a6d
commit a1c0bd7a1c
43 changed files with 3834 additions and 44 deletions

View File

@@ -0,0 +1 @@
# Core integration tests package

View File

@@ -0,0 +1 @@
# App integration tests package

View File

@@ -0,0 +1 @@
# Layers integration tests package

View File

@@ -0,0 +1,520 @@
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class.
This test suite covers complete integration scenarios including:
- Real database interactions using containerized PostgreSQL
- Real storage operations using test storage backend
- Complete workflow: event -> state serialization -> database save -> storage save
- Testing with actual WorkflowRunService (not mocked)
- Real Workflow and WorkflowRun instances in database
- Database transactions and rollback behavior
- Actual file upload and retrieval through storage
- Workflow status transitions in database
- Error handling with real database constraints
- Multiple pause events in sequence
- Integration with real ReadOnlyGraphRuntimeState implementations
These tests use TestContainers to spin up real services for integration testing,
providing more reliable and realistic test scenarios than mocks.
"""
import json
import uuid
from time import time
import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.runtime.variable_pool import SystemVariable, VariablePool
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from services.file_service import FileService
from services.workflow_run_service import WorkflowRunService
class _TestCommandChannelImpl:
"""Real implementation of CommandChannel for testing."""
def __init__(self):
self._commands: list[GraphEngineCommand] = []
def fetch_commands(self) -> list[GraphEngineCommand]:
"""Fetch pending commands for this GraphEngine instance."""
return self._commands.copy()
def send_command(self, command: GraphEngineCommand) -> None:
"""Send a command to be processed by this GraphEngine instance."""
self._commands.append(command)
class TestPauseStatePersistenceLayerTestContainers:
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class."""
@pytest.fixture
def engine(self, db_session_with_containers: Session):
"""Get database engine from TestContainers session."""
bind = db_session_with_containers.get_bind()
assert isinstance(bind, Engine)
return bind
@pytest.fixture
def file_service(self, engine: Engine):
"""Create FileService instance with TestContainers engine."""
return FileService(engine)
@pytest.fixture
def workflow_run_service(self, engine: Engine, file_service: FileService):
"""Create WorkflowRunService instance with TestContainers engine and FileService."""
return WorkflowRunService(engine)
@pytest.fixture(autouse=True)
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
"""Set up test data for each test method using TestContainers."""
# Create test tenant and account
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
tenant = Tenant(
name="Test Tenant",
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
account = Account(
email="test@example.com",
name="Test User",
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant-account join
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()
# Set test data
self.test_tenant_id = tenant.id
self.test_user_id = account.id
self.test_app_id = str(uuid.uuid4())
self.test_workflow_id = str(uuid.uuid4())
self.test_workflow_run_id = str(uuid.uuid4())
# Create test workflow
self.test_workflow = Workflow(
id=self.test_workflow_id,
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=self.test_user_id,
created_at=naive_utc_now(),
)
# Create test workflow run
self.test_workflow_run = WorkflowRun(
id=self.test_workflow_run_id,
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
workflow_id=self.test_workflow_id,
type="workflow",
triggered_from="debugging",
version="draft",
status=WorkflowExecutionStatus.RUNNING,
created_by=self.test_user_id,
created_by_role="account",
created_at=naive_utc_now(),
)
# Store session and service instances
self.session = db_session_with_containers
self.file_service = file_service
self.workflow_run_service = workflow_run_service
# Save test data to database
self.session.add(self.test_workflow)
self.session.add(self.test_workflow_run)
self.session.commit()
yield
# Cleanup
self._cleanup_test_data()
def _cleanup_test_data(self):
"""Clean up test data after each test method."""
try:
# Clean up workflow pauses
self.session.execute(delete(WorkflowPauseModel))
# Clean up upload files
self.session.execute(
delete(UploadFile).where(
UploadFile.tenant_id == self.test_tenant_id,
)
)
# Clean up workflow runs
self.session.execute(
delete(WorkflowRun).where(
WorkflowRun.tenant_id == self.test_tenant_id,
WorkflowRun.app_id == self.test_app_id,
)
)
# Clean up workflows
self.session.execute(
delete(Workflow).where(
Workflow.tenant_id == self.test_tenant_id,
Workflow.app_id == self.test_app_id,
)
)
self.session.commit()
except Exception as e:
self.session.rollback()
raise e
def _create_graph_runtime_state(
self,
outputs: dict[str, object] | None = None,
total_tokens: int = 0,
node_run_steps: int = 0,
variables: dict[tuple[str, str], object] | None = None,
workflow_run_id: str | None = None,
) -> ReadOnlyGraphRuntimeState:
"""Create a real GraphRuntimeState for testing."""
start_at = time()
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
# Create variable pool
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id))
if variables:
for (node_id, var_key), value in variables.items():
variable_pool.add([node_id, var_key], value)
# Create LLM usage
llm_usage = LLMUsage.empty_usage()
# Create graph runtime state
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=start_at,
total_tokens=total_tokens,
llm_usage=llm_usage,
outputs=outputs or {},
node_run_steps=node_run_steps,
)
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
def _create_pause_state_persistence_layer(
self,
workflow_run: WorkflowRun | None = None,
workflow: Workflow | None = None,
state_owner_user_id: str | None = None,
) -> PauseStatePersistenceLayer:
"""Create PauseStatePersistenceLayer with real dependencies."""
owner_id = state_owner_user_id
if owner_id is None:
if workflow is not None and workflow.created_by:
owner_id = workflow.created_by
elif workflow_run is not None and workflow_run.created_by:
owner_id = workflow_run.created_by
else:
owner_id = getattr(self, "test_user_id", None)
assert owner_id is not None
owner_id = str(owner_id)
return PauseStatePersistenceLayer(
session_factory=self.session.get_bind(),
state_owner_user_id=owner_id,
)
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Create real graph runtime state with test data
test_outputs = {"result": "test_output", "step": "intermediate"}
test_variables = {
("node1", "var1"): "string_value",
("node2", "var2"): {"complex": "object"},
}
graph_runtime_state = self._create_graph_runtime_state(
outputs=test_outputs,
total_tokens=100,
node_run_steps=5,
variables=test_variables,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
# Create pause event
event = GraphRunPausedEvent(
reason=SchedulingPause(message="test pause"),
outputs={"intermediate": "result"},
)
# Act
layer.on_event(event)
# Assert - Verify pause state was saved to database
self.session.refresh(self.test_workflow_run)
workflow_run = self.session.get(WorkflowRun, self.test_workflow_run_id)
assert workflow_run is not None
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
# Verify pause state exists in database
pause_model = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
).first()
assert pause_model is not None
assert pause_model.workflow_id == self.test_workflow_id
assert pause_model.workflow_run_id == self.test_workflow_run_id
assert pause_model.state_object_key != ""
assert pause_model.resumed_at is None
storage_content = storage.load(pause_model.state_object_key).decode()
expected_state = json.loads(graph_runtime_state.dumps())
actual_state = json.loads(storage_content)
assert actual_state == expected_state
def test_state_persistence_and_retrieval(self, db_session_with_containers):
"""Test that pause state can be persisted and retrieved correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Create complex test data
complex_outputs = {
"nested": {"key": "value", "number": 42},
"list": [1, 2, 3, {"nested": "item"}],
"boolean": True,
"null_value": None,
}
complex_variables = {
("node1", "var1"): "string_value",
("node2", "var2"): {"complex": "object"},
("node3", "var3"): [1, 2, 3],
}
graph_runtime_state = self._create_graph_runtime_state(
outputs=complex_outputs,
total_tokens=250,
node_run_steps=10,
variables=complex_variables,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act - Save pause state
layer.on_event(event)
# Assert - Retrieve and verify
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
assert pause_entity is not None
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
state_bytes = pause_entity.get_state()
retrieved_state = json.loads(state_bytes.decode())
expected_state = json.loads(graph_runtime_state.dumps())
assert retrieved_state == expected_state
assert retrieved_state["outputs"] == complex_outputs
assert retrieved_state["total_tokens"] == 250
assert retrieved_state["node_run_steps"] == 10
def test_database_transaction_handling(self, db_session_with_containers):
"""Test that database transactions are handled correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
graph_runtime_state = self._create_graph_runtime_state(
outputs={"test": "transaction"},
total_tokens=50,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act
layer.on_event(event)
# Assert - Verify data is committed and accessible in new session
with Session(bind=self.session.get_bind(), expire_on_commit=False) as new_session:
workflow_run = new_session.get(WorkflowRun, self.test_workflow_run_id)
assert workflow_run is not None
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
pause_model = new_session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
).first()
assert pause_model is not None
assert pause_model.workflow_run_id == self.test_workflow_run_id
assert pause_model.resumed_at is None
assert pause_model.state_object_key != ""
def test_file_storage_integration(self, db_session_with_containers):
"""Test integration with file storage system."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Create large state data to test storage
large_outputs = {"data": "x" * 10000} # 10KB of data
graph_runtime_state = self._create_graph_runtime_state(
outputs=large_outputs,
total_tokens=1000,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act
layer.on_event(event)
# Assert - Verify file was uploaded to storage
self.session.refresh(self.test_workflow_run)
pause_model = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run.id)
).first()
assert pause_model is not None
assert pause_model.state_object_key != ""
# Verify content in storage
storage_content = storage.load(pause_model.state_object_key).decode()
assert storage_content == graph_runtime_state.dumps()
def test_workflow_with_different_creators(self, db_session_with_containers):
"""Test pause state with workflows created by different users."""
# Arrange - Create workflow with different creator
different_user_id = str(uuid.uuid4())
different_workflow = Workflow(
id=str(uuid.uuid4()),
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=different_user_id,
created_at=naive_utc_now(),
)
different_workflow_run = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
workflow_id=different_workflow.id,
type="workflow",
triggered_from="debugging",
version="draft",
status=WorkflowExecutionStatus.RUNNING,
created_by=self.test_user_id, # Run created by different user
created_by_role="account",
created_at=naive_utc_now(),
)
self.session.add(different_workflow)
self.session.add(different_workflow_run)
self.session.commit()
layer = self._create_pause_state_persistence_layer(
workflow_run=different_workflow_run,
workflow=different_workflow,
)
graph_runtime_state = self._create_graph_runtime_state(
outputs={"creator_test": "different_creator"},
workflow_run_id=different_workflow_run.id,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act
layer.on_event(event)
# Assert - Should use workflow creator (not run creator)
self.session.refresh(different_workflow_run)
pause_model = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == different_workflow_run.id)
).first()
assert pause_model is not None
# Verify the state owner is the workflow creator
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
assert pause_entity is not None
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
"""Test that layer ignores non-pause events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
graph_runtime_state = self._create_graph_runtime_state()
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
# Import other event types
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
# Act - Send non-pause events
layer.on_event(GraphRunStartedEvent())
layer.on_event(GraphRunSucceededEvent(outputs={"result": "success"}))
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
# Assert - No pause state should be created
self.session.refresh(self.test_workflow_run)
assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
pause_states = (
self.session.query(WorkflowPauseModel)
.filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
.all()
)
assert len(pause_states) == 0
def test_layer_requires_initialization(self, db_session_with_containers):
"""Test that layer requires proper initialization before handling events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Don't initialize - graph_runtime_state should not be set
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act & Assert - Should raise AttributeError
with pytest.raises(AttributeError):
layer.on_event(event)

View File

@@ -0,0 +1,948 @@
"""Comprehensive integration tests for workflow pause functionality.
This test suite covers complete workflow pause functionality including:
- Real database interactions using containerized PostgreSQL
- Real storage operations using the test storage backend
- Complete workflow: create -> pause -> resume -> delete
- Testing with actual FileService (not mocked)
- Database transactions and rollback behavior
- Actual file upload and retrieval through storage
- Workflow status transitions in the database
- Error handling with real database constraints
- Concurrent access scenarios
- Multi-tenant isolation
- Prune functionality
- File storage integration
These tests use TestContainers to spin up real services for integration testing,
providing more reliable and realistic test scenarios than mocks.
"""
import json
import uuid
from dataclasses import dataclass
from datetime import timedelta
import pytest
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities import WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_WorkflowRunError,
)
@dataclass
class PauseWorkflowSuccessCase:
"""Test case for successful pause workflow operations."""
name: str
initial_status: WorkflowExecutionStatus
description: str = ""
@dataclass
class PauseWorkflowFailureCase:
"""Test case for pause workflow failure scenarios."""
name: str
initial_status: WorkflowExecutionStatus
description: str = ""
@dataclass
class ResumeWorkflowSuccessCase:
"""Test case for successful resume workflow operations."""
name: str
initial_status: WorkflowExecutionStatus
description: str = ""
@dataclass
class ResumeWorkflowFailureCase:
"""Test case for resume workflow failure scenarios."""
name: str
initial_status: WorkflowExecutionStatus
pause_resumed: bool
set_running_status: bool = False
description: str = ""
@dataclass
class PrunePausesTestCase:
"""Test case for prune pauses operations."""
name: str
pause_age: timedelta
resume_age: timedelta | None
expected_pruned_count: int
description: str = ""
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
"""Create test cases for pause workflow failure scenarios."""
return [
PauseWorkflowFailureCase(
name="pause_already_paused_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
description="Should fail to pause an already paused workflow",
),
PauseWorkflowFailureCase(
name="pause_completed_workflow",
initial_status=WorkflowExecutionStatus.SUCCEEDED,
description="Should fail to pause a completed workflow",
),
PauseWorkflowFailureCase(
name="pause_failed_workflow",
initial_status=WorkflowExecutionStatus.FAILED,
description="Should fail to pause a failed workflow",
),
]
def resume_workflow_success_cases() -> list[ResumeWorkflowSuccessCase]:
"""Create test cases for successful resume workflow operations."""
return [
ResumeWorkflowSuccessCase(
name="resume_paused_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
description="Should successfully resume a paused workflow",
),
]
def resume_workflow_failure_cases() -> list[ResumeWorkflowFailureCase]:
"""Create test cases for resume workflow failure scenarios."""
return [
ResumeWorkflowFailureCase(
name="resume_already_resumed_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
pause_resumed=True,
description="Should fail to resume an already resumed workflow",
),
ResumeWorkflowFailureCase(
name="resume_running_workflow",
initial_status=WorkflowExecutionStatus.RUNNING,
pause_resumed=False,
set_running_status=True,
description="Should fail to resume a running workflow",
),
]
def prune_pauses_test_cases() -> list[PrunePausesTestCase]:
"""Create test cases for prune pauses operations."""
return [
PrunePausesTestCase(
name="prune_old_active_pauses",
pause_age=timedelta(days=7),
resume_age=None,
expected_pruned_count=1,
description="Should prune old active pauses",
),
PrunePausesTestCase(
name="prune_old_resumed_pauses",
pause_age=timedelta(hours=12), # Created 12 hours ago (recent)
resume_age=timedelta(days=7),
expected_pruned_count=1,
description="Should prune old resumed pauses",
),
PrunePausesTestCase(
name="keep_recent_active_pauses",
pause_age=timedelta(hours=1),
resume_age=None,
expected_pruned_count=0,
description="Should keep recent active pauses",
),
PrunePausesTestCase(
name="keep_recent_resumed_pauses",
pause_age=timedelta(days=1),
resume_age=timedelta(hours=1),
expected_pruned_count=0,
description="Should keep recent resumed pauses",
),
]
class TestWorkflowPauseIntegration:
"""Comprehensive integration tests for workflow pause functionality."""
@pytest.fixture(autouse=True)
def setup_test_data(self, db_session_with_containers):
"""Set up test data for each test method using TestContainers."""
# Create test tenant and account
tenant = Tenant(
name="Test Tenant",
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
account = Account(
email="test@example.com",
name="Test User",
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant-account join
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()
# Set test data
self.test_tenant_id = tenant.id
self.test_user_id = account.id
self.test_app_id = str(uuid.uuid4())
self.test_workflow_id = str(uuid.uuid4())
# Create test workflow
self.test_workflow = Workflow(
id=self.test_workflow_id,
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=self.test_user_id,
created_at=naive_utc_now(),
)
# Store session instance
self.session = db_session_with_containers
# Save test data to database
self.session.add(self.test_workflow)
self.session.commit()
yield
# Cleanup
self._cleanup_test_data()
def _cleanup_test_data(self):
"""Clean up test data after each test method."""
# Clean up workflow pauses
self.session.execute(delete(WorkflowPauseModel))
# Clean up upload files
self.session.execute(
delete(UploadFile).where(
UploadFile.tenant_id == self.test_tenant_id,
)
)
# Clean up workflow runs
self.session.execute(
delete(WorkflowRun).where(
WorkflowRun.tenant_id == self.test_tenant_id,
WorkflowRun.app_id == self.test_app_id,
)
)
# Clean up workflows
self.session.execute(
delete(Workflow).where(
Workflow.tenant_id == self.test_tenant_id,
Workflow.app_id == self.test_app_id,
)
)
self.session.commit()
def _create_test_workflow_run(
self, status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
) -> WorkflowRun:
"""Create a test workflow run with specified status."""
workflow_run = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
workflow_id=self.test_workflow_id,
type="workflow",
triggered_from="debugging",
version="draft",
status=status,
created_by=self.test_user_id,
created_by_role="account",
created_at=naive_utc_now(),
)
self.session.add(workflow_run)
self.session.commit()
return workflow_run
def _create_test_state(self) -> str:
"""Create a test state string."""
return json.dumps(
{
"node_id": "test-node",
"node_type": "llm",
"status": "paused",
"data": {"key": "value"},
"timestamp": naive_utc_now().isoformat(),
}
)
def _get_workflow_run_repository(self):
"""Get workflow run repository instance for testing."""
# Create session factory from the test session
engine = self.session.get_bind()
session_factory = sessionmaker(bind=engine, expire_on_commit=False)
# Create a test-specific repository that implements the missing save method
class TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
"""Test-specific repository that implements the missing save method."""
def save(self, execution: WorkflowExecution):
"""Implement the missing save method for testing."""
# For testing purposes, we don't need to implement this method
# as it's not used in the pause functionality tests
pass
# Create and return repository instance
repository = TestWorkflowRunRepository(session_maker=session_factory)
return repository
# ==================== Complete Pause Workflow Tests ====================
def test_complete_pause_resume_workflow(self):
"""Test complete workflow: create -> pause -> resume -> delete."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act - Create pause state
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Assert - Pause state created
assert pause_entity is not None
assert pause_entity.id is not None
assert pause_entity.workflow_execution_id == workflow_run.id
# Convert both to strings for comparison
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
# Verify database state
query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
pause_model = self.session.scalars(query).first()
assert pause_model is not None
assert pause_model.resumed_at is None
assert pause_model.id == pause_entity.id
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
# Act - Get pause state
retrieved_entity = repository.get_workflow_pause(workflow_run.id)
# Assert - Pause state retrieved
assert retrieved_entity is not None
assert retrieved_entity.id == pause_entity.id
retrieved_state = retrieved_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
# Act - Resume workflow
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
# Assert - Workflow resumed
assert resumed_entity is not None
assert resumed_entity.id == pause_entity.id
assert resumed_entity.resumed_at is not None
# Verify database state
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
self.session.refresh(pause_model)
assert pause_model.resumed_at is not None
# Act - Delete pause state
repository.delete_workflow_pause(pause_entity)
# Assert - Pause state deleted
with Session(bind=self.session.get_bind()) as session:
deleted_pause = session.get(WorkflowPauseModel, pause_entity.id)
assert deleted_pause is None
def test_pause_workflow_success(self):
"""Test successful pause workflow scenarios."""
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
assert pause_entity is not None
assert pause_entity.workflow_execution_id == workflow_run.id
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
pause_model = self.session.scalars(pause_query).first()
assert pause_model is not None
assert pause_model.id == pause_entity.id
assert pause_model.resumed_at is None
@pytest.mark.parametrize("test_case", pause_workflow_failure_cases(), ids=lambda tc: tc.name)
def test_pause_workflow_failure(self, test_case: PauseWorkflowFailureCase):
"""Test pause workflow failure scenarios."""
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
with pytest.raises(_WorkflowRunError):
repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
def test_resume_workflow_success(self, test_case: ResumeWorkflowSuccessCase):
"""Test successful resume workflow scenarios."""
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
workflow_run.status = WorkflowExecutionStatus.RUNNING
self.session.commit()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
assert resumed_entity is not None
assert resumed_entity.id == pause_entity.id
assert resumed_entity.resumed_at is not None
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
pause_model = self.session.scalars(pause_query).first()
assert pause_model is not None
assert pause_model.id == pause_entity.id
assert pause_model.resumed_at is not None
def test_resume_running_workflow(self):
"""Test resume workflow failure scenarios."""
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
self.session.refresh(workflow_run)
workflow_run.status = WorkflowExecutionStatus.RUNNING
self.session.add(workflow_run)
self.session.commit()
with pytest.raises(_WorkflowRunError):
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
def test_resume_resumed_pause(self):
"""Test resume workflow failure scenarios."""
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.resumed_at = naive_utc_now()
self.session.add(pause_model)
self.session.commit()
with pytest.raises(_WorkflowRunError):
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
# ==================== Error Scenario Tests ====================
def test_pause_nonexistent_workflow_run(self):
"""Test pausing a non-existent workflow run."""
# Arrange
nonexistent_id = str(uuid.uuid4())
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act & Assert
with pytest.raises(ValueError, match="WorkflowRun not found"):
repository.create_workflow_pause(
workflow_run_id=nonexistent_id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
def test_resume_nonexistent_workflow_run(self):
"""Test resuming a non-existent workflow run."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
nonexistent_id = str(uuid.uuid4())
# Act & Assert
with pytest.raises(ValueError, match="WorkflowRun not found"):
repository.resume_workflow_pause(
workflow_run_id=nonexistent_id,
pause_entity=pause_entity,
)
# ==================== Prune Functionality Tests ====================
@pytest.mark.parametrize("test_case", prune_pauses_test_cases(), ids=lambda tc: tc.name)
def test_prune_pauses_scenarios(self, test_case: PrunePausesTestCase):
"""Test various prune pauses scenarios."""
now = naive_utc_now()
# Create pause state
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Manually adjust timestamps for testing
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.created_at = now - test_case.pause_age
if test_case.resume_age is not None:
# Resume pause and adjust resume time
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
# Need to refresh to get the updated model
self.session.refresh(pause_model)
# Manually set the resumed_at to an older time for testing
pause_model.resumed_at = now - test_case.resume_age
self.session.commit() # Commit the resumed_at change
# Refresh again to ensure the change is persisted
self.session.refresh(pause_model)
self.session.commit()
# Act - Prune pauses
expiration_time = now - timedelta(days=1, seconds=1) # Expire pauses older than 1 day (plus 1 second)
resumption_time = now - timedelta(
days=7, seconds=1
) # Clean up pauses resumed more than 7 days ago (plus 1 second)
# Debug: Check pause state before pruning
self.session.refresh(pause_model)
print(f"Pause created_at: {pause_model.created_at}")
print(f"Pause resumed_at: {pause_model.resumed_at}")
print(f"Expiration time: {expiration_time}")
print(f"Resumption time: {resumption_time}")
# Force commit to ensure timestamps are saved
self.session.commit()
# Determine if the pause should be pruned based on timestamps
should_be_pruned = False
if test_case.resume_age is not None:
# If resumed, check if resumed_at is older than resumption_time
should_be_pruned = pause_model.resumed_at < resumption_time
else:
# If not resumed, check if created_at is older than expiration_time
should_be_pruned = pause_model.created_at < expiration_time
# Act - Prune pauses
pruned_ids = repository.prune_pauses(
expiration=expiration_time,
resumption_expiration=resumption_time,
)
# Assert - Check pruning results
if should_be_pruned:
assert len(pruned_ids) == test_case.expected_pruned_count
# Verify pause was actually deleted
# The pause should be in the pruned_ids list if it was pruned
assert pause_entity.id in pruned_ids
else:
assert len(pruned_ids) == 0
def test_prune_pauses_with_limit(self):
"""Test prune pauses with limit parameter."""
now = naive_utc_now()
# Create multiple pause states
pause_entities = []
repository = self._get_workflow_run_repository()
for i in range(5):
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
pause_entities.append(pause_entity)
# Make all pauses old enough to be pruned
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.created_at = now - timedelta(days=7)
self.session.commit()
# Act - Prune with limit
expiration_time = now - timedelta(days=1)
resumption_time = now - timedelta(days=7)
pruned_ids = repository.prune_pauses(
expiration=expiration_time,
resumption_expiration=resumption_time,
limit=3,
)
# Assert
assert len(pruned_ids) == 3
# Verify only 3 were deleted
remaining_count = (
self.session.query(WorkflowPauseModel)
.filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
.count()
)
assert remaining_count == 2
# ==================== Multi-tenant Isolation Tests ====================
def test_multi_tenant_pause_isolation(self):
"""Test that pause states are properly isolated by tenant."""
# Arrange - Create second tenant
tenant2 = Tenant(
name="Test Tenant 2",
status="normal",
)
self.session.add(tenant2)
self.session.commit()
account2 = Account(
email="test2@example.com",
name="Test User 2",
interface_language="en-US",
status="active",
)
self.session.add(account2)
self.session.commit()
tenant2_join = TenantAccountJoin(
tenant_id=tenant2.id,
account_id=account2.id,
role=TenantAccountRole.OWNER,
current=True,
)
self.session.add(tenant2_join)
self.session.commit()
# Create workflow for tenant 2
workflow2 = Workflow(
id=str(uuid.uuid4()),
tenant_id=tenant2.id,
app_id=str(uuid.uuid4()),
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=account2.id,
created_at=naive_utc_now(),
)
self.session.add(workflow2)
self.session.commit()
# Create workflow runs for both tenants
workflow_run1 = self._create_test_workflow_run()
workflow_run2 = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=tenant2.id,
app_id=workflow2.app_id,
workflow_id=workflow2.id,
type="workflow",
triggered_from="debugging",
version="draft",
status=WorkflowExecutionStatus.RUNNING,
created_by=account2.id,
created_by_role="account",
created_at=naive_utc_now(),
)
self.session.add(workflow_run2)
self.session.commit()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act - Create pause for tenant 1
pause_entity1 = repository.create_workflow_pause(
workflow_run_id=workflow_run1.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Try to access pause from tenant 2 using tenant 1's repository
# This should work because we're using the same repository
pause_entity2 = repository.get_workflow_pause(workflow_run2.id)
assert pause_entity2 is None # No pause for tenant 2 yet
# Create pause for tenant 2
pause_entity2 = repository.create_workflow_pause(
workflow_run_id=workflow_run2.id,
state_owner_user_id=account2.id,
state=test_state,
)
# Assert - Both pauses should exist and be separate
assert pause_entity1 is not None
assert pause_entity2 is not None
assert pause_entity1.id != pause_entity2.id
assert pause_entity1.workflow_execution_id != pause_entity2.workflow_execution_id
def test_cross_tenant_access_restriction(self):
"""Test that cross-tenant access is properly restricted."""
# This test would require tenant-specific repositories
# For now, we test that pause entities are properly scoped by tenant_id
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Verify pause is properly scoped
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
assert pause_model.workflow_id == self.test_workflow_id
# ==================== File Storage Integration Tests ====================
def test_file_storage_integration(self):
"""Test that state files are properly stored and retrieved."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act - Create pause state
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Assert - Verify file was uploaded to storage
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
assert pause_model.state_object_key != ""
# Verify file content in storage
file_key = pause_model.state_object_key
storage_content = storage.load(file_key).decode()
assert storage_content == test_state
# Verify retrieval through entity
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
def test_file_cleanup_on_pause_deletion(self):
"""Test that files are properly handled on pause deletion."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Get file info before deletion
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
file_key = pause_model.state_object_key
# Act - Delete pause state
repository.delete_workflow_pause(pause_entity)
# Assert - Pause record should be deleted
self.session.expire_all() # Clear session to ensure fresh query
deleted_pause = self.session.get(WorkflowPauseModel, pause_entity.id)
assert deleted_pause is None
try:
content = storage.load(file_key).decode()
pytest.fail("File should be deleted from storage after pause deletion")
except FileNotFoundError:
# This is expected - file should be deleted from storage
pass
except Exception as e:
pytest.fail(f"Unexpected error when checking file deletion: {e}")
def test_large_state_file_handling(self):
"""Test handling of large state files."""
# Arrange - Create a large state (1MB)
large_state = "x" * (1024 * 1024) # 1MB of data
large_state_json = json.dumps({"large_data": large_state})
workflow_run = self._create_test_workflow_run()
repository = self._get_workflow_run_repository()
# Act
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=large_state_json,
)
# Assert
assert pause_entity is not None
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == large_state_json
# Verify file size in database
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
assert pause_model.state_object_key != ""
loaded_state = storage.load(pause_model.state_object_key)
assert loaded_state.decode() == large_state_json
def test_multiple_pause_resume_cycles(self):
"""Test multiple pause/resume cycles on the same workflow run."""
# Arrange
workflow_run = self._create_test_workflow_run()
repository = self._get_workflow_run_repository()
# Act & Assert - Multiple cycles
for i in range(3):
state = json.dumps({"cycle": i, "data": f"state_{i}"})
# Reset workflow run status to RUNNING before each pause (after first cycle)
if i > 0:
self.session.refresh(workflow_run) # Refresh to get latest state from session
workflow_run.status = WorkflowExecutionStatus.RUNNING
self.session.commit()
self.session.refresh(workflow_run) # Refresh again after commit
# Pause
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=state,
)
assert pause_entity is not None
# Verify pause
self.session.expire_all() # Clear session to ensure fresh query
self.session.refresh(workflow_run)
# Use the test session directly to verify the pause
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run.id)
workflow_run_with_pause = self.session.scalar(stmt)
pause_model = workflow_run_with_pause.pause
# Verify pause using test session directly
assert pause_model is not None
assert pause_model.id == pause_entity.id
assert pause_model.state_object_key != ""
# Load file content using storage directly
file_content = storage.load(pause_model.state_object_key)
if isinstance(file_content, bytes):
file_content = file_content.decode()
assert file_content == state
# Resume
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
assert resumed_entity is not None
assert resumed_entity.resumed_at is not None
# Verify resume - check that pause is marked as resumed
self.session.expire_all() # Clear session to ensure fresh query
stmt = select(WorkflowPauseModel).where(WorkflowPauseModel.id == pause_entity.id)
resumed_pause_model = self.session.scalar(stmt)
assert resumed_pause_model is not None
assert resumed_pause_model.resumed_at is not None
# Verify workflow run status
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.RUNNING