feat(workflow_cycle_manager): Removes redundant repository methods and adds caching (#22597)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -80,15 +80,12 @@ def real_workflow_system_variables():
|
||||
@pytest.fixture
|
||||
def mock_node_execution_repository():
|
||||
repo = MagicMock(spec=WorkflowNodeExecutionRepository)
|
||||
repo.get_by_node_execution_id.return_value = None
|
||||
repo.get_running_executions.return_value = []
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_execution_repository():
|
||||
repo = MagicMock(spec=WorkflowExecutionRepository)
|
||||
repo.get.return_value = None
|
||||
return repo
|
||||
|
||||
|
||||
@@ -217,8 +214,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
|
||||
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_run_success(
|
||||
@@ -251,11 +248,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
|
||||
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
|
||||
|
||||
# Mock get_running_executions to return an empty list
|
||||
workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
|
||||
# No running node executions in cache (empty cache)
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_run_failed(
|
||||
@@ -289,8 +285,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
|
||||
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
|
||||
|
||||
# Create a mock event
|
||||
event = MagicMock(spec=QueueNodeStartedEvent)
|
||||
@@ -342,8 +338,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Mock the repository get method to return the real execution
|
||||
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
|
||||
@@ -351,11 +347,13 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
|
||||
# Verify the result
|
||||
assert result == workflow_execution
|
||||
|
||||
# Test error case
|
||||
workflow_cycle_manager._workflow_execution_repository.get.return_value = None
|
||||
# Test error case - clear cache
|
||||
workflow_cycle_manager._workflow_execution_cache.clear()
|
||||
|
||||
# Expect an error when execution is not found
|
||||
with pytest.raises(ValueError):
|
||||
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
|
||||
|
||||
with pytest.raises(WorkflowRunNotFoundError):
|
||||
workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
|
||||
|
||||
|
||||
@@ -384,8 +382,8 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Mock the repository to return the node execution
|
||||
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
|
||||
# Pre-populate the cache with the node execution
|
||||
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_node_execution_success(
|
||||
@@ -414,8 +412,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
|
||||
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_run_partial_success(
|
||||
@@ -462,8 +460,8 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
# Mock the repository to return the node execution
|
||||
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
|
||||
# Pre-populate the cache with the node execution
|
||||
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_node_execution_failed(
|
||||
|
||||
@@ -137,37 +137,6 @@ def test_save_with_existing_tenant_id(repository, session):
|
||||
session_obj.merge.assert_called_once_with(modified_execution)
|
||||
|
||||
|
||||
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
||||
"""Test get_by_node_execution_id method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
|
||||
# Create a properly configured mock execution
|
||||
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
configure_mock_execution(mock_execution)
|
||||
session_obj.scalar.return_value = mock_execution
|
||||
|
||||
# Create a mock domain model to be returned by _to_domain_model
|
||||
mock_domain_model = mocker.MagicMock()
|
||||
# Mock the _to_domain_model method to return our mock domain model
|
||||
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
|
||||
|
||||
# Call method
|
||||
result = repository.get_by_node_execution_id("test-node-execution-id")
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalar.assert_called_once_with(mock_stmt)
|
||||
# Assert _to_domain_model was called with the mock execution
|
||||
repository._to_domain_model.assert_called_once_with(mock_execution)
|
||||
# Assert the result is our mock domain model
|
||||
assert result is mock_domain_model
|
||||
|
||||
|
||||
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
"""Test get_by_workflow_run method."""
|
||||
session_obj, _ = session
|
||||
@@ -202,88 +171,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
assert result[0] is mock_domain_model
|
||||
|
||||
|
||||
def test_get_running_executions(repository, session, mocker: MockerFixture):
|
||||
"""Test get_running_executions method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
|
||||
# Create a properly configured mock execution
|
||||
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
configure_mock_execution(mock_execution)
|
||||
session_obj.scalars.return_value.all.return_value = [mock_execution]
|
||||
|
||||
# Create a mock domain model to be returned by _to_domain_model
|
||||
mock_domain_model = mocker.MagicMock()
|
||||
# Mock the _to_domain_model method to return our mock domain model
|
||||
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
|
||||
|
||||
# Call method
|
||||
result = repository.get_running_executions("test-workflow-run-id")
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
# Assert _to_domain_model was called with the mock execution
|
||||
repository._to_domain_model.assert_called_once_with(mock_execution)
|
||||
# Assert the result contains our mock domain model
|
||||
assert len(result) == 1
|
||||
assert result[0] is mock_domain_model
|
||||
|
||||
|
||||
def test_update_via_save(repository, session):
|
||||
"""Test updating an existing record via save method."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution
|
||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
execution.tenant_id = None
|
||||
execution.app_id = None
|
||||
execution.inputs = None
|
||||
execution.process_data = None
|
||||
execution.outputs = None
|
||||
execution.metadata = None
|
||||
|
||||
# Mock the to_db_model method to return the execution itself
|
||||
# This simulates the behavior of setting tenant_id and app_id
|
||||
repository.to_db_model = MagicMock(return_value=execution)
|
||||
|
||||
# Call save method to update an existing record
|
||||
repository.save(execution)
|
||||
|
||||
# Assert to_db_model was called with the execution
|
||||
repository.to_db_model.assert_called_once_with(execution)
|
||||
|
||||
# Assert session.merge was called (for updates)
|
||||
session_obj.merge.assert_called_once_with(execution)
|
||||
|
||||
|
||||
def test_clear(repository, session, mocker: MockerFixture):
|
||||
"""Test clear method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_delete.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
|
||||
# Mock the execute result with rowcount
|
||||
mock_result = mocker.MagicMock()
|
||||
mock_result.rowcount = 5 # Simulate 5 records deleted
|
||||
session_obj.execute.return_value = mock_result
|
||||
|
||||
# Call method
|
||||
repository.clear()
|
||||
|
||||
# Assert delete was called with correct parameters
|
||||
mock_delete.assert_called_once_with(WorkflowNodeExecutionModel)
|
||||
mock_stmt.where.assert_called()
|
||||
session_obj.execute.assert_called_once_with(mock_stmt)
|
||||
session_obj.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_to_db_model(repository):
|
||||
"""Test to_db_model method."""
|
||||
# Create a domain model
|
||||
|
||||
Reference in New Issue
Block a user