feat(api/repo): Allow to config repository implementation (#21458)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
-LAN-
2025-07-14 14:54:38 +08:00
committed by GitHub
parent b27c540379
commit 6eb155ae69
38 changed files with 2361 additions and 329 deletions

View File

@@ -0,0 +1 @@
# Unit tests for core repositories module

View File

@@ -0,0 +1,455 @@
"""
Unit tests for the RepositoryFactory.
This module tests the factory pattern implementation for creating repository instances
based on configuration, including error handling and validation.
"""
from unittest.mock import MagicMock, patch
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom
class TestRepositoryFactory:
"""Test cases for RepositoryFactory."""
def test_import_class_success(self):
"""Test successful class import."""
# Test importing a real class
class_path = "unittest.mock.MagicMock"
result = DifyCoreRepositoryFactory._import_class(class_path)
assert result is MagicMock
def test_import_class_invalid_path(self):
"""Test import with invalid module path."""
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("invalid.module.path")
assert "Cannot import repository class" in str(exc_info.value)
def test_import_class_invalid_class_name(self):
"""Test import with invalid class name."""
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
assert "Cannot import repository class" in str(exc_info.value)
def test_import_class_malformed_path(self):
"""Test import with malformed path (no dots)."""
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("invalidpath")
assert "Cannot import repository class" in str(exc_info.value)
def test_validate_repository_interface_success(self):
"""Test successful interface validation."""
# Create a mock class that implements the required methods
class MockRepository:
def save(self):
pass
def get_by_id(self):
pass
# Create a mock interface with the same methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
# Should not raise an exception
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_repository_interface_missing_methods(self):
"""Test interface validation with missing methods."""
# Create a mock class that doesn't implement all required methods
class IncompleteRepository:
def save(self):
pass
# Missing get_by_id method
# Create a mock interface with required methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
assert "does not implement required methods" in str(exc_info.value)
assert "get_by_id" in str(exc_info.value)
def test_validate_constructor_signature_success(self):
"""Test successful constructor signature validation."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from):
pass
# Should not raise an exception
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_missing_params(self):
"""Test constructor validation with missing parameters."""
class IncompleteRepository:
def __init__(self, session_factory, user):
# Missing app_id and triggered_from parameters
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)
def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
"""Test constructor validation when inspection fails."""
# Mock inspect.signature to raise an exception
mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
class MockRepository:
def __init__(self, session_factory):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
assert "Failed to validate constructor signature" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowExecutionRepository."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
# Mock the imported class to be a valid repository
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
# Mock the validation methods
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
triggered_from=triggered_from,
)
# Verify the repository was created with correct parameters
mock_repository_class.assert_called_once_with(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
triggered_from=triggered_from,
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_import_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Cannot import repository class" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowNodeExecutionRepository."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
app_id = "test-app-id"
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
# Mock the imported class to be a valid repository
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
# Mock the validation methods
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
triggered_from=triggered_from,
)
# Verify the repository was created with correct parameters
mock_repository_class.assert_called_once_with(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
triggered_from=triggered_from,
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Cannot import repository class" in str(exc_info.value)
def test_repository_import_error_exception(self):
"""Test RepositoryImportError exception."""
error_message = "Test error message"
exception = RepositoryImportError(error_message)
assert str(exception) == error_message
assert isinstance(exception, Exception)
@patch("core.repositories.factory.dify_config")
def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
"""Test repository creation with Engine instead of sessionmaker."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies with Engine instead of sessionmaker
mock_engine = MagicMock(spec=Engine)
mock_user = MagicMock(spec=Account)
# Mock the imported class to be a valid repository
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
# Mock the validation methods
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Verify the repository was created with the Engine
mock_repository_class.assert_called_once_with(
session_factory=mock_engine,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
def test_validate_repository_interface_with_private_methods(self):
"""Test interface validation ignores private methods."""
# Create a mock class with private methods
class MockRepository:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Create a mock interface with private methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Should not raise an exception (private methods are ignored)
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_constructor_signature_with_extra_params(self):
"""Test constructor validation with extra parameters (should pass)."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
pass
# Should not raise an exception (extra parameters are allowed)
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_with_kwargs(self):
"""Test constructor validation with **kwargs (current implementation doesn't support this)."""
class MockRepository:
def __init__(self, session_factory, user, **kwargs):
pass
# Current implementation doesn't handle **kwargs, so this should raise an exception
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)

View File

@@ -10,7 +10,8 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
@pytest.fixture
def workflow_setup():
workflow_service = WorkflowService()
mock_session_maker = MagicMock()
workflow_service = WorkflowService(mock_session_maker)
session = MagicMock(spec=Session)
tenant_id = "test-tenant-id"
workflow_id = "test-workflow-id"

View File

@@ -1,14 +1,14 @@
import dataclasses
import secrets
from unittest import mock
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session
from core.variables import StringSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from core.workflow.nodes.enums import NodeType
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from services.workflow_draft_variable_service import (
@@ -18,13 +18,25 @@ from services.workflow_draft_variable_service import (
)
@pytest.fixture
def mock_engine() -> Engine:
return Mock(spec=Engine)
@pytest.fixture
def mock_session(mock_engine) -> Session:
mock_session = Mock(spec=Session)
mock_session.get_bind.return_value = mock_engine
return mock_session
class TestDraftVariableSaver:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test__should_variable_be_visible(self):
mock_session = mock.MagicMock(spec=Session)
mock_session = MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -70,7 +82,7 @@ class TestDraftVariableSaver:
),
]
mock_session = mock.MagicMock(spec=Session)
mock_session = MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -105,9 +117,8 @@ class TestWorkflowDraftVariableService:
conversation_variables=[],
)
def test_reset_conversation_variable(self):
def test_reset_conversation_variable(self, mock_session):
"""Test resetting a conversation variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -131,9 +142,8 @@ class TestWorkflowDraftVariableService:
mock_reset_conv.assert_called_once_with(workflow, variable)
assert result == expected_result
def test_reset_node_variable_with_no_execution_id(self):
def test_reset_node_variable_with_no_execution_id(self, mock_session):
"""Test resetting a node variable with no execution ID - should delete variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -158,11 +168,26 @@ class TestWorkflowDraftVariableService:
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_missing_execution_record(self):
def test_reset_node_variable_with_missing_execution_record(
self,
mock_engine,
mock_session,
monkeypatch,
):
"""Test resetting a node variable when execution record doesn't exist"""
mock_session = Mock(spec=Session)
mock_repo_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__exit__.return_value = None
monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
# Mock the repository to return None (no execution record found)
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = None
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@@ -171,24 +196,41 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Mock session.scalars to return None (no execution record found)
mock_scalars = Mock()
mock_scalars.first.return_value = None
mock_session.scalars.return_value = mock_scalars
# Variable is editable by default from factory method
result = service._reset_node_var_or_sys_var(workflow, variable)
mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_valid_execution_record(self):
def test_reset_node_variable_with_valid_execution_record(
self,
mock_session,
monkeypatch,
):
"""Test resetting a node variable with valid execution record - should restore from execution"""
mock_session = Mock(spec=Session)
mock_repo_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__exit__.return_value = None
mock_session_maker = monkeypatch.setattr(
"services.workflow_draft_variable_service.sessionmaker", mock_session_maker
)
service = WorkflowDraftVariableService(mock_session)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@@ -197,16 +239,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.process_data_dict = {"test_var": "process_value"}
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
# Variable is editable by default from factory method
# Mock workflow methods
mock_node_config = {"type": "test_node"}
@@ -224,9 +257,8 @@ class TestWorkflowDraftVariableService:
# Should return the updated variable
assert result == variable
def test_reset_non_editable_system_variable_raises_error(self):
def test_reset_non_editable_system_variable_raises_error(self, mock_session):
"""Test that resetting a non-editable system variable raises an error"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -242,24 +274,13 @@ class TestWorkflowDraftVariableService:
editable=False, # Non-editable system variable
)
# Mock the service to properly check system variable editability
with patch.object(service, "reset_variable") as mock_reset:
with pytest.raises(VariableResetError) as exc_info:
service.reset_variable(workflow, variable)
assert "cannot reset system variable" in str(exc_info.value)
assert f"variable_id={variable.id}" in str(exc_info.value)
def side_effect(wf, var):
if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name):
raise VariableResetError(f"cannot reset system variable, variable_id={var.id}")
return var
mock_reset.side_effect = side_effect
with pytest.raises(VariableResetError) as exc_info:
service.reset_variable(workflow, variable)
assert "cannot reset system variable" in str(exc_info.value)
assert f"variable_id={variable.id}" in str(exc_info.value)
def test_reset_editable_system_variable_succeeds(self):
def test_reset_editable_system_variable_succeeds(self, mock_session):
"""Test that resetting an editable system variable succeeds"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -279,10 +300,9 @@ class TestWorkflowDraftVariableService:
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.files": "[]"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
@@ -291,9 +311,8 @@ class TestWorkflowDraftVariableService:
assert variable.last_edited_at is None
mock_session.flush.assert_called()
def test_reset_query_system_variable_succeeds(self):
def test_reset_query_system_variable_succeeds(self, mock_session):
"""Test that resetting query system variable (another editable one) succeeds"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -313,10 +332,9 @@ class TestWorkflowDraftVariableService:
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.query": "reset query"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)

View File

@@ -0,0 +1,288 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
@pytest.fixture
def repository(self):
mock_session_maker = MagicMock()
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
@pytest.fixture
def mock_execution(self):
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.id = str(uuid4())
execution.tenant_id = "tenant-123"
execution.app_id = "app-456"
execution.workflow_id = "workflow-789"
execution.workflow_run_id = "run-101"
execution.node_id = "node-202"
execution.index = 1
execution.created_at = "2023-01-01T00:00:00Z"
return execution
def test_get_node_last_execution_found(self, repository, mock_execution):
"""Test getting the last execution for a node when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.scalar.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_node_last_execution_not_found(self, repository):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_get_executions_by_workflow_run(self, repository, mock_execution):
"""Test getting all executions for a workflow run."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
executions = [mock_execution]
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == executions
mock_session.execute.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.execute.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_executions_by_workflow_run_empty(self, repository):
"""Test getting executions for a workflow run when none exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalars.return_value.all.return_value = []
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == []
mock_session.execute.assert_called_once()
def test_get_execution_by_id_found(self, repository, mock_execution):
"""Test getting execution by ID when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_execution_by_id(mock_execution.id)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
def test_get_execution_by_id_not_found(self, repository):
"""Test getting execution by ID when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_execution_by_id("non-existent-id")
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_repository_implements_protocol(self, repository):
"""Test that the repository implements the required protocol methods."""
# Verify all protocol methods are implemented
assert hasattr(repository, "get_node_last_execution")
assert hasattr(repository, "get_executions_by_workflow_run")
assert hasattr(repository, "get_execution_by_id")
# Verify methods are callable
assert callable(repository.get_node_last_execution)
assert callable(repository.get_executions_by_workflow_run)
assert callable(repository.get_execution_by_id)
assert callable(repository.delete_expired_executions)
assert callable(repository.delete_executions_by_app)
assert callable(repository.get_expired_executions_batch)
assert callable(repository.delete_executions_by_ids)
def test_delete_expired_executions(self, repository):
"""Test deleting expired executions."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
before_date = datetime(2023, 1, 1)
# Act
result = repository.delete_expired_executions(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_delete_executions_by_app(self, repository):
"""Test deleting executions by app."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"]
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
# Act
result = repository.delete_executions_by_app(
tenant_id="tenant-123",
app_id="app-456",
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_get_expired_executions_batch(self, repository):
"""Test getting expired executions batch for backup."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Create mock execution objects
mock_execution1 = MagicMock()
mock_execution1.id = "exec-1"
mock_execution2 = MagicMock()
mock_execution2.id = "exec-2"
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
before_date = datetime(2023, 1, 1)
# Act
result = repository.get_expired_executions_batch(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert len(result) == 2
assert result[0].id == "exec-1"
assert result[1].id == "exec-2"
mock_session.execute.assert_called_once()
def test_delete_executions_by_ids(self, repository):
"""Test deleting executions by IDs."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the delete query result
mock_result = MagicMock()
mock_result.rowcount = 3
mock_session.execute.return_value = mock_result
execution_ids = ["id1", "id2", "id3"]
# Act
result = repository.delete_executions_by_ids(execution_ids)
# Assert
assert result == 3
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_delete_executions_by_ids_empty_list(self, repository):
"""Test deleting executions with empty ID list."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Act
result = repository.delete_executions_by_ids([])
# Assert
assert result == 0
mock_session.query.assert_not_called()
mock_session.commit.assert_not_called()

View File

@@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService
class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
return WorkflowService()
mock_session_maker = MagicMock()
return WorkflowService(mock_session_maker)
@pytest.fixture
def mock_app(self):