feat(graph_engine): Support pausing workflow graph executions (#26585)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-10-19 21:33:41 +08:00
committed by GitHub
parent 9a5f214623
commit 578247ffbc
112 changed files with 3766 additions and 2415 deletions

View File

@@ -5,12 +5,13 @@ import pytest
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock

View File

@@ -5,10 +5,11 @@ from urllib.parse import urlencode
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
@@ -174,13 +175,13 @@ def test_custom_authorization_header(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
# Create variable pool

View File

@@ -6,12 +6,13 @@ from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom

View File

@@ -5,11 +5,12 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom

View File

@@ -4,11 +4,12 @@ import uuid
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock

View File

@@ -4,12 +4,13 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

View File

@@ -99,6 +99,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
workflow=mock_workflow,
system_user_id=str(uuid4()),
app=MagicMock(),
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
# Mock database session
@@ -237,6 +239,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
workflow=mock_workflow,
system_user_id=str(uuid4()),
app=MagicMock(),
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
# Mock database session
@@ -390,6 +394,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
workflow=mock_workflow,
system_user_id=str(uuid4()),
app=MagicMock(),
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
# Mock database session

View File

@@ -0,0 +1,63 @@
from types import SimpleNamespace
import pytest
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.workflow.runtime import GraphRuntimeState
from core.workflow.runtime.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
def _make_state(workflow_run_id: str | None) -> GraphRuntimeState:
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id))
return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
class _StubPipeline(GraphRuntimeStateSupport):
def __init__(self, *, cached_state: GraphRuntimeState | None, queue_state: GraphRuntimeState | None):
self._graph_runtime_state = cached_state
self._base_task_pipeline = SimpleNamespace(queue_manager=SimpleNamespace(graph_runtime_state=queue_state))
def test_ensure_graph_runtime_initialized_caches_explicit_state():
explicit_state = _make_state("run-explicit")
pipeline = _StubPipeline(cached_state=None, queue_state=None)
resolved = pipeline._ensure_graph_runtime_initialized(explicit_state)
assert resolved is explicit_state
assert pipeline._graph_runtime_state is explicit_state
def test_resolve_graph_runtime_state_reads_from_queue_when_cache_empty():
queued_state = _make_state("run-queue")
pipeline = _StubPipeline(cached_state=None, queue_state=queued_state)
resolved = pipeline._resolve_graph_runtime_state()
assert resolved is queued_state
assert pipeline._graph_runtime_state is queued_state
def test_resolve_graph_runtime_state_raises_when_no_state_available():
pipeline = _StubPipeline(cached_state=None, queue_state=None)
with pytest.raises(ValueError):
pipeline._resolve_graph_runtime_state()
def test_extract_workflow_run_id_returns_value():
state = _make_state("run-identifier")
pipeline = _StubPipeline(cached_state=state, queue_state=None)
run_id = pipeline._extract_workflow_run_id(state)
assert run_id == "run-identifier"
def test_extract_workflow_run_id_raises_when_missing():
state = _make_state(None)
pipeline = _StubPipeline(cached_state=state, queue_state=None)
with pytest.raises(ValueError):
pipeline._extract_workflow_run_id(state)

View File

@@ -3,8 +3,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
"""
import uuid
from dataclasses import dataclass
from datetime import datetime
from collections.abc import Mapping
from typing import Any
from unittest.mock import Mock
@@ -12,24 +11,17 @@ import pytest
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueNodeRetryEvent, QueueNodeSucceededEvent
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.app.entities.queue_entities import (
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.enums import NodeType
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models import Account
@dataclass
class ProcessDataResponseScenario:
"""Test scenario for process_data in responses."""
name: str
original_process_data: dict[str, Any] | None
truncated_process_data: dict[str, Any] | None
expected_response_data: dict[str, Any] | None
expected_truncated_flag: bool
class TestWorkflowResponseConverterCenarios:
"""Test process_data truncation in WorkflowResponseConverter."""
@@ -39,6 +31,7 @@ class TestWorkflowResponseConverterCenarios:
mock_app_config = Mock()
mock_app_config.tenant_id = "test-tenant-id"
mock_entity.app_config = mock_app_config
mock_entity.inputs = {}
return mock_entity
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
@@ -50,54 +43,59 @@ class TestWorkflowResponseConverterCenarios:
mock_user.name = "Test User"
mock_user.email = "test@example.com"
return WorkflowResponseConverter(application_generate_entity=mock_entity, user=mock_user)
def create_workflow_node_execution(
self,
process_data: dict[str, Any] | None = None,
truncated_process_data: dict[str, Any] | None = None,
execution_id: str = "test-execution-id",
) -> WorkflowNodeExecution:
"""Create a WorkflowNodeExecution for testing."""
execution = WorkflowNodeExecution(
id=execution_id,
workflow_id="test-workflow-id",
workflow_execution_id="test-run-id",
index=1,
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
process_data=process_data,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_at=datetime.now(),
finished_at=datetime.now(),
system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
return WorkflowResponseConverter(
application_generate_entity=mock_entity,
user=mock_user,
system_variables=system_variables,
)
if truncated_process_data is not None:
execution.set_truncated_process_data(truncated_process_data)
def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
"""Create a QueueNodeStartedEvent for testing."""
return QueueNodeStartedEvent(
node_execution_id=node_execution_id or str(uuid.uuid4()),
node_id="test-node-id",
node_title="Test Node",
node_type=NodeType.CODE,
start_at=naive_utc_now(),
predecessor_node_id=None,
in_iteration_id=None,
in_loop_id=None,
provider_type="built-in",
provider_id="code",
)
return execution
def create_node_succeeded_event(self) -> QueueNodeSucceededEvent:
def create_node_succeeded_event(
self,
*,
node_execution_id: str,
process_data: Mapping[str, Any] | None = None,
) -> QueueNodeSucceededEvent:
"""Create a QueueNodeSucceededEvent for testing."""
return QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=NodeType.CODE,
node_execution_id=str(uuid.uuid4()),
node_execution_id=node_execution_id,
start_at=naive_utc_now(),
parallel_id=None,
parallel_start_node_id=None,
parent_parallel_id=None,
parent_parallel_start_node_id=None,
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data=process_data or {},
outputs={},
execution_metadata={},
)
def create_node_retry_event(self) -> QueueNodeRetryEvent:
def create_node_retry_event(
self,
*,
node_execution_id: str,
process_data: Mapping[str, Any] | None = None,
) -> QueueNodeRetryEvent:
"""Create a QueueNodeRetryEvent for testing."""
return QueueNodeRetryEvent(
inputs={"data": "inputs"},
outputs={"data": "outputs"},
process_data=process_data or {},
error="oops",
retry_index=1,
node_id="test-node-id",
@@ -105,12 +103,8 @@ class TestWorkflowResponseConverterCenarios:
node_title="test code",
provider_type="built-in",
provider_id="code",
node_execution_id=str(uuid.uuid4()),
node_execution_id=node_execution_id,
start_at=naive_utc_now(),
parallel_id=None,
parallel_start_node_id=None,
parent_parallel_id=None,
parent_parallel_start_node_id=None,
in_iteration_id=None,
in_loop_id=None,
)
@@ -122,15 +116,28 @@ class TestWorkflowResponseConverterCenarios:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
execution = self.create_workflow_node_execution(
process_data=original_data, truncated_process_data=truncated_data
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event()
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
if mapping == dict(original_data):
return truncated_data, True
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should use truncated data, not original
@@ -145,13 +152,26 @@ class TestWorkflowResponseConverterCenarios:
original_data = {"small": "data"}
execution = self.create_workflow_node_execution(process_data=original_data)
event = self.create_node_succeeded_event()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should use original data
@@ -163,18 +183,31 @@ class TestWorkflowResponseConverterCenarios:
"""Test node finish response when process_data is None."""
converter = self.create_workflow_response_converter()
execution = self.create_workflow_node_execution(process_data=None)
event = self.create_node_succeeded_event()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=None,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should have None process_data
# Response should normalize missing process_data to an empty mapping
assert response is not None
assert response.data.process_data is None
assert response.data.process_data == {}
assert response.data.process_data_truncated is False
def test_workflow_node_retry_response_uses_truncated_process_data(self):
@@ -184,15 +217,28 @@ class TestWorkflowResponseConverterCenarios:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
execution = self.create_workflow_node_execution(
process_data=original_data, truncated_process_data=truncated_data
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_retry_event()
event = self.create_node_retry_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
if mapping == dict(original_data):
return truncated_data, True
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_retry_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should use truncated data, not original
@@ -207,224 +253,72 @@ class TestWorkflowResponseConverterCenarios:
original_data = {"small": "data"}
execution = self.create_workflow_node_execution(process_data=original_data)
event = self.create_node_retry_event()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_retry_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_retry_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should use original data
assert response is not None
assert response.data.process_data == original_data
assert response.data.process_data_truncated is False
def test_iteration_and_loop_nodes_return_none(self):
"""Test that iteration and loop nodes return None (no change from existing behavior)."""
"""Test that iteration and loop nodes return None (no streaming events)."""
converter = self.create_workflow_response_converter()
# Test iteration node
iteration_execution = self.create_workflow_node_execution(process_data={"test": "data"})
iteration_execution.node_type = NodeType.ITERATION
event = self.create_node_succeeded_event()
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=iteration_execution,
)
# Should return None for iteration nodes
assert response is None
# Test loop node
loop_execution = self.create_workflow_node_execution(process_data={"test": "data"})
loop_execution.node_type = NodeType.LOOP
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=loop_execution,
)
# Should return None for loop nodes
assert response is None
def test_execution_without_workflow_execution_id_returns_none(self):
"""Test that executions without workflow_execution_id return None."""
converter = self.create_workflow_response_converter()
execution = self.create_workflow_node_execution(process_data={"test": "data"})
execution.workflow_execution_id = None # Single-step debugging
event = self.create_node_succeeded_event()
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Should return None for single-step debugging
assert response is None
@staticmethod
def get_process_data_response_scenarios() -> list[ProcessDataResponseScenario]:
"""Create test scenarios for process_data responses."""
return [
ProcessDataResponseScenario(
name="none_process_data",
original_process_data=None,
truncated_process_data=None,
expected_response_data=None,
expected_truncated_flag=False,
),
ProcessDataResponseScenario(
name="small_process_data_no_truncation",
original_process_data={"small": "data"},
truncated_process_data=None,
expected_response_data={"small": "data"},
expected_truncated_flag=False,
),
ProcessDataResponseScenario(
name="large_process_data_with_truncation",
original_process_data={"large": "x" * 10000, "metadata": "info"},
truncated_process_data={"large": "[TRUNCATED]", "metadata": "info"},
expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
expected_truncated_flag=True,
),
ProcessDataResponseScenario(
name="empty_process_data",
original_process_data={},
truncated_process_data=None,
expected_response_data={},
expected_truncated_flag=False,
),
ProcessDataResponseScenario(
name="complex_data_with_truncation",
original_process_data={
"logs": ["entry"] * 1000, # Large array
"config": {"setting": "value"},
"status": "processing",
},
truncated_process_data={
"logs": "[TRUNCATED: 1000 items]",
"config": {"setting": "value"},
"status": "processing",
},
expected_response_data={
"logs": "[TRUNCATED: 1000 items]",
"config": {"setting": "value"},
"status": "processing",
},
expected_truncated_flag=True,
),
]
@pytest.mark.parametrize(
"scenario",
get_process_data_response_scenarios(),
ids=[scenario.name for scenario in get_process_data_response_scenarios()],
)
def test_node_finish_response_scenarios(self, scenario: ProcessDataResponseScenario):
"""Test various scenarios for node finish responses."""
mock_user = Mock(spec=Account)
mock_user.id = "test-user-id"
mock_user.name = "Test User"
mock_user.email = "test@example.com"
converter = WorkflowResponseConverter(
application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
user=mock_user,
)
execution = WorkflowNodeExecution(
id="test-execution-id",
workflow_id="test-workflow-id",
workflow_execution_id="test-run-id",
index=1,
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
process_data=scenario.original_process_data,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_at=datetime.now(),
finished_at=datetime.now(),
)
if scenario.truncated_process_data is not None:
execution.set_truncated_process_data(scenario.truncated_process_data)
event = QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=NodeType.CODE,
iteration_event = QueueNodeSucceededEvent(
node_id="iteration-node",
node_type=NodeType.ITERATION,
node_execution_id=str(uuid.uuid4()),
start_at=naive_utc_now(),
parallel_id=None,
parallel_start_node_id=None,
parent_parallel_id=None,
parent_parallel_start_node_id=None,
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data={},
outputs={},
execution_metadata={},
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
event=iteration_event,
task_id="test-task-id",
workflow_node_execution=execution,
)
assert response is None
assert response is not None
assert response.data.process_data == scenario.expected_response_data
assert response.data.process_data_truncated == scenario.expected_truncated_flag
@pytest.mark.parametrize(
"scenario",
get_process_data_response_scenarios(),
ids=[scenario.name for scenario in get_process_data_response_scenarios()],
)
def test_node_retry_response_scenarios(self, scenario: ProcessDataResponseScenario):
"""Test various scenarios for node retry responses."""
mock_user = Mock(spec=Account)
mock_user.id = "test-user-id"
mock_user.name = "Test User"
mock_user.email = "test@example.com"
converter = WorkflowResponseConverter(
application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
user=mock_user,
)
execution = WorkflowNodeExecution(
id="test-execution-id",
workflow_id="test-workflow-id",
workflow_execution_id="test-run-id",
index=1,
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
process_data=scenario.original_process_data,
status=WorkflowNodeExecutionStatus.FAILED, # Retry scenario
created_at=datetime.now(),
finished_at=datetime.now(),
)
if scenario.truncated_process_data is not None:
execution.set_truncated_process_data(scenario.truncated_process_data)
event = self.create_node_retry_event()
response = converter.workflow_node_retry_to_stream_response(
event=event,
loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
response = converter.workflow_node_finish_to_stream_response(
event=loop_event,
task_id="test-task-id",
workflow_node_execution=execution,
)
assert response is None
def test_finish_without_start_raises(self):
"""Ensure finish responses require a prior workflow start."""
converter = self.create_workflow_response_converter()
event = self.create_node_succeeded_event(
node_execution_id=str(uuid.uuid4()),
process_data={},
)
assert response is not None
assert response.data.process_data == scenario.expected_response_data
assert response.data.process_data_truncated == scenario.expected_truncated_flag
with pytest.raises(ValueError):
converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)

View File

@@ -37,7 +37,7 @@ from core.variables.variables import (
Variable,
VariableUnion,
)
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable

View File

@@ -1,9 +1,11 @@
import json
from time import time
from unittest.mock import MagicMock, patch
import pytest
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
class TestGraphRuntimeState:
@@ -95,3 +97,141 @@ class TestGraphRuntimeState:
# Test add_tokens validation
with pytest.raises(ValueError):
state.add_tokens(-1)
def test_ready_queue_default_instantiation(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
queue = state.ready_queue
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
assert isinstance(queue, InMemoryReadyQueue)
assert state.ready_queue is queue
def test_graph_execution_lazy_instantiation(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
execution = state.graph_execution
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
assert isinstance(execution, GraphExecution)
assert execution.workflow_id == ""
assert state.graph_execution is execution
def test_response_coordinator_configuration(self):
variable_pool = VariablePool()
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
with pytest.raises(ValueError):
_ = state.response_coordinator
mock_graph = MagicMock()
with patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls:
coordinator_instance = MagicMock()
coordinator_cls.return_value = coordinator_instance
state.configure(graph=mock_graph)
assert state.response_coordinator is coordinator_instance
coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph)
# Configure again with same graph should be idempotent
state.configure(graph=mock_graph)
other_graph = MagicMock()
with pytest.raises(ValueError):
state.attach_graph(other_graph)
def test_read_only_wrapper_exposes_additional_state(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
state.configure()
wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
assert wrapper.ready_queue_size == 0
assert wrapper.exceptions_count == 0
def test_read_only_wrapper_serializes_runtime_state(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
state.total_tokens = 5
state.set_output("result", {"success": True})
state.ready_queue.put("node-1")
wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
wrapper_snapshot = json.loads(wrapper.dumps())
state_snapshot = json.loads(state.dumps())
assert wrapper_snapshot == state_snapshot
def test_dumps_and_loads_roundtrip_with_response_coordinator(self):
variable_pool = VariablePool()
variable_pool.add(("node1", "value"), "payload")
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
state.total_tokens = 10
state.node_run_steps = 3
state.set_output("final", {"result": True})
usage = LLMUsage.from_metadata(
{
"prompt_tokens": 2,
"completion_tokens": 3,
"total_tokens": 5,
"total_price": "1.23",
"currency": "USD",
"latency": 0.5,
}
)
state.llm_usage = usage
state.ready_queue.put("node-A")
graph_execution = state.graph_execution
graph_execution.workflow_id = "wf-123"
graph_execution.exceptions_count = 4
graph_execution.started = True
class StubCoordinator:
def __init__(self) -> None:
self.state = "initial"
def dumps(self) -> str:
return json.dumps({"state": self.state})
def loads(self, data: str) -> None:
payload = json.loads(data)
self.state = payload["state"]
mock_graph = MagicMock()
stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
state.attach_graph(mock_graph)
stub.state = "configured"
snapshot = state.dumps()
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
restored.loads(snapshot)
assert restored.total_tokens == 10
assert restored.node_run_steps == 3
assert restored.get_output("final") == {"result": True}
assert restored.llm_usage.total_tokens == usage.total_tokens
assert restored.ready_queue.qsize() == 1
assert restored.ready_queue.get(timeout=0.01) == "node-A"
restored_segment = restored.variable_pool.get(("node1", "value"))
assert restored_segment is not None
assert restored_segment.value == "payload"
restored_execution = restored.graph_execution
assert restored_execution.workflow_id == "wf-123"
assert restored_execution.exceptions_count == 4
assert restored_execution.started is True
new_stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
restored.attach_graph(mock_graph)
assert new_stub.state == "configured"

View File

@@ -4,7 +4,7 @@ from core.variables.segments import (
NoneSegment,
StringSegment,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool
class TestVariablePoolGetAndNestedAttribute:

View File

@@ -0,0 +1,59 @@
from unittest.mock import MagicMock
import pytest
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.nodes.base.node import Node
def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node:
node = MagicMock(spec=Node)
node.id = node_id
node.node_type = node_type
node.execution_type = None # attribute not used in builder path
return node
def test_graph_builder_creates_linear_graph():
builder = Graph.new()
root = _make_node("root", NodeType.START)
mid = _make_node("mid", NodeType.LLM)
end = _make_node("end", NodeType.END)
graph = builder.add_root(root).add_node(mid).add_node(end).build()
assert graph.root_node is root
assert graph.nodes == {"root": root, "mid": mid, "end": end}
assert len(graph.edges) == 2
first_edge = next(iter(graph.edges.values()))
assert first_edge.tail == "root"
assert first_edge.head == "mid"
assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"]
def test_graph_builder_supports_custom_predecessor():
    """add_node(from_node_id=...) attaches a node to an explicit predecessor."""
    root = _make_node("root")
    branch = _make_node("branch")
    other = _make_node("other")

    # "other" is wired back to the root instead of the most recent node.
    graph = Graph.new().add_root(root).add_node(branch).add_node(other, from_node_id="root").build()

    root_out = graph.out_edges["root"]
    assert len(root_out) == 2
    assert {graph.edges[eid].head for eid in root_out} == {"branch", "other"}
def test_graph_builder_validates_usage():
    """The builder rejects nodes added before a root and duplicate node ids."""
    builder = Graph.new()
    first = _make_node("node")
    # Adding a node before any root is registered must fail.
    with pytest.raises(ValueError, match="Root node"):
        builder.add_node(first)
    builder.add_root(first)
    # A second node reusing an already-registered id must be rejected.
    with pytest.raises(ValueError, match="Duplicate"):
        builder.add_node(_make_node("node"))

View File

@@ -20,9 +20,6 @@ The TableTestRunner (`test_table_runner.py`) provides a robust table-driven test
- **Mock configuration** - Seamless integration with the auto-mock system
- **Performance metrics** - Track execution times and bottlenecks
- **Detailed error reporting** - Comprehensive failure diagnostics
- **Test tagging** - Organize and filter tests by tags
- **Retry mechanism** - Handle flaky tests gracefully
- **Custom validators** - Define custom validation logic
### Basic Usage
@@ -68,49 +65,6 @@ suite_result = runner.run_table_tests(
print(f"Success rate: {suite_result.success_rate:.1f}%")
```
#### Test Tagging and Filtering
```python
test_case = WorkflowTestCase(
fixture_path="workflow",
inputs={},
expected_outputs={},
tags=["smoke", "critical"],
)
# Run only tests with specific tags
suite_result = runner.run_table_tests(
test_cases,
tags_filter=["smoke"]
)
```
#### Retry Mechanism
```python
test_case = WorkflowTestCase(
fixture_path="flaky_workflow",
inputs={},
expected_outputs={},
retry_count=2, # Retry up to 2 times on failure
)
```
#### Custom Validators
```python
def custom_validator(outputs: dict) -> bool:
# Custom validation logic
return "error" not in outputs.get("status", "")
test_case = WorkflowTestCase(
fixture_path="workflow",
inputs={},
expected_outputs={"status": "success"},
custom_validator=custom_validator,
)
```
#### Event Sequence Validation
```python

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
from datetime import datetime
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
@@ -16,6 +15,7 @@ from core.workflow.graph_engine.response_coordinator.coordinator import Response
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig
from core.workflow.runtime import GraphRuntimeState, VariablePool
class _StubEdgeProcessor:

View File

@@ -3,12 +3,12 @@
import time
from unittest.mock import MagicMock
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
from core.workflow.runtime import GraphRuntimeState, VariablePool
def test_abort_command():
@@ -100,8 +100,57 @@ def test_redis_channel_serialization():
assert command_data["command_type"] == "abort"
assert command_data["reason"] == "Test abort"
# Test pause command serialization
pause_command = PauseCommand(reason="User requested pause")
channel.send_command(pause_command)
if __name__ == "__main__":
test_abort_command()
test_redis_channel_serialization()
print("All tests passed!")
assert len(mock_pipeline.rpush.call_args_list) == 2
second_call_args = mock_pipeline.rpush.call_args_list[1]
pause_command_json = second_call_args[0][1]
pause_command_data = json.loads(pause_command_json)
assert pause_command_data["command_type"] == CommandType.PAUSE.value
assert pause_command_data["reason"] == "User requested pause"
def test_pause_command():
    """A queued PauseCommand makes the engine emit GraphRunPausedEvent and record pause state."""
    runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())

    # Minimal start-node stand-in sharing the engine's runtime state.
    start_stub = MagicMock()
    start_stub.id = "start"
    start_stub.state = None
    start_stub.graph_runtime_state = runtime_state

    # Graph with a single node and no edges.
    graph_stub = MagicMock(spec=Graph)
    graph_stub.nodes = {"start": start_stub}
    graph_stub.edges = {}
    graph_stub.root_node = MagicMock()
    graph_stub.root_node.id = "start"
    graph_stub.get_outgoing_edges = MagicMock(return_value=[])
    graph_stub.get_incoming_edges = MagicMock(return_value=[])

    channel = InMemoryChannel()
    # Queue the pause before running so the engine observes it on its first poll.
    channel.send_command(PauseCommand(reason="User requested pause"))

    engine = GraphEngine(
        workflow_id="test_workflow",
        graph=graph_stub,
        graph_runtime_state=runtime_state,
        command_channel=channel,
    )
    events = list(engine.run())

    assert any(isinstance(event, GraphRunStartedEvent) for event in events)
    paused = [event for event in events if isinstance(event, GraphRunPausedEvent)]
    assert len(paused) == 1
    assert paused[0].reason == "User requested pause"

    execution = engine.graph_runtime_state.graph_execution
    assert execution.is_paused
    assert execution.pause_reason == "User requested pause"

View File

@@ -21,6 +21,7 @@ class _StubExecutionCoordinator:
self._execution_complete = False
self.mark_complete_called = False
self.failed = False
self._paused = False
def check_commands(self) -> None:
self.command_checks += 1
@@ -28,6 +29,10 @@ class _StubExecutionCoordinator:
def check_scaling(self) -> None:
self.scaling_checks += 1
@property
def is_paused(self) -> bool:
return self._paused
def is_execution_complete(self) -> bool:
return self._execution_complete
@@ -96,7 +101,7 @@ def _make_succeeded_event() -> NodeRunSucceededEvent:
def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None:
"""Dispatcher polls commands when idle and re-checks after completion events."""
"""Dispatcher polls commands when idle and after completion events."""
started_checks = _run_dispatcher_for_event(_make_started_event())
succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event())

View File

@@ -0,0 +1,62 @@
"""Unit tests for the execution coordinator orchestration logic."""
from unittest.mock import MagicMock
from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator
from core.workflow.graph_engine.worker_management.worker_pool import WorkerPool
def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]:
    """Wire an ExecutionCoordinator to mocked collaborators.

    Returns the coordinator plus the state-manager and worker-pool mocks so
    tests can assert on the calls the coordinator makes.
    """
    processor = MagicMock(spec=CommandProcessor)
    manager = MagicMock(spec=GraphStateManager)
    pool = MagicMock(spec=WorkerPool)
    coordinator = ExecutionCoordinator(
        graph_execution=graph_execution,
        state_manager=manager,
        command_processor=processor,
        worker_pool=pool,
    )
    return coordinator, manager, pool
def test_handle_pause_stops_workers_and_clears_state() -> None:
    """Paused execution should stop workers and clear executing state."""
    execution = GraphExecution(workflow_id="workflow")
    execution.start()
    execution.pause("Awaiting human input")
    coordinator, manager, pool = _build_coordinator(execution)

    coordinator.handle_pause_if_needed()

    pool.stop.assert_called_once_with()
    manager.clear_executing.assert_called_once_with()
def test_handle_pause_noop_when_execution_running() -> None:
    """Running (non-paused) execution should not trigger pause handling."""
    execution = GraphExecution(workflow_id="workflow")
    execution.start()
    coordinator, manager, pool = _build_coordinator(execution)

    coordinator.handle_pause_if_needed()

    pool.stop.assert_not_called()
    manager.clear_executing.assert_not_called()
def test_is_execution_complete_when_paused() -> None:
    """A paused execution counts as complete even with outstanding work."""
    execution = GraphExecution(workflow_id="workflow")
    execution.start()
    execution.pause("Awaiting input")
    coordinator, manager, _pool = _build_coordinator(execution)
    # Even when the state manager still reports pending work, pause wins.
    manager.is_execution_complete.return_value = False
    assert coordinator.is_execution_complete()

View File

@@ -0,0 +1,341 @@
import time
from collections.abc import Iterable
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
    """Build a start -> llm_initial -> human -> {llm_primary|llm_secondary} -> end graph.

    The human-input node pauses the run; its ``edge_source_handle`` variable later
    selects the primary or secondary branch.  Returns the built graph together with
    the shared runtime state so a test can resume from the paused snapshot.
    """
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
        user_inputs={},
        conversation_variables=[],
    )
    # One runtime state is shared by every node so pause/resume sees the same pool.
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    start_node.init_node_data(start_config["data"])

    def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
        # Helper: a MockLLMNode whose streamed output is driven by mock_config.
        llm_data = LLMNodeData(
            title=title,
            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
            prompt_template=[
                LLMNodeChatModelMessage(
                    text=prompt_text,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
            ],
            context=ContextConfig(enabled=False, variable_selector=None),
            vision=VisionConfig(enabled=False),
            reasoning_format="tagged",
        )
        llm_config = {"id": node_id, "data": llm_data.model_dump()}
        llm_node = MockLLMNode(
            id=node_id,
            config=llm_config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        llm_node.init_node_data(llm_config["data"])
        return llm_node

    llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
    # Human node pauses until "human.input_ready" appears in the variable pool.
    human_data = HumanInputNodeData(
        title="Human Input",
        required_variables=["human.input_ready"],
        pause_reason="Awaiting human input",
    )
    human_config = {"id": "human", "data": human_data.model_dump()}
    human_node = HumanInputNode(
        id=human_config["id"],
        config=human_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    human_node.init_node_data(human_config["data"])
    llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
    llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
    end_primary_data = EndNodeData(
        title="End Primary",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
        ],
        desc=None,
    )
    end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()}
    end_primary = EndNode(
        id=end_primary_config["id"],
        config=end_primary_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_primary.init_node_data(end_primary_config["data"])
    end_secondary_data = EndNodeData(
        title="End Secondary",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
        ],
        desc=None,
    )
    end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()}
    end_secondary = EndNode(
        id=end_secondary_config["id"],
        config=end_secondary_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_secondary.init_node_data(end_secondary_config["data"])
    # The two branches hang off the human node via named source handles.
    graph = (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_initial)
        .add_node(human_node)
        .add_node(llm_primary, from_node_id="human", source_handle="primary")
        .add_node(end_primary, from_node_id="llm_primary")
        .add_node(llm_secondary, from_node_id="human", source_handle="secondary")
        .add_node(end_secondary, from_node_id="llm_secondary")
        .build()
    )
    return graph, graph_runtime_state
def _expected_mock_llm_chunks(text: str) -> list[str]:
chunks: list[str] = []
for index, word in enumerate(text.split(" ")):
chunk = word if index == 0 else f" {word}"
chunks.append(chunk)
chunks.append("")
return chunks
def _assert_stream_chunk_sequence(
    chunk_events: Iterable[NodeRunStreamChunkEvent],
    expected_nodes: list[str],
    expected_chunks: list[str],
) -> None:
    """Assert *chunk_events* carry exactly the expected node ids and chunks, in order.

    Fix: the original iterated ``chunk_events`` twice; for a generator argument the
    second pass would see nothing and compare against an empty list.  Materialize
    the events once so any Iterable is handled correctly.
    """
    events = list(chunk_events)
    assert [event.node_id for event in events] == expected_nodes
    assert [event.chunk for event in events] == expected_chunks
def test_human_input_llm_streaming_across_multiple_branches() -> None:
    """Pause at the human node, then resume down each branch and verify stream order.

    For both the primary and secondary branch the test checks that cached
    llm_initial chunks replay before the human node completes, and the chosen
    branch's LLM streams live afterwards.
    """
    mock_config = MockConfig()
    mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"})
    mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"})
    mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"})
    branch_scenarios = [
        {
            "handle": "primary",
            "resume_llm": "llm_primary",
            "end_node": "end_primary",
            "expected_pre_chunks": [
                ("llm_initial", _expected_mock_llm_chunks("Initial stream")),  # cached output before branch completes
                ("end_primary", ["\n"]),  # literal segment emitted when end_primary session activates
            ],
            "expected_post_chunks": [
                ("llm_primary", _expected_mock_llm_chunks("Primary stream output")),  # live stream from chosen branch
            ],
        },
        {
            "handle": "secondary",
            "resume_llm": "llm_secondary",
            "end_node": "end_secondary",
            "expected_pre_chunks": [
                ("llm_initial", _expected_mock_llm_chunks("Initial stream")),  # cached output before branch completes
                ("end_secondary", ["\n"]),  # literal segment emitted when end_secondary session activates
            ],
            "expected_post_chunks": [
                ("llm_secondary", _expected_mock_llm_chunks("Secondary")),  # live stream from chosen branch
            ],
        },
    ]
    for scenario in branch_scenarios:
        runner = TableTestRunner()

        def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
            # Fresh graph per scenario so pause state never leaks across runs.
            return _build_branching_graph(mock_config)

        initial_case = WorkflowTestCase(
            description="HumanInput pause before branching decision",
            graph_factory=initial_graph_factory,
            expected_event_sequence=[
                GraphRunStartedEvent,  # initial run: graph execution starts
                NodeRunStartedEvent,  # start node begins execution
                NodeRunSucceededEvent,  # start node completes
                NodeRunStartedEvent,  # llm_initial starts streaming
                NodeRunSucceededEvent,  # llm_initial completes streaming
                NodeRunStartedEvent,  # human node begins and issues pause
                NodeRunPauseRequestedEvent,  # human node requests pause awaiting input
                GraphRunPausedEvent,  # graph run pauses awaiting resume
            ],
        )
        initial_result = runner.run_test_case(initial_case)
        assert initial_result.success, initial_result.event_mismatch_details
        # No chunks may be emitted while the run is paused before the branch decision.
        assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events)
        graph_runtime_state = initial_result.graph_runtime_state
        graph = initial_result.graph
        assert graph_runtime_state is not None
        assert graph is not None
        # Provide the human input and pick the branch, then clear the pause marker.
        graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
        graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"])
        graph_runtime_state.graph_execution.pause_reason = None
        pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
        post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
        expected_resume_sequence: list[type] = (
            [
                GraphRunStartedEvent,
                NodeRunStartedEvent,
            ]
            + [NodeRunStreamChunkEvent] * pre_chunk_count
            + [
                NodeRunSucceededEvent,
                NodeRunStartedEvent,
            ]
            + [NodeRunStreamChunkEvent] * post_chunk_count
            + [
                NodeRunSucceededEvent,
                NodeRunStartedEvent,
                NodeRunSucceededEvent,
                GraphRunSucceededEvent,
            ]
        )

        def resume_graph_factory(
            # Bind the snapshots as defaults to avoid late-binding closure capture.
            graph_snapshot: Graph = graph,
            state_snapshot: GraphRuntimeState = graph_runtime_state,
        ) -> tuple[Graph, GraphRuntimeState]:
            return graph_snapshot, state_snapshot

        resume_case = WorkflowTestCase(
            description=f"HumanInput resumes via {scenario['handle']} branch",
            graph_factory=resume_graph_factory,
            expected_event_sequence=expected_resume_sequence,
        )
        resume_result = runner.run_test_case(resume_case)
        assert resume_result.success, resume_result.event_mismatch_details
        resume_events = resume_result.events
        chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)]
        assert len(chunk_events) == pre_chunk_count + post_chunk_count
        pre_chunk_events = chunk_events[:pre_chunk_count]
        post_chunk_events = chunk_events[pre_chunk_count:]
        # Flatten the per-node expectations into parallel id/chunk lists.
        expected_pre_nodes: list[str] = []
        expected_pre_chunks: list[str] = []
        for node_id, chunks in scenario["expected_pre_chunks"]:
            expected_pre_nodes.extend([node_id] * len(chunks))
            expected_pre_chunks.extend(chunks)
        _assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks)
        expected_post_nodes: list[str] = []
        expected_post_chunks: list[str] = []
        for node_id, chunks in scenario["expected_post_chunks"]:
            expected_post_nodes.extend([node_id] * len(chunks))
            expected_post_chunks.extend(chunks)
        _assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks)
        human_success_index = next(
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human"
        )
        # Cached chunks occupy positions 2..2+pre_chunk_count, before the human node succeeds.
        pre_indices = [
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
        ]
        assert pre_indices == list(range(2, 2 + pre_chunk_count))
        resume_chunk_indices = [
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"]
        ]
        assert resume_chunk_indices, "Expected streaming output from the selected branch"
        resume_start_index = next(
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"]
        )
        resume_success_index = next(
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"]
        )
        # The chosen branch streams strictly between its start and success events.
        assert resume_start_index < min(resume_chunk_indices)
        assert max(resume_chunk_indices) < resume_success_index
        started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)]
        assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]]

View File

@@ -0,0 +1,297 @@
import time
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
    """Build a linear start -> llm_initial -> human -> llm_resume -> end graph.

    The human-input node pauses the run between the two mock LLM nodes.  Returns
    the graph plus the shared runtime state so a test can resume the snapshot.
    """
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
        user_inputs={},
        conversation_variables=[],
    )
    # A single runtime state is shared by every node so pause/resume is seamless.
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    start_node.init_node_data(start_config["data"])

    def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
        # Helper: a MockLLMNode whose streamed output is driven by mock_config.
        llm_data = LLMNodeData(
            title=title,
            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
            prompt_template=[
                LLMNodeChatModelMessage(
                    text=prompt_text,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
            ],
            context=ContextConfig(enabled=False, variable_selector=None),
            vision=VisionConfig(enabled=False),
            reasoning_format="tagged",
        )
        llm_config = {"id": node_id, "data": llm_data.model_dump()}
        llm_node = MockLLMNode(
            id=node_id,
            config=llm_config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        llm_node.init_node_data(llm_config["data"])
        return llm_node

    llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")
    # Human node pauses until "human.input_ready" appears in the variable pool.
    human_data = HumanInputNodeData(
        title="Human Input",
        required_variables=["human.input_ready"],
        pause_reason="Awaiting human input",
    )
    human_config = {"id": "human", "data": human_data.model_dump()}
    human_node = HumanInputNode(
        id=human_config["id"],
        config=human_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    human_node.init_node_data(human_config["data"])
    llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
    end_data = EndNodeData(
        title="End",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="resume_text", value_selector=["llm_resume", "text"]),
        ],
        desc=None,
    )
    end_config = {"id": "end", "data": end_data.model_dump()}
    end_node = EndNode(
        id=end_config["id"],
        config=end_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_node.init_node_data(end_config["data"])
    graph = (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_first)
        .add_node(human_node)
        .add_node(llm_second)
        .add_node(end_node)
        .build()
    )
    return graph, graph_runtime_state
def _expected_mock_llm_chunks(text: str) -> list[str]:
chunks: list[str] = []
for index, word in enumerate(text.split(" ")):
chunk = word if index == 0 else f" {word}"
chunks.append(chunk)
chunks.append("")
return chunks
def test_human_input_llm_streaming_order_across_pause() -> None:
    """Verify LLM chunk ordering is preserved across a human-input pause/resume cycle.

    The initial run buffers llm_initial's chunks instead of emitting them; on
    resume those cached chunks replay before the human node completes, then
    llm_resume streams live.
    """
    runner = TableTestRunner()
    initial_text = "Hello, pause"
    resume_text = "Welcome back!"
    mock_config = MockConfig()
    mock_config.set_node_outputs("llm_initial", {"text": initial_text})
    mock_config.set_node_outputs("llm_resume", {"text": resume_text})
    expected_initial_sequence: list[type] = [
        GraphRunStartedEvent,  # graph run begins
        NodeRunStartedEvent,  # start node begins
        NodeRunSucceededEvent,  # start node completes
        NodeRunStartedEvent,  # llm_initial begins streaming
        NodeRunSucceededEvent,  # llm_initial completes streaming
        NodeRunStartedEvent,  # human node begins and requests pause
        NodeRunPauseRequestedEvent,  # human node pause requested
        GraphRunPausedEvent,  # graph run pauses awaiting resume
    ]

    def graph_factory() -> tuple[Graph, GraphRuntimeState]:
        return _build_llm_human_llm_graph(mock_config)

    initial_case = WorkflowTestCase(
        description="HumanInput pause preserves LLM streaming order",
        graph_factory=graph_factory,
        expected_event_sequence=expected_initial_sequence,
    )
    initial_result = runner.run_test_case(initial_case)
    assert initial_result.success, initial_result.event_mismatch_details
    initial_events = initial_result.events
    initial_chunks = _expected_mock_llm_chunks(initial_text)
    # Nothing may stream before the resume: chunks are buffered, not emitted.
    initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)]
    assert initial_stream_chunk_events == []
    pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent))
    llm_succeeded_index = next(
        i
        for i, event in enumerate(initial_events)
        if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial"
    )
    # llm_initial must have finished before the pause was reported.
    assert llm_succeeded_index < pause_index
    graph_runtime_state = initial_result.graph_runtime_state
    graph = initial_result.graph
    assert graph_runtime_state is not None
    assert graph is not None
    coordinator = graph_runtime_state.response_coordinator
    stream_buffers = coordinator._stream_buffers  # Tests may access internals for assertions
    # The paused coordinator holds exactly llm_initial's chunks, nothing from llm_resume.
    assert ("llm_initial", "text") in stream_buffers
    initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]]
    assert initial_stream_chunks == initial_chunks
    assert ("llm_resume", "text") not in stream_buffers
    resume_chunks = _expected_mock_llm_chunks(resume_text)
    expected_resume_sequence: list[type] = [
        GraphRunStartedEvent,  # resumed graph run begins
        NodeRunStartedEvent,  # human node restarts
        NodeRunStreamChunkEvent,  # cached llm_initial chunk 1
        NodeRunStreamChunkEvent,  # cached llm_initial chunk 2
        NodeRunStreamChunkEvent,  # cached llm_initial final chunk
        NodeRunStreamChunkEvent,  # end node emits combined template separator
        NodeRunSucceededEvent,  # human node finishes instantly after input
        NodeRunStartedEvent,  # llm_resume begins streaming
        NodeRunStreamChunkEvent,  # llm_resume chunk 1
        NodeRunStreamChunkEvent,  # llm_resume chunk 2
        NodeRunStreamChunkEvent,  # llm_resume final chunk
        NodeRunSucceededEvent,  # llm_resume completes streaming
        NodeRunStartedEvent,  # end node starts
        NodeRunSucceededEvent,  # end node finishes
        GraphRunSucceededEvent,  # graph run succeeds after resume
    ]

    def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
        # Supply the awaited input and clear the pause marker on the same snapshot.
        assert graph_runtime_state is not None
        assert graph is not None
        graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
        graph_runtime_state.graph_execution.pause_reason = None
        return graph, graph_runtime_state

    resume_case = WorkflowTestCase(
        description="HumanInput resume continues LLM streaming order",
        graph_factory=resume_graph_factory,
        expected_event_sequence=expected_resume_sequence,
    )
    resume_result = runner.run_test_case(resume_case)
    assert resume_result.success, resume_result.event_mismatch_details
    resume_events = resume_result.events
    success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent))
    llm_resume_succeeded_index = next(
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume"
    )
    assert llm_resume_succeeded_index < success_index
    resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)]
    # Replayed llm_initial chunks, then the end node's separator, then live llm_resume chunks.
    assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3
    assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks
    assert resume_chunk_events[3].node_id == "end"
    assert resume_chunk_events[3].chunk == "\n"
    assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3
    assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks
    human_success_index = next(
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human"
    )
    # Every cached chunk must land before the human node completes.
    cached_chunk_indices = [
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"}
    ]
    assert all(index < human_success_index for index in cached_chunk_indices)
    llm_resume_start_index = next(
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume"
    )
    llm_resume_success_index = next(
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume"
    )
    llm_resume_chunk_indices = [
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume"
    ]
    assert llm_resume_chunk_indices
    first_resume_chunk_index = min(llm_resume_chunk_indices)
    last_resume_chunk_index = max(llm_resume_chunk_indices)
    # llm_resume streams strictly between its own start and success events.
    assert llm_resume_start_index < first_resume_chunk_index
    assert last_resume_chunk_index < llm_resume_success_index
    started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)]
    assert started_nodes == ["human", "llm_resume", "end"]

View File

@@ -0,0 +1,321 @@
import time
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.utils.condition.entities import Condition
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
    """Build a start -> llm_initial -> if/else -> (llm -> end) graph for streaming-order tests.

    The if/else node routes to the "primary" or "secondary" branch based on
    ``branch_value``, which is seeded into the variable pool under the
    ``("branch", "value")`` selector. Each branch has its own mock LLM node and
    End node; both End nodes also reference ``llm_initial``'s text output.

    Args:
        branch_value: Value compared by the if/else conditions ("primary" or "secondary").
        mock_config: Shared mock configuration injected into every MockLLMNode.

    Returns:
        The built Graph together with its GraphRuntimeState.
    """
    # Minimal graph config; nodes are constructed directly below rather than parsed from it.
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
        user_inputs={},
        conversation_variables=[],
    )
    # Seed the value the if/else node's conditions will compare against.
    variable_pool.add(("branch", "value"), branch_value)
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    start_node.init_node_data(start_config["data"])

    def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
        # Helper: build a MockLLMNode sharing this graph's init params, runtime state,
        # and mock configuration. The model/prompt values are placeholders; actual
        # output is driven by mock_config.
        llm_data = LLMNodeData(
            title=title,
            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
            prompt_template=[
                LLMNodeChatModelMessage(
                    text=prompt_text,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
            ],
            context=ContextConfig(enabled=False, variable_selector=None),
            vision=VisionConfig(enabled=False),
            reasoning_format="tagged",
        )
        llm_config = {"id": node_id, "data": llm_data.model_dump()}
        llm_node = MockLLMNode(
            id=node_id,
            config=llm_config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        llm_node.init_node_data(llm_config["data"])
        return llm_node

    llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
    # Two cases whose case_ids double as the branch source handles used when wiring edges.
    if_else_data = IfElseNodeData(
        title="IfElse",
        cases=[
            IfElseNodeData.Case(
                case_id="primary",
                logical_operator="and",
                conditions=[
                    Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary")
                ],
            ),
            IfElseNodeData.Case(
                case_id="secondary",
                logical_operator="and",
                conditions=[
                    Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary")
                ],
            ),
        ],
    )
    if_else_config = {"id": "if_else", "data": if_else_data.model_dump()}
    if_else_node = IfElseNode(
        id=if_else_config["id"],
        config=if_else_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    if_else_node.init_node_data(if_else_config["data"])
    llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
    llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
    # Each End node aggregates the initial LLM's text plus its own branch LLM's text.
    end_primary_data = EndNodeData(
        title="End Primary",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
        ],
        desc=None,
    )
    end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()}
    end_primary = EndNode(
        id=end_primary_config["id"],
        config=end_primary_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_primary.init_node_data(end_primary_config["data"])
    end_secondary_data = EndNodeData(
        title="End Secondary",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
        ],
        desc=None,
    )
    end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()}
    end_secondary = EndNode(
        id=end_secondary_config["id"],
        config=end_secondary_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_secondary.init_node_data(end_secondary_config["data"])
    # Wire the topology: the source_handle on each branch edge matches the case_id above.
    graph = (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_initial)
        .add_node(if_else_node)
        .add_node(llm_primary, from_node_id="if_else", source_handle="primary")
        .add_node(end_primary, from_node_id="llm_primary")
        .add_node(llm_secondary, from_node_id="if_else", source_handle="secondary")
        .add_node(end_secondary, from_node_id="llm_secondary")
        .build()
    )
    return graph, graph_runtime_state
def _expected_mock_llm_chunks(text: str) -> list[str]:
chunks: list[str] = []
for index, word in enumerate(text.split(" ")):
chunk = word if index == 0 else f" {word}"
chunks.append(chunk)
chunks.append("")
return chunks
def test_if_else_llm_streaming_order() -> None:
    """Verify stream-chunk ordering around an if/else branch for both branches.

    For each scenario the test asserts: the exact engine event sequence, the
    exact per-node chunk payloads, that llm_initial's cached chunks are flushed
    between the if/else node's start and success events, that the branch LLM
    streams strictly between its own start and success events, and the overall
    node start order.
    """
    mock_config = MockConfig()
    # Fixed mock outputs so chunk payloads are deterministic.
    mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"})
    mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"})
    mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"})
    # One scenario per if/else branch; each pins the full expected event sequence
    # and the expected (node_id, chunks) stream ordering.
    scenarios = [
        {
            "branch": "primary",
            "resume_llm": "llm_primary",
            "end_node": "end_primary",
            "expected_sequence": [
                GraphRunStartedEvent,  # graph run begins
                NodeRunStartedEvent,  # start node begins execution
                NodeRunSucceededEvent,  # start node completes
                NodeRunStartedEvent,  # llm_initial starts and streams
                NodeRunSucceededEvent,  # llm_initial completes streaming
                NodeRunStartedEvent,  # if_else evaluates conditions
                NodeRunStreamChunkEvent,  # cached llm_initial chunk 1 flushed
                NodeRunStreamChunkEvent,  # cached llm_initial chunk 2 flushed
                NodeRunStreamChunkEvent,  # cached llm_initial final chunk flushed
                NodeRunStreamChunkEvent,  # template literal newline emitted
                NodeRunSucceededEvent,  # if_else completes branch selection
                NodeRunStartedEvent,  # llm_primary begins streaming
                NodeRunStreamChunkEvent,  # llm_primary chunk 1
                NodeRunStreamChunkEvent,  # llm_primary chunk 2
                NodeRunStreamChunkEvent,  # llm_primary chunk 3
                NodeRunStreamChunkEvent,  # llm_primary final chunk
                NodeRunSucceededEvent,  # llm_primary completes streaming
                NodeRunStartedEvent,  # end_primary node starts
                NodeRunSucceededEvent,  # end_primary finishes aggregation
                GraphRunSucceededEvent,  # graph run succeeds
            ],
            "expected_chunks": [
                ("llm_initial", _expected_mock_llm_chunks("Initial stream")),
                ("end_primary", ["\n"]),
                ("llm_primary", _expected_mock_llm_chunks("Primary stream output")),
            ],
        },
        {
            "branch": "secondary",
            "resume_llm": "llm_secondary",
            "end_node": "end_secondary",
            "expected_sequence": [
                GraphRunStartedEvent,  # graph run begins
                NodeRunStartedEvent,  # start node begins execution
                NodeRunSucceededEvent,  # start node completes
                NodeRunStartedEvent,  # llm_initial starts and streams
                NodeRunSucceededEvent,  # llm_initial completes streaming
                NodeRunStartedEvent,  # if_else evaluates conditions
                NodeRunStreamChunkEvent,  # cached llm_initial chunk 1 flushed
                NodeRunStreamChunkEvent,  # cached llm_initial chunk 2 flushed
                NodeRunStreamChunkEvent,  # cached llm_initial final chunk flushed
                NodeRunStreamChunkEvent,  # template literal newline emitted
                NodeRunSucceededEvent,  # if_else completes branch selection
                NodeRunStartedEvent,  # llm_secondary begins streaming
                NodeRunStreamChunkEvent,  # llm_secondary chunk 1
                NodeRunStreamChunkEvent,  # llm_secondary final chunk
                NodeRunSucceededEvent,  # llm_secondary completes
                NodeRunStartedEvent,  # end_secondary node starts
                NodeRunSucceededEvent,  # end_secondary finishes aggregation
                GraphRunSucceededEvent,  # graph run succeeds
            ],
            "expected_chunks": [
                ("llm_initial", _expected_mock_llm_chunks("Initial stream")),
                ("end_secondary", ["\n"]),
                ("llm_secondary", _expected_mock_llm_chunks("Secondary")),
            ],
        },
    ]
    for scenario in scenarios:
        runner = TableTestRunner()

        # Default arguments bind the current loop values, avoiding Python's
        # late-binding closure capture of `scenario`.
        def graph_factory(
            branch_value: str = scenario["branch"],
            cfg: MockConfig = mock_config,
        ) -> tuple[Graph, GraphRuntimeState]:
            return _build_if_else_graph(branch_value, cfg)

        test_case = WorkflowTestCase(
            description=f"IfElse streaming via {scenario['branch']} branch",
            graph_factory=graph_factory,
            expected_event_sequence=scenario["expected_sequence"],
        )
        result = runner.run_test_case(test_case)
        assert result.success, result.event_mismatch_details
        # Flatten expected (node_id, chunks) pairs and compare against the
        # actual chunk events in emission order.
        chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)]
        expected_nodes: list[str] = []
        expected_chunks: list[str] = []
        for node_id, chunks in scenario["expected_chunks"]:
            expected_nodes.extend([node_id] * len(chunks))
            expected_chunks.extend(chunks)
        assert [event.node_id for event in chunk_events] == expected_nodes
        assert [event.chunk for event in chunk_events] == expected_chunks
        # Locate the if_else start/success events to bound the cached-chunk flush window.
        branch_node_index = next(
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else"
        )
        branch_success_index = next(
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else"
        )
        pre_branch_chunk_indices = [
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index
        ]
        # llm_initial's chunks plus the single template newline chunk must all be
        # flushed immediately after if_else starts and before it succeeds.
        assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1
        assert min(pre_branch_chunk_indices) == branch_node_index + 1
        assert max(pre_branch_chunk_indices) < branch_success_index
        # The selected branch LLM must stream strictly within its own start/success window.
        resume_chunk_indices = [
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"]
        ]
        assert resume_chunk_indices
        resume_start_index = next(
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"]
        )
        resume_success_index = next(
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"]
        )
        assert resume_start_index < min(resume_chunk_indices)
        assert max(resume_chunk_indices) < resume_success_index
        # Only the selected branch's nodes should have started, in topological order.
        started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)]
        assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]]

View File

@@ -27,7 +27,8 @@ from .test_mock_nodes import (
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
from .test_mock_config import MockConfig

View File

@@ -42,7 +42,8 @@ def test_mock_iteration_node_preserves_config():
"""Test that MockIterationNode preserves mock configuration."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode
@@ -103,7 +104,8 @@ def test_mock_loop_node_preserves_config():
"""Test that MockLoopNode preserves mock configuration."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode

View File

@@ -24,7 +24,8 @@ from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
from .test_mock_config import MockConfig
@@ -561,10 +562,11 @@ class MockIterationNode(MockNodeMixin, IterationNode):
def _create_graph_engine(self, index: int, item: Any):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.runtime import GraphRuntimeState
# Import our MockNodeFactory instead of DifyNodeFactory
from .test_mock_factory import MockNodeFactory
@@ -635,10 +637,11 @@ class MockLoopNode(MockNodeMixin, LoopNode):
def _create_graph_engine(self, start_at, root_node_id: str):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.runtime import GraphRuntimeState
# Import our MockNodeFactory instead of DifyNodeFactory
from .test_mock_factory import MockNodeFactory

View File

@@ -16,8 +16,8 @@ class TestMockTemplateTransformNode:
def test_mock_template_transform_node_default_output(self):
"""Test that MockTemplateTransformNode processes templates with Jinja2."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -76,8 +76,8 @@ class TestMockTemplateTransformNode:
def test_mock_template_transform_node_custom_output(self):
"""Test that MockTemplateTransformNode returns custom configured output."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -137,8 +137,8 @@ class TestMockTemplateTransformNode:
def test_mock_template_transform_node_error_simulation(self):
"""Test that MockTemplateTransformNode can simulate errors."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -196,8 +196,8 @@ class TestMockTemplateTransformNode:
def test_mock_template_transform_node_with_variables(self):
"""Test that MockTemplateTransformNode processes templates with variables."""
from core.variables import StringVariable
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -262,8 +262,8 @@ class TestMockCodeNode:
def test_mock_code_node_default_output(self):
"""Test that MockCodeNode returns default output."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -323,8 +323,8 @@ class TestMockCodeNode:
def test_mock_code_node_with_output_schema(self):
"""Test that MockCodeNode generates outputs based on schema."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -392,8 +392,8 @@ class TestMockCodeNode:
def test_mock_code_node_custom_output(self):
"""Test that MockCodeNode returns custom configured output."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -463,8 +463,8 @@ class TestMockNodeFactory:
def test_code_and_template_nodes_mocked_by_default(self):
"""Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy)."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -504,8 +504,8 @@ class TestMockNodeFactory:
def test_factory_creates_mock_template_transform_node(self):
"""Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -555,8 +555,8 @@ class TestMockNodeFactory:
def test_factory_creates_mock_code_node(self):
"""Test that MockNodeFactory creates MockCodeNode for code type."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(

View File

@@ -13,7 +13,7 @@ from unittest.mock import patch
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
@@ -27,6 +27,7 @@ from core.workflow.graph_events import (
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

View File

@@ -13,7 +13,7 @@ import redis
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_engine.manager import GraphEngineManager
@@ -52,6 +52,29 @@ class TestRedisStopIntegration:
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "Test stop"
def test_graph_engine_manager_sends_pause_command(self):
"""Test that GraphEngineManager correctly sends pause command through Redis."""
task_id = "test-task-pause-123"
expected_channel_key = f"workflow:{task_id}:commands"
mock_redis = MagicMock()
mock_pipeline = MagicMock()
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
GraphEngineManager.send_pause_command(task_id, reason="Awaiting resources")
mock_redis.pipeline.assert_called_once()
calls = mock_pipeline.rpush.call_args_list
assert len(calls) == 1
assert calls[0][0][0] == expected_channel_key
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.PAUSE.value
assert command_data["reason"] == "Awaiting resources"
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
"""Test that GraphEngineManager handles Redis failures without raising exceptions."""
task_id = "test-task-456"
@@ -105,28 +128,37 @@ class TestRedisStopIntegration:
channel_key = "workflow:test:commands"
channel = RedisChannel(mock_redis, channel_key)
# Create abort command
# Create commands
abort_command = AbortCommand(reason="User requested stop")
pause_command = PauseCommand(reason="User requested pause")
# Execute
channel.send_command(abort_command)
channel.send_command(pause_command)
# Verify
mock_redis.pipeline.assert_called_once()
mock_redis.pipeline.assert_called()
# Check rpush was called
calls = mock_pipeline.rpush.call_args_list
assert len(calls) == 1
assert len(calls) == 2
assert calls[0][0][0] == channel_key
assert calls[1][0][0] == channel_key
# Verify serialized command
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "User requested stop"
# Verify serialized commands
abort_command_json = calls[0][0][1]
abort_command_data = json.loads(abort_command_json)
assert abort_command_data["command_type"] == CommandType.ABORT.value
assert abort_command_data["reason"] == "User requested stop"
# Check expire was set
mock_pipeline.expire.assert_called_once_with(channel_key, 3600)
pause_command_json = calls[1][0][1]
pause_command_data = json.loads(pause_command_json)
assert pause_command_data["command_type"] == CommandType.PAUSE.value
assert pause_command_data["reason"] == "User requested pause"
# Check expire was set for each
assert mock_pipeline.expire.call_count == 2
mock_pipeline.expire.assert_any_call(channel_key, 3600)
def test_redis_channel_fetch_commands(self):
"""Test RedisChannel correctly fetches and deserializes commands."""
@@ -143,12 +175,17 @@ class TestRedisStopIntegration:
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
# Mock command data
abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
abort_command_json = json.dumps(
{"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
)
pause_command_json = json.dumps(
{"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None}
)
# Mock pipeline execute to return commands
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [
[abort_command_json.encode()], # lrange result
[abort_command_json.encode(), pause_command_json.encode()], # lrange result
True, # delete result
]
@@ -159,10 +196,13 @@ class TestRedisStopIntegration:
commands = channel.fetch_commands()
# Verify
assert len(commands) == 1
assert len(commands) == 2
assert isinstance(commands[0], AbortCommand)
assert commands[0].command_type == CommandType.ABORT
assert commands[0].reason == "Test abort"
assert isinstance(commands[1], PauseCommand)
assert commands[1].command_type == CommandType.PAUSE
assert commands[1].reason == "Pause requested"
# Verify Redis operations
pending_pipe.get.assert_called_once_with(f"{channel_key}:pending")

View File

@@ -29,7 +29,6 @@ from core.variables import (
ObjectVariable,
StringVariable,
)
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
@@ -40,6 +39,7 @@ from core.workflow.graph_events import (
GraphRunSucceededEvent,
)
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from .test_mock_config import MockConfig
@@ -52,8 +52,8 @@ logger = logging.getLogger(__name__)
class WorkflowTestCase:
"""Represents a single test case for table-driven testing."""
fixture_path: str
expected_outputs: dict[str, Any]
fixture_path: str = ""
expected_outputs: dict[str, Any] = field(default_factory=dict)
inputs: dict[str, Any] = field(default_factory=dict)
query: str = ""
description: str = ""
@@ -61,11 +61,7 @@ class WorkflowTestCase:
mock_config: MockConfig | None = None
use_auto_mock: bool = False
expected_event_sequence: Sequence[type[GraphEngineEvent]] | None = None
tags: list[str] = field(default_factory=list)
skip: bool = False
skip_reason: str = ""
retry_count: int = 0
custom_validator: Callable[[dict[str, Any]], bool] | None = None
graph_factory: Callable[[], tuple[Graph, GraphRuntimeState]] | None = None
@dataclass
@@ -80,7 +76,8 @@ class WorkflowTestResult:
event_sequence_match: bool | None = None
event_mismatch_details: str | None = None
events: list[GraphEngineEvent] = field(default_factory=list)
retry_attempts: int = 0
graph: Graph | None = None
graph_runtime_state: GraphRuntimeState | None = None
validation_details: str | None = None
@@ -91,7 +88,6 @@ class TestSuiteResult:
total_tests: int
passed_tests: int
failed_tests: int
skipped_tests: int
total_execution_time: float
results: list[WorkflowTestResult]
@@ -106,10 +102,6 @@ class TestSuiteResult:
"""Get all failed test results."""
return [r for r in self.results if not r.success]
def get_results_by_tag(self, tag: str) -> list[WorkflowTestResult]:
"""Get test results filtered by tag."""
return [r for r in self.results if tag in r.test_case.tags]
class WorkflowRunner:
"""Core workflow execution engine for tests."""
@@ -286,90 +278,30 @@ class TableTestRunner:
Returns:
WorkflowTestResult with execution details
"""
if test_case.skip:
self.logger.info("Skipping test: %s - %s", test_case.description, test_case.skip_reason)
return WorkflowTestResult(
test_case=test_case,
success=True,
execution_time=0.0,
validation_details=f"Skipped: {test_case.skip_reason}",
)
retry_attempts = 0
last_result = None
last_error = None
start_time = time.perf_counter()
for attempt in range(test_case.retry_count + 1):
start_time = time.perf_counter()
try:
result = self._execute_test_case(test_case)
last_result = result # Save the last result
if result.success:
result.retry_attempts = retry_attempts
self.logger.info("Test passed: %s", test_case.description)
return result
last_error = result.error
retry_attempts += 1
if attempt < test_case.retry_count:
self.logger.warning(
"Test failed (attempt %d/%d): %s",
attempt + 1,
test_case.retry_count + 1,
test_case.description,
)
time.sleep(0.5 * (attempt + 1)) # Exponential backoff
except Exception as e:
last_error = e
retry_attempts += 1
if attempt < test_case.retry_count:
self.logger.warning(
"Test error (attempt %d/%d): %s - %s",
attempt + 1,
test_case.retry_count + 1,
test_case.description,
str(e),
)
time.sleep(0.5 * (attempt + 1))
# All retries failed - return the last result if available
if last_result:
last_result.retry_attempts = retry_attempts
self.logger.error("Test failed after %d attempts: %s", retry_attempts, test_case.description)
return last_result
# If no result available (all attempts threw exceptions), create a failure result
self.logger.error("Test failed after %d attempts: %s", retry_attempts, test_case.description)
return WorkflowTestResult(
test_case=test_case,
success=False,
error=last_error,
execution_time=time.perf_counter() - start_time,
retry_attempts=retry_attempts,
)
try:
result = self._execute_test_case(test_case)
if result.success:
self.logger.info("Test passed: %s", test_case.description)
else:
self.logger.error("Test failed: %s", test_case.description)
return result
except Exception as exc:
self.logger.exception("Error executing test case: %s", test_case.description)
return WorkflowTestResult(
test_case=test_case,
success=False,
error=exc,
execution_time=time.perf_counter() - start_time,
)
def _execute_test_case(self, test_case: WorkflowTestCase) -> WorkflowTestResult:
"""Internal method to execute a single test case."""
start_time = time.perf_counter()
try:
# Load fixture data
fixture_data = self.workflow_runner.load_fixture(test_case.fixture_path)
# Create graph from fixture
graph, graph_runtime_state = self.workflow_runner.create_graph_from_fixture(
fixture_data=fixture_data,
inputs=test_case.inputs,
query=test_case.query,
use_mock_factory=test_case.use_auto_mock,
mock_config=test_case.mock_config,
)
graph, graph_runtime_state = self._create_graph_runtime_state(test_case)
# Create and run the engine with configured worker settings
engine = GraphEngine(
@@ -384,7 +316,7 @@ class TableTestRunner:
)
# Execute and collect events
events = []
events: list[GraphEngineEvent] = []
for event in engine.run():
events.append(event)
@@ -416,6 +348,8 @@ class TableTestRunner:
events=events,
event_sequence_match=event_sequence_match,
event_mismatch_details=event_mismatch_details,
graph=graph,
graph_runtime_state=graph_runtime_state,
)
# Get actual outputs
@@ -423,9 +357,7 @@ class TableTestRunner:
actual_outputs = success_event.outputs or {}
# Validate outputs
output_success, validation_details = self._validate_outputs(
test_case.expected_outputs, actual_outputs, test_case.custom_validator
)
output_success, validation_details = self._validate_outputs(test_case.expected_outputs, actual_outputs)
# Overall success requires both output and event sequence validation
success = output_success and (event_sequence_match if event_sequence_match is not None else True)
@@ -440,6 +372,8 @@ class TableTestRunner:
events=events,
validation_details=validation_details,
error=None if success else Exception(validation_details or event_mismatch_details or "Test failed"),
graph=graph,
graph_runtime_state=graph_runtime_state,
)
except Exception as e:
@@ -449,13 +383,33 @@ class TableTestRunner:
success=False,
error=e,
execution_time=time.perf_counter() - start_time,
graph=graph if "graph" in locals() else None,
graph_runtime_state=graph_runtime_state if "graph_runtime_state" in locals() else None,
)
def _create_graph_runtime_state(self, test_case: WorkflowTestCase) -> tuple[Graph, GraphRuntimeState]:
"""Create or retrieve graph/runtime state according to test configuration."""
if test_case.graph_factory is not None:
return test_case.graph_factory()
if not test_case.fixture_path:
raise ValueError("fixture_path must be provided when graph_factory is not specified")
fixture_data = self.workflow_runner.load_fixture(test_case.fixture_path)
return self.workflow_runner.create_graph_from_fixture(
fixture_data=fixture_data,
inputs=test_case.inputs,
query=test_case.query,
use_mock_factory=test_case.use_auto_mock,
mock_config=test_case.mock_config,
)
def _validate_outputs(
self,
expected_outputs: dict[str, Any],
actual_outputs: dict[str, Any],
custom_validator: Callable[[dict[str, Any]], bool] | None = None,
) -> tuple[bool, str | None]:
"""
Validate actual outputs against expected outputs.
@@ -490,14 +444,6 @@ class TableTestRunner:
f"Value mismatch for key '{key}':\n Expected: {expected_value}\n Actual: {actual_value}"
)
# Apply custom validator if provided
if custom_validator:
try:
if not custom_validator(actual_outputs):
validation_errors.append("Custom validator failed")
except Exception as e:
validation_errors.append(f"Custom validator error: {str(e)}")
if validation_errors:
return False, "\n".join(validation_errors)
@@ -537,7 +483,6 @@ class TableTestRunner:
self,
test_cases: list[WorkflowTestCase],
parallel: bool = False,
tags_filter: list[str] | None = None,
fail_fast: bool = False,
) -> TestSuiteResult:
"""
@@ -546,22 +491,16 @@ class TableTestRunner:
Args:
test_cases: List of test cases to execute
parallel: Run tests in parallel
tags_filter: Only run tests with specified tags
fail_fast: Stop execution on first failure
fail_fast: Stop execution on first failure
Returns:
TestSuiteResult with aggregated results
"""
# Filter by tags if specified
if tags_filter:
test_cases = [tc for tc in test_cases if any(tag in tc.tags for tag in tags_filter)]
if not test_cases:
return TestSuiteResult(
total_tests=0,
passed_tests=0,
failed_tests=0,
skipped_tests=0,
total_execution_time=0.0,
results=[],
)
@@ -576,16 +515,14 @@ class TableTestRunner:
# Calculate statistics
total_tests = len(results)
passed_tests = sum(1 for r in results if r.success and not r.test_case.skip)
failed_tests = sum(1 for r in results if not r.success and not r.test_case.skip)
skipped_tests = sum(1 for r in results if r.test_case.skip)
passed_tests = sum(1 for r in results if r.success)
failed_tests = total_tests - passed_tests
total_execution_time = time.perf_counter() - start_time
return TestSuiteResult(
total_tests=total_tests,
passed_tests=passed_tests,
failed_tests=failed_tests,
skipped_tests=skipped_tests,
total_execution_time=total_execution_time,
results=results,
)
@@ -598,7 +535,7 @@ class TableTestRunner:
result = self.run_test_case(test_case)
results.append(result)
if fail_fast and not result.success and not result.test_case.skip:
if fail_fast and not result.success:
self.logger.info("Fail-fast enabled: stopping execution")
break
@@ -618,11 +555,11 @@ class TableTestRunner:
result = future.result()
results.append(result)
if fail_fast and not result.success and not result.test_case.skip:
if fail_fast and not result.success:
self.logger.info("Fail-fast enabled: cancelling remaining tests")
# Cancel remaining futures
for f in future_to_test:
f.cancel()
for remaining_future in future_to_test:
if not remaining_future.done():
remaining_future.cancel()
break
except Exception as e:
@@ -636,8 +573,9 @@ class TableTestRunner:
)
if fail_fast:
for f in future_to_test:
f.cancel()
for remaining_future in future_to_test:
if not remaining_future.done():
remaining_future.cancel()
break
return results
@@ -663,7 +601,6 @@ class TableTestRunner:
report.append(f" Total Tests: {suite_result.total_tests}")
report.append(f" Passed: {suite_result.passed_tests}")
report.append(f" Failed: {suite_result.failed_tests}")
report.append(f" Skipped: {suite_result.skipped_tests}")
report.append(f" Success Rate: {suite_result.success_rate:.1f}%")
report.append(f" Total Time: {suite_result.total_execution_time:.2f}s")
report.append("")

View File

@@ -3,11 +3,12 @@ import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom

View File

@@ -1,4 +1,3 @@
from core.workflow.entities import VariablePool
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNodeAuthorization,
@@ -7,6 +6,7 @@ from core.workflow.nodes.http_request import (
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable

View File

@@ -20,7 +20,7 @@ from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import (
ContextConfig,
@@ -32,6 +32,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.provider import ProviderType

View File

@@ -7,12 +7,13 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
from extensions.ext_database import db

View File

@@ -6,11 +6,12 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

View File

@@ -4,11 +4,12 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

View File

@@ -27,7 +27,7 @@ from core.variables.variables import (
VariableUnion,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from factories.variable_factory import build_segment, segment_to_variable

View File

@@ -1,476 +0,0 @@
import json
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import Session
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.entities import (
WorkflowExecution,
WorkflowNodeExecution,
)
from core.workflow.enums import (
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
WorkflowType,
)
from core.workflow.nodes import NodeType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import Workflow, WorkflowRun
@pytest.fixture
def real_app_generate_entity():
    """Build a concrete AdvancedChatAppGenerateEntity used to drive the cycle manager.

    Every optional app feature is disabled so the entity carries no behavior
    beyond the identifiers the tests assert against.
    """
    features = AppAdditionalFeatures(
        file_upload=None,
        opening_statement=None,
        suggested_questions=[],
        suggested_questions_after_answer=False,
        show_retrieve_source=False,
        more_like_this=False,
        speech_to_text=False,
        text_to_speech=None,
        trace_config=None,
    )
    config = WorkflowUIBasedAppConfig(
        tenant_id="test-tenant-id",
        app_id="test-app-id",
        app_mode=AppMode.WORKFLOW,
        additional_features=features,
        workflow_id="test-workflow-id",
    )
    return AdvancedChatAppGenerateEntity(
        task_id="test-task-id",
        app_config=config,
        inputs={"query": "test query"},
        files=[],
        user_id="test-user-id",
        stream=False,
        invoke_from=InvokeFrom.WEB_APP,
        query="test query",
        conversation_id="test-conversation-id",
    )
@pytest.fixture
def real_workflow_system_variables():
    """System variables whose IDs line up with the other fixtures in this module."""
    return SystemVariable(
        app_id="test-app-id",
        user_id="test-user-id",
        workflow_id="test-workflow-id",
        workflow_execution_id="test-workflow-run-id",
        conversation_id="test-conversation-id",
        query="test query",
    )
@pytest.fixture
def mock_node_execution_repository():
    """Spec'd mock so only real WorkflowNodeExecutionRepository methods may be called."""
    return MagicMock(spec=WorkflowNodeExecutionRepository)
@pytest.fixture
def mock_workflow_execution_repository():
    """Spec'd mock so only real WorkflowExecutionRepository methods may be called."""
    return MagicMock(spec=WorkflowExecutionRepository)
@pytest.fixture
def real_workflow_entity():
    """CycleManagerWorkflowInfo whose workflow_id matches the other fixtures."""
    # NodeType is a string enum, so a plain "chat" string is accepted here.
    graph = {
        "nodes": [
            {
                "id": "node1",
                "type": "chat",
                "name": "Chat Node",
                "data": {"model": "gpt-3.5-turbo", "prompt": "test prompt"},
            }
        ],
        "edges": [],
    }
    return CycleManagerWorkflowInfo(
        workflow_id="test-workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        version="1.0.0",
        graph_data=graph,
    )
@pytest.fixture
def workflow_cycle_manager(
    real_app_generate_entity,
    real_workflow_system_variables,
    mock_workflow_execution_repository,
    mock_node_execution_repository,
    real_workflow_entity,
):
    """Wire a WorkflowCycleManager from the fixtures above (mocked repositories)."""
    dependencies = {
        "application_generate_entity": real_app_generate_entity,
        "workflow_system_variables": real_workflow_system_variables,
        "workflow_info": real_workflow_entity,
        "workflow_execution_repository": mock_workflow_execution_repository,
        "workflow_node_execution_repository": mock_node_execution_repository,
    }
    return WorkflowCycleManager(**dependencies)
@pytest.fixture
def mock_session():
    """Spec'd SQLAlchemy Session mock; no real database is touched."""
    return MagicMock(spec=Session)
@pytest.fixture
def real_workflow():
    """A minimal in-memory Workflow model instance (never persisted)."""
    wf = Workflow()
    # Identifiers shared with the rest of the fixture set.
    wf.id = "test-workflow-id"
    wf.tenant_id = "test-tenant-id"
    wf.app_id = "test-app-id"
    wf.type = "chat"
    wf.version = "1.0"
    # Graph and features are stored as JSON strings on the model.
    wf.graph = json.dumps({"nodes": [], "edges": []})
    wf.features = json.dumps({"file_upload": {"enabled": False}})
    wf.created_by = "test-user-id"
    wf.created_at = naive_utc_now()
    wf.updated_at = naive_utc_now()
    # Private JSON-string columns backing environment/conversation variables.
    wf._environment_variables = "{}"
    wf._conversation_variables = "{}"
    return wf
@pytest.fixture
def real_workflow_run():
    """A RUNNING WorkflowRun model instance wired to the shared test IDs."""
    run = WorkflowRun()
    run.id = "test-workflow-run-id"
    run.tenant_id = "test-tenant-id"
    run.app_id = "test-app-id"
    run.workflow_id = "test-workflow-id"
    run.type = "chat"
    run.triggered_from = "app-run"
    run.version = "1.0"
    # Graph, inputs and outputs are persisted as JSON strings.
    run.graph = json.dumps({"nodes": [], "edges": []})
    run.inputs = json.dumps({"query": "test query"})
    run.outputs = json.dumps({"answer": "test answer"})
    run.status = WorkflowExecutionStatus.RUNNING
    run.created_by_role = CreatorUserRole.ACCOUNT
    run.created_by = "test-user-id"
    run.created_at = naive_utc_now()
    return run
def test_init(
    workflow_cycle_manager,
    real_app_generate_entity,
    real_workflow_system_variables,
    mock_workflow_execution_repository,
    mock_node_execution_repository,
):
    """The constructor should stash each collaborator on its private attribute."""
    expected = {
        "_application_generate_entity": real_app_generate_entity,
        "_workflow_system_variables": real_workflow_system_variables,
        "_workflow_execution_repository": mock_workflow_execution_repository,
        "_workflow_node_execution_repository": mock_node_execution_repository,
    }
    for attr, collaborator in expected.items():
        assert getattr(workflow_cycle_manager, attr) == collaborator
def test_handle_workflow_run_start(workflow_cycle_manager):
    """handle_workflow_run_start should build an execution and persist it once."""
    execution = workflow_cycle_manager.handle_workflow_run_start()
    # The new execution is bound to the workflow configured by the fixtures.
    assert execution.workflow_id == "test-workflow-id"
    # It must be saved through the repository exactly once.
    workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(execution)
def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execution_repository):
    """handle_workflow_run_success should mark a cached execution SUCCEEDED in place."""
    # Use a real WorkflowExecution (not a mock) so attribute mutations are observable.
    workflow_execution = WorkflowExecution(
        id_="test-workflow-run-id",
        workflow_id="test-workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"query": "test query"},
        started_at=naive_utc_now(),
    )
    # Seed the manager's in-memory cache; the lookup below resolves from it.
    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
    # Exercise the success path.
    result = workflow_cycle_manager.handle_workflow_run_success(
        workflow_run_id="test-workflow-run-id",
        total_tokens=100,
        total_steps=5,
        outputs={"answer": "test answer"},
    )
    # The same cached object comes back, updated with the final stats.
    assert result == workflow_execution
    assert result.status == WorkflowExecutionStatus.SUCCEEDED
    assert result.outputs == {"answer": "test answer"}
    assert result.total_tokens == 100
    assert result.total_steps == 5
    assert result.finished_at is not None
def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execution_repository):
    """handle_workflow_run_failed should mark a cached execution FAILED with the error."""
    # Use a real WorkflowExecution (not a mock) so attribute mutations are observable.
    workflow_execution = WorkflowExecution(
        id_="test-workflow-run-id",
        workflow_id="test-workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"query": "test query"},
        started_at=naive_utc_now(),
    )
    # Seed the manager's in-memory cache; the lookup below resolves from it.
    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
    # Node-execution cache stays empty: no running nodes need to be failed over.
    # Exercise the failure path.
    result = workflow_cycle_manager.handle_workflow_run_failed(
        workflow_run_id="test-workflow-run-id",
        total_tokens=50,
        total_steps=3,
        status=WorkflowExecutionStatus.FAILED,
        error_message="Test error message",
    )
    # The same cached object comes back, updated with status, error and stats.
    assert result == workflow_execution
    assert result.status == WorkflowExecutionStatus.FAILED
    assert result.error_message == "Test error message"
    assert result.total_tokens == 50
    assert result.total_steps == 3
    assert result.finished_at is not None
def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execution_repository):
    """handle_node_execution_start should build a RUNNING node execution and save it."""
    # A real parent WorkflowExecution the node execution will be attached to.
    workflow_execution = WorkflowExecution(
        id_="test-workflow-execution-id",
        workflow_id="test-workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"query": "test query"},
        started_at=naive_utc_now(),
    )
    # Seed the manager's cache so the parent execution is resolvable by id.
    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
    # Mock event carrying every attribute the handler reads.
    event = MagicMock(spec=QueueNodeStartedEvent)
    event.node_execution_id = "test-node-execution-id"
    event.node_id = "test-node-id"
    event.node_type = NodeType.LLM
    event.node_title = "Test Node"
    event.predecessor_node_id = "test-predecessor-node-id"
    event.node_run_index = 1
    event.parallel_mode_run_id = "test-parallel-mode-run-id"
    event.in_iteration_id = "test-iteration-id"
    event.in_loop_id = "test-loop-id"
    # Exercise the handler.
    result = workflow_cycle_manager.handle_node_execution_start(
        workflow_execution_id=workflow_execution.id_,
        event=event,
    )
    # The node execution mirrors the parent execution and the event fields.
    assert result.workflow_id == workflow_execution.workflow_id
    assert result.workflow_execution_id == workflow_execution.id_
    assert result.node_execution_id == event.node_execution_id
    assert result.node_id == event.node_id
    assert result.node_type == event.node_type
    assert result.title == event.node_title
    assert result.status == WorkflowNodeExecutionStatus.RUNNING
    # The new node execution must be persisted exactly once.
    workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_workflow_execution_repository):
    """_get_workflow_execution_or_raise_error: cache hit returns; cache miss raises."""
    # A real WorkflowExecution placed in the cache for the hit case.
    workflow_execution = WorkflowExecution(
        id_="test-workflow-run-id",
        workflow_id="test-workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"query": "test query"},
        started_at=naive_utc_now(),
    )
    # Seed the cache under the id that will be looked up.
    workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution
    # Hit: the cached object is returned as-is.
    result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
    assert result == workflow_execution
    # Miss: empty the cache so the next lookup cannot succeed.
    workflow_cycle_manager._workflow_execution_cache.clear()
    # Imported here to keep the module import list free of error-path-only names.
    from core.app.task_pipeline.exc import WorkflowRunNotFoundError
    with pytest.raises(WorkflowRunNotFoundError):
        workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
def test_handle_workflow_node_execution_success(workflow_cycle_manager):
    """handle_workflow_node_execution_success should mark a cached node execution SUCCEEDED."""
    # Mock event carrying the completion payload the handler copies over.
    event = MagicMock(spec=QueueNodeSucceededEvent)
    event.node_execution_id = "test-node-execution-id"
    event.inputs = {"input": "test input"}
    event.process_data = {"process": "test process"}
    event.outputs = {"output": "test output"}
    event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
    event.start_at = naive_utc_now()
    # A real node execution (not a mock) so status mutation is observable.
    node_execution = WorkflowNodeExecution(
        id="test-node-execution-record-id",
        node_execution_id="test-node-execution-id",
        workflow_id="test-workflow-id",
        workflow_execution_id="test-workflow-run-id",
        index=1,
        node_id="test-node-id",
        node_type=NodeType.LLM,
        title="Test Node",
        created_at=naive_utc_now(),
    )
    # Seed the node-execution cache; the handler resolves by node_execution_id.
    workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
    # Exercise the success path.
    result = workflow_cycle_manager.handle_workflow_node_execution_success(
        event=event,
    )
    # The same cached object comes back, flipped to SUCCEEDED.
    assert result == node_execution
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    # The updated node execution must be persisted exactly once.
    workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workflow_execution_repository):
    """handle_workflow_run_partial_success should record PARTIAL_SUCCEEDED plus exception count."""
    # Use a real WorkflowExecution (not a mock) so attribute mutations are observable.
    workflow_execution = WorkflowExecution(
        id_="test-workflow-run-id",
        workflow_id="test-workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"query": "test query"},
        started_at=naive_utc_now(),
    )
    # Seed the manager's in-memory cache; the lookup below resolves from it.
    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
    # Exercise the partial-success path.
    result = workflow_cycle_manager.handle_workflow_run_partial_success(
        workflow_run_id="test-workflow-run-id",
        total_tokens=75,
        total_steps=4,
        outputs={"partial_answer": "test partial answer"},
        exceptions_count=2,
    )
    # The same cached object comes back with partial-success status and stats.
    assert result == workflow_execution
    assert result.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED
    assert result.outputs == {"partial_answer": "test partial answer"}
    assert result.total_tokens == 75
    assert result.total_steps == 4
    assert result.exceptions_count == 2
    assert result.finished_at is not None
def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
    """handle_workflow_node_execution_failed should mark a cached node execution FAILED."""
    # Mock event carrying the failure payload, including the error message.
    event = MagicMock(spec=QueueNodeFailedEvent)
    event.node_execution_id = "test-node-execution-id"
    event.inputs = {"input": "test input"}
    event.process_data = {"process": "test process"}
    event.outputs = {"output": "test output"}
    event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}
    event.start_at = naive_utc_now()
    event.error = "Test error message"
    # A real node execution (not a mock) so status/error mutation is observable.
    node_execution = WorkflowNodeExecution(
        id="test-node-execution-record-id",
        node_execution_id="test-node-execution-id",
        workflow_id="test-workflow-id",
        workflow_execution_id="test-workflow-run-id",
        index=1,
        node_id="test-node-id",
        node_type=NodeType.LLM,
        title="Test Node",
        created_at=naive_utc_now(),
    )
    # Seed the node-execution cache; the handler resolves by node_execution_id.
    workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
    # Exercise the failure path.
    result = workflow_cycle_manager.handle_workflow_node_execution_failed(
        event=event,
    )
    # The same cached object comes back, flipped to FAILED with the event's error.
    assert result == node_execution
    assert result.status == WorkflowNodeExecutionStatus.FAILED
    assert result.error == "Test error message"
    # The updated node execution must be persisted exactly once.
    workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)

View File

@@ -7,7 +7,7 @@ from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry

View File

@@ -3,8 +3,8 @@
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom