feat(api): Introduce workflow pause state management (#27298)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,278 @@
|
||||
import json
|
||||
from time import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class TestDataFactory:
|
||||
"""Factory helpers for constructing graph events used in tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
|
||||
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_started_event() -> GraphRunStartedEvent:
|
||||
return GraphRunStartedEvent()
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
|
||||
return GraphRunSucceededEvent(outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_failed_event(
|
||||
error: str = "Test error",
|
||||
exceptions_count: int = 1,
|
||||
) -> GraphRunFailedEvent:
|
||||
return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
|
||||
|
||||
|
||||
class MockSystemVariableReadOnlyView:
|
||||
"""Minimal read-only system variable view for testing."""
|
||||
|
||||
def __init__(self, workflow_execution_id: str | None = None) -> None:
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str | None:
|
||||
return self._workflow_execution_id
|
||||
|
||||
|
||||
class MockReadOnlyVariablePool:
|
||||
"""Mock implementation of ReadOnlyVariablePool for testing."""
|
||||
|
||||
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
|
||||
self._variables = variables or {}
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
value = self._variables.get((node_id, variable_key))
|
||||
if value is None:
|
||||
return None
|
||||
mock_segment = Mock(spec=Segment)
|
||||
mock_segment.value = value
|
||||
return mock_segment
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
|
||||
|
||||
def get_by_prefix(self, prefix: str) -> dict[str, object]:
|
||||
return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}
|
||||
|
||||
|
||||
class MockReadOnlyGraphRuntimeState:
|
||||
"""Mock implementation of ReadOnlyGraphRuntimeState for testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_at: float | None = None,
|
||||
total_tokens: int = 0,
|
||||
node_run_steps: int = 0,
|
||||
ready_queue_size: int = 0,
|
||||
exceptions_count: int = 0,
|
||||
outputs: dict[str, object] | None = None,
|
||||
variables: dict[tuple[str, str], object] | None = None,
|
||||
workflow_execution_id: str | None = None,
|
||||
):
|
||||
self._start_at = start_at or time()
|
||||
self._total_tokens = total_tokens
|
||||
self._node_run_steps = node_run_steps
|
||||
self._ready_queue_size = ready_queue_size
|
||||
self._exceptions_count = exceptions_count
|
||||
self._outputs = outputs or {}
|
||||
self._variable_pool = MockReadOnlyVariablePool(variables)
|
||||
self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)
|
||||
|
||||
@property
|
||||
def system_variable(self) -> MockSystemVariableReadOnlyView:
|
||||
return self._system_variable
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||
return self._variable_pool
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
return self._start_at
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self._total_tokens
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
return self._node_run_steps
|
||||
|
||||
@property
|
||||
def ready_queue_size(self) -> int:
|
||||
return self._ready_queue_size
|
||||
|
||||
@property
|
||||
def exceptions_count(self) -> int:
|
||||
return self._exceptions_count
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, object]:
|
||||
return self._outputs.copy()
|
||||
|
||||
@property
|
||||
def llm_usage(self):
|
||||
mock_usage = Mock()
|
||||
mock_usage.prompt_tokens = 10
|
||||
mock_usage.completion_tokens = 20
|
||||
mock_usage.total_tokens = 30
|
||||
return mock_usage
|
||||
|
||||
def get_output(self, key: str, default: object = None) -> object:
|
||||
return self._outputs.get(key, default)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"start_at": self._start_at,
|
||||
"total_tokens": self._total_tokens,
|
||||
"node_run_steps": self._node_run_steps,
|
||||
"ready_queue_size": self._ready_queue_size,
|
||||
"exceptions_count": self._exceptions_count,
|
||||
"outputs": self._outputs,
|
||||
"variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
|
||||
"workflow_execution_id": self._system_variable.workflow_execution_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class MockCommandChannel:
|
||||
"""Mock implementation of CommandChannel for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._commands: list[GraphEngineCommand] = []
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
return self._commands.copy()
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
self._commands.append(command)
|
||||
|
||||
|
||||
class TestPauseStatePersistenceLayer:
|
||||
"""Unit tests for PauseStatePersistenceLayer."""
|
||||
|
||||
def test_init_with_dependency_injection(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
state_owner_user_id = "user-123"
|
||||
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
)
|
||||
|
||||
assert layer._session_maker is session_factory
|
||||
assert layer._state_owner_user_id == state_owner_user_id
|
||||
assert not hasattr(layer, "graph_runtime_state")
|
||||
assert not hasattr(layer, "command_channel")
|
||||
|
||||
def test_initialize_sets_dependencies(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||
command_channel = MockCommandChannel()
|
||||
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
assert layer.graph_runtime_state is graph_runtime_state
|
||||
assert layer.command_channel is command_channel
|
||||
|
||||
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState(
|
||||
outputs={"result": "test_output"},
|
||||
total_tokens=100,
|
||||
workflow_execution_id="run-123",
|
||||
)
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
|
||||
expected_state = graph_runtime_state.dumps()
|
||||
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_called_once_with(session_factory)
|
||||
mock_repo.create_workflow_pause.assert_called_once_with(
|
||||
workflow_run_id="run-123",
|
||||
state_owner_user_id="owner-123",
|
||||
state=expected_state,
|
||||
)
|
||||
|
||||
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
events = [
|
||||
TestDataFactory.create_graph_run_started_event(),
|
||||
TestDataFactory.create_graph_run_succeeded_event(),
|
||||
TestDataFactory.create_graph_run_failed_event(),
|
||||
]
|
||||
|
||||
for event in events:
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
|
||||
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
layer.on_event(event)
|
||||
|
||||
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
@@ -0,0 +1,171 @@
|
||||
"""Tests for _PrivateWorkflowPauseEntity implementation."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity:
|
||||
"""Test _PrivateWorkflowPauseEntity implementation."""
|
||||
|
||||
def test_entity_initialization(self):
|
||||
"""Test entity initialization with required parameters."""
|
||||
# Create mock models
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
# Create entity
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert entity._pause_model is mock_pause_model
|
||||
assert entity._cached_state is None
|
||||
|
||||
def test_from_models_classmethod(self):
|
||||
"""Test from_models class method."""
|
||||
# Create mock models
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
# Create entity using from_models
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(
|
||||
workflow_pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Verify entity creation
|
||||
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||
assert entity._pause_model is mock_pause_model
|
||||
|
||||
def test_id_property(self):
|
||||
"""Test id property returns pause model ID."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.id == "pause-123"
|
||||
|
||||
def test_workflow_execution_id_property(self):
|
||||
"""Test workflow_execution_id property returns workflow run ID."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.workflow_execution_id == "execution-456"
|
||||
|
||||
def test_resumed_at_property(self):
|
||||
"""Test resumed_at property returns pause model resumed_at."""
|
||||
resumed_at = datetime(2023, 12, 25, 15, 30, 45)
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = resumed_at
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.resumed_at == resumed_at
|
||||
|
||||
def test_resumed_at_property_none(self):
|
||||
"""Test resumed_at property returns None when not set."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.resumed_at is None
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
def test_get_state_first_call(self, mock_storage):
|
||||
"""Test get_state loads from storage on first call."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
mock_storage.load.return_value = state_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# First call should load from storage
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == state_data
|
||||
mock_storage.load.assert_called_once_with("test-state-key")
|
||||
assert entity._cached_state == state_data
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
def test_get_state_cached_call(self, mock_storage):
|
||||
"""Test get_state returns cached data on subsequent calls."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
mock_storage.load.return_value = state_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# First call
|
||||
result1 = entity.get_state()
|
||||
# Second call should use cache
|
||||
result2 = entity.get_state()
|
||||
|
||||
assert result1 == state_data
|
||||
assert result2 == state_data
|
||||
# Storage should only be called once
|
||||
mock_storage.load.assert_called_once_with("test-state-key")
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
def test_get_state_with_pre_cached_data(self, mock_storage):
|
||||
"""Test get_state returns pre-cached data."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Pre-cache data
|
||||
entity._cached_state = state_data
|
||||
|
||||
# Should return cached data without calling storage
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == state_data
|
||||
mock_storage.load.assert_not_called()
|
||||
|
||||
def test_entity_with_binary_state_data(self):
|
||||
"""Test entity with binary state data."""
|
||||
# Test with binary data that's not valid JSON
|
||||
binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = binary_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == binary_data
|
||||
@@ -3,6 +3,7 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
@@ -149,8 +150,8 @@ def test_pause_command():
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
|
||||
assert len(pause_events) == 1
|
||||
assert pause_events[0].reason == "User requested pause"
|
||||
assert pause_events[0].reason == SchedulingPause(message="User requested pause")
|
||||
|
||||
graph_execution = engine.graph_runtime_state.graph_execution
|
||||
assert graph_execution.is_paused
|
||||
assert graph_execution.pause_reason == "User requested pause"
|
||||
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
|
||||
|
||||
32
api/tests/unit_tests/core/workflow/test_enums.py
Normal file
32
api/tests/unit_tests/core/workflow/test_enums.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Tests for workflow pause related enums and constants."""
|
||||
|
||||
from core.workflow.enums import (
|
||||
WorkflowExecutionStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowExecutionStatus:
|
||||
"""Test WorkflowExecutionStatus enum."""
|
||||
|
||||
def test_is_ended_method(self):
|
||||
"""Test is_ended method for different statuses."""
|
||||
# Test ended statuses
|
||||
ended_statuses = [
|
||||
WorkflowExecutionStatus.SUCCEEDED,
|
||||
WorkflowExecutionStatus.FAILED,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
WorkflowExecutionStatus.STOPPED,
|
||||
]
|
||||
|
||||
for status in ended_statuses:
|
||||
assert status.is_ended(), f"{status} should be considered ended"
|
||||
|
||||
# Test non-ended statuses
|
||||
non_ended_statuses = [
|
||||
WorkflowExecutionStatus.SCHEDULED,
|
||||
WorkflowExecutionStatus.RUNNING,
|
||||
WorkflowExecutionStatus.PAUSED,
|
||||
]
|
||||
|
||||
for status in non_ended_statuses:
|
||||
assert not status.is_ended(), f"{status} should not be considered ended"
|
||||
@@ -0,0 +1,202 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.models import File, FileTransferMethod, FileType
|
||||
from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView
|
||||
|
||||
|
||||
class TestSystemVariableReadOnlyView:
|
||||
"""Test cases for SystemVariableReadOnlyView class."""
|
||||
|
||||
def test_read_only_property_access(self):
|
||||
"""Test that all properties return correct values from wrapped instance."""
|
||||
# Create test data
|
||||
test_file = File(
|
||||
id="file-123",
|
||||
tenant_id="tenant-123",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related-123",
|
||||
)
|
||||
|
||||
datasource_info = {"key": "value", "nested": {"data": 42}}
|
||||
|
||||
# Create SystemVariable with all fields
|
||||
system_var = SystemVariable(
|
||||
user_id="user-123",
|
||||
app_id="app-123",
|
||||
workflow_id="workflow-123",
|
||||
files=[test_file],
|
||||
workflow_execution_id="exec-123",
|
||||
query="test query",
|
||||
conversation_id="conv-123",
|
||||
dialogue_count=5,
|
||||
document_id="doc-123",
|
||||
original_document_id="orig-doc-123",
|
||||
dataset_id="dataset-123",
|
||||
batch="batch-123",
|
||||
datasource_type="type-123",
|
||||
datasource_info=datasource_info,
|
||||
invoke_from="invoke-123",
|
||||
)
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test all properties
|
||||
assert read_only_view.user_id == "user-123"
|
||||
assert read_only_view.app_id == "app-123"
|
||||
assert read_only_view.workflow_id == "workflow-123"
|
||||
assert read_only_view.workflow_execution_id == "exec-123"
|
||||
assert read_only_view.query == "test query"
|
||||
assert read_only_view.conversation_id == "conv-123"
|
||||
assert read_only_view.dialogue_count == 5
|
||||
assert read_only_view.document_id == "doc-123"
|
||||
assert read_only_view.original_document_id == "orig-doc-123"
|
||||
assert read_only_view.dataset_id == "dataset-123"
|
||||
assert read_only_view.batch == "batch-123"
|
||||
assert read_only_view.datasource_type == "type-123"
|
||||
assert read_only_view.invoke_from == "invoke-123"
|
||||
|
||||
def test_defensive_copying_of_mutable_objects(self):
|
||||
"""Test that mutable objects are defensively copied."""
|
||||
# Create test data
|
||||
test_file = File(
|
||||
id="file-123",
|
||||
tenant_id="tenant-123",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related-123",
|
||||
)
|
||||
|
||||
datasource_info = {"key": "original_value"}
|
||||
|
||||
# Create SystemVariable
|
||||
system_var = SystemVariable(
|
||||
files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123"
|
||||
)
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test files defensive copying
|
||||
files_copy = read_only_view.files
|
||||
assert isinstance(files_copy, tuple) # Should be immutable tuple
|
||||
assert len(files_copy) == 1
|
||||
assert files_copy[0].id == "file-123"
|
||||
|
||||
# Verify it's a copy (can't modify original through view)
|
||||
assert isinstance(files_copy, tuple)
|
||||
# tuples don't have append method, so they're immutable
|
||||
|
||||
# Test datasource_info defensive copying
|
||||
datasource_copy = read_only_view.datasource_info
|
||||
assert datasource_copy is not None
|
||||
assert datasource_copy["key"] == "original_value"
|
||||
|
||||
datasource_copy = cast(dict, datasource_copy)
|
||||
with pytest.raises(TypeError):
|
||||
datasource_copy["key"] = "modified value"
|
||||
|
||||
# Verify original is unchanged
|
||||
assert system_var.datasource_info is not None
|
||||
assert system_var.datasource_info["key"] == "original_value"
|
||||
assert read_only_view.datasource_info is not None
|
||||
assert read_only_view.datasource_info["key"] == "original_value"
|
||||
|
||||
def test_always_accesses_latest_data(self):
|
||||
"""Test that properties always return the latest data from wrapped instance."""
|
||||
# Create SystemVariable
|
||||
system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Verify initial value
|
||||
assert read_only_view.user_id == "original-user"
|
||||
|
||||
# Modify the wrapped instance
|
||||
system_var.user_id = "modified-user"
|
||||
|
||||
# Verify view returns the new value
|
||||
assert read_only_view.user_id == "modified-user"
|
||||
|
||||
def test_repr_method(self):
|
||||
"""Test the __repr__ method."""
|
||||
# Create SystemVariable
|
||||
system_var = SystemVariable(workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test repr
|
||||
repr_str = repr(read_only_view)
|
||||
assert "SystemVariableReadOnlyView" in repr_str
|
||||
assert "system_variable=" in repr_str
|
||||
|
||||
def test_none_value_handling(self):
|
||||
"""Test that None values are properly handled."""
|
||||
# Create SystemVariable with all None values except workflow_execution_id
|
||||
system_var = SystemVariable(
|
||||
user_id=None,
|
||||
app_id=None,
|
||||
workflow_id=None,
|
||||
workflow_execution_id="exec-123",
|
||||
query=None,
|
||||
conversation_id=None,
|
||||
dialogue_count=None,
|
||||
document_id=None,
|
||||
original_document_id=None,
|
||||
dataset_id=None,
|
||||
batch=None,
|
||||
datasource_type=None,
|
||||
datasource_info=None,
|
||||
invoke_from=None,
|
||||
)
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test all None values
|
||||
assert read_only_view.user_id is None
|
||||
assert read_only_view.app_id is None
|
||||
assert read_only_view.workflow_id is None
|
||||
assert read_only_view.query is None
|
||||
assert read_only_view.conversation_id is None
|
||||
assert read_only_view.dialogue_count is None
|
||||
assert read_only_view.document_id is None
|
||||
assert read_only_view.original_document_id is None
|
||||
assert read_only_view.dataset_id is None
|
||||
assert read_only_view.batch is None
|
||||
assert read_only_view.datasource_type is None
|
||||
assert read_only_view.datasource_info is None
|
||||
assert read_only_view.invoke_from is None
|
||||
|
||||
# files should be empty tuple even when default list is empty
|
||||
assert read_only_view.files == ()
|
||||
|
||||
def test_empty_files_handling(self):
|
||||
"""Test that empty files list is handled correctly."""
|
||||
# Create SystemVariable with empty files
|
||||
system_var = SystemVariable(files=[], workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test files handling
|
||||
assert read_only_view.files == ()
|
||||
assert isinstance(read_only_view.files, tuple)
|
||||
|
||||
def test_empty_datasource_info_handling(self):
|
||||
"""Test that empty datasource_info is handled correctly."""
|
||||
# Create SystemVariable with empty datasource_info
|
||||
system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test datasource_info handling
|
||||
assert read_only_view.datasource_info == {}
|
||||
# Should be a copy, not the same object
|
||||
assert read_only_view.datasource_info is not system_var.datasource_info
|
||||
11
api/tests/unit_tests/models/test_base.py
Normal file
11
api/tests/unit_tests/models/test_base.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from models.base import DefaultFieldsMixin
|
||||
|
||||
|
||||
class FooModel(DefaultFieldsMixin):
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
|
||||
|
||||
def test_repr():
|
||||
foo_model = FooModel(id="test-id")
|
||||
assert repr(foo_model) == "<FooModel(id=test-id)>"
|
||||
@@ -0,0 +1,370 @@
|
||||
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
_WorkflowRunError,
|
||||
)
|
||||
|
||||
|
||||
class TestDifyAPISQLAlchemyWorkflowRunRepository:
|
||||
"""Test DifyAPISQLAlchemyWorkflowRunRepository implementation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock session."""
|
||||
return Mock(spec=Session)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(self, mock_session):
|
||||
"""Create a mock sessionmaker."""
|
||||
session_maker = Mock(spec=sessionmaker)
|
||||
|
||||
# Create a context manager mock
|
||||
context_manager = Mock()
|
||||
context_manager.__enter__ = Mock(return_value=mock_session)
|
||||
context_manager.__exit__ = Mock(return_value=None)
|
||||
session_maker.return_value = context_manager
|
||||
|
||||
# Mock session.begin() context manager
|
||||
begin_context_manager = Mock()
|
||||
begin_context_manager.__enter__ = Mock(return_value=None)
|
||||
begin_context_manager.__exit__ = Mock(return_value=None)
|
||||
mock_session.begin = Mock(return_value=begin_context_manager)
|
||||
|
||||
# Add missing session methods
|
||||
mock_session.commit = Mock()
|
||||
mock_session.rollback = Mock()
|
||||
mock_session.add = Mock()
|
||||
mock_session.delete = Mock()
|
||||
mock_session.get = Mock()
|
||||
mock_session.scalar = Mock()
|
||||
mock_session.scalars = Mock()
|
||||
|
||||
# Also support expire_on_commit parameter
|
||||
def make_session(expire_on_commit=None):
|
||||
cm = Mock()
|
||||
cm.__enter__ = Mock(return_value=mock_session)
|
||||
cm.__exit__ = Mock(return_value=None)
|
||||
return cm
|
||||
|
||||
session_maker.side_effect = make_session
|
||||
return session_maker
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, mock_session_maker):
|
||||
"""Create repository instance with mocked dependencies."""
|
||||
|
||||
# Create a testable subclass that implements the save method
|
||||
class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||
def __init__(self, session_maker):
|
||||
# Initialize without calling parent __init__ to avoid any instantiation issues
|
||||
self._session_maker = session_maker
|
||||
|
||||
def save(self, execution):
|
||||
"""Mock implementation of save method."""
|
||||
return None
|
||||
|
||||
# Create repository instance
|
||||
repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
|
||||
|
||||
return repo
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_run(self):
|
||||
"""Create a sample WorkflowRun model."""
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.id = "workflow-run-123"
|
||||
workflow_run.tenant_id = "tenant-123"
|
||||
workflow_run.app_id = "app-123"
|
||||
workflow_run.workflow_id = "workflow-123"
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
return workflow_run
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_pause(self):
|
||||
"""Create a sample WorkflowPauseModel."""
|
||||
pause = Mock(spec=WorkflowPauseModel)
|
||||
pause.id = "pause-123"
|
||||
pause.workflow_id = "workflow-123"
|
||||
pause.workflow_run_id = "workflow-run-123"
|
||||
pause.state_object_key = "workflow-state-123.json"
|
||||
pause.resumed_at = None
|
||||
pause.created_at = datetime.now(UTC)
|
||||
return pause
|
||||
|
||||
|
||||
class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test create_workflow_pause method."""
|
||||
|
||||
def test_create_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
):
|
||||
"""Test successful workflow pause creation."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
state_owner_user_id = "user-123"
|
||||
state = '{"test": "state"}'
|
||||
|
||||
mock_session.get.return_value = sample_workflow_run
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
|
||||
mock_uuidv7.side_effect = ["pause-123"]
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
# Act
|
||||
result = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||
assert result.id == "pause-123"
|
||||
assert result.workflow_execution_id == workflow_run_id
|
||||
|
||||
# Verify database interactions
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
|
||||
mock_storage.save.assert_called_once()
|
||||
mock_session.add.assert_called()
|
||||
# When using session.begin() context manager, commit is handled automatically
|
||||
# No explicit commit call is expected
|
||||
|
||||
def test_create_workflow_pause_not_found(
|
||||
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
|
||||
):
|
||||
"""Test workflow pause creation when workflow run not found."""
|
||||
# Arrange
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
)
|
||||
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
|
||||
|
||||
def test_create_workflow_pause_invalid_status(
|
||||
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
|
||||
):
|
||||
"""Test workflow pause creation when workflow not in RUNNING status."""
|
||||
# Arrange
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
mock_session.get.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
)
|
||||
|
||||
|
||||
class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test resume_workflow_pause method."""
|
||||
|
||||
def test_resume_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
sample_workflow_pause: Mock,
|
||||
):
|
||||
"""Test successful workflow pause resume."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
# Setup workflow run and pause
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
sample_workflow_run.pause = sample_workflow_pause
|
||||
sample_workflow_pause.resumed_at = None
|
||||
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
|
||||
mock_now.return_value = datetime.now(UTC)
|
||||
|
||||
# Act
|
||||
result = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||
assert result.id == "pause-123"
|
||||
|
||||
# Verify state transitions
|
||||
assert sample_workflow_pause.resumed_at is not None
|
||||
assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
|
||||
# Verify database interactions
|
||||
mock_session.add.assert_called()
|
||||
# When using session.begin() context manager, commit is handled automatically
|
||||
# No explicit commit call is expected
|
||||
|
||||
def test_resume_workflow_pause_not_paused(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
):
|
||||
"""Test resume when workflow is not paused."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
def test_resume_workflow_pause_id_mismatch(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
sample_workflow_pause: Mock,
|
||||
):
|
||||
"""Test resume when pause ID doesn't match."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-456" # Different ID
|
||||
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
sample_workflow_pause.id = "pause-123"
|
||||
sample_workflow_run.pause = sample_workflow_pause
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test delete_workflow_pause method."""
|
||||
|
||||
def test_delete_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_pause: Mock,
|
||||
):
|
||||
"""Test successful workflow pause deletion."""
|
||||
# Arrange
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
mock_session.get.return_value = sample_workflow_pause
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
# Act
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
|
||||
# Assert
|
||||
mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||
mock_session.delete.assert_called_once_with(sample_workflow_pause)
|
||||
# When using session.begin() context manager, commit is handled automatically
|
||||
# No explicit commit call is expected
|
||||
|
||||
def test_delete_workflow_pause_not_found(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
):
|
||||
"""Test delete when pause not found."""
|
||||
# Arrange
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test _PrivateWorkflowPauseEntity class."""
|
||||
|
||||
def test_from_models(self, sample_workflow_pause: Mock):
|
||||
"""Test creating _PrivateWorkflowPauseEntity from models."""
|
||||
# Act
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
|
||||
# Assert
|
||||
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||
assert entity._pause_model == sample_workflow_pause
|
||||
|
||||
def test_properties(self, sample_workflow_pause: Mock):
|
||||
"""Test entity properties."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
|
||||
# Act & Assert
|
||||
assert entity.id == sample_workflow_pause.id
|
||||
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
|
||||
assert entity.resumed_at == sample_workflow_pause.resumed_at
|
||||
|
||||
def test_get_state(self, sample_workflow_pause: Mock):
|
||||
"""Test getting state from storage."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = expected_state
|
||||
|
||||
# Act
|
||||
result = entity.get_state()
|
||||
|
||||
# Assert
|
||||
assert result == expected_state
|
||||
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||
|
||||
def test_get_state_caching(self, sample_workflow_pause: Mock):
|
||||
"""Test state caching in get_state method."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = expected_state
|
||||
|
||||
# Act
|
||||
result1 = entity.get_state()
|
||||
result2 = entity.get_state() # Should use cache
|
||||
|
||||
# Assert
|
||||
assert result1 == expected_state
|
||||
assert result2 == expected_state
|
||||
mock_storage.load.assert_called_once() # Only called once due to caching
|
||||
200
api/tests/unit_tests/services/test_workflow_run_service_pause.py
Normal file
200
api/tests/unit_tests/services/test_workflow_run_service_pause.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Comprehensive unit tests for WorkflowRunService class.
|
||||
|
||||
This test suite covers all pause state management operations including:
|
||||
- Retrieving pause state for workflow runs
|
||||
- Saving pause state with file uploads
|
||||
- Marking paused workflows as resumed
|
||||
- Error handling and edge cases
|
||||
- Database transaction management
|
||||
- Repository-based approach testing
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||
from services.workflow_run_service import (
|
||||
WorkflowRunService,
|
||||
)
|
||||
|
||||
|
||||
class TestDataFactory:
|
||||
"""Factory class for creating test data objects."""
|
||||
|
||||
@staticmethod
|
||||
def create_workflow_run_mock(
|
||||
id: str = "workflow-run-123",
|
||||
tenant_id: str = "tenant-456",
|
||||
app_id: str = "app-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
status: str | WorkflowExecutionStatus = "paused",
|
||||
pause_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowRun object."""
|
||||
mock_run = MagicMock()
|
||||
mock_run.id = id
|
||||
mock_run.tenant_id = tenant_id
|
||||
mock_run.app_id = app_id
|
||||
mock_run.workflow_id = workflow_id
|
||||
mock_run.status = status
|
||||
mock_run.pause_id = pause_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_run, key, value)
|
||||
|
||||
return mock_run
|
||||
|
||||
@staticmethod
|
||||
def create_workflow_pause_mock(
|
||||
id: str = "pause-123",
|
||||
tenant_id: str = "tenant-456",
|
||||
app_id: str = "app-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
workflow_execution_id: str = "workflow-execution-123",
|
||||
state_file_id: str = "file-456",
|
||||
resumed_at: datetime | None = None,
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowPauseModel object."""
|
||||
mock_pause = MagicMock()
|
||||
mock_pause.id = id
|
||||
mock_pause.tenant_id = tenant_id
|
||||
mock_pause.app_id = app_id
|
||||
mock_pause.workflow_id = workflow_id
|
||||
mock_pause.workflow_execution_id = workflow_execution_id
|
||||
mock_pause.state_file_id = state_file_id
|
||||
mock_pause.resumed_at = resumed_at
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_pause, key, value)
|
||||
|
||||
return mock_pause
|
||||
|
||||
@staticmethod
|
||||
def create_upload_file_mock(
|
||||
id: str = "file-456",
|
||||
key: str = "upload_files/test/state.json",
|
||||
name: str = "state.json",
|
||||
tenant_id: str = "tenant-456",
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock UploadFile object."""
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = id
|
||||
mock_file.key = key
|
||||
mock_file.name = name
|
||||
mock_file.tenant_id = tenant_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_file, key, value)
|
||||
|
||||
return mock_file
|
||||
|
||||
@staticmethod
|
||||
def create_pause_entity_mock(
|
||||
pause_model: MagicMock | None = None,
|
||||
upload_file: MagicMock | None = None,
|
||||
) -> _PrivateWorkflowPauseEntity:
|
||||
"""Create a mock _PrivateWorkflowPauseEntity object."""
|
||||
if pause_model is None:
|
||||
pause_model = TestDataFactory.create_workflow_pause_mock()
|
||||
if upload_file is None:
|
||||
upload_file = TestDataFactory.create_upload_file_mock()
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
|
||||
|
||||
|
||||
class TestWorkflowRunService:
|
||||
"""Comprehensive unit tests for WorkflowRunService class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory(self):
|
||||
"""Create a mock session factory with proper session management."""
|
||||
mock_session = create_autospec(Session)
|
||||
|
||||
# Create a mock context manager for the session
|
||||
mock_session_cm = MagicMock()
|
||||
mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session_cm.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Create a mock context manager for the transaction
|
||||
mock_transaction_cm = MagicMock()
|
||||
mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_transaction_cm.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
mock_session.begin = MagicMock(return_value=mock_transaction_cm)
|
||||
|
||||
# Create mock factory that returns the context manager
|
||||
mock_factory = MagicMock(spec=sessionmaker)
|
||||
mock_factory.return_value = mock_session_cm
|
||||
|
||||
return mock_factory, mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_run_repository(self):
|
||||
"""Create a mock APIWorkflowRunRepository."""
|
||||
mock_repo = create_autospec(APIWorkflowRunRepository)
|
||||
return mock_repo
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Create WorkflowRunService instance with mocked dependencies."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Create WorkflowRunService instance with Engine input."""
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(mock_engine)
|
||||
return service
|
||||
|
||||
# ==================== Initialization Tests ====================
|
||||
|
||||
def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Test WorkflowRunService initialization with session_factory."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
assert service._session_factory == session_factory
|
||||
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||
|
||||
def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
|
||||
service = WorkflowRunService(mock_engine)
|
||||
|
||||
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
|
||||
assert service._session_factory == session_factory
|
||||
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||
|
||||
def test_init_with_default_dependencies(self, mock_session_factory):
|
||||
"""Test WorkflowRunService initialization with default dependencies."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
assert service._session_factory == session_factory
|
||||
Reference in New Issue
Block a user