feat(api): Introduce workflow pause state management (#27298)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1 @@
|
||||
# Core integration tests package
|
||||
@@ -0,0 +1 @@
|
||||
# App integration tests package
|
||||
@@ -0,0 +1 @@
|
||||
# Layers integration tests package
|
||||
@@ -0,0 +1,520 @@
|
||||
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class.
|
||||
|
||||
This test suite covers complete integration scenarios including:
|
||||
- Real database interactions using containerized PostgreSQL
|
||||
- Real storage operations using test storage backend
|
||||
- Complete workflow: event -> state serialization -> database save -> storage save
|
||||
- Testing with actual WorkflowRunService (not mocked)
|
||||
- Real Workflow and WorkflowRun instances in database
|
||||
- Database transactions and rollback behavior
|
||||
- Actual file upload and retrieval through storage
|
||||
- Workflow status transitions in database
|
||||
- Error handling with real database constraints
|
||||
- Multiple pause events in sequence
|
||||
- Integration with real ReadOnlyGraphRuntimeState implementations
|
||||
|
||||
These tests use TestContainers to spin up real services for integration testing,
|
||||
providing more reliable and realistic test scenarios than mocks.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from time import time
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper
|
||||
from core.workflow.runtime.variable_pool import SystemVariable, VariablePool
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models import WorkflowPause as WorkflowPauseModel
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.file_service import FileService
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
class _TestCommandChannelImpl:
|
||||
"""Real implementation of CommandChannel for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._commands: list[GraphEngineCommand] = []
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""Fetch pending commands for this GraphEngine instance."""
|
||||
return self._commands.copy()
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to be processed by this GraphEngine instance."""
|
||||
self._commands.append(command)
|
||||
|
||||
|
||||
class TestPauseStatePersistenceLayerTestContainers:
|
||||
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, db_session_with_containers: Session):
|
||||
"""Get database engine from TestContainers session."""
|
||||
bind = db_session_with_containers.get_bind()
|
||||
assert isinstance(bind, Engine)
|
||||
return bind
|
||||
|
||||
@pytest.fixture
|
||||
def file_service(self, engine: Engine):
|
||||
"""Create FileService instance with TestContainers engine."""
|
||||
return FileService(engine)
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_run_service(self, engine: Engine, file_service: FileService):
|
||||
"""Create WorkflowRunService instance with TestContainers engine and FileService."""
|
||||
return WorkflowRunService(engine)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
|
||||
"""Set up test data for each test method using TestContainers."""
|
||||
# Create test tenant and account
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
tenant = Tenant(
|
||||
name="Test Tenant",
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account = Account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set test data
|
||||
self.test_tenant_id = tenant.id
|
||||
self.test_user_id = account.id
|
||||
self.test_app_id = str(uuid.uuid4())
|
||||
self.test_workflow_id = str(uuid.uuid4())
|
||||
self.test_workflow_run_id = str(uuid.uuid4())
|
||||
|
||||
# Create test workflow
|
||||
self.test_workflow = Workflow(
|
||||
id=self.test_workflow_id,
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=self.test_user_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Create test workflow run
|
||||
self.test_workflow_run = WorkflowRun(
|
||||
id=self.test_workflow_run_id,
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
workflow_id=self.test_workflow_id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=self.test_user_id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store session and service instances
|
||||
self.session = db_session_with_containers
|
||||
self.file_service = file_service
|
||||
self.workflow_run_service = workflow_run_service
|
||||
|
||||
# Save test data to database
|
||||
self.session.add(self.test_workflow)
|
||||
self.session.add(self.test_workflow_run)
|
||||
self.session.commit()
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
self._cleanup_test_data()
|
||||
|
||||
def _cleanup_test_data(self):
|
||||
"""Clean up test data after each test method."""
|
||||
try:
|
||||
# Clean up workflow pauses
|
||||
self.session.execute(delete(WorkflowPauseModel))
|
||||
# Clean up upload files
|
||||
self.session.execute(
|
||||
delete(UploadFile).where(
|
||||
UploadFile.tenant_id == self.test_tenant_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflow runs
|
||||
self.session.execute(
|
||||
delete(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == self.test_tenant_id,
|
||||
WorkflowRun.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflows
|
||||
self.session.execute(
|
||||
delete(Workflow).where(
|
||||
Workflow.tenant_id == self.test_tenant_id,
|
||||
Workflow.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
raise e
|
||||
|
||||
def _create_graph_runtime_state(
|
||||
self,
|
||||
outputs: dict[str, object] | None = None,
|
||||
total_tokens: int = 0,
|
||||
node_run_steps: int = 0,
|
||||
variables: dict[tuple[str, str], object] | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
) -> ReadOnlyGraphRuntimeState:
|
||||
"""Create a real GraphRuntimeState for testing."""
|
||||
start_at = time()
|
||||
|
||||
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
|
||||
|
||||
# Create variable pool
|
||||
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id))
|
||||
if variables:
|
||||
for (node_id, var_key), value in variables.items():
|
||||
variable_pool.add([node_id, var_key], value)
|
||||
|
||||
# Create LLM usage
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
# Create graph runtime state
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=start_at,
|
||||
total_tokens=total_tokens,
|
||||
llm_usage=llm_usage,
|
||||
outputs=outputs or {},
|
||||
node_run_steps=node_run_steps,
|
||||
)
|
||||
|
||||
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
|
||||
|
||||
def _create_pause_state_persistence_layer(
|
||||
self,
|
||||
workflow_run: WorkflowRun | None = None,
|
||||
workflow: Workflow | None = None,
|
||||
state_owner_user_id: str | None = None,
|
||||
) -> PauseStatePersistenceLayer:
|
||||
"""Create PauseStatePersistenceLayer with real dependencies."""
|
||||
owner_id = state_owner_user_id
|
||||
if owner_id is None:
|
||||
if workflow is not None and workflow.created_by:
|
||||
owner_id = workflow.created_by
|
||||
elif workflow_run is not None and workflow_run.created_by:
|
||||
owner_id = workflow_run.created_by
|
||||
else:
|
||||
owner_id = getattr(self, "test_user_id", None)
|
||||
|
||||
assert owner_id is not None
|
||||
owner_id = str(owner_id)
|
||||
|
||||
return PauseStatePersistenceLayer(
|
||||
session_factory=self.session.get_bind(),
|
||||
state_owner_user_id=owner_id,
|
||||
)
|
||||
|
||||
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
|
||||
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
# Create real graph runtime state with test data
|
||||
test_outputs = {"result": "test_output", "step": "intermediate"}
|
||||
test_variables = {
|
||||
("node1", "var1"): "string_value",
|
||||
("node2", "var2"): {"complex": "object"},
|
||||
}
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs=test_outputs,
|
||||
total_tokens=100,
|
||||
node_run_steps=5,
|
||||
variables=test_variables,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
# Create pause event
|
||||
event = GraphRunPausedEvent(
|
||||
reason=SchedulingPause(message="test pause"),
|
||||
outputs={"intermediate": "result"},
|
||||
)
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Verify pause state was saved to database
|
||||
self.session.refresh(self.test_workflow_run)
|
||||
workflow_run = self.session.get(WorkflowRun, self.test_workflow_run_id)
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
# Verify pause state exists in database
|
||||
pause_model = self.session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.workflow_id == self.test_workflow_id
|
||||
assert pause_model.workflow_run_id == self.test_workflow_run_id
|
||||
assert pause_model.state_object_key != ""
|
||||
assert pause_model.resumed_at is None
|
||||
|
||||
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||
expected_state = json.loads(graph_runtime_state.dumps())
|
||||
actual_state = json.loads(storage_content)
|
||||
|
||||
assert actual_state == expected_state
|
||||
|
||||
def test_state_persistence_and_retrieval(self, db_session_with_containers):
|
||||
"""Test that pause state can be persisted and retrieved correctly."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
# Create complex test data
|
||||
complex_outputs = {
|
||||
"nested": {"key": "value", "number": 42},
|
||||
"list": [1, 2, 3, {"nested": "item"}],
|
||||
"boolean": True,
|
||||
"null_value": None,
|
||||
}
|
||||
complex_variables = {
|
||||
("node1", "var1"): "string_value",
|
||||
("node2", "var2"): {"complex": "object"},
|
||||
("node3", "var3"): [1, 2, 3],
|
||||
}
|
||||
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs=complex_outputs,
|
||||
total_tokens=250,
|
||||
node_run_steps=10,
|
||||
variables=complex_variables,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act - Save pause state
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Retrieve and verify
|
||||
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
state_bytes = pause_entity.get_state()
|
||||
retrieved_state = json.loads(state_bytes.decode())
|
||||
expected_state = json.loads(graph_runtime_state.dumps())
|
||||
|
||||
assert retrieved_state == expected_state
|
||||
assert retrieved_state["outputs"] == complex_outputs
|
||||
assert retrieved_state["total_tokens"] == 250
|
||||
assert retrieved_state["node_run_steps"] == 10
|
||||
|
||||
def test_database_transaction_handling(self, db_session_with_containers):
|
||||
"""Test that database transactions are handled correctly."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs={"test": "transaction"},
|
||||
total_tokens=50,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Verify data is committed and accessible in new session
|
||||
with Session(bind=self.session.get_bind(), expire_on_commit=False) as new_session:
|
||||
workflow_run = new_session.get(WorkflowRun, self.test_workflow_run_id)
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
pause_model = new_session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.workflow_run_id == self.test_workflow_run_id
|
||||
assert pause_model.resumed_at is None
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
def test_file_storage_integration(self, db_session_with_containers):
|
||||
"""Test integration with file storage system."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
# Create large state data to test storage
|
||||
large_outputs = {"data": "x" * 10000} # 10KB of data
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs=large_outputs,
|
||||
total_tokens=1000,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Verify file was uploaded to storage
|
||||
self.session.refresh(self.test_workflow_run)
|
||||
pause_model = self.session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Verify content in storage
|
||||
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||
assert storage_content == graph_runtime_state.dumps()
|
||||
|
||||
def test_workflow_with_different_creators(self, db_session_with_containers):
|
||||
"""Test pause state with workflows created by different users."""
|
||||
# Arrange - Create workflow with different creator
|
||||
different_user_id = str(uuid.uuid4())
|
||||
different_workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=different_user_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
different_workflow_run = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
workflow_id=different_workflow.id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=self.test_user_id, # Run created by different user
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
self.session.add(different_workflow)
|
||||
self.session.add(different_workflow_run)
|
||||
self.session.commit()
|
||||
|
||||
layer = self._create_pause_state_persistence_layer(
|
||||
workflow_run=different_workflow_run,
|
||||
workflow=different_workflow,
|
||||
)
|
||||
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs={"creator_test": "different_creator"},
|
||||
workflow_run_id=different_workflow_run.id,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Should use workflow creator (not run creator)
|
||||
self.session.refresh(different_workflow_run)
|
||||
pause_model = self.session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == different_workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
|
||||
# Verify the state owner is the workflow creator
|
||||
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
|
||||
assert pause_entity is not None
|
||||
|
||||
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
|
||||
"""Test that layer ignores non-pause events."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
graph_runtime_state = self._create_graph_runtime_state()
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
# Import other event types
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
|
||||
# Act - Send non-pause events
|
||||
layer.on_event(GraphRunStartedEvent())
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={"result": "success"}))
|
||||
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
||||
|
||||
# Assert - No pause state should be created
|
||||
self.session.refresh(self.test_workflow_run)
|
||||
assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
|
||||
pause_states = (
|
||||
self.session.query(WorkflowPauseModel)
|
||||
.filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
assert len(pause_states) == 0
|
||||
|
||||
def test_layer_requires_initialization(self, db_session_with_containers):
|
||||
"""Test that layer requires proper initialization before handling events."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
# Don't initialize - graph_runtime_state should not be set
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act & Assert - Should raise AttributeError
|
||||
with pytest.raises(AttributeError):
|
||||
layer.on_event(event)
|
||||
@@ -0,0 +1,948 @@
|
||||
"""Comprehensive integration tests for workflow pause functionality.
|
||||
|
||||
This test suite covers complete workflow pause functionality including:
|
||||
- Real database interactions using containerized PostgreSQL
|
||||
- Real storage operations using the test storage backend
|
||||
- Complete workflow: create -> pause -> resume -> delete
|
||||
- Testing with actual FileService (not mocked)
|
||||
- Database transactions and rollback behavior
|
||||
- Actual file upload and retrieval through storage
|
||||
- Workflow status transitions in the database
|
||||
- Error handling with real database constraints
|
||||
- Concurrent access scenarios
|
||||
- Multi-tenant isolation
|
||||
- Prune functionality
|
||||
- File storage integration
|
||||
|
||||
These tests use TestContainers to spin up real services for integration testing,
|
||||
providing more reliable and realistic test scenarios than mocks.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.entities import WorkflowExecution
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models import WorkflowPause as WorkflowPauseModel
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_WorkflowRunError,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PauseWorkflowSuccessCase:
|
||||
"""Test case for successful pause workflow operations."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PauseWorkflowFailureCase:
|
||||
"""Test case for pause workflow failure scenarios."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumeWorkflowSuccessCase:
|
||||
"""Test case for successful resume workflow operations."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumeWorkflowFailureCase:
|
||||
"""Test case for resume workflow failure scenarios."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
pause_resumed: bool
|
||||
set_running_status: bool = False
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrunePausesTestCase:
|
||||
"""Test case for prune pauses operations."""
|
||||
|
||||
name: str
|
||||
pause_age: timedelta
|
||||
resume_age: timedelta | None
|
||||
expected_pruned_count: int
|
||||
description: str = ""
|
||||
|
||||
|
||||
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
|
||||
"""Create test cases for pause workflow failure scenarios."""
|
||||
return [
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_already_paused_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
description="Should fail to pause an already paused workflow",
|
||||
),
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_completed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
description="Should fail to pause a completed workflow",
|
||||
),
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_failed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.FAILED,
|
||||
description="Should fail to pause a failed workflow",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def resume_workflow_success_cases() -> list[ResumeWorkflowSuccessCase]:
|
||||
"""Create test cases for successful resume workflow operations."""
|
||||
return [
|
||||
ResumeWorkflowSuccessCase(
|
||||
name="resume_paused_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
description="Should successfully resume a paused workflow",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def resume_workflow_failure_cases() -> list[ResumeWorkflowFailureCase]:
|
||||
"""Create test cases for resume workflow failure scenarios."""
|
||||
return [
|
||||
ResumeWorkflowFailureCase(
|
||||
name="resume_already_resumed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
pause_resumed=True,
|
||||
description="Should fail to resume an already resumed workflow",
|
||||
),
|
||||
ResumeWorkflowFailureCase(
|
||||
name="resume_running_workflow",
|
||||
initial_status=WorkflowExecutionStatus.RUNNING,
|
||||
pause_resumed=False,
|
||||
set_running_status=True,
|
||||
description="Should fail to resume a running workflow",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def prune_pauses_test_cases() -> list[PrunePausesTestCase]:
|
||||
"""Create test cases for prune pauses operations."""
|
||||
return [
|
||||
PrunePausesTestCase(
|
||||
name="prune_old_active_pauses",
|
||||
pause_age=timedelta(days=7),
|
||||
resume_age=None,
|
||||
expected_pruned_count=1,
|
||||
description="Should prune old active pauses",
|
||||
),
|
||||
PrunePausesTestCase(
|
||||
name="prune_old_resumed_pauses",
|
||||
pause_age=timedelta(hours=12), # Created 12 hours ago (recent)
|
||||
resume_age=timedelta(days=7),
|
||||
expected_pruned_count=1,
|
||||
description="Should prune old resumed pauses",
|
||||
),
|
||||
PrunePausesTestCase(
|
||||
name="keep_recent_active_pauses",
|
||||
pause_age=timedelta(hours=1),
|
||||
resume_age=None,
|
||||
expected_pruned_count=0,
|
||||
description="Should keep recent active pauses",
|
||||
),
|
||||
PrunePausesTestCase(
|
||||
name="keep_recent_resumed_pauses",
|
||||
pause_age=timedelta(days=1),
|
||||
resume_age=timedelta(hours=1),
|
||||
expected_pruned_count=0,
|
||||
description="Should keep recent resumed pauses",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TestWorkflowPauseIntegration:
|
||||
"""Comprehensive integration tests for workflow pause functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_data(self, db_session_with_containers):
|
||||
"""Set up test data for each test method using TestContainers."""
|
||||
# Create test tenant and account
|
||||
|
||||
tenant = Tenant(
|
||||
name="Test Tenant",
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account = Account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set test data
|
||||
self.test_tenant_id = tenant.id
|
||||
self.test_user_id = account.id
|
||||
self.test_app_id = str(uuid.uuid4())
|
||||
self.test_workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Create test workflow
|
||||
self.test_workflow = Workflow(
|
||||
id=self.test_workflow_id,
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=self.test_user_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store session instance
|
||||
self.session = db_session_with_containers
|
||||
|
||||
# Save test data to database
|
||||
self.session.add(self.test_workflow)
|
||||
self.session.commit()
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
self._cleanup_test_data()
|
||||
|
||||
def _cleanup_test_data(self):
|
||||
"""Clean up test data after each test method."""
|
||||
# Clean up workflow pauses
|
||||
self.session.execute(delete(WorkflowPauseModel))
|
||||
# Clean up upload files
|
||||
self.session.execute(
|
||||
delete(UploadFile).where(
|
||||
UploadFile.tenant_id == self.test_tenant_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflow runs
|
||||
self.session.execute(
|
||||
delete(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == self.test_tenant_id,
|
||||
WorkflowRun.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflows
|
||||
self.session.execute(
|
||||
delete(Workflow).where(
|
||||
Workflow.tenant_id == self.test_tenant_id,
|
||||
Workflow.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
|
||||
def _create_test_workflow_run(
|
||||
self, status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
|
||||
) -> WorkflowRun:
|
||||
"""Create a test workflow run with specified status."""
|
||||
workflow_run = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
workflow_id=self.test_workflow_id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=status,
|
||||
created_by=self.test_user_id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow_run)
|
||||
self.session.commit()
|
||||
return workflow_run
|
||||
|
||||
def _create_test_state(self) -> str:
|
||||
"""Create a test state string."""
|
||||
return json.dumps(
|
||||
{
|
||||
"node_id": "test-node",
|
||||
"node_type": "llm",
|
||||
"status": "paused",
|
||||
"data": {"key": "value"},
|
||||
"timestamp": naive_utc_now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
def _get_workflow_run_repository(self):
|
||||
"""Get workflow run repository instance for testing."""
|
||||
# Create session factory from the test session
|
||||
engine = self.session.get_bind()
|
||||
session_factory = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
# Create a test-specific repository that implements the missing save method
|
||||
class TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test-specific repository that implements the missing save method."""
|
||||
|
||||
def save(self, execution: WorkflowExecution):
|
||||
"""Implement the missing save method for testing."""
|
||||
# For testing purposes, we don't need to implement this method
|
||||
# as it's not used in the pause functionality tests
|
||||
pass
|
||||
|
||||
# Create and return repository instance
|
||||
repository = TestWorkflowRunRepository(session_maker=session_factory)
|
||||
return repository
|
||||
|
||||
# ==================== Complete Pause Workflow Tests ====================
|
||||
|
||||
def test_complete_pause_resume_workflow(self):
|
||||
"""Test complete workflow: create -> pause -> resume -> delete."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause state
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Pause state created
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.id is not None
|
||||
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||
# Convert both to strings for comparison
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
# Verify database state
|
||||
query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
pause_model = self.session.scalars(query).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.resumed_at is None
|
||||
assert pause_model.id == pause_entity.id
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
# Act - Get pause state
|
||||
retrieved_entity = repository.get_workflow_pause(workflow_run.id)
|
||||
|
||||
# Assert - Pause state retrieved
|
||||
assert retrieved_entity is not None
|
||||
assert retrieved_entity.id == pause_entity.id
|
||||
retrieved_state = retrieved_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
# Act - Resume workflow
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# Assert - Workflow resumed
|
||||
assert resumed_entity is not None
|
||||
assert resumed_entity.id == pause_entity.id
|
||||
assert resumed_entity.resumed_at is not None
|
||||
|
||||
# Verify database state
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
self.session.refresh(pause_model)
|
||||
assert pause_model.resumed_at is not None
|
||||
|
||||
# Act - Delete pause state
|
||||
repository.delete_workflow_pause(pause_entity)
|
||||
|
||||
# Assert - Pause state deleted
|
||||
with Session(bind=self.session.get_bind()) as session:
|
||||
deleted_pause = session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert deleted_pause is None
|
||||
|
||||
def test_pause_workflow_success(self):
|
||||
"""Test successful pause workflow scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
pause_model = self.session.scalars(pause_query).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.id == pause_entity.id
|
||||
assert pause_model.resumed_at is None
|
||||
|
||||
@pytest.mark.parametrize("test_case", pause_workflow_failure_cases(), ids=lambda tc: tc.name)
|
||||
def test_pause_workflow_failure(self, test_case: PauseWorkflowFailureCase):
|
||||
"""Test pause workflow failure scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
with pytest.raises(_WorkflowRunError):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
|
||||
def test_resume_workflow_success(self, test_case: ResumeWorkflowSuccessCase):
|
||||
"""Test successful resume workflow scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
self.session.commit()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
assert resumed_entity is not None
|
||||
assert resumed_entity.id == pause_entity.id
|
||||
assert resumed_entity.resumed_at is not None
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
pause_model = self.session.scalars(pause_query).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.id == pause_entity.id
|
||||
assert pause_model.resumed_at is not None
|
||||
|
||||
def test_resume_running_workflow(self):
|
||||
"""Test resume workflow failure scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
self.session.add(workflow_run)
|
||||
self.session.commit()
|
||||
|
||||
with pytest.raises(_WorkflowRunError):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
def test_resume_resumed_pause(self):
|
||||
"""Test resume workflow failure scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.resumed_at = naive_utc_now()
|
||||
self.session.add(pause_model)
|
||||
self.session.commit()
|
||||
|
||||
with pytest.raises(_WorkflowRunError):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# ==================== Error Scenario Tests ====================
|
||||
|
||||
def test_pause_nonexistent_workflow_run(self):
|
||||
"""Test pausing a non-existent workflow run."""
|
||||
# Arrange
|
||||
nonexistent_id = str(uuid.uuid4())
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id=nonexistent_id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
def test_resume_nonexistent_workflow_run(self):
|
||||
"""Test resuming a non-existent workflow run."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
nonexistent_id = str(uuid.uuid4())
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=nonexistent_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# ==================== Prune Functionality Tests ====================
|
||||
|
||||
@pytest.mark.parametrize("test_case", prune_pauses_test_cases(), ids=lambda tc: tc.name)
|
||||
def test_prune_pauses_scenarios(self, test_case: PrunePausesTestCase):
|
||||
"""Test various prune pauses scenarios."""
|
||||
now = naive_utc_now()
|
||||
|
||||
# Create pause state
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Manually adjust timestamps for testing
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.created_at = now - test_case.pause_age
|
||||
|
||||
if test_case.resume_age is not None:
|
||||
# Resume pause and adjust resume time
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
# Need to refresh to get the updated model
|
||||
self.session.refresh(pause_model)
|
||||
# Manually set the resumed_at to an older time for testing
|
||||
pause_model.resumed_at = now - test_case.resume_age
|
||||
self.session.commit() # Commit the resumed_at change
|
||||
# Refresh again to ensure the change is persisted
|
||||
self.session.refresh(pause_model)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
# Act - Prune pauses
|
||||
expiration_time = now - timedelta(days=1, seconds=1) # Expire pauses older than 1 day (plus 1 second)
|
||||
resumption_time = now - timedelta(
|
||||
days=7, seconds=1
|
||||
) # Clean up pauses resumed more than 7 days ago (plus 1 second)
|
||||
|
||||
# Debug: Check pause state before pruning
|
||||
self.session.refresh(pause_model)
|
||||
print(f"Pause created_at: {pause_model.created_at}")
|
||||
print(f"Pause resumed_at: {pause_model.resumed_at}")
|
||||
print(f"Expiration time: {expiration_time}")
|
||||
print(f"Resumption time: {resumption_time}")
|
||||
|
||||
# Force commit to ensure timestamps are saved
|
||||
self.session.commit()
|
||||
|
||||
# Determine if the pause should be pruned based on timestamps
|
||||
should_be_pruned = False
|
||||
if test_case.resume_age is not None:
|
||||
# If resumed, check if resumed_at is older than resumption_time
|
||||
should_be_pruned = pause_model.resumed_at < resumption_time
|
||||
else:
|
||||
# If not resumed, check if created_at is older than expiration_time
|
||||
should_be_pruned = pause_model.created_at < expiration_time
|
||||
|
||||
# Act - Prune pauses
|
||||
pruned_ids = repository.prune_pauses(
|
||||
expiration=expiration_time,
|
||||
resumption_expiration=resumption_time,
|
||||
)
|
||||
|
||||
# Assert - Check pruning results
|
||||
if should_be_pruned:
|
||||
assert len(pruned_ids) == test_case.expected_pruned_count
|
||||
# Verify pause was actually deleted
|
||||
# The pause should be in the pruned_ids list if it was pruned
|
||||
assert pause_entity.id in pruned_ids
|
||||
else:
|
||||
assert len(pruned_ids) == 0
|
||||
|
||||
def test_prune_pauses_with_limit(self):
|
||||
"""Test prune pauses with limit parameter."""
|
||||
now = naive_utc_now()
|
||||
|
||||
# Create multiple pause states
|
||||
pause_entities = []
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
for i in range(5):
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
pause_entities.append(pause_entity)
|
||||
|
||||
# Make all pauses old enough to be pruned
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.created_at = now - timedelta(days=7)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
# Act - Prune with limit
|
||||
expiration_time = now - timedelta(days=1)
|
||||
resumption_time = now - timedelta(days=7)
|
||||
|
||||
pruned_ids = repository.prune_pauses(
|
||||
expiration=expiration_time,
|
||||
resumption_expiration=resumption_time,
|
||||
limit=3,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(pruned_ids) == 3
|
||||
|
||||
# Verify only 3 were deleted
|
||||
remaining_count = (
|
||||
self.session.query(WorkflowPauseModel)
|
||||
.filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
|
||||
.count()
|
||||
)
|
||||
assert remaining_count == 2
|
||||
|
||||
# ==================== Multi-tenant Isolation Tests ====================
|
||||
|
||||
def test_multi_tenant_pause_isolation(self):
|
||||
"""Test that pause states are properly isolated by tenant."""
|
||||
# Arrange - Create second tenant
|
||||
|
||||
tenant2 = Tenant(
|
||||
name="Test Tenant 2",
|
||||
status="normal",
|
||||
)
|
||||
self.session.add(tenant2)
|
||||
self.session.commit()
|
||||
|
||||
account2 = Account(
|
||||
email="test2@example.com",
|
||||
name="Test User 2",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
self.session.add(account2)
|
||||
self.session.commit()
|
||||
|
||||
tenant2_join = TenantAccountJoin(
|
||||
tenant_id=tenant2.id,
|
||||
account_id=account2.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
self.session.add(tenant2_join)
|
||||
self.session.commit()
|
||||
|
||||
# Create workflow for tenant 2
|
||||
workflow2 = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant2.id,
|
||||
app_id=str(uuid.uuid4()),
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=account2.id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow2)
|
||||
self.session.commit()
|
||||
|
||||
# Create workflow runs for both tenants
|
||||
workflow_run1 = self._create_test_workflow_run()
|
||||
workflow_run2 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant2.id,
|
||||
app_id=workflow2.app_id,
|
||||
workflow_id=workflow2.id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=account2.id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow_run2)
|
||||
self.session.commit()
|
||||
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause for tenant 1
|
||||
pause_entity1 = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run1.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Try to access pause from tenant 2 using tenant 1's repository
|
||||
# This should work because we're using the same repository
|
||||
pause_entity2 = repository.get_workflow_pause(workflow_run2.id)
|
||||
assert pause_entity2 is None # No pause for tenant 2 yet
|
||||
|
||||
# Create pause for tenant 2
|
||||
pause_entity2 = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run2.id,
|
||||
state_owner_user_id=account2.id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Both pauses should exist and be separate
|
||||
assert pause_entity1 is not None
|
||||
assert pause_entity2 is not None
|
||||
assert pause_entity1.id != pause_entity2.id
|
||||
assert pause_entity1.workflow_execution_id != pause_entity2.workflow_execution_id
|
||||
|
||||
def test_cross_tenant_access_restriction(self):
|
||||
"""Test that cross-tenant access is properly restricted."""
|
||||
# This test would require tenant-specific repositories
|
||||
# For now, we test that pause entities are properly scoped by tenant_id
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Verify pause is properly scoped
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.workflow_id == self.test_workflow_id
|
||||
|
||||
# ==================== File Storage Integration Tests ====================
|
||||
|
||||
def test_file_storage_integration(self):
|
||||
"""Test that state files are properly stored and retrieved."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause state
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Verify file was uploaded to storage
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Verify file content in storage
|
||||
|
||||
file_key = pause_model.state_object_key
|
||||
storage_content = storage.load(file_key).decode()
|
||||
assert storage_content == test_state
|
||||
|
||||
# Verify retrieval through entity
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
def test_file_cleanup_on_pause_deletion(self):
|
||||
"""Test that files are properly handled on pause deletion."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Get file info before deletion
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
file_key = pause_model.state_object_key
|
||||
|
||||
# Act - Delete pause state
|
||||
repository.delete_workflow_pause(pause_entity)
|
||||
|
||||
# Assert - Pause record should be deleted
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
deleted_pause = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert deleted_pause is None
|
||||
|
||||
try:
|
||||
content = storage.load(file_key).decode()
|
||||
pytest.fail("File should be deleted from storage after pause deletion")
|
||||
except FileNotFoundError:
|
||||
# This is expected - file should be deleted from storage
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Unexpected error when checking file deletion: {e}")
|
||||
|
||||
def test_large_state_file_handling(self):
|
||||
"""Test handling of large state files."""
|
||||
# Arrange - Create a large state (1MB)
|
||||
large_state = "x" * (1024 * 1024) # 1MB of data
|
||||
large_state_json = json.dumps({"large_data": large_state})
|
||||
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=large_state_json,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert pause_entity is not None
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == large_state_json
|
||||
|
||||
# Verify file size in database
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.state_object_key != ""
|
||||
loaded_state = storage.load(pause_model.state_object_key)
|
||||
assert loaded_state.decode() == large_state_json
|
||||
|
||||
def test_multiple_pause_resume_cycles(self):
|
||||
"""Test multiple pause/resume cycles on the same workflow run."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act & Assert - Multiple cycles
|
||||
for i in range(3):
|
||||
state = json.dumps({"cycle": i, "data": f"state_{i}"})
|
||||
|
||||
# Reset workflow run status to RUNNING before each pause (after first cycle)
|
||||
if i > 0:
|
||||
self.session.refresh(workflow_run) # Refresh to get latest state from session
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
self.session.commit()
|
||||
self.session.refresh(workflow_run) # Refresh again after commit
|
||||
|
||||
# Pause
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=state,
|
||||
)
|
||||
assert pause_entity is not None
|
||||
|
||||
# Verify pause
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
self.session.refresh(workflow_run)
|
||||
|
||||
# Use the test session directly to verify the pause
|
||||
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run.id)
|
||||
workflow_run_with_pause = self.session.scalar(stmt)
|
||||
pause_model = workflow_run_with_pause.pause
|
||||
|
||||
# Verify pause using test session directly
|
||||
assert pause_model is not None
|
||||
assert pause_model.id == pause_entity.id
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Load file content using storage directly
|
||||
file_content = storage.load(pause_model.state_object_key)
|
||||
if isinstance(file_content, bytes):
|
||||
file_content = file_content.decode()
|
||||
assert file_content == state
|
||||
|
||||
# Resume
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
assert resumed_entity is not None
|
||||
assert resumed_entity.resumed_at is not None
|
||||
|
||||
# Verify resume - check that pause is marked as resumed
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
stmt = select(WorkflowPauseModel).where(WorkflowPauseModel.id == pause_entity.id)
|
||||
resumed_pause_model = self.session.scalar(stmt)
|
||||
assert resumed_pause_model is not None
|
||||
assert resumed_pause_model.resumed_at is not None
|
||||
|
||||
# Verify workflow run status
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
Reference in New Issue
Block a user