feat(api): Introduce workflow pause state management (#27298)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,278 @@
|
||||
import json
|
||||
from time import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class TestDataFactory:
|
||||
"""Factory helpers for constructing graph events used in tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
|
||||
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_started_event() -> GraphRunStartedEvent:
|
||||
return GraphRunStartedEvent()
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
|
||||
return GraphRunSucceededEvent(outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_failed_event(
|
||||
error: str = "Test error",
|
||||
exceptions_count: int = 1,
|
||||
) -> GraphRunFailedEvent:
|
||||
return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
|
||||
|
||||
|
||||
class MockSystemVariableReadOnlyView:
|
||||
"""Minimal read-only system variable view for testing."""
|
||||
|
||||
def __init__(self, workflow_execution_id: str | None = None) -> None:
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str | None:
|
||||
return self._workflow_execution_id
|
||||
|
||||
|
||||
class MockReadOnlyVariablePool:
|
||||
"""Mock implementation of ReadOnlyVariablePool for testing."""
|
||||
|
||||
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
|
||||
self._variables = variables or {}
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
value = self._variables.get((node_id, variable_key))
|
||||
if value is None:
|
||||
return None
|
||||
mock_segment = Mock(spec=Segment)
|
||||
mock_segment.value = value
|
||||
return mock_segment
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
|
||||
|
||||
def get_by_prefix(self, prefix: str) -> dict[str, object]:
|
||||
return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}
|
||||
|
||||
|
||||
class MockReadOnlyGraphRuntimeState:
|
||||
"""Mock implementation of ReadOnlyGraphRuntimeState for testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_at: float | None = None,
|
||||
total_tokens: int = 0,
|
||||
node_run_steps: int = 0,
|
||||
ready_queue_size: int = 0,
|
||||
exceptions_count: int = 0,
|
||||
outputs: dict[str, object] | None = None,
|
||||
variables: dict[tuple[str, str], object] | None = None,
|
||||
workflow_execution_id: str | None = None,
|
||||
):
|
||||
self._start_at = start_at or time()
|
||||
self._total_tokens = total_tokens
|
||||
self._node_run_steps = node_run_steps
|
||||
self._ready_queue_size = ready_queue_size
|
||||
self._exceptions_count = exceptions_count
|
||||
self._outputs = outputs or {}
|
||||
self._variable_pool = MockReadOnlyVariablePool(variables)
|
||||
self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)
|
||||
|
||||
@property
|
||||
def system_variable(self) -> MockSystemVariableReadOnlyView:
|
||||
return self._system_variable
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||
return self._variable_pool
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
return self._start_at
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self._total_tokens
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
return self._node_run_steps
|
||||
|
||||
@property
|
||||
def ready_queue_size(self) -> int:
|
||||
return self._ready_queue_size
|
||||
|
||||
@property
|
||||
def exceptions_count(self) -> int:
|
||||
return self._exceptions_count
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, object]:
|
||||
return self._outputs.copy()
|
||||
|
||||
@property
|
||||
def llm_usage(self):
|
||||
mock_usage = Mock()
|
||||
mock_usage.prompt_tokens = 10
|
||||
mock_usage.completion_tokens = 20
|
||||
mock_usage.total_tokens = 30
|
||||
return mock_usage
|
||||
|
||||
def get_output(self, key: str, default: object = None) -> object:
|
||||
return self._outputs.get(key, default)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"start_at": self._start_at,
|
||||
"total_tokens": self._total_tokens,
|
||||
"node_run_steps": self._node_run_steps,
|
||||
"ready_queue_size": self._ready_queue_size,
|
||||
"exceptions_count": self._exceptions_count,
|
||||
"outputs": self._outputs,
|
||||
"variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
|
||||
"workflow_execution_id": self._system_variable.workflow_execution_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class MockCommandChannel:
|
||||
"""Mock implementation of CommandChannel for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._commands: list[GraphEngineCommand] = []
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
return self._commands.copy()
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
self._commands.append(command)
|
||||
|
||||
|
||||
class TestPauseStatePersistenceLayer:
|
||||
"""Unit tests for PauseStatePersistenceLayer."""
|
||||
|
||||
def test_init_with_dependency_injection(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
state_owner_user_id = "user-123"
|
||||
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
)
|
||||
|
||||
assert layer._session_maker is session_factory
|
||||
assert layer._state_owner_user_id == state_owner_user_id
|
||||
assert not hasattr(layer, "graph_runtime_state")
|
||||
assert not hasattr(layer, "command_channel")
|
||||
|
||||
def test_initialize_sets_dependencies(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||
command_channel = MockCommandChannel()
|
||||
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
assert layer.graph_runtime_state is graph_runtime_state
|
||||
assert layer.command_channel is command_channel
|
||||
|
||||
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState(
|
||||
outputs={"result": "test_output"},
|
||||
total_tokens=100,
|
||||
workflow_execution_id="run-123",
|
||||
)
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
|
||||
expected_state = graph_runtime_state.dumps()
|
||||
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_called_once_with(session_factory)
|
||||
mock_repo.create_workflow_pause.assert_called_once_with(
|
||||
workflow_run_id="run-123",
|
||||
state_owner_user_id="owner-123",
|
||||
state=expected_state,
|
||||
)
|
||||
|
||||
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
events = [
|
||||
TestDataFactory.create_graph_run_started_event(),
|
||||
TestDataFactory.create_graph_run_succeeded_event(),
|
||||
TestDataFactory.create_graph_run_failed_event(),
|
||||
]
|
||||
|
||||
for event in events:
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
|
||||
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
layer.on_event(event)
|
||||
|
||||
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
@@ -0,0 +1,171 @@
|
||||
"""Tests for _PrivateWorkflowPauseEntity implementation."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity:
|
||||
"""Test _PrivateWorkflowPauseEntity implementation."""
|
||||
|
||||
def test_entity_initialization(self):
|
||||
"""Test entity initialization with required parameters."""
|
||||
# Create mock models
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
# Create entity
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert entity._pause_model is mock_pause_model
|
||||
assert entity._cached_state is None
|
||||
|
||||
def test_from_models_classmethod(self):
|
||||
"""Test from_models class method."""
|
||||
# Create mock models
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
# Create entity using from_models
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(
|
||||
workflow_pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Verify entity creation
|
||||
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||
assert entity._pause_model is mock_pause_model
|
||||
|
||||
def test_id_property(self):
|
||||
"""Test id property returns pause model ID."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.id == "pause-123"
|
||||
|
||||
def test_workflow_execution_id_property(self):
|
||||
"""Test workflow_execution_id property returns workflow run ID."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.workflow_execution_id == "execution-456"
|
||||
|
||||
def test_resumed_at_property(self):
|
||||
"""Test resumed_at property returns pause model resumed_at."""
|
||||
resumed_at = datetime(2023, 12, 25, 15, 30, 45)
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = resumed_at
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.resumed_at == resumed_at
|
||||
|
||||
def test_resumed_at_property_none(self):
|
||||
"""Test resumed_at property returns None when not set."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.resumed_at is None
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
def test_get_state_first_call(self, mock_storage):
|
||||
"""Test get_state loads from storage on first call."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
mock_storage.load.return_value = state_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# First call should load from storage
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == state_data
|
||||
mock_storage.load.assert_called_once_with("test-state-key")
|
||||
assert entity._cached_state == state_data
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
def test_get_state_cached_call(self, mock_storage):
|
||||
"""Test get_state returns cached data on subsequent calls."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
mock_storage.load.return_value = state_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# First call
|
||||
result1 = entity.get_state()
|
||||
# Second call should use cache
|
||||
result2 = entity.get_state()
|
||||
|
||||
assert result1 == state_data
|
||||
assert result2 == state_data
|
||||
# Storage should only be called once
|
||||
mock_storage.load.assert_called_once_with("test-state-key")
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
def test_get_state_with_pre_cached_data(self, mock_storage):
|
||||
"""Test get_state returns pre-cached data."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Pre-cache data
|
||||
entity._cached_state = state_data
|
||||
|
||||
# Should return cached data without calling storage
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == state_data
|
||||
mock_storage.load.assert_not_called()
|
||||
|
||||
def test_entity_with_binary_state_data(self):
|
||||
"""Test entity with binary state data."""
|
||||
# Test with binary data that's not valid JSON
|
||||
binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = binary_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == binary_data
|
||||
@@ -3,6 +3,7 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
@@ -149,8 +150,8 @@ def test_pause_command():
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
|
||||
assert len(pause_events) == 1
|
||||
assert pause_events[0].reason == "User requested pause"
|
||||
assert pause_events[0].reason == SchedulingPause(message="User requested pause")
|
||||
|
||||
graph_execution = engine.graph_runtime_state.graph_execution
|
||||
assert graph_execution.is_paused
|
||||
assert graph_execution.pause_reason == "User requested pause"
|
||||
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
|
||||
|
||||
32
api/tests/unit_tests/core/workflow/test_enums.py
Normal file
32
api/tests/unit_tests/core/workflow/test_enums.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Tests for workflow pause related enums and constants."""
|
||||
|
||||
from core.workflow.enums import (
|
||||
WorkflowExecutionStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowExecutionStatus:
|
||||
"""Test WorkflowExecutionStatus enum."""
|
||||
|
||||
def test_is_ended_method(self):
|
||||
"""Test is_ended method for different statuses."""
|
||||
# Test ended statuses
|
||||
ended_statuses = [
|
||||
WorkflowExecutionStatus.SUCCEEDED,
|
||||
WorkflowExecutionStatus.FAILED,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
WorkflowExecutionStatus.STOPPED,
|
||||
]
|
||||
|
||||
for status in ended_statuses:
|
||||
assert status.is_ended(), f"{status} should be considered ended"
|
||||
|
||||
# Test non-ended statuses
|
||||
non_ended_statuses = [
|
||||
WorkflowExecutionStatus.SCHEDULED,
|
||||
WorkflowExecutionStatus.RUNNING,
|
||||
WorkflowExecutionStatus.PAUSED,
|
||||
]
|
||||
|
||||
for status in non_ended_statuses:
|
||||
assert not status.is_ended(), f"{status} should not be considered ended"
|
||||
@@ -0,0 +1,202 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.models import File, FileTransferMethod, FileType
|
||||
from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView
|
||||
|
||||
|
||||
class TestSystemVariableReadOnlyView:
|
||||
"""Test cases for SystemVariableReadOnlyView class."""
|
||||
|
||||
def test_read_only_property_access(self):
|
||||
"""Test that all properties return correct values from wrapped instance."""
|
||||
# Create test data
|
||||
test_file = File(
|
||||
id="file-123",
|
||||
tenant_id="tenant-123",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related-123",
|
||||
)
|
||||
|
||||
datasource_info = {"key": "value", "nested": {"data": 42}}
|
||||
|
||||
# Create SystemVariable with all fields
|
||||
system_var = SystemVariable(
|
||||
user_id="user-123",
|
||||
app_id="app-123",
|
||||
workflow_id="workflow-123",
|
||||
files=[test_file],
|
||||
workflow_execution_id="exec-123",
|
||||
query="test query",
|
||||
conversation_id="conv-123",
|
||||
dialogue_count=5,
|
||||
document_id="doc-123",
|
||||
original_document_id="orig-doc-123",
|
||||
dataset_id="dataset-123",
|
||||
batch="batch-123",
|
||||
datasource_type="type-123",
|
||||
datasource_info=datasource_info,
|
||||
invoke_from="invoke-123",
|
||||
)
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test all properties
|
||||
assert read_only_view.user_id == "user-123"
|
||||
assert read_only_view.app_id == "app-123"
|
||||
assert read_only_view.workflow_id == "workflow-123"
|
||||
assert read_only_view.workflow_execution_id == "exec-123"
|
||||
assert read_only_view.query == "test query"
|
||||
assert read_only_view.conversation_id == "conv-123"
|
||||
assert read_only_view.dialogue_count == 5
|
||||
assert read_only_view.document_id == "doc-123"
|
||||
assert read_only_view.original_document_id == "orig-doc-123"
|
||||
assert read_only_view.dataset_id == "dataset-123"
|
||||
assert read_only_view.batch == "batch-123"
|
||||
assert read_only_view.datasource_type == "type-123"
|
||||
assert read_only_view.invoke_from == "invoke-123"
|
||||
|
||||
def test_defensive_copying_of_mutable_objects(self):
|
||||
"""Test that mutable objects are defensively copied."""
|
||||
# Create test data
|
||||
test_file = File(
|
||||
id="file-123",
|
||||
tenant_id="tenant-123",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related-123",
|
||||
)
|
||||
|
||||
datasource_info = {"key": "original_value"}
|
||||
|
||||
# Create SystemVariable
|
||||
system_var = SystemVariable(
|
||||
files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123"
|
||||
)
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test files defensive copying
|
||||
files_copy = read_only_view.files
|
||||
assert isinstance(files_copy, tuple) # Should be immutable tuple
|
||||
assert len(files_copy) == 1
|
||||
assert files_copy[0].id == "file-123"
|
||||
|
||||
# Verify it's a copy (can't modify original through view)
|
||||
assert isinstance(files_copy, tuple)
|
||||
# tuples don't have append method, so they're immutable
|
||||
|
||||
# Test datasource_info defensive copying
|
||||
datasource_copy = read_only_view.datasource_info
|
||||
assert datasource_copy is not None
|
||||
assert datasource_copy["key"] == "original_value"
|
||||
|
||||
datasource_copy = cast(dict, datasource_copy)
|
||||
with pytest.raises(TypeError):
|
||||
datasource_copy["key"] = "modified value"
|
||||
|
||||
# Verify original is unchanged
|
||||
assert system_var.datasource_info is not None
|
||||
assert system_var.datasource_info["key"] == "original_value"
|
||||
assert read_only_view.datasource_info is not None
|
||||
assert read_only_view.datasource_info["key"] == "original_value"
|
||||
|
||||
def test_always_accesses_latest_data(self):
|
||||
"""Test that properties always return the latest data from wrapped instance."""
|
||||
# Create SystemVariable
|
||||
system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Verify initial value
|
||||
assert read_only_view.user_id == "original-user"
|
||||
|
||||
# Modify the wrapped instance
|
||||
system_var.user_id = "modified-user"
|
||||
|
||||
# Verify view returns the new value
|
||||
assert read_only_view.user_id == "modified-user"
|
||||
|
||||
def test_repr_method(self):
|
||||
"""Test the __repr__ method."""
|
||||
# Create SystemVariable
|
||||
system_var = SystemVariable(workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test repr
|
||||
repr_str = repr(read_only_view)
|
||||
assert "SystemVariableReadOnlyView" in repr_str
|
||||
assert "system_variable=" in repr_str
|
||||
|
||||
def test_none_value_handling(self):
|
||||
"""Test that None values are properly handled."""
|
||||
# Create SystemVariable with all None values except workflow_execution_id
|
||||
system_var = SystemVariable(
|
||||
user_id=None,
|
||||
app_id=None,
|
||||
workflow_id=None,
|
||||
workflow_execution_id="exec-123",
|
||||
query=None,
|
||||
conversation_id=None,
|
||||
dialogue_count=None,
|
||||
document_id=None,
|
||||
original_document_id=None,
|
||||
dataset_id=None,
|
||||
batch=None,
|
||||
datasource_type=None,
|
||||
datasource_info=None,
|
||||
invoke_from=None,
|
||||
)
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test all None values
|
||||
assert read_only_view.user_id is None
|
||||
assert read_only_view.app_id is None
|
||||
assert read_only_view.workflow_id is None
|
||||
assert read_only_view.query is None
|
||||
assert read_only_view.conversation_id is None
|
||||
assert read_only_view.dialogue_count is None
|
||||
assert read_only_view.document_id is None
|
||||
assert read_only_view.original_document_id is None
|
||||
assert read_only_view.dataset_id is None
|
||||
assert read_only_view.batch is None
|
||||
assert read_only_view.datasource_type is None
|
||||
assert read_only_view.datasource_info is None
|
||||
assert read_only_view.invoke_from is None
|
||||
|
||||
# files should be empty tuple even when default list is empty
|
||||
assert read_only_view.files == ()
|
||||
|
||||
def test_empty_files_handling(self):
|
||||
"""Test that empty files list is handled correctly."""
|
||||
# Create SystemVariable with empty files
|
||||
system_var = SystemVariable(files=[], workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test files handling
|
||||
assert read_only_view.files == ()
|
||||
assert isinstance(read_only_view.files, tuple)
|
||||
|
||||
def test_empty_datasource_info_handling(self):
|
||||
"""Test that empty datasource_info is handled correctly."""
|
||||
# Create SystemVariable with empty datasource_info
|
||||
system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test datasource_info handling
|
||||
assert read_only_view.datasource_info == {}
|
||||
# Should be a copy, not the same object
|
||||
assert read_only_view.datasource_info is not system_var.datasource_info
|
||||
Reference in New Issue
Block a user