feat(graph_engine): Support pausing workflow graph executions (#26585)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -5,12 +5,13 @@ import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
@@ -5,10 +5,11 @@ from urllib.parse import urlencode
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
|
||||
@@ -174,13 +175,13 @@ def test_custom_authorization_header(setup_http_mock):
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
|
||||
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.http_request.entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
)
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
# Create variable pool
|
||||
|
||||
@@ -6,12 +6,13 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.output_parser.structured_output import _parse_structured_output
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import StreamCompletedEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
|
||||
@@ -5,11 +5,12 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_runtime.entities import AssistantPromptMessage
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
|
||||
@@ -4,11 +4,12 @@ import uuid
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
@@ -4,12 +4,13 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import StreamCompletedEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@@ -99,6 +99,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
workflow=mock_workflow,
|
||||
system_user_id=str(uuid4()),
|
||||
app=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
@@ -237,6 +239,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
workflow=mock_workflow,
|
||||
system_user_id=str(uuid4()),
|
||||
app=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
@@ -390,6 +394,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
workflow=mock_workflow,
|
||||
system_user_id=str(uuid4()),
|
||||
app=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
# Mock database session
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def _make_state(workflow_run_id: str | None) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id))
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
|
||||
|
||||
class _StubPipeline(GraphRuntimeStateSupport):
|
||||
def __init__(self, *, cached_state: GraphRuntimeState | None, queue_state: GraphRuntimeState | None):
|
||||
self._graph_runtime_state = cached_state
|
||||
self._base_task_pipeline = SimpleNamespace(queue_manager=SimpleNamespace(graph_runtime_state=queue_state))
|
||||
|
||||
|
||||
def test_ensure_graph_runtime_initialized_caches_explicit_state():
|
||||
explicit_state = _make_state("run-explicit")
|
||||
pipeline = _StubPipeline(cached_state=None, queue_state=None)
|
||||
|
||||
resolved = pipeline._ensure_graph_runtime_initialized(explicit_state)
|
||||
|
||||
assert resolved is explicit_state
|
||||
assert pipeline._graph_runtime_state is explicit_state
|
||||
|
||||
|
||||
def test_resolve_graph_runtime_state_reads_from_queue_when_cache_empty():
|
||||
queued_state = _make_state("run-queue")
|
||||
pipeline = _StubPipeline(cached_state=None, queue_state=queued_state)
|
||||
|
||||
resolved = pipeline._resolve_graph_runtime_state()
|
||||
|
||||
assert resolved is queued_state
|
||||
assert pipeline._graph_runtime_state is queued_state
|
||||
|
||||
|
||||
def test_resolve_graph_runtime_state_raises_when_no_state_available():
|
||||
pipeline = _StubPipeline(cached_state=None, queue_state=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
pipeline._resolve_graph_runtime_state()
|
||||
|
||||
|
||||
def test_extract_workflow_run_id_returns_value():
|
||||
state = _make_state("run-identifier")
|
||||
pipeline = _StubPipeline(cached_state=state, queue_state=None)
|
||||
|
||||
run_id = pipeline._extract_workflow_run_id(state)
|
||||
|
||||
assert run_id == "run-identifier"
|
||||
|
||||
|
||||
def test_extract_workflow_run_id_raises_when_missing():
|
||||
state = _make_state(None)
|
||||
pipeline = _StubPipeline(cached_state=state, queue_state=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
pipeline._extract_workflow_run_id(state)
|
||||
@@ -3,8 +3,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
@@ -12,24 +11,17 @@ import pytest
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueNodeRetryEvent, QueueNodeSucceededEvent
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessDataResponseScenario:
|
||||
"""Test scenario for process_data in responses."""
|
||||
|
||||
name: str
|
||||
original_process_data: dict[str, Any] | None
|
||||
truncated_process_data: dict[str, Any] | None
|
||||
expected_response_data: dict[str, Any] | None
|
||||
expected_truncated_flag: bool
|
||||
|
||||
|
||||
class TestWorkflowResponseConverterCenarios:
|
||||
"""Test process_data truncation in WorkflowResponseConverter."""
|
||||
|
||||
@@ -39,6 +31,7 @@ class TestWorkflowResponseConverterCenarios:
|
||||
mock_app_config = Mock()
|
||||
mock_app_config.tenant_id = "test-tenant-id"
|
||||
mock_entity.app_config = mock_app_config
|
||||
mock_entity.inputs = {}
|
||||
return mock_entity
|
||||
|
||||
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
|
||||
@@ -50,54 +43,59 @@ class TestWorkflowResponseConverterCenarios:
|
||||
mock_user.name = "Test User"
|
||||
mock_user.email = "test@example.com"
|
||||
|
||||
return WorkflowResponseConverter(application_generate_entity=mock_entity, user=mock_user)
|
||||
|
||||
def create_workflow_node_execution(
|
||||
self,
|
||||
process_data: dict[str, Any] | None = None,
|
||||
truncated_process_data: dict[str, Any] | None = None,
|
||||
execution_id: str = "test-execution-id",
|
||||
) -> WorkflowNodeExecution:
|
||||
"""Create a WorkflowNodeExecution for testing."""
|
||||
execution = WorkflowNodeExecution(
|
||||
id=execution_id,
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-run-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=process_data,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_at=datetime.now(),
|
||||
finished_at=datetime.now(),
|
||||
system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=mock_entity,
|
||||
user=mock_user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
if truncated_process_data is not None:
|
||||
execution.set_truncated_process_data(truncated_process_data)
|
||||
def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
|
||||
"""Create a QueueNodeStartedEvent for testing."""
|
||||
return QueueNodeStartedEvent(
|
||||
node_execution_id=node_execution_id or str(uuid.uuid4()),
|
||||
node_id="test-node-id",
|
||||
node_title="Test Node",
|
||||
node_type=NodeType.CODE,
|
||||
start_at=naive_utc_now(),
|
||||
predecessor_node_id=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
provider_type="built-in",
|
||||
provider_id="code",
|
||||
)
|
||||
|
||||
return execution
|
||||
|
||||
def create_node_succeeded_event(self) -> QueueNodeSucceededEvent:
|
||||
def create_node_succeeded_event(
|
||||
self,
|
||||
*,
|
||||
node_execution_id: str,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
) -> QueueNodeSucceededEvent:
|
||||
"""Create a QueueNodeSucceededEvent for testing."""
|
||||
return QueueNodeSucceededEvent(
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
node_execution_id=node_execution_id,
|
||||
start_at=naive_utc_now(),
|
||||
parallel_id=None,
|
||||
parallel_start_node_id=None,
|
||||
parent_parallel_id=None,
|
||||
parent_parallel_start_node_id=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
inputs={},
|
||||
process_data=process_data or {},
|
||||
outputs={},
|
||||
execution_metadata={},
|
||||
)
|
||||
|
||||
def create_node_retry_event(self) -> QueueNodeRetryEvent:
|
||||
def create_node_retry_event(
|
||||
self,
|
||||
*,
|
||||
node_execution_id: str,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
) -> QueueNodeRetryEvent:
|
||||
"""Create a QueueNodeRetryEvent for testing."""
|
||||
return QueueNodeRetryEvent(
|
||||
inputs={"data": "inputs"},
|
||||
outputs={"data": "outputs"},
|
||||
process_data=process_data or {},
|
||||
error="oops",
|
||||
retry_index=1,
|
||||
node_id="test-node-id",
|
||||
@@ -105,12 +103,8 @@ class TestWorkflowResponseConverterCenarios:
|
||||
node_title="test code",
|
||||
provider_type="built-in",
|
||||
provider_id="code",
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
node_execution_id=node_execution_id,
|
||||
start_at=naive_utc_now(),
|
||||
parallel_id=None,
|
||||
parallel_start_node_id=None,
|
||||
parent_parallel_id=None,
|
||||
parent_parallel_start_node_id=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
@@ -122,15 +116,28 @@ class TestWorkflowResponseConverterCenarios:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
execution = self.create_workflow_node_execution(
|
||||
process_data=original_data, truncated_process_data=truncated_data
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
event = self.create_node_succeeded_event()
|
||||
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
if mapping == dict(original_data):
|
||||
return truncated_data, True
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should use truncated data, not original
|
||||
@@ -145,13 +152,26 @@ class TestWorkflowResponseConverterCenarios:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data=original_data)
|
||||
event = self.create_node_succeeded_event()
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should use original data
|
||||
@@ -163,18 +183,31 @@ class TestWorkflowResponseConverterCenarios:
|
||||
"""Test node finish response when process_data is None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data=None)
|
||||
event = self.create_node_succeeded_event()
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=None,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should have None process_data
|
||||
# Response should normalize missing process_data to an empty mapping
|
||||
assert response is not None
|
||||
assert response.data.process_data is None
|
||||
assert response.data.process_data == {}
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_retry_response_uses_truncated_process_data(self):
|
||||
@@ -184,15 +217,28 @@ class TestWorkflowResponseConverterCenarios:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
execution = self.create_workflow_node_execution(
|
||||
process_data=original_data, truncated_process_data=truncated_data
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
event = self.create_node_retry_event()
|
||||
|
||||
event = self.create_node_retry_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
if mapping == dict(original_data):
|
||||
return truncated_data, True
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should use truncated data, not original
|
||||
@@ -207,224 +253,72 @@ class TestWorkflowResponseConverterCenarios:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data=original_data)
|
||||
event = self.create_node_retry_event()
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_retry_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Response should use original data
|
||||
assert response is not None
|
||||
assert response.data.process_data == original_data
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_iteration_and_loop_nodes_return_none(self):
|
||||
"""Test that iteration and loop nodes return None (no change from existing behavior)."""
|
||||
"""Test that iteration and loop nodes return None (no streaming events)."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
# Test iteration node
|
||||
iteration_execution = self.create_workflow_node_execution(process_data={"test": "data"})
|
||||
iteration_execution.node_type = NodeType.ITERATION
|
||||
|
||||
event = self.create_node_succeeded_event()
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=iteration_execution,
|
||||
)
|
||||
|
||||
# Should return None for iteration nodes
|
||||
assert response is None
|
||||
|
||||
# Test loop node
|
||||
loop_execution = self.create_workflow_node_execution(process_data={"test": "data"})
|
||||
loop_execution.node_type = NodeType.LOOP
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=loop_execution,
|
||||
)
|
||||
|
||||
# Should return None for loop nodes
|
||||
assert response is None
|
||||
|
||||
def test_execution_without_workflow_execution_id_returns_none(self):
|
||||
"""Test that executions without workflow_execution_id return None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data={"test": "data"})
|
||||
execution.workflow_execution_id = None # Single-step debugging
|
||||
|
||||
event = self.create_node_succeeded_event()
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
|
||||
# Should return None for single-step debugging
|
||||
assert response is None
|
||||
|
||||
@staticmethod
|
||||
def get_process_data_response_scenarios() -> list[ProcessDataResponseScenario]:
|
||||
"""Create test scenarios for process_data responses."""
|
||||
return [
|
||||
ProcessDataResponseScenario(
|
||||
name="none_process_data",
|
||||
original_process_data=None,
|
||||
truncated_process_data=None,
|
||||
expected_response_data=None,
|
||||
expected_truncated_flag=False,
|
||||
),
|
||||
ProcessDataResponseScenario(
|
||||
name="small_process_data_no_truncation",
|
||||
original_process_data={"small": "data"},
|
||||
truncated_process_data=None,
|
||||
expected_response_data={"small": "data"},
|
||||
expected_truncated_flag=False,
|
||||
),
|
||||
ProcessDataResponseScenario(
|
||||
name="large_process_data_with_truncation",
|
||||
original_process_data={"large": "x" * 10000, "metadata": "info"},
|
||||
truncated_process_data={"large": "[TRUNCATED]", "metadata": "info"},
|
||||
expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
|
||||
expected_truncated_flag=True,
|
||||
),
|
||||
ProcessDataResponseScenario(
|
||||
name="empty_process_data",
|
||||
original_process_data={},
|
||||
truncated_process_data=None,
|
||||
expected_response_data={},
|
||||
expected_truncated_flag=False,
|
||||
),
|
||||
ProcessDataResponseScenario(
|
||||
name="complex_data_with_truncation",
|
||||
original_process_data={
|
||||
"logs": ["entry"] * 1000, # Large array
|
||||
"config": {"setting": "value"},
|
||||
"status": "processing",
|
||||
},
|
||||
truncated_process_data={
|
||||
"logs": "[TRUNCATED: 1000 items]",
|
||||
"config": {"setting": "value"},
|
||||
"status": "processing",
|
||||
},
|
||||
expected_response_data={
|
||||
"logs": "[TRUNCATED: 1000 items]",
|
||||
"config": {"setting": "value"},
|
||||
"status": "processing",
|
||||
},
|
||||
expected_truncated_flag=True,
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scenario",
|
||||
get_process_data_response_scenarios(),
|
||||
ids=[scenario.name for scenario in get_process_data_response_scenarios()],
|
||||
)
|
||||
def test_node_finish_response_scenarios(self, scenario: ProcessDataResponseScenario):
|
||||
"""Test various scenarios for node finish responses."""
|
||||
|
||||
mock_user = Mock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.name = "Test User"
|
||||
mock_user.email = "test@example.com"
|
||||
|
||||
converter = WorkflowResponseConverter(
|
||||
application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
|
||||
user=mock_user,
|
||||
)
|
||||
|
||||
execution = WorkflowNodeExecution(
|
||||
id="test-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-run-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=scenario.original_process_data,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
created_at=datetime.now(),
|
||||
finished_at=datetime.now(),
|
||||
)
|
||||
|
||||
if scenario.truncated_process_data is not None:
|
||||
execution.set_truncated_process_data(scenario.truncated_process_data)
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
iteration_event = QueueNodeSucceededEvent(
|
||||
node_id="iteration-node",
|
||||
node_type=NodeType.ITERATION,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
start_at=naive_utc_now(),
|
||||
parallel_id=None,
|
||||
parallel_start_node_id=None,
|
||||
parent_parallel_id=None,
|
||||
parent_parallel_start_node_id=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
execution_metadata={},
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
event=iteration_event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
assert response is None
|
||||
|
||||
assert response is not None
|
||||
assert response.data.process_data == scenario.expected_response_data
|
||||
assert response.data.process_data_truncated == scenario.expected_truncated_flag
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scenario",
|
||||
get_process_data_response_scenarios(),
|
||||
ids=[scenario.name for scenario in get_process_data_response_scenarios()],
|
||||
)
|
||||
def test_node_retry_response_scenarios(self, scenario: ProcessDataResponseScenario):
|
||||
"""Test various scenarios for node retry responses."""
|
||||
|
||||
mock_user = Mock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.name = "Test User"
|
||||
mock_user.email = "test@example.com"
|
||||
|
||||
converter = WorkflowResponseConverter(
|
||||
application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
|
||||
user=mock_user,
|
||||
)
|
||||
|
||||
execution = WorkflowNodeExecution(
|
||||
id="test-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-run-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=scenario.original_process_data,
|
||||
status=WorkflowNodeExecutionStatus.FAILED, # Retry scenario
|
||||
created_at=datetime.now(),
|
||||
finished_at=datetime.now(),
|
||||
)
|
||||
|
||||
if scenario.truncated_process_data is not None:
|
||||
execution.set_truncated_process_data(scenario.truncated_process_data)
|
||||
|
||||
event = self.create_node_retry_event()
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=loop_event,
|
||||
task_id="test-task-id",
|
||||
workflow_node_execution=execution,
|
||||
)
|
||||
assert response is None
|
||||
|
||||
def test_finish_without_start_raises(self):
|
||||
"""Ensure finish responses require a prior workflow start."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
process_data={},
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.process_data == scenario.expected_response_data
|
||||
assert response.data.process_data_truncated == scenario.expected_truncated_flag
|
||||
with pytest.raises(ValueError):
|
||||
converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
@@ -37,7 +37,7 @@ from core.variables.variables import (
|
||||
Variable,
|
||||
VariableUnion,
|
||||
)
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
from time import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
|
||||
|
||||
|
||||
class TestGraphRuntimeState:
|
||||
@@ -95,3 +97,141 @@ class TestGraphRuntimeState:
|
||||
# Test add_tokens validation
|
||||
with pytest.raises(ValueError):
|
||||
state.add_tokens(-1)
|
||||
|
||||
def test_ready_queue_default_instantiation(self):
|
||||
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
|
||||
|
||||
queue = state.ready_queue
|
||||
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
|
||||
assert isinstance(queue, InMemoryReadyQueue)
|
||||
assert state.ready_queue is queue
|
||||
|
||||
def test_graph_execution_lazy_instantiation(self):
|
||||
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
|
||||
|
||||
execution = state.graph_execution
|
||||
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
|
||||
assert isinstance(execution, GraphExecution)
|
||||
assert execution.workflow_id == ""
|
||||
assert state.graph_execution is execution
|
||||
|
||||
def test_response_coordinator_configuration(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = state.response_coordinator
|
||||
|
||||
mock_graph = MagicMock()
|
||||
with patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls:
|
||||
coordinator_instance = MagicMock()
|
||||
coordinator_cls.return_value = coordinator_instance
|
||||
|
||||
state.configure(graph=mock_graph)
|
||||
|
||||
assert state.response_coordinator is coordinator_instance
|
||||
coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph)
|
||||
|
||||
# Configure again with same graph should be idempotent
|
||||
state.configure(graph=mock_graph)
|
||||
|
||||
other_graph = MagicMock()
|
||||
with pytest.raises(ValueError):
|
||||
state.attach_graph(other_graph)
|
||||
|
||||
def test_read_only_wrapper_exposes_additional_state(self):
|
||||
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
|
||||
state.configure()
|
||||
|
||||
wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
|
||||
|
||||
assert wrapper.ready_queue_size == 0
|
||||
assert wrapper.exceptions_count == 0
|
||||
|
||||
def test_read_only_wrapper_serializes_runtime_state(self):
|
||||
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
|
||||
state.total_tokens = 5
|
||||
state.set_output("result", {"success": True})
|
||||
state.ready_queue.put("node-1")
|
||||
|
||||
wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
|
||||
|
||||
wrapper_snapshot = json.loads(wrapper.dumps())
|
||||
state_snapshot = json.loads(state.dumps())
|
||||
|
||||
assert wrapper_snapshot == state_snapshot
|
||||
|
||||
def test_dumps_and_loads_roundtrip_with_response_coordinator(self):
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(("node1", "value"), "payload")
|
||||
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
state.total_tokens = 10
|
||||
state.node_run_steps = 3
|
||||
state.set_output("final", {"result": True})
|
||||
usage = LLMUsage.from_metadata(
|
||||
{
|
||||
"prompt_tokens": 2,
|
||||
"completion_tokens": 3,
|
||||
"total_tokens": 5,
|
||||
"total_price": "1.23",
|
||||
"currency": "USD",
|
||||
"latency": 0.5,
|
||||
}
|
||||
)
|
||||
state.llm_usage = usage
|
||||
state.ready_queue.put("node-A")
|
||||
|
||||
graph_execution = state.graph_execution
|
||||
graph_execution.workflow_id = "wf-123"
|
||||
graph_execution.exceptions_count = 4
|
||||
graph_execution.started = True
|
||||
|
||||
class StubCoordinator:
|
||||
def __init__(self) -> None:
|
||||
self.state = "initial"
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps({"state": self.state})
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
payload = json.loads(data)
|
||||
self.state = payload["state"]
|
||||
|
||||
mock_graph = MagicMock()
|
||||
stub = StubCoordinator()
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
|
||||
state.attach_graph(mock_graph)
|
||||
|
||||
stub.state = "configured"
|
||||
|
||||
snapshot = state.dumps()
|
||||
|
||||
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
restored.loads(snapshot)
|
||||
|
||||
assert restored.total_tokens == 10
|
||||
assert restored.node_run_steps == 3
|
||||
assert restored.get_output("final") == {"result": True}
|
||||
assert restored.llm_usage.total_tokens == usage.total_tokens
|
||||
assert restored.ready_queue.qsize() == 1
|
||||
assert restored.ready_queue.get(timeout=0.01) == "node-A"
|
||||
|
||||
restored_segment = restored.variable_pool.get(("node1", "value"))
|
||||
assert restored_segment is not None
|
||||
assert restored_segment.value == "payload"
|
||||
|
||||
restored_execution = restored.graph_execution
|
||||
assert restored_execution.workflow_id == "wf-123"
|
||||
assert restored_execution.exceptions_count == 4
|
||||
assert restored_execution.started is True
|
||||
|
||||
new_stub = StubCoordinator()
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
|
||||
restored.attach_graph(mock_graph)
|
||||
|
||||
assert new_stub.state == "configured"
|
||||
|
||||
@@ -4,7 +4,7 @@ from core.variables.segments import (
|
||||
NoneSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
class TestVariablePoolGetAndNestedAttribute:
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node:
|
||||
node = MagicMock(spec=Node)
|
||||
node.id = node_id
|
||||
node.node_type = node_type
|
||||
node.execution_type = None # attribute not used in builder path
|
||||
return node
|
||||
|
||||
|
||||
def test_graph_builder_creates_linear_graph():
|
||||
builder = Graph.new()
|
||||
root = _make_node("root", NodeType.START)
|
||||
mid = _make_node("mid", NodeType.LLM)
|
||||
end = _make_node("end", NodeType.END)
|
||||
|
||||
graph = builder.add_root(root).add_node(mid).add_node(end).build()
|
||||
|
||||
assert graph.root_node is root
|
||||
assert graph.nodes == {"root": root, "mid": mid, "end": end}
|
||||
assert len(graph.edges) == 2
|
||||
first_edge = next(iter(graph.edges.values()))
|
||||
assert first_edge.tail == "root"
|
||||
assert first_edge.head == "mid"
|
||||
assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"]
|
||||
|
||||
|
||||
def test_graph_builder_supports_custom_predecessor():
|
||||
builder = Graph.new()
|
||||
root = _make_node("root")
|
||||
branch = _make_node("branch")
|
||||
other = _make_node("other")
|
||||
|
||||
graph = builder.add_root(root).add_node(branch).add_node(other, from_node_id="root").build()
|
||||
|
||||
outgoing_root = graph.out_edges["root"]
|
||||
assert len(outgoing_root) == 2
|
||||
edge_targets = {graph.edges[eid].head for eid in outgoing_root}
|
||||
assert edge_targets == {"branch", "other"}
|
||||
|
||||
|
||||
def test_graph_builder_validates_usage():
|
||||
builder = Graph.new()
|
||||
node = _make_node("node")
|
||||
|
||||
with pytest.raises(ValueError, match="Root node"):
|
||||
builder.add_node(node)
|
||||
|
||||
builder.add_root(node)
|
||||
duplicate = _make_node("node")
|
||||
with pytest.raises(ValueError, match="Duplicate"):
|
||||
builder.add_node(duplicate)
|
||||
@@ -20,9 +20,6 @@ The TableTestRunner (`test_table_runner.py`) provides a robust table-driven test
|
||||
- **Mock configuration** - Seamless integration with the auto-mock system
|
||||
- **Performance metrics** - Track execution times and bottlenecks
|
||||
- **Detailed error reporting** - Comprehensive failure diagnostics
|
||||
- **Test tagging** - Organize and filter tests by tags
|
||||
- **Retry mechanism** - Handle flaky tests gracefully
|
||||
- **Custom validators** - Define custom validation logic
|
||||
|
||||
### Basic Usage
|
||||
|
||||
@@ -68,49 +65,6 @@ suite_result = runner.run_table_tests(
|
||||
print(f"Success rate: {suite_result.success_rate:.1f}%")
|
||||
```
|
||||
|
||||
#### Test Tagging and Filtering
|
||||
|
||||
```python
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="workflow",
|
||||
inputs={},
|
||||
expected_outputs={},
|
||||
tags=["smoke", "critical"],
|
||||
)
|
||||
|
||||
# Run only tests with specific tags
|
||||
suite_result = runner.run_table_tests(
|
||||
test_cases,
|
||||
tags_filter=["smoke"]
|
||||
)
|
||||
```
|
||||
|
||||
#### Retry Mechanism
|
||||
|
||||
```python
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="flaky_workflow",
|
||||
inputs={},
|
||||
expected_outputs={},
|
||||
retry_count=2, # Retry up to 2 times on failure
|
||||
)
|
||||
```
|
||||
|
||||
#### Custom Validators
|
||||
|
||||
```python
|
||||
def custom_validator(outputs: dict) -> bool:
|
||||
# Custom validation logic
|
||||
return "error" not in outputs.get("status", "")
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="workflow",
|
||||
inputs={},
|
||||
expected_outputs={"status": "success"},
|
||||
custom_validator=custom_validator,
|
||||
)
|
||||
```
|
||||
|
||||
#### Event Sequence Validation
|
||||
|
||||
```python
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
@@ -16,6 +15,7 @@ from core.workflow.graph_engine.response_coordinator.coordinator import Response
|
||||
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import RetryConfig
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
|
||||
class _StubEdgeProcessor:
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
|
||||
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
|
||||
def test_abort_command():
|
||||
@@ -100,8 +100,57 @@ def test_redis_channel_serialization():
|
||||
assert command_data["command_type"] == "abort"
|
||||
assert command_data["reason"] == "Test abort"
|
||||
|
||||
# Test pause command serialization
|
||||
pause_command = PauseCommand(reason="User requested pause")
|
||||
channel.send_command(pause_command)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_abort_command()
|
||||
test_redis_channel_serialization()
|
||||
print("All tests passed!")
|
||||
assert len(mock_pipeline.rpush.call_args_list) == 2
|
||||
second_call_args = mock_pipeline.rpush.call_args_list[1]
|
||||
pause_command_json = second_call_args[0][1]
|
||||
pause_command_data = json.loads(pause_command_json)
|
||||
assert pause_command_data["command_type"] == CommandType.PAUSE.value
|
||||
assert pause_command_data["reason"] == "User requested pause"
|
||||
|
||||
|
||||
def test_pause_command():
|
||||
"""Test that GraphEngine properly handles pause commands."""
|
||||
|
||||
shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start"
|
||||
|
||||
mock_start_node = MagicMock()
|
||||
mock_start_node.state = None
|
||||
mock_start_node.id = "start"
|
||||
mock_start_node.graph_runtime_state = shared_runtime_state
|
||||
mock_graph.nodes["start"] = mock_start_node
|
||||
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
command_channel = InMemoryChannel()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=shared_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
pause_command = PauseCommand(reason="User requested pause")
|
||||
command_channel.send_command(pause_command)
|
||||
|
||||
events = list(engine.run())
|
||||
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
|
||||
assert len(pause_events) == 1
|
||||
assert pause_events[0].reason == "User requested pause"
|
||||
|
||||
graph_execution = engine.graph_runtime_state.graph_execution
|
||||
assert graph_execution.is_paused
|
||||
assert graph_execution.pause_reason == "User requested pause"
|
||||
|
||||
@@ -21,6 +21,7 @@ class _StubExecutionCoordinator:
|
||||
self._execution_complete = False
|
||||
self.mark_complete_called = False
|
||||
self.failed = False
|
||||
self._paused = False
|
||||
|
||||
def check_commands(self) -> None:
|
||||
self.command_checks += 1
|
||||
@@ -28,6 +29,10 @@ class _StubExecutionCoordinator:
|
||||
def check_scaling(self) -> None:
|
||||
self.scaling_checks += 1
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
return self._paused
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
return self._execution_complete
|
||||
|
||||
@@ -96,7 +101,7 @@ def _make_succeeded_event() -> NodeRunSucceededEvent:
|
||||
|
||||
|
||||
def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None:
|
||||
"""Dispatcher polls commands when idle and re-checks after completion events."""
|
||||
"""Dispatcher polls commands when idle and after completion events."""
|
||||
started_checks = _run_dispatcher_for_event(_make_started_event())
|
||||
succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event())
|
||||
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Unit tests for the execution coordinator orchestration logic."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator
|
||||
from core.workflow.graph_engine.worker_management.worker_pool import WorkerPool
|
||||
|
||||
|
||||
def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]:
|
||||
command_processor = MagicMock(spec=CommandProcessor)
|
||||
state_manager = MagicMock(spec=GraphStateManager)
|
||||
worker_pool = MagicMock(spec=WorkerPool)
|
||||
|
||||
coordinator = ExecutionCoordinator(
|
||||
graph_execution=graph_execution,
|
||||
state_manager=state_manager,
|
||||
command_processor=command_processor,
|
||||
worker_pool=worker_pool,
|
||||
)
|
||||
return coordinator, state_manager, worker_pool
|
||||
|
||||
|
||||
def test_handle_pause_stops_workers_and_clears_state() -> None:
|
||||
"""Paused execution should stop workers and clear executing state."""
|
||||
graph_execution = GraphExecution(workflow_id="workflow")
|
||||
graph_execution.start()
|
||||
graph_execution.pause("Awaiting human input")
|
||||
|
||||
coordinator, state_manager, worker_pool = _build_coordinator(graph_execution)
|
||||
|
||||
coordinator.handle_pause_if_needed()
|
||||
|
||||
worker_pool.stop.assert_called_once_with()
|
||||
state_manager.clear_executing.assert_called_once_with()
|
||||
|
||||
|
||||
def test_handle_pause_noop_when_execution_running() -> None:
|
||||
"""Running execution should not trigger pause handling."""
|
||||
graph_execution = GraphExecution(workflow_id="workflow")
|
||||
graph_execution.start()
|
||||
|
||||
coordinator, state_manager, worker_pool = _build_coordinator(graph_execution)
|
||||
|
||||
coordinator.handle_pause_if_needed()
|
||||
|
||||
worker_pool.stop.assert_not_called()
|
||||
state_manager.clear_executing.assert_not_called()
|
||||
|
||||
|
||||
def test_is_execution_complete_when_paused() -> None:
|
||||
"""Paused execution should be treated as complete."""
|
||||
graph_execution = GraphExecution(workflow_id="workflow")
|
||||
graph_execution.start()
|
||||
graph_execution.pause("Awaiting input")
|
||||
|
||||
coordinator, state_manager, _worker_pool = _build_coordinator(graph_execution)
|
||||
state_manager.is_execution_complete.return_value = False
|
||||
|
||||
assert coordinator.is_execution_complete()
|
||||
@@ -0,0 +1,341 @@
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
start_node.init_node_data(start_config["data"])
|
||||
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=title,
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=prompt_text,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
)
|
||||
llm_config = {"id": node_id, "data": llm_data.model_dump()}
|
||||
llm_node = MockLLMNode(
|
||||
id=node_id,
|
||||
config=llm_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
llm_node.init_node_data(llm_config["data"])
|
||||
return llm_node
|
||||
|
||||
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
required_variables=["human.input_ready"],
|
||||
pause_reason="Awaiting human input",
|
||||
)
|
||||
human_config = {"id": "human", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
human_node.init_node_data(human_config["data"])
|
||||
|
||||
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
||||
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
|
||||
|
||||
end_primary_data = EndNodeData(
|
||||
title="End Primary",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()}
|
||||
end_primary = EndNode(
|
||||
id=end_primary_config["id"],
|
||||
config=end_primary_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_primary.init_node_data(end_primary_config["data"])
|
||||
|
||||
end_secondary_data = EndNodeData(
|
||||
title="End Secondary",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()}
|
||||
end_secondary = EndNode(
|
||||
id=end_secondary_config["id"],
|
||||
config=end_secondary_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_secondary.init_node_data(end_secondary_config["data"])
|
||||
|
||||
graph = (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(llm_initial)
|
||||
.add_node(human_node)
|
||||
.add_node(llm_primary, from_node_id="human", source_handle="primary")
|
||||
.add_node(end_primary, from_node_id="llm_primary")
|
||||
.add_node(llm_secondary, from_node_id="human", source_handle="secondary")
|
||||
.add_node(end_secondary, from_node_id="llm_secondary")
|
||||
.build()
|
||||
)
|
||||
return graph, graph_runtime_state
|
||||
|
||||
|
||||
def _expected_mock_llm_chunks(text: str) -> list[str]:
|
||||
chunks: list[str] = []
|
||||
for index, word in enumerate(text.split(" ")):
|
||||
chunk = word if index == 0 else f" {word}"
|
||||
chunks.append(chunk)
|
||||
chunks.append("")
|
||||
return chunks
|
||||
|
||||
|
||||
def _assert_stream_chunk_sequence(
|
||||
chunk_events: Iterable[NodeRunStreamChunkEvent],
|
||||
expected_nodes: list[str],
|
||||
expected_chunks: list[str],
|
||||
) -> None:
|
||||
actual_nodes = [event.node_id for event in chunk_events]
|
||||
actual_chunks = [event.chunk for event in chunk_events]
|
||||
assert actual_nodes == expected_nodes
|
||||
assert actual_chunks == expected_chunks
|
||||
|
||||
|
||||
def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
mock_config = MockConfig()
|
||||
mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"})
|
||||
mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"})
|
||||
mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"})
|
||||
|
||||
branch_scenarios = [
|
||||
{
|
||||
"handle": "primary",
|
||||
"resume_llm": "llm_primary",
|
||||
"end_node": "end_primary",
|
||||
"expected_pre_chunks": [
|
||||
("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes
|
||||
("end_primary", ["\n"]), # literal segment emitted when end_primary session activates
|
||||
],
|
||||
"expected_post_chunks": [
|
||||
("llm_primary", _expected_mock_llm_chunks("Primary stream output")), # live stream from chosen branch
|
||||
],
|
||||
},
|
||||
{
|
||||
"handle": "secondary",
|
||||
"resume_llm": "llm_secondary",
|
||||
"end_node": "end_secondary",
|
||||
"expected_pre_chunks": [
|
||||
("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes
|
||||
("end_secondary", ["\n"]), # literal segment emitted when end_secondary session activates
|
||||
],
|
||||
"expected_post_chunks": [
|
||||
("llm_secondary", _expected_mock_llm_chunks("Secondary")), # live stream from chosen branch
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for scenario in branch_scenarios:
|
||||
runner = TableTestRunner()
|
||||
|
||||
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_branching_graph(mock_config)
|
||||
|
||||
initial_case = WorkflowTestCase(
|
||||
description="HumanInput pause before branching decision",
|
||||
graph_factory=initial_graph_factory,
|
||||
expected_event_sequence=[
|
||||
GraphRunStartedEvent, # initial run: graph execution starts
|
||||
NodeRunStartedEvent, # start node begins execution
|
||||
NodeRunSucceededEvent, # start node completes
|
||||
NodeRunStartedEvent, # llm_initial starts streaming
|
||||
NodeRunSucceededEvent, # llm_initial completes streaming
|
||||
NodeRunStartedEvent, # human node begins and issues pause
|
||||
NodeRunPauseRequestedEvent, # human node requests pause awaiting input
|
||||
GraphRunPausedEvent, # graph run pauses awaiting resume
|
||||
],
|
||||
)
|
||||
|
||||
initial_result = runner.run_test_case(initial_case)
|
||||
|
||||
assert initial_result.success, initial_result.event_mismatch_details
|
||||
assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events)
|
||||
|
||||
graph_runtime_state = initial_result.graph_runtime_state
|
||||
graph = initial_result.graph
|
||||
assert graph_runtime_state is not None
|
||||
assert graph is not None
|
||||
|
||||
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
|
||||
graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"])
|
||||
graph_runtime_state.graph_execution.pause_reason = None
|
||||
|
||||
pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
|
||||
post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
|
||||
|
||||
expected_resume_sequence: list[type] = (
|
||||
[
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
]
|
||||
+ [NodeRunStreamChunkEvent] * pre_chunk_count
|
||||
+ [
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
]
|
||||
+ [NodeRunStreamChunkEvent] * post_chunk_count
|
||||
+ [
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
]
|
||||
)
|
||||
|
||||
def resume_graph_factory(
|
||||
graph_snapshot: Graph = graph,
|
||||
state_snapshot: GraphRuntimeState = graph_runtime_state,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
return graph_snapshot, state_snapshot
|
||||
|
||||
resume_case = WorkflowTestCase(
|
||||
description=f"HumanInput resumes via {scenario['handle']} branch",
|
||||
graph_factory=resume_graph_factory,
|
||||
expected_event_sequence=expected_resume_sequence,
|
||||
)
|
||||
|
||||
resume_result = runner.run_test_case(resume_case)
|
||||
|
||||
assert resume_result.success, resume_result.event_mismatch_details
|
||||
|
||||
resume_events = resume_result.events
|
||||
|
||||
chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)]
|
||||
assert len(chunk_events) == pre_chunk_count + post_chunk_count
|
||||
|
||||
pre_chunk_events = chunk_events[:pre_chunk_count]
|
||||
post_chunk_events = chunk_events[pre_chunk_count:]
|
||||
|
||||
expected_pre_nodes: list[str] = []
|
||||
expected_pre_chunks: list[str] = []
|
||||
for node_id, chunks in scenario["expected_pre_chunks"]:
|
||||
expected_pre_nodes.extend([node_id] * len(chunks))
|
||||
expected_pre_chunks.extend(chunks)
|
||||
_assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks)
|
||||
|
||||
expected_post_nodes: list[str] = []
|
||||
expected_post_chunks: list[str] = []
|
||||
for node_id, chunks in scenario["expected_post_chunks"]:
|
||||
expected_post_nodes.extend([node_id] * len(chunks))
|
||||
expected_post_chunks.extend(chunks)
|
||||
_assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks)
|
||||
|
||||
human_success_index = next(
|
||||
index
|
||||
for index, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human"
|
||||
)
|
||||
pre_indices = [
|
||||
index
|
||||
for index, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
|
||||
]
|
||||
assert pre_indices == list(range(2, 2 + pre_chunk_count))
|
||||
|
||||
resume_chunk_indices = [
|
||||
index
|
||||
for index, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"]
|
||||
]
|
||||
assert resume_chunk_indices, "Expected streaming output from the selected branch"
|
||||
resume_start_index = next(
|
||||
index
|
||||
for index, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"]
|
||||
)
|
||||
resume_success_index = next(
|
||||
index
|
||||
for index, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"]
|
||||
)
|
||||
assert resume_start_index < min(resume_chunk_indices)
|
||||
assert max(resume_chunk_indices) < resume_success_index
|
||||
|
||||
started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)]
|
||||
assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]]
|
||||
@@ -0,0 +1,297 @@
|
||||
import time
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
start_node.init_node_data(start_config["data"])
|
||||
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=title,
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=prompt_text,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
)
|
||||
llm_config = {"id": node_id, "data": llm_data.model_dump()}
|
||||
llm_node = MockLLMNode(
|
||||
id=node_id,
|
||||
config=llm_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
llm_node.init_node_data(llm_config["data"])
|
||||
return llm_node
|
||||
|
||||
llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
required_variables=["human.input_ready"],
|
||||
pause_reason="Awaiting human input",
|
||||
)
|
||||
human_config = {"id": "human", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
human_node.init_node_data(human_config["data"])
|
||||
|
||||
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="End",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="resume_text", value_selector=["llm_resume", "text"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
id=end_config["id"],
|
||||
config=end_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_node.init_node_data(end_config["data"])
|
||||
|
||||
graph = (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(llm_first)
|
||||
.add_node(human_node)
|
||||
.add_node(llm_second)
|
||||
.add_node(end_node)
|
||||
.build()
|
||||
)
|
||||
return graph, graph_runtime_state
|
||||
|
||||
|
||||
def _expected_mock_llm_chunks(text: str) -> list[str]:
|
||||
chunks: list[str] = []
|
||||
for index, word in enumerate(text.split(" ")):
|
||||
chunk = word if index == 0 else f" {word}"
|
||||
chunks.append(chunk)
|
||||
chunks.append("")
|
||||
return chunks
|
||||
|
||||
|
||||
def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
runner = TableTestRunner()
|
||||
|
||||
initial_text = "Hello, pause"
|
||||
resume_text = "Welcome back!"
|
||||
|
||||
mock_config = MockConfig()
|
||||
mock_config.set_node_outputs("llm_initial", {"text": initial_text})
|
||||
mock_config.set_node_outputs("llm_resume", {"text": resume_text})
|
||||
|
||||
expected_initial_sequence: list[type] = [
|
||||
GraphRunStartedEvent, # graph run begins
|
||||
NodeRunStartedEvent, # start node begins
|
||||
NodeRunSucceededEvent, # start node completes
|
||||
NodeRunStartedEvent, # llm_initial begins streaming
|
||||
NodeRunSucceededEvent, # llm_initial completes streaming
|
||||
NodeRunStartedEvent, # human node begins and requests pause
|
||||
NodeRunPauseRequestedEvent, # human node pause requested
|
||||
GraphRunPausedEvent, # graph run pauses awaiting resume
|
||||
]
|
||||
|
||||
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_llm_human_llm_graph(mock_config)
|
||||
|
||||
initial_case = WorkflowTestCase(
|
||||
description="HumanInput pause preserves LLM streaming order",
|
||||
graph_factory=graph_factory,
|
||||
expected_event_sequence=expected_initial_sequence,
|
||||
)
|
||||
|
||||
initial_result = runner.run_test_case(initial_case)
|
||||
|
||||
assert initial_result.success, initial_result.event_mismatch_details
|
||||
|
||||
initial_events = initial_result.events
|
||||
initial_chunks = _expected_mock_llm_chunks(initial_text)
|
||||
|
||||
initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)]
|
||||
assert initial_stream_chunk_events == []
|
||||
|
||||
pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent))
|
||||
llm_succeeded_index = next(
|
||||
i
|
||||
for i, event in enumerate(initial_events)
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial"
|
||||
)
|
||||
assert llm_succeeded_index < pause_index
|
||||
|
||||
graph_runtime_state = initial_result.graph_runtime_state
|
||||
graph = initial_result.graph
|
||||
assert graph_runtime_state is not None
|
||||
assert graph is not None
|
||||
|
||||
coordinator = graph_runtime_state.response_coordinator
|
||||
stream_buffers = coordinator._stream_buffers # Tests may access internals for assertions
|
||||
assert ("llm_initial", "text") in stream_buffers
|
||||
initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]]
|
||||
assert initial_stream_chunks == initial_chunks
|
||||
assert ("llm_resume", "text") not in stream_buffers
|
||||
|
||||
resume_chunks = _expected_mock_llm_chunks(resume_text)
|
||||
expected_resume_sequence: list[type] = [
|
||||
GraphRunStartedEvent, # resumed graph run begins
|
||||
NodeRunStartedEvent, # human node restarts
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 1
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 2
|
||||
NodeRunStreamChunkEvent, # cached llm_initial final chunk
|
||||
NodeRunStreamChunkEvent, # end node emits combined template separator
|
||||
NodeRunSucceededEvent, # human node finishes instantly after input
|
||||
NodeRunStartedEvent, # llm_resume begins streaming
|
||||
NodeRunStreamChunkEvent, # llm_resume chunk 1
|
||||
NodeRunStreamChunkEvent, # llm_resume chunk 2
|
||||
NodeRunStreamChunkEvent, # llm_resume final chunk
|
||||
NodeRunSucceededEvent, # llm_resume completes streaming
|
||||
NodeRunStartedEvent, # end node starts
|
||||
NodeRunSucceededEvent, # end node finishes
|
||||
GraphRunSucceededEvent, # graph run succeeds after resume
|
||||
]
|
||||
|
||||
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
assert graph_runtime_state is not None
|
||||
assert graph is not None
|
||||
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
|
||||
graph_runtime_state.graph_execution.pause_reason = None
|
||||
return graph, graph_runtime_state
|
||||
|
||||
resume_case = WorkflowTestCase(
|
||||
description="HumanInput resume continues LLM streaming order",
|
||||
graph_factory=resume_graph_factory,
|
||||
expected_event_sequence=expected_resume_sequence,
|
||||
)
|
||||
|
||||
resume_result = runner.run_test_case(resume_case)
|
||||
|
||||
assert resume_result.success, resume_result.event_mismatch_details
|
||||
|
||||
resume_events = resume_result.events
|
||||
|
||||
success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent))
|
||||
llm_resume_succeeded_index = next(
|
||||
i
|
||||
for i, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume"
|
||||
)
|
||||
assert llm_resume_succeeded_index < success_index
|
||||
|
||||
resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)]
|
||||
assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3
|
||||
assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks
|
||||
assert resume_chunk_events[3].node_id == "end"
|
||||
assert resume_chunk_events[3].chunk == "\n"
|
||||
assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3
|
||||
assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks
|
||||
|
||||
human_success_index = next(
|
||||
i
|
||||
for i, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human"
|
||||
)
|
||||
cached_chunk_indices = [
|
||||
i
|
||||
for i, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"}
|
||||
]
|
||||
assert all(index < human_success_index for index in cached_chunk_indices)
|
||||
|
||||
llm_resume_start_index = next(
|
||||
i
|
||||
for i, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume"
|
||||
)
|
||||
llm_resume_success_index = next(
|
||||
i
|
||||
for i, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume"
|
||||
)
|
||||
llm_resume_chunk_indices = [
|
||||
i
|
||||
for i, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume"
|
||||
]
|
||||
assert llm_resume_chunk_indices
|
||||
first_resume_chunk_index = min(llm_resume_chunk_indices)
|
||||
last_resume_chunk_index = max(llm_resume_chunk_indices)
|
||||
assert llm_resume_start_index < first_resume_chunk_index
|
||||
assert last_resume_chunk_index < llm_resume_success_index
|
||||
|
||||
started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)]
|
||||
assert started_nodes == ["human", "llm_resume", "end"]
|
||||
@@ -0,0 +1,321 @@
|
||||
import time
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("branch", "value"), branch_value)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
start_node.init_node_data(start_config["data"])
|
||||
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=title,
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=prompt_text,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
)
|
||||
llm_config = {"id": node_id, "data": llm_data.model_dump()}
|
||||
llm_node = MockLLMNode(
|
||||
id=node_id,
|
||||
config=llm_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
llm_node.init_node_data(llm_config["data"])
|
||||
return llm_node
|
||||
|
||||
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
|
||||
|
||||
if_else_data = IfElseNodeData(
|
||||
title="IfElse",
|
||||
cases=[
|
||||
IfElseNodeData.Case(
|
||||
case_id="primary",
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary")
|
||||
],
|
||||
),
|
||||
IfElseNodeData.Case(
|
||||
case_id="secondary",
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary")
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
if_else_config = {"id": "if_else", "data": if_else_data.model_dump()}
|
||||
if_else_node = IfElseNode(
|
||||
id=if_else_config["id"],
|
||||
config=if_else_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if_else_node.init_node_data(if_else_config["data"])
|
||||
|
||||
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
||||
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
|
||||
|
||||
end_primary_data = EndNodeData(
|
||||
title="End Primary",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()}
|
||||
end_primary = EndNode(
|
||||
id=end_primary_config["id"],
|
||||
config=end_primary_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_primary.init_node_data(end_primary_config["data"])
|
||||
|
||||
end_secondary_data = EndNodeData(
|
||||
title="End Secondary",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()}
|
||||
end_secondary = EndNode(
|
||||
id=end_secondary_config["id"],
|
||||
config=end_secondary_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_secondary.init_node_data(end_secondary_config["data"])
|
||||
|
||||
graph = (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(llm_initial)
|
||||
.add_node(if_else_node)
|
||||
.add_node(llm_primary, from_node_id="if_else", source_handle="primary")
|
||||
.add_node(end_primary, from_node_id="llm_primary")
|
||||
.add_node(llm_secondary, from_node_id="if_else", source_handle="secondary")
|
||||
.add_node(end_secondary, from_node_id="llm_secondary")
|
||||
.build()
|
||||
)
|
||||
return graph, graph_runtime_state
|
||||
|
||||
|
||||
def _expected_mock_llm_chunks(text: str) -> list[str]:
|
||||
chunks: list[str] = []
|
||||
for index, word in enumerate(text.split(" ")):
|
||||
chunk = word if index == 0 else f" {word}"
|
||||
chunks.append(chunk)
|
||||
chunks.append("")
|
||||
return chunks
|
||||
|
||||
|
||||
def test_if_else_llm_streaming_order() -> None:
|
||||
mock_config = MockConfig()
|
||||
mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"})
|
||||
mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"})
|
||||
mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"})
|
||||
|
||||
scenarios = [
|
||||
{
|
||||
"branch": "primary",
|
||||
"resume_llm": "llm_primary",
|
||||
"end_node": "end_primary",
|
||||
"expected_sequence": [
|
||||
GraphRunStartedEvent, # graph run begins
|
||||
NodeRunStartedEvent, # start node begins execution
|
||||
NodeRunSucceededEvent, # start node completes
|
||||
NodeRunStartedEvent, # llm_initial starts and streams
|
||||
NodeRunSucceededEvent, # llm_initial completes streaming
|
||||
NodeRunStartedEvent, # if_else evaluates conditions
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed
|
||||
NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed
|
||||
NodeRunStreamChunkEvent, # template literal newline emitted
|
||||
NodeRunSucceededEvent, # if_else completes branch selection
|
||||
NodeRunStartedEvent, # llm_primary begins streaming
|
||||
NodeRunStreamChunkEvent, # llm_primary chunk 1
|
||||
NodeRunStreamChunkEvent, # llm_primary chunk 2
|
||||
NodeRunStreamChunkEvent, # llm_primary chunk 3
|
||||
NodeRunStreamChunkEvent, # llm_primary final chunk
|
||||
NodeRunSucceededEvent, # llm_primary completes streaming
|
||||
NodeRunStartedEvent, # end_primary node starts
|
||||
NodeRunSucceededEvent, # end_primary finishes aggregation
|
||||
GraphRunSucceededEvent, # graph run succeeds
|
||||
],
|
||||
"expected_chunks": [
|
||||
("llm_initial", _expected_mock_llm_chunks("Initial stream")),
|
||||
("end_primary", ["\n"]),
|
||||
("llm_primary", _expected_mock_llm_chunks("Primary stream output")),
|
||||
],
|
||||
},
|
||||
{
|
||||
"branch": "secondary",
|
||||
"resume_llm": "llm_secondary",
|
||||
"end_node": "end_secondary",
|
||||
"expected_sequence": [
|
||||
GraphRunStartedEvent, # graph run begins
|
||||
NodeRunStartedEvent, # start node begins execution
|
||||
NodeRunSucceededEvent, # start node completes
|
||||
NodeRunStartedEvent, # llm_initial starts and streams
|
||||
NodeRunSucceededEvent, # llm_initial completes streaming
|
||||
NodeRunStartedEvent, # if_else evaluates conditions
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed
|
||||
NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed
|
||||
NodeRunStreamChunkEvent, # template literal newline emitted
|
||||
NodeRunSucceededEvent, # if_else completes branch selection
|
||||
NodeRunStartedEvent, # llm_secondary begins streaming
|
||||
NodeRunStreamChunkEvent, # llm_secondary chunk 1
|
||||
NodeRunStreamChunkEvent, # llm_secondary final chunk
|
||||
NodeRunSucceededEvent, # llm_secondary completes
|
||||
NodeRunStartedEvent, # end_secondary node starts
|
||||
NodeRunSucceededEvent, # end_secondary finishes aggregation
|
||||
GraphRunSucceededEvent, # graph run succeeds
|
||||
],
|
||||
"expected_chunks": [
|
||||
("llm_initial", _expected_mock_llm_chunks("Initial stream")),
|
||||
("end_secondary", ["\n"]),
|
||||
("llm_secondary", _expected_mock_llm_chunks("Secondary")),
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for scenario in scenarios:
|
||||
runner = TableTestRunner()
|
||||
|
||||
def graph_factory(
|
||||
branch_value: str = scenario["branch"],
|
||||
cfg: MockConfig = mock_config,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_if_else_graph(branch_value, cfg)
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
description=f"IfElse streaming via {scenario['branch']} branch",
|
||||
graph_factory=graph_factory,
|
||||
expected_event_sequence=scenario["expected_sequence"],
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, result.event_mismatch_details
|
||||
|
||||
chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)]
|
||||
expected_nodes: list[str] = []
|
||||
expected_chunks: list[str] = []
|
||||
for node_id, chunks in scenario["expected_chunks"]:
|
||||
expected_nodes.extend([node_id] * len(chunks))
|
||||
expected_chunks.extend(chunks)
|
||||
assert [event.node_id for event in chunk_events] == expected_nodes
|
||||
assert [event.chunk for event in chunk_events] == expected_chunks
|
||||
|
||||
branch_node_index = next(
|
||||
index
|
||||
for index, event in enumerate(result.events)
|
||||
if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else"
|
||||
)
|
||||
branch_success_index = next(
|
||||
index
|
||||
for index, event in enumerate(result.events)
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else"
|
||||
)
|
||||
pre_branch_chunk_indices = [
|
||||
index
|
||||
for index, event in enumerate(result.events)
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index
|
||||
]
|
||||
assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1
|
||||
assert min(pre_branch_chunk_indices) == branch_node_index + 1
|
||||
assert max(pre_branch_chunk_indices) < branch_success_index
|
||||
|
||||
resume_chunk_indices = [
|
||||
index
|
||||
for index, event in enumerate(result.events)
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"]
|
||||
]
|
||||
assert resume_chunk_indices
|
||||
resume_start_index = next(
|
||||
index
|
||||
for index, event in enumerate(result.events)
|
||||
if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"]
|
||||
)
|
||||
resume_success_index = next(
|
||||
index
|
||||
for index, event in enumerate(result.events)
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"]
|
||||
)
|
||||
assert resume_start_index < min(resume_chunk_indices)
|
||||
assert max(resume_chunk_indices) < resume_success_index
|
||||
|
||||
started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)]
|
||||
assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]]
|
||||
@@ -27,7 +27,8 @@ from .test_mock_nodes import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
|
||||
|
||||
@@ -42,7 +42,8 @@ def test_mock_iteration_node_preserves_config():
|
||||
"""Test that MockIterationNode preserves mock configuration."""
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode
|
||||
|
||||
@@ -103,7 +104,8 @@ def test_mock_loop_node_preserves_config():
|
||||
"""Test that MockLoopNode preserves mock configuration."""
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode
|
||||
|
||||
|
||||
@@ -24,7 +24,8 @@ from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
|
||||
@@ -561,10 +562,11 @@ class MockIterationNode(MockNodeMixin, IterationNode):
|
||||
def _create_graph_engine(self, index: int, item: Any):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Import our MockNodeFactory instead of DifyNodeFactory
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
@@ -635,10 +637,11 @@ class MockLoopNode(MockNodeMixin, LoopNode):
|
||||
def _create_graph_engine(self, start_at, root_node_id: str):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Import our MockNodeFactory instead of DifyNodeFactory
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
@@ -16,8 +16,8 @@ class TestMockTemplateTransformNode:
|
||||
|
||||
def test_mock_template_transform_node_default_output(self):
|
||||
"""Test that MockTemplateTransformNode processes templates with Jinja2."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -76,8 +76,8 @@ class TestMockTemplateTransformNode:
|
||||
|
||||
def test_mock_template_transform_node_custom_output(self):
|
||||
"""Test that MockTemplateTransformNode returns custom configured output."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -137,8 +137,8 @@ class TestMockTemplateTransformNode:
|
||||
|
||||
def test_mock_template_transform_node_error_simulation(self):
|
||||
"""Test that MockTemplateTransformNode can simulate errors."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -196,8 +196,8 @@ class TestMockTemplateTransformNode:
|
||||
def test_mock_template_transform_node_with_variables(self):
|
||||
"""Test that MockTemplateTransformNode processes templates with variables."""
|
||||
from core.variables import StringVariable
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -262,8 +262,8 @@ class TestMockCodeNode:
|
||||
|
||||
def test_mock_code_node_default_output(self):
|
||||
"""Test that MockCodeNode returns default output."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -323,8 +323,8 @@ class TestMockCodeNode:
|
||||
|
||||
def test_mock_code_node_with_output_schema(self):
|
||||
"""Test that MockCodeNode generates outputs based on schema."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -392,8 +392,8 @@ class TestMockCodeNode:
|
||||
|
||||
def test_mock_code_node_custom_output(self):
|
||||
"""Test that MockCodeNode returns custom configured output."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -463,8 +463,8 @@ class TestMockNodeFactory:
|
||||
|
||||
def test_code_and_template_nodes_mocked_by_default(self):
|
||||
"""Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy)."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -504,8 +504,8 @@ class TestMockNodeFactory:
|
||||
|
||||
def test_factory_creates_mock_template_transform_node(self):
|
||||
"""Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -555,8 +555,8 @@ class TestMockNodeFactory:
|
||||
|
||||
def test_factory_creates_mock_code_node(self):
|
||||
"""Test that MockNodeFactory creates MockCodeNode for code type."""
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
# Create test parameters
|
||||
graph_init_params = GraphInitParams(
|
||||
|
||||
@@ -13,7 +13,7 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
@@ -27,6 +27,7 @@ from core.workflow.graph_events import (
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import redis
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
|
||||
|
||||
@@ -52,6 +52,29 @@ class TestRedisStopIntegration:
|
||||
assert command_data["command_type"] == CommandType.ABORT
|
||||
assert command_data["reason"] == "Test stop"
|
||||
|
||||
def test_graph_engine_manager_sends_pause_command(self):
|
||||
"""Test that GraphEngineManager correctly sends pause command through Redis."""
|
||||
task_id = "test-task-pause-123"
|
||||
expected_channel_key = f"workflow:{task_id}:commands"
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_pipeline = MagicMock()
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
|
||||
GraphEngineManager.send_pause_command(task_id, reason="Awaiting resources")
|
||||
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
calls = mock_pipeline.rpush.call_args_list
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0][0] == expected_channel_key
|
||||
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.PAUSE.value
|
||||
assert command_data["reason"] == "Awaiting resources"
|
||||
|
||||
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
|
||||
"""Test that GraphEngineManager handles Redis failures without raising exceptions."""
|
||||
task_id = "test-task-456"
|
||||
@@ -105,28 +128,37 @@ class TestRedisStopIntegration:
|
||||
channel_key = "workflow:test:commands"
|
||||
channel = RedisChannel(mock_redis, channel_key)
|
||||
|
||||
# Create abort command
|
||||
# Create commands
|
||||
abort_command = AbortCommand(reason="User requested stop")
|
||||
pause_command = PauseCommand(reason="User requested pause")
|
||||
|
||||
# Execute
|
||||
channel.send_command(abort_command)
|
||||
channel.send_command(pause_command)
|
||||
|
||||
# Verify
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
mock_redis.pipeline.assert_called()
|
||||
|
||||
# Check rpush was called
|
||||
calls = mock_pipeline.rpush.call_args_list
|
||||
assert len(calls) == 1
|
||||
assert len(calls) == 2
|
||||
assert calls[0][0][0] == channel_key
|
||||
assert calls[1][0][0] == channel_key
|
||||
|
||||
# Verify serialized command
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.ABORT
|
||||
assert command_data["reason"] == "User requested stop"
|
||||
# Verify serialized commands
|
||||
abort_command_json = calls[0][0][1]
|
||||
abort_command_data = json.loads(abort_command_json)
|
||||
assert abort_command_data["command_type"] == CommandType.ABORT.value
|
||||
assert abort_command_data["reason"] == "User requested stop"
|
||||
|
||||
# Check expire was set
|
||||
mock_pipeline.expire.assert_called_once_with(channel_key, 3600)
|
||||
pause_command_json = calls[1][0][1]
|
||||
pause_command_data = json.loads(pause_command_json)
|
||||
assert pause_command_data["command_type"] == CommandType.PAUSE.value
|
||||
assert pause_command_data["reason"] == "User requested pause"
|
||||
|
||||
# Check expire was set for each
|
||||
assert mock_pipeline.expire.call_count == 2
|
||||
mock_pipeline.expire.assert_any_call(channel_key, 3600)
|
||||
|
||||
def test_redis_channel_fetch_commands(self):
|
||||
"""Test RedisChannel correctly fetches and deserializes commands."""
|
||||
@@ -143,12 +175,17 @@ class TestRedisStopIntegration:
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
# Mock command data
|
||||
abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
|
||||
abort_command_json = json.dumps(
|
||||
{"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
|
||||
)
|
||||
pause_command_json = json.dumps(
|
||||
{"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None}
|
||||
)
|
||||
|
||||
# Mock pipeline execute to return commands
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [
|
||||
[abort_command_json.encode()], # lrange result
|
||||
[abort_command_json.encode(), pause_command_json.encode()], # lrange result
|
||||
True, # delete result
|
||||
]
|
||||
|
||||
@@ -159,10 +196,13 @@ class TestRedisStopIntegration:
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
# Verify
|
||||
assert len(commands) == 1
|
||||
assert len(commands) == 2
|
||||
assert isinstance(commands[0], AbortCommand)
|
||||
assert commands[0].command_type == CommandType.ABORT
|
||||
assert commands[0].reason == "Test abort"
|
||||
assert isinstance(commands[1], PauseCommand)
|
||||
assert commands[1].command_type == CommandType.PAUSE
|
||||
assert commands[1].reason == "Pause requested"
|
||||
|
||||
# Verify Redis operations
|
||||
pending_pipe.get.assert_called_once_with(f"{channel_key}:pending")
|
||||
|
||||
@@ -29,7 +29,6 @@ from core.variables import (
|
||||
ObjectVariable,
|
||||
StringVariable,
|
||||
)
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
@@ -40,6 +39,7 @@ from core.workflow.graph_events import (
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
@@ -52,8 +52,8 @@ logger = logging.getLogger(__name__)
|
||||
class WorkflowTestCase:
|
||||
"""Represents a single test case for table-driven testing."""
|
||||
|
||||
fixture_path: str
|
||||
expected_outputs: dict[str, Any]
|
||||
fixture_path: str = ""
|
||||
expected_outputs: dict[str, Any] = field(default_factory=dict)
|
||||
inputs: dict[str, Any] = field(default_factory=dict)
|
||||
query: str = ""
|
||||
description: str = ""
|
||||
@@ -61,11 +61,7 @@ class WorkflowTestCase:
|
||||
mock_config: MockConfig | None = None
|
||||
use_auto_mock: bool = False
|
||||
expected_event_sequence: Sequence[type[GraphEngineEvent]] | None = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
skip: bool = False
|
||||
skip_reason: str = ""
|
||||
retry_count: int = 0
|
||||
custom_validator: Callable[[dict[str, Any]], bool] | None = None
|
||||
graph_factory: Callable[[], tuple[Graph, GraphRuntimeState]] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -80,7 +76,8 @@ class WorkflowTestResult:
|
||||
event_sequence_match: bool | None = None
|
||||
event_mismatch_details: str | None = None
|
||||
events: list[GraphEngineEvent] = field(default_factory=list)
|
||||
retry_attempts: int = 0
|
||||
graph: Graph | None = None
|
||||
graph_runtime_state: GraphRuntimeState | None = None
|
||||
validation_details: str | None = None
|
||||
|
||||
|
||||
@@ -91,7 +88,6 @@ class TestSuiteResult:
|
||||
total_tests: int
|
||||
passed_tests: int
|
||||
failed_tests: int
|
||||
skipped_tests: int
|
||||
total_execution_time: float
|
||||
results: list[WorkflowTestResult]
|
||||
|
||||
@@ -106,10 +102,6 @@ class TestSuiteResult:
|
||||
"""Get all failed test results."""
|
||||
return [r for r in self.results if not r.success]
|
||||
|
||||
def get_results_by_tag(self, tag: str) -> list[WorkflowTestResult]:
|
||||
"""Get test results filtered by tag."""
|
||||
return [r for r in self.results if tag in r.test_case.tags]
|
||||
|
||||
|
||||
class WorkflowRunner:
|
||||
"""Core workflow execution engine for tests."""
|
||||
@@ -286,90 +278,30 @@ class TableTestRunner:
|
||||
Returns:
|
||||
WorkflowTestResult with execution details
|
||||
"""
|
||||
if test_case.skip:
|
||||
self.logger.info("Skipping test: %s - %s", test_case.description, test_case.skip_reason)
|
||||
return WorkflowTestResult(
|
||||
test_case=test_case,
|
||||
success=True,
|
||||
execution_time=0.0,
|
||||
validation_details=f"Skipped: {test_case.skip_reason}",
|
||||
)
|
||||
|
||||
retry_attempts = 0
|
||||
last_result = None
|
||||
last_error = None
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for attempt in range(test_case.retry_count + 1):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
result = self._execute_test_case(test_case)
|
||||
last_result = result # Save the last result
|
||||
|
||||
if result.success:
|
||||
result.retry_attempts = retry_attempts
|
||||
self.logger.info("Test passed: %s", test_case.description)
|
||||
return result
|
||||
|
||||
last_error = result.error
|
||||
retry_attempts += 1
|
||||
|
||||
if attempt < test_case.retry_count:
|
||||
self.logger.warning(
|
||||
"Test failed (attempt %d/%d): %s",
|
||||
attempt + 1,
|
||||
test_case.retry_count + 1,
|
||||
test_case.description,
|
||||
)
|
||||
time.sleep(0.5 * (attempt + 1)) # Exponential backoff
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retry_attempts += 1
|
||||
|
||||
if attempt < test_case.retry_count:
|
||||
self.logger.warning(
|
||||
"Test error (attempt %d/%d): %s - %s",
|
||||
attempt + 1,
|
||||
test_case.retry_count + 1,
|
||||
test_case.description,
|
||||
str(e),
|
||||
)
|
||||
time.sleep(0.5 * (attempt + 1))
|
||||
|
||||
# All retries failed - return the last result if available
|
||||
if last_result:
|
||||
last_result.retry_attempts = retry_attempts
|
||||
self.logger.error("Test failed after %d attempts: %s", retry_attempts, test_case.description)
|
||||
return last_result
|
||||
|
||||
# If no result available (all attempts threw exceptions), create a failure result
|
||||
self.logger.error("Test failed after %d attempts: %s", retry_attempts, test_case.description)
|
||||
return WorkflowTestResult(
|
||||
test_case=test_case,
|
||||
success=False,
|
||||
error=last_error,
|
||||
execution_time=time.perf_counter() - start_time,
|
||||
retry_attempts=retry_attempts,
|
||||
)
|
||||
try:
|
||||
result = self._execute_test_case(test_case)
|
||||
if result.success:
|
||||
self.logger.info("Test passed: %s", test_case.description)
|
||||
else:
|
||||
self.logger.error("Test failed: %s", test_case.description)
|
||||
return result
|
||||
except Exception as exc:
|
||||
self.logger.exception("Error executing test case: %s", test_case.description)
|
||||
return WorkflowTestResult(
|
||||
test_case=test_case,
|
||||
success=False,
|
||||
error=exc,
|
||||
execution_time=time.perf_counter() - start_time,
|
||||
)
|
||||
|
||||
def _execute_test_case(self, test_case: WorkflowTestCase) -> WorkflowTestResult:
|
||||
"""Internal method to execute a single test case."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
# Load fixture data
|
||||
fixture_data = self.workflow_runner.load_fixture(test_case.fixture_path)
|
||||
|
||||
# Create graph from fixture
|
||||
graph, graph_runtime_state = self.workflow_runner.create_graph_from_fixture(
|
||||
fixture_data=fixture_data,
|
||||
inputs=test_case.inputs,
|
||||
query=test_case.query,
|
||||
use_mock_factory=test_case.use_auto_mock,
|
||||
mock_config=test_case.mock_config,
|
||||
)
|
||||
graph, graph_runtime_state = self._create_graph_runtime_state(test_case)
|
||||
|
||||
# Create and run the engine with configured worker settings
|
||||
engine = GraphEngine(
|
||||
@@ -384,7 +316,7 @@ class TableTestRunner:
|
||||
)
|
||||
|
||||
# Execute and collect events
|
||||
events = []
|
||||
events: list[GraphEngineEvent] = []
|
||||
for event in engine.run():
|
||||
events.append(event)
|
||||
|
||||
@@ -416,6 +348,8 @@ class TableTestRunner:
|
||||
events=events,
|
||||
event_sequence_match=event_sequence_match,
|
||||
event_mismatch_details=event_mismatch_details,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Get actual outputs
|
||||
@@ -423,9 +357,7 @@ class TableTestRunner:
|
||||
actual_outputs = success_event.outputs or {}
|
||||
|
||||
# Validate outputs
|
||||
output_success, validation_details = self._validate_outputs(
|
||||
test_case.expected_outputs, actual_outputs, test_case.custom_validator
|
||||
)
|
||||
output_success, validation_details = self._validate_outputs(test_case.expected_outputs, actual_outputs)
|
||||
|
||||
# Overall success requires both output and event sequence validation
|
||||
success = output_success and (event_sequence_match if event_sequence_match is not None else True)
|
||||
@@ -440,6 +372,8 @@ class TableTestRunner:
|
||||
events=events,
|
||||
validation_details=validation_details,
|
||||
error=None if success else Exception(validation_details or event_mismatch_details or "Test failed"),
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -449,13 +383,33 @@ class TableTestRunner:
|
||||
success=False,
|
||||
error=e,
|
||||
execution_time=time.perf_counter() - start_time,
|
||||
graph=graph if "graph" in locals() else None,
|
||||
graph_runtime_state=graph_runtime_state if "graph_runtime_state" in locals() else None,
|
||||
)
|
||||
|
||||
def _create_graph_runtime_state(self, test_case: WorkflowTestCase) -> tuple[Graph, GraphRuntimeState]:
|
||||
"""Create or retrieve graph/runtime state according to test configuration."""
|
||||
|
||||
if test_case.graph_factory is not None:
|
||||
return test_case.graph_factory()
|
||||
|
||||
if not test_case.fixture_path:
|
||||
raise ValueError("fixture_path must be provided when graph_factory is not specified")
|
||||
|
||||
fixture_data = self.workflow_runner.load_fixture(test_case.fixture_path)
|
||||
|
||||
return self.workflow_runner.create_graph_from_fixture(
|
||||
fixture_data=fixture_data,
|
||||
inputs=test_case.inputs,
|
||||
query=test_case.query,
|
||||
use_mock_factory=test_case.use_auto_mock,
|
||||
mock_config=test_case.mock_config,
|
||||
)
|
||||
|
||||
def _validate_outputs(
|
||||
self,
|
||||
expected_outputs: dict[str, Any],
|
||||
actual_outputs: dict[str, Any],
|
||||
custom_validator: Callable[[dict[str, Any]], bool] | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Validate actual outputs against expected outputs.
|
||||
@@ -490,14 +444,6 @@ class TableTestRunner:
|
||||
f"Value mismatch for key '{key}':\n Expected: {expected_value}\n Actual: {actual_value}"
|
||||
)
|
||||
|
||||
# Apply custom validator if provided
|
||||
if custom_validator:
|
||||
try:
|
||||
if not custom_validator(actual_outputs):
|
||||
validation_errors.append("Custom validator failed")
|
||||
except Exception as e:
|
||||
validation_errors.append(f"Custom validator error: {str(e)}")
|
||||
|
||||
if validation_errors:
|
||||
return False, "\n".join(validation_errors)
|
||||
|
||||
@@ -537,7 +483,6 @@ class TableTestRunner:
|
||||
self,
|
||||
test_cases: list[WorkflowTestCase],
|
||||
parallel: bool = False,
|
||||
tags_filter: list[str] | None = None,
|
||||
fail_fast: bool = False,
|
||||
) -> TestSuiteResult:
|
||||
"""
|
||||
@@ -546,22 +491,16 @@ class TableTestRunner:
|
||||
Args:
|
||||
test_cases: List of test cases to execute
|
||||
parallel: Run tests in parallel
|
||||
tags_filter: Only run tests with specified tags
|
||||
fail_fast: Stop execution on first failure
|
||||
fail_fast: Stop execution on first failure
|
||||
|
||||
Returns:
|
||||
TestSuiteResult with aggregated results
|
||||
"""
|
||||
# Filter by tags if specified
|
||||
if tags_filter:
|
||||
test_cases = [tc for tc in test_cases if any(tag in tc.tags for tag in tags_filter)]
|
||||
|
||||
if not test_cases:
|
||||
return TestSuiteResult(
|
||||
total_tests=0,
|
||||
passed_tests=0,
|
||||
failed_tests=0,
|
||||
skipped_tests=0,
|
||||
total_execution_time=0.0,
|
||||
results=[],
|
||||
)
|
||||
@@ -576,16 +515,14 @@ class TableTestRunner:
|
||||
|
||||
# Calculate statistics
|
||||
total_tests = len(results)
|
||||
passed_tests = sum(1 for r in results if r.success and not r.test_case.skip)
|
||||
failed_tests = sum(1 for r in results if not r.success and not r.test_case.skip)
|
||||
skipped_tests = sum(1 for r in results if r.test_case.skip)
|
||||
passed_tests = sum(1 for r in results if r.success)
|
||||
failed_tests = total_tests - passed_tests
|
||||
total_execution_time = time.perf_counter() - start_time
|
||||
|
||||
return TestSuiteResult(
|
||||
total_tests=total_tests,
|
||||
passed_tests=passed_tests,
|
||||
failed_tests=failed_tests,
|
||||
skipped_tests=skipped_tests,
|
||||
total_execution_time=total_execution_time,
|
||||
results=results,
|
||||
)
|
||||
@@ -598,7 +535,7 @@ class TableTestRunner:
|
||||
result = self.run_test_case(test_case)
|
||||
results.append(result)
|
||||
|
||||
if fail_fast and not result.success and not result.test_case.skip:
|
||||
if fail_fast and not result.success:
|
||||
self.logger.info("Fail-fast enabled: stopping execution")
|
||||
break
|
||||
|
||||
@@ -618,11 +555,11 @@ class TableTestRunner:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
|
||||
if fail_fast and not result.success and not result.test_case.skip:
|
||||
if fail_fast and not result.success:
|
||||
self.logger.info("Fail-fast enabled: cancelling remaining tests")
|
||||
# Cancel remaining futures
|
||||
for f in future_to_test:
|
||||
f.cancel()
|
||||
for remaining_future in future_to_test:
|
||||
if not remaining_future.done():
|
||||
remaining_future.cancel()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
@@ -636,8 +573,9 @@ class TableTestRunner:
|
||||
)
|
||||
|
||||
if fail_fast:
|
||||
for f in future_to_test:
|
||||
f.cancel()
|
||||
for remaining_future in future_to_test:
|
||||
if not remaining_future.done():
|
||||
remaining_future.cancel()
|
||||
break
|
||||
|
||||
return results
|
||||
@@ -663,7 +601,6 @@ class TableTestRunner:
|
||||
report.append(f" Total Tests: {suite_result.total_tests}")
|
||||
report.append(f" Passed: {suite_result.passed_tests}")
|
||||
report.append(f" Failed: {suite_result.failed_tests}")
|
||||
report.append(f" Skipped: {suite_result.skipped_tests}")
|
||||
report.append(f" Success Rate: {suite_result.success_rate:.1f}%")
|
||||
report.append(f" Total Time: {suite_result.total_execution_time:.2f}s")
|
||||
report.append("")
|
||||
|
||||
@@ -3,11 +3,12 @@ import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNodeAuthorization,
|
||||
@@ -7,6 +6,7 @@ from core.workflow.nodes.http_request import (
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
@@ -32,6 +32,7 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.provider import ProviderType
|
||||
|
||||
@@ -7,12 +7,13 @@ import pytest
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -6,11 +6,12 @@ from uuid import uuid4
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import ArrayStringVariable, StringVariable
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@@ -4,11 +4,12 @@ from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import ArrayStringVariable
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from core.variables.variables import (
|
||||
VariableUnion,
|
||||
)
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from factories.variable_factory import build_segment, segment_to_variable
|
||||
|
||||
|
||||
@@ -1,476 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities import (
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from core.workflow.enums import (
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_app_generate_entity():
|
||||
additional_features = AppAdditionalFeatures(
|
||||
file_upload=None,
|
||||
opening_statement=None,
|
||||
suggested_questions=[],
|
||||
suggested_questions_after_answer=False,
|
||||
show_retrieve_source=False,
|
||||
more_like_this=False,
|
||||
speech_to_text=False,
|
||||
text_to_speech=None,
|
||||
trace_config=None,
|
||||
)
|
||||
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="test-tenant-id",
|
||||
app_id="test-app-id",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=additional_features,
|
||||
workflow_id="test-workflow-id",
|
||||
)
|
||||
|
||||
entity = AdvancedChatAppGenerateEntity(
|
||||
task_id="test-task-id",
|
||||
app_config=app_config,
|
||||
inputs={"query": "test query"},
|
||||
files=[],
|
||||
user_id="test-user-id",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
query="test query",
|
||||
conversation_id="test-conversation-id",
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_workflow_system_variables():
|
||||
return SystemVariable(
|
||||
query="test query",
|
||||
conversation_id="test-conversation-id",
|
||||
user_id="test-user-id",
|
||||
app_id="test-app-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-workflow-run-id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_node_execution_repository():
|
||||
repo = MagicMock(spec=WorkflowNodeExecutionRepository)
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_execution_repository():
|
||||
repo = MagicMock(spec=WorkflowExecutionRepository)
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_workflow_entity():
|
||||
return CycleManagerWorkflowInfo(
|
||||
workflow_id="test-workflow-id", # Matches ID used in other fixtures
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
version="1.0.0",
|
||||
graph_data={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"type": "chat", # NodeType is a string enum
|
||||
"name": "Chat Node",
|
||||
"data": {"model": "gpt-3.5-turbo", "prompt": "test prompt"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_cycle_manager(
|
||||
real_app_generate_entity,
|
||||
real_workflow_system_variables,
|
||||
mock_workflow_execution_repository,
|
||||
mock_node_execution_repository,
|
||||
real_workflow_entity,
|
||||
):
|
||||
return WorkflowCycleManager(
|
||||
application_generate_entity=real_app_generate_entity,
|
||||
workflow_system_variables=real_workflow_system_variables,
|
||||
workflow_info=real_workflow_entity,
|
||||
workflow_execution_repository=mock_workflow_execution_repository,
|
||||
workflow_node_execution_repository=mock_node_execution_repository,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
session = MagicMock(spec=Session)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_workflow():
|
||||
workflow = Workflow()
|
||||
workflow.id = "test-workflow-id"
|
||||
workflow.tenant_id = "test-tenant-id"
|
||||
workflow.app_id = "test-app-id"
|
||||
workflow.type = "chat"
|
||||
workflow.version = "1.0"
|
||||
|
||||
graph_data = {"nodes": [], "edges": []}
|
||||
workflow.graph = json.dumps(graph_data)
|
||||
workflow.features = json.dumps({"file_upload": {"enabled": False}})
|
||||
workflow.created_by = "test-user-id"
|
||||
workflow.created_at = naive_utc_now()
|
||||
workflow.updated_at = naive_utc_now()
|
||||
workflow._environment_variables = "{}"
|
||||
workflow._conversation_variables = "{}"
|
||||
|
||||
return workflow
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_workflow_run():
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = "test-workflow-run-id"
|
||||
workflow_run.tenant_id = "test-tenant-id"
|
||||
workflow_run.app_id = "test-app-id"
|
||||
workflow_run.workflow_id = "test-workflow-id"
|
||||
workflow_run.type = "chat"
|
||||
workflow_run.triggered_from = "app-run"
|
||||
workflow_run.version = "1.0"
|
||||
workflow_run.graph = json.dumps({"nodes": [], "edges": []})
|
||||
workflow_run.inputs = json.dumps({"query": "test query"})
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
workflow_run.outputs = json.dumps({"answer": "test answer"})
|
||||
workflow_run.created_by_role = CreatorUserRole.ACCOUNT
|
||||
workflow_run.created_by = "test-user-id"
|
||||
workflow_run.created_at = naive_utc_now()
|
||||
|
||||
return workflow_run
|
||||
|
||||
|
||||
def test_init(
|
||||
workflow_cycle_manager,
|
||||
real_app_generate_entity,
|
||||
real_workflow_system_variables,
|
||||
mock_workflow_execution_repository,
|
||||
mock_node_execution_repository,
|
||||
):
|
||||
"""Test initialization of WorkflowCycleManager"""
|
||||
assert workflow_cycle_manager._application_generate_entity == real_app_generate_entity
|
||||
assert workflow_cycle_manager._workflow_system_variables == real_workflow_system_variables
|
||||
assert workflow_cycle_manager._workflow_execution_repository == mock_workflow_execution_repository
|
||||
assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
|
||||
|
||||
|
||||
def test_handle_workflow_run_start(workflow_cycle_manager):
|
||||
"""Test handle_workflow_run_start method"""
|
||||
# Call the method
|
||||
workflow_execution = workflow_cycle_manager.handle_workflow_run_start()
|
||||
|
||||
# Verify the result
|
||||
assert workflow_execution.workflow_id == "test-workflow-id"
|
||||
|
||||
# Verify the workflow_execution_repository.save was called
|
||||
workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(workflow_execution)
|
||||
|
||||
|
||||
def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execution_repository):
|
||||
"""Test handle_workflow_run_success method"""
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_run_success(
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
total_tokens=100,
|
||||
total_steps=5,
|
||||
outputs={"answer": "test answer"},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == workflow_execution
|
||||
assert result.status == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert result.outputs == {"answer": "test answer"}
|
||||
assert result.total_tokens == 100
|
||||
assert result.total_steps == 5
|
||||
assert result.finished_at is not None
|
||||
|
||||
|
||||
def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execution_repository):
|
||||
"""Test handle_workflow_run_failed method"""
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
|
||||
|
||||
# No running node executions in cache (empty cache)
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_run_failed(
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
total_tokens=50,
|
||||
total_steps=3,
|
||||
status=WorkflowExecutionStatus.FAILED,
|
||||
error_message="Test error message",
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == workflow_execution
|
||||
assert result.status == WorkflowExecutionStatus.FAILED
|
||||
assert result.error_message == "Test error message"
|
||||
assert result.total_tokens == 50
|
||||
assert result.total_steps == 3
|
||||
assert result.finished_at is not None
|
||||
|
||||
|
||||
def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execution_repository):
|
||||
"""Test handle_node_execution_start method"""
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
|
||||
|
||||
# Create a mock event
|
||||
event = MagicMock(spec=QueueNodeStartedEvent)
|
||||
event.node_execution_id = "test-node-execution-id"
|
||||
event.node_id = "test-node-id"
|
||||
event.node_type = NodeType.LLM
|
||||
event.node_title = "Test Node"
|
||||
event.predecessor_node_id = "test-predecessor-node-id"
|
||||
event.node_run_index = 1
|
||||
event.parallel_mode_run_id = "test-parallel-mode-run-id"
|
||||
event.in_iteration_id = "test-iteration-id"
|
||||
event.in_loop_id = "test-loop-id"
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_node_execution_start(
|
||||
workflow_execution_id=workflow_execution.id_,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result.workflow_id == workflow_execution.workflow_id
|
||||
assert result.workflow_execution_id == workflow_execution.id_
|
||||
assert result.node_execution_id == event.node_execution_id
|
||||
assert result.node_id == event.node_id
|
||||
assert result.node_type == event.node_type
|
||||
assert result.title == event.node_title
|
||||
assert result.status == WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
# Verify save was called
|
||||
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
|
||||
|
||||
|
||||
def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_workflow_execution_repository):
|
||||
"""Test _get_workflow_execution_or_raise_error method"""
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
|
||||
|
||||
# Verify the result
|
||||
assert result == workflow_execution
|
||||
|
||||
# Test error case - clear cache
|
||||
workflow_cycle_manager._workflow_execution_cache.clear()
|
||||
|
||||
# Expect an error when execution is not found
|
||||
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
|
||||
|
||||
with pytest.raises(WorkflowRunNotFoundError):
|
||||
workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
|
||||
|
||||
|
||||
def test_handle_workflow_node_execution_success(workflow_cycle_manager):
|
||||
"""Test handle_workflow_node_execution_success method"""
|
||||
# Create a mock event
|
||||
event = MagicMock(spec=QueueNodeSucceededEvent)
|
||||
event.node_execution_id = "test-node-execution-id"
|
||||
event.inputs = {"input": "test input"}
|
||||
event.process_data = {"process": "test process"}
|
||||
event.outputs = {"output": "test output"}
|
||||
event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
|
||||
event.start_at = naive_utc_now()
|
||||
|
||||
# Create a real node execution
|
||||
|
||||
node_execution = WorkflowNodeExecution(
|
||||
id="test-node-execution-record-id",
|
||||
node_execution_id="test-node-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-workflow-run-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the node execution
|
||||
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_node_execution_success(
|
||||
event=event,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == node_execution
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
# Verify save was called
|
||||
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
|
||||
|
||||
|
||||
def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workflow_execution_repository):
|
||||
"""Test handle_workflow_run_partial_success method"""
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the workflow execution
|
||||
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_run_partial_success(
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
total_tokens=75,
|
||||
total_steps=4,
|
||||
outputs={"partial_answer": "test partial answer"},
|
||||
exceptions_count=2,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == workflow_execution
|
||||
assert result.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED
|
||||
assert result.outputs == {"partial_answer": "test partial answer"}
|
||||
assert result.total_tokens == 75
|
||||
assert result.total_steps == 4
|
||||
assert result.exceptions_count == 2
|
||||
assert result.finished_at is not None
|
||||
|
||||
|
||||
def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
|
||||
"""Test handle_workflow_node_execution_failed method"""
|
||||
# Create a mock event
|
||||
event = MagicMock(spec=QueueNodeFailedEvent)
|
||||
event.node_execution_id = "test-node-execution-id"
|
||||
event.inputs = {"input": "test input"}
|
||||
event.process_data = {"process": "test process"}
|
||||
event.outputs = {"output": "test output"}
|
||||
event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
|
||||
event.start_at = naive_utc_now()
|
||||
event.error = "Test error message"
|
||||
|
||||
# Create a real node execution
|
||||
|
||||
node_execution = WorkflowNodeExecution(
|
||||
id="test-node-execution-record-id",
|
||||
node_execution_id="test-node-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_execution_id="test-workflow-run-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Pre-populate the cache with the node execution
|
||||
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_workflow_node_execution_failed(
|
||||
event=event,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == node_execution
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "Test error message"
|
||||
|
||||
# Verify save was called
|
||||
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
|
||||
@@ -7,7 +7,7 @@ from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
Reference in New Issue
Block a user