refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization.

Key changes:
- Introduce distinct `value_type` tags for Integer and Float segments/variables
- Add `VariableUnion` and `SegmentUnion` types for proper type discrimination
- Leverage Pydantic's discriminated union feature for seamless serialization/deserialization
- Enable accurate serialization of data structures containing these types

Closes #22024.
This commit is contained in:
QuantumGhost
2025-07-16 12:31:37 +08:00
committed by GitHub
parent 229b4d621e
commit 2c1ab4879f
58 changed files with 2325 additions and 328 deletions

View File

@@ -1,14 +1,49 @@
import dataclasses
from pydantic import BaseModel
from core.file import File, FileTransferMethod, FileType
from core.helper import encrypter
from core.variables import SecretVariable, StringVariable
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
SegmentUnion,
StringSegment,
get_segment_discriminator,
)
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
VariableUnion,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
def test_segment_group_to_text():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
@@ -30,7 +65,7 @@ def test_segment_group_to_text():
def test_convert_constant_to_segment_group():
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group():
assert segments_group.log == "fake-user-id"
assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id"
class _Segments(BaseModel):
segments: list[SegmentUnion]
class _Variables(BaseModel):
variables: list[VariableUnion]
def create_test_file(
file_type: FileType = FileType.DOCUMENT,
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
filename: str = "test.txt",
extension: str = ".txt",
mime_type: str = "text/plain",
size: int = 1024,
) -> File:
"""Factory function to create File objects for testing"""
return File(
tenant_id="test-tenant",
type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
storage_key="test-storage-key",
)
class TestSegmentDumpAndLoad:
"""Test suite for segment and variable serialization/deserialization"""
def test_segments(self):
"""Test basic segment serialization compatibility"""
model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
json = model.model_dump_json()
print("Json: ", json)
loaded = _Segments.model_validate_json(json)
assert loaded == model
def test_segment_number(self):
"""Test number segment serialization compatibility"""
model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
json = model.model_dump_json()
print("Json: ", json)
loaded = _Segments.model_validate_json(json)
assert loaded == model
def test_variables(self):
"""Test variable serialization compatibility"""
model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
json = model.model_dump_json()
print("Json: ", json)
restored = _Variables.model_validate_json(json)
assert restored == model
def test_all_segments_serialization(self):
"""Test serialization/deserialization of all segment types"""
# Create one instance of each segment type
test_file = create_test_file()
all_segments: list[SegmentUnion] = [
NoneSegment(),
StringSegment(value="test string"),
IntegerSegment(value=42),
FloatSegment(value=3.14),
ObjectSegment(value={"key": "value", "number": 123}),
FileSegment(value=test_file),
ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
ArrayStringSegment(value=["hello", "world"]),
ArrayNumberSegment(value=[1, 2.5, 3]),
ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
ArrayFileSegment(value=[]), # Empty array to avoid file complexity
]
# Test serialization and deserialization
model = _Segments(segments=all_segments)
json_str = model.model_dump_json()
loaded = _Segments.model_validate_json(json_str)
# Verify all segments are preserved
assert len(loaded.segments) == len(all_segments)
for original, loaded_segment in zip(all_segments, loaded.segments):
assert type(loaded_segment) == type(original)
assert loaded_segment.value_type == original.value_type
# For file segments, compare key properties instead of exact equality
if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
orig_file = original.value
loaded_file = loaded_segment.value
assert isinstance(orig_file, File)
assert isinstance(loaded_file, File)
assert loaded_file.tenant_id == orig_file.tenant_id
assert loaded_file.type == orig_file.type
assert loaded_file.filename == orig_file.filename
else:
assert loaded_segment.value == original.value
def test_all_variables_serialization(self):
"""Test serialization/deserialization of all variable types"""
# Create one instance of each variable type
test_file = create_test_file()
all_variables: list[VariableUnion] = [
NoneVariable(name="none_var"),
StringVariable(value="test string", name="string_var"),
IntegerVariable(value=42, name="int_var"),
FloatVariable(value=3.14, name="float_var"),
ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
FileVariable(value=test_file, name="file_var"),
ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity
]
# Test serialization and deserialization
model = _Variables(variables=all_variables)
json_str = model.model_dump_json()
loaded = _Variables.model_validate_json(json_str)
# Verify all variables are preserved
assert len(loaded.variables) == len(all_variables)
for original, loaded_variable in zip(all_variables, loaded.variables):
assert type(loaded_variable) == type(original)
assert loaded_variable.value_type == original.value_type
assert loaded_variable.name == original.name
# For file variables, compare key properties instead of exact equality
if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
orig_file = original.value
loaded_file = loaded_variable.value
assert isinstance(orig_file, File)
assert isinstance(loaded_file, File)
assert loaded_file.tenant_id == orig_file.tenant_id
assert loaded_file.type == orig_file.type
assert loaded_file.filename == orig_file.filename
else:
assert loaded_variable.value == original.value
def test_segment_discriminator_function_for_segment_types(self):
"""Test the segment discriminator function"""
@dataclasses.dataclass
class TestCase:
segment: Segment
expected_segment_type: SegmentType
file1 = create_test_file()
file2 = create_test_file(filename="test2.txt")
cases = [
TestCase(
NoneSegment(),
SegmentType.NONE,
),
TestCase(
StringSegment(value=""),
SegmentType.STRING,
),
TestCase(
FloatSegment(value=0.0),
SegmentType.FLOAT,
),
TestCase(
IntegerSegment(value=0),
SegmentType.INTEGER,
),
TestCase(
ObjectSegment(value={}),
SegmentType.OBJECT,
),
TestCase(
FileSegment(value=file1),
SegmentType.FILE,
),
TestCase(
ArrayAnySegment(value=[0, 0.0, ""]),
SegmentType.ARRAY_ANY,
),
TestCase(
ArrayStringSegment(value=[""]),
SegmentType.ARRAY_STRING,
),
TestCase(
ArrayNumberSegment(value=[0, 0.0]),
SegmentType.ARRAY_NUMBER,
),
TestCase(
ArrayObjectSegment(value=[{}]),
SegmentType.ARRAY_OBJECT,
),
TestCase(
ArrayFileSegment(value=[file1, file2]),
SegmentType.ARRAY_FILE,
),
]
for test_case in cases:
segment = test_case.segment
assert get_segment_discriminator(segment) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for type {type(segment)}"
)
model_dict = segment.model_dump(mode="json")
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for serialized form of type {type(segment)}"
)
def test_variable_discriminator_function_for_variable_types(self):
"""Test the variable discriminator function"""
@dataclasses.dataclass
class TestCase:
variable: Variable
expected_segment_type: SegmentType
file1 = create_test_file()
file2 = create_test_file(filename="test2.txt")
cases = [
TestCase(
NoneVariable(name="none_var"),
SegmentType.NONE,
),
TestCase(
StringVariable(value="test", name="string_var"),
SegmentType.STRING,
),
TestCase(
FloatVariable(value=0.0, name="float_var"),
SegmentType.FLOAT,
),
TestCase(
IntegerVariable(value=0, name="int_var"),
SegmentType.INTEGER,
),
TestCase(
ObjectVariable(value={}, name="object_var"),
SegmentType.OBJECT,
),
TestCase(
FileVariable(value=file1, name="file_var"),
SegmentType.FILE,
),
TestCase(
SecretVariable(value="secret", name="secret_var"),
SegmentType.SECRET,
),
TestCase(
ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"),
SegmentType.ARRAY_ANY,
),
TestCase(
ArrayStringVariable(value=[""], name="array_string_var"),
SegmentType.ARRAY_STRING,
),
TestCase(
ArrayNumberVariable(value=[0, 0.0], name="array_number_var"),
SegmentType.ARRAY_NUMBER,
),
TestCase(
ArrayObjectVariable(value=[{}], name="array_object_var"),
SegmentType.ARRAY_OBJECT,
),
TestCase(
ArrayFileVariable(value=[file1, file2], name="array_file_var"),
SegmentType.ARRAY_FILE,
),
]
for test_case in cases:
variable = test_case.variable
assert get_segment_discriminator(variable) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for type {type(variable)}"
)
model_dict = variable.model_dump(mode="json")
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for serialized form of type {type(variable)}"
)
def test_invlaid_value_for_discriminator(self):
# Test invalid cases
assert get_segment_discriminator({"value_type": "invalid"}) is None
assert get_segment_discriminator({}) is None
assert get_segment_discriminator("not_a_dict") is None
assert get_segment_discriminator(42) is None
assert get_segment_discriminator(object) is None

View File

@@ -0,0 +1,60 @@
from core.variables.types import SegmentType
class TestSegmentTypeIsArrayType:
"""
Test class for SegmentType.is_array_type method.
Provides comprehensive coverage of all SegmentType values to ensure
correct identification of array and non-array types.
"""
def test_is_array_type(self):
"""
Test that all SegmentType enum values are covered in our test cases.
Ensures comprehensive coverage by verifying that every SegmentType
value is tested for the is_array_type method.
"""
# Arrange
all_segment_types = set(SegmentType)
expected_array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
expected_non_array_types = [
SegmentType.INTEGER,
SegmentType.FLOAT,
SegmentType.NUMBER,
SegmentType.STRING,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
SegmentType.GROUP,
]
for seg_type in expected_array_types:
assert seg_type.is_array_type()
for seg_type in expected_non_array_types:
assert not seg_type.is_array_type()
# Act & Assert
covered_types = set(expected_array_types) | set(expected_non_array_types)
assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests"
def test_all_enum_values_are_supported(self):
"""
Test that all enum values are supported and return boolean values.
Validates that every SegmentType enum value can be processed by
is_array_type method and returns a boolean value.
"""
enum_values: list[SegmentType] = list(SegmentType)
for seg_type in enum_values:
is_array = seg_type.is_array_type()
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"

View File

@@ -11,6 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
from core.variables.variables import Variable
def test_frozen_variables():
@@ -75,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
var = StringVariable(name="text", value="text")
var: Variable = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42

View File

@@ -0,0 +1,146 @@
import time
from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from core.workflow.system_variable import SystemVariable
def create_test_graph_runtime_state() -> GraphRuntimeState:
"""Factory function to create a GraphRuntimeState with non-empty values for testing."""
# Create a variable pool with system variables
system_vars = SystemVariable(
user_id="test_user_123",
app_id="test_app_456",
workflow_id="test_workflow_789",
workflow_execution_id="test_execution_001",
query="test query",
conversation_id="test_conv_123",
dialogue_count=5,
)
variable_pool = VariablePool(system_variables=system_vars)
# Add some variables to the variable pool
variable_pool.add(["test_node", "test_var"], "test_value")
variable_pool.add(["another_node", "another_var"], 42)
# Create LLM usage with realistic values
llm_usage = LLMUsage(
prompt_tokens=150,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.15"),
completion_tokens=75,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.15"),
total_tokens=225,
total_price=Decimal("0.30"),
currency="USD",
latency=1.25,
)
# Create runtime route state with some node states
node_run_state = RuntimeRouteState()
node_state = node_run_state.create_node_state("test_node_1")
node_run_state.add_route(node_state.id, "target_node_id")
return GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
total_tokens=100,
llm_usage=llm_usage,
outputs={
"string_output": "test result",
"int_output": 42,
"float_output": 3.14,
"list_output": ["item1", "item2", "item3"],
"dict_output": {"key1": "value1", "key2": 123},
"nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
},
node_run_steps=5,
node_run_state=node_run_state,
)
def test_basic_round_trip_serialization():
"""Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
# Create a state with non-empty values
original_state = create_test_graph_runtime_state()
# Serialize to JSON and deserialize back
json_data = original_state.model_dump_json()
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
# Core test: ensure the round-trip preserves all values
assert deserialized_state == original_state
# Serialize to JSON and deserialize back
dict_data = original_state.model_dump(mode="python")
deserialized_state = GraphRuntimeState.model_validate(dict_data)
assert deserialized_state == original_state
# Serialize to JSON and deserialize back
dict_data = original_state.model_dump(mode="json")
deserialized_state = GraphRuntimeState.model_validate(dict_data)
assert deserialized_state == original_state
def test_outputs_field_round_trip():
"""Test the problematic outputs field maintains values through round-trip serialization."""
original_state = create_test_graph_runtime_state()
# Serialize and deserialize
json_data = original_state.model_dump_json()
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
# Verify the outputs field specifically maintains its values
assert deserialized_state.outputs == original_state.outputs
assert deserialized_state == original_state
def test_empty_outputs_round_trip():
"""Test round-trip serialization with empty outputs field."""
variable_pool = VariablePool.empty()
original_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
outputs={}, # Empty outputs
)
json_data = original_state.model_dump_json()
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
assert deserialized_state == original_state
def test_llm_usage_round_trip():
# Create LLM usage with specific decimal values
llm_usage = LLMUsage(
prompt_tokens=100,
prompt_unit_price=Decimal("0.0015"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.15"),
completion_tokens=50,
completion_unit_price=Decimal("0.003"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.15"),
total_tokens=150,
total_price=Decimal("0.30"),
currency="USD",
latency=2.5,
)
json_data = llm_usage.model_dump_json()
deserialized = LLMUsage.model_validate_json(json_data)
assert deserialized == llm_usage
dict_data = llm_usage.model_dump(mode="python")
deserialized = LLMUsage.model_validate(dict_data)
assert deserialized == llm_usage
dict_data = llm_usage.model_dump(mode="json")
deserialized = LLMUsage.model_validate(dict_data)
assert deserialized == llm_usage

View File

@@ -0,0 +1,401 @@
import json
import uuid
from datetime import UTC, datetime
import pytest
from pydantic import ValidationError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
class TestRouteNodeStateSerialization:
"""Test cases for RouteNodeState Pydantic serialization/deserialization."""
def _test_route_node_state(self):
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
node_run_result = NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"input_key": "input_value"},
outputs={"output_key": "output_value"},
)
node_state = RouteNodeState(
node_id="comprehensive_test_node",
start_at=_TEST_DATETIME,
finished_at=_TEST_DATETIME,
status=RouteNodeState.Status.SUCCESS,
node_run_result=node_run_result,
index=5,
paused_at=_TEST_DATETIME,
paused_by="user_123",
failed_reason="test_reason",
)
return node_state
def test_route_node_state_comprehensive_field_validation(self):
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
node_state = self._test_route_node_state()
serialized = node_state.model_dump()
# Comprehensive validation of all RouteNodeState fields
assert serialized["node_id"] == "comprehensive_test_node"
assert serialized["status"] == RouteNodeState.Status.SUCCESS
assert serialized["start_at"] == _TEST_DATETIME
assert serialized["finished_at"] == _TEST_DATETIME
assert serialized["paused_at"] == _TEST_DATETIME
assert serialized["paused_by"] == "user_123"
assert serialized["failed_reason"] == "test_reason"
assert serialized["index"] == 5
assert "id" in serialized
assert isinstance(serialized["id"], str)
uuid.UUID(serialized["id"]) # Validate UUID format
# Validate nested NodeRunResult structure
assert serialized["node_run_result"] is not None
assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}
def test_route_node_state_minimal_required_fields(self):
"""Test RouteNodeState with only required fields, focusing on defaults."""
node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)
serialized = node_state.model_dump()
# Focus on required fields and default values (not re-testing all fields)
assert serialized["node_id"] == "minimal_node"
assert serialized["start_at"] == _TEST_DATETIME
assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status
assert serialized["index"] == 1 # Default index
assert serialized["node_run_result"] is None # Default None
json = node_state.model_dump_json()
deserialized = RouteNodeState.model_validate_json(json)
assert deserialized == node_state
def test_route_node_state_deserialization_from_dict(self):
"""Test RouteNodeState deserialization from dictionary data."""
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
test_id = str(uuid.uuid4())
dict_data = {
"id": test_id,
"node_id": "deserialized_node",
"start_at": test_datetime,
"status": "success",
"finished_at": test_datetime,
"index": 3,
}
node_state = RouteNodeState.model_validate(dict_data)
# Focus on deserialization accuracy
assert node_state.id == test_id
assert node_state.node_id == "deserialized_node"
assert node_state.start_at == test_datetime
assert node_state.status == RouteNodeState.Status.SUCCESS
assert node_state.finished_at == test_datetime
assert node_state.index == 3
def test_route_node_state_round_trip_consistency(self):
node_states = (
self._test_route_node_state(),
RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
)
for node_state in node_states:
json = node_state.model_dump_json()
deserialized = RouteNodeState.model_validate_json(json)
assert deserialized == node_state
dict_ = node_state.model_dump(mode="python")
deserialized = RouteNodeState.model_validate(dict_)
assert deserialized == node_state
dict_ = node_state.model_dump(mode="json")
deserialized = RouteNodeState.model_validate(dict_)
assert deserialized == node_state
class TestRouteNodeStateEnumSerialization:
"""Dedicated tests for RouteNodeState Status enum serialization behavior."""
def test_status_enum_model_dump_behavior(self):
"""Test Status enum serialization in model_dump() returns enum objects."""
for status_enum in RouteNodeState.Status:
node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum)
serialized = node_state.model_dump(mode="python")
assert serialized["status"] == status_enum
serialized = node_state.model_dump(mode="json")
assert serialized["status"] == status_enum.value
def test_status_enum_json_serialization_behavior(self):
"""Test Status enum serialization in JSON returns string values."""
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
enum_to_string_mapping = {
RouteNodeState.Status.RUNNING: "running",
RouteNodeState.Status.SUCCESS: "success",
RouteNodeState.Status.FAILED: "failed",
RouteNodeState.Status.PAUSED: "paused",
RouteNodeState.Status.EXCEPTION: "exception",
}
for status_enum, expected_string in enum_to_string_mapping.items():
node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum)
json_data = json.loads(node_state.model_dump_json())
assert json_data["status"] == expected_string
def test_status_enum_deserialization_from_string(self):
"""Test Status enum deserialization from string values."""
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
string_to_enum_mapping = {
"running": RouteNodeState.Status.RUNNING,
"success": RouteNodeState.Status.SUCCESS,
"failed": RouteNodeState.Status.FAILED,
"paused": RouteNodeState.Status.PAUSED,
"exception": RouteNodeState.Status.EXCEPTION,
}
for status_string, expected_enum in string_to_enum_mapping.items():
dict_data = {
"node_id": "enum_deserialize_test",
"start_at": test_datetime,
"status": status_string,
}
node_state = RouteNodeState.model_validate(dict_data)
assert node_state.status == expected_enum
class TestRuntimeRouteStateSerialization:
"""Test cases for RuntimeRouteState Pydantic serialization/deserialization."""
_NODE1_ID = "node_1"
_ROUTE_STATE1_ID = str(uuid.uuid4())
_NODE2_ID = "node_2"
_ROUTE_STATE2_ID = str(uuid.uuid4())
_NODE3_ID = "node_3"
_ROUTE_STATE3_ID = str(uuid.uuid4())
def _get_runtime_route_state(self):
# Create node states with different configurations
node_state_1 = RouteNodeState(
id=self._ROUTE_STATE1_ID,
node_id=self._NODE1_ID,
start_at=_TEST_DATETIME,
index=1,
)
node_state_2 = RouteNodeState(
id=self._ROUTE_STATE2_ID,
node_id=self._NODE2_ID,
start_at=_TEST_DATETIME,
status=RouteNodeState.Status.SUCCESS,
finished_at=_TEST_DATETIME,
index=2,
)
node_state_3 = RouteNodeState(
id=self._ROUTE_STATE3_ID,
node_id=self._NODE3_ID,
start_at=_TEST_DATETIME,
status=RouteNodeState.Status.FAILED,
failed_reason="Test failure",
index=3,
)
runtime_state = RuntimeRouteState(
routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
node_state_mapping={
node_state_1.id: node_state_1,
node_state_2.id: node_state_2,
node_state_3.id: node_state_3,
},
)
return runtime_state
def test_runtime_route_state_comprehensive_structure_validation(self):
"""Test comprehensive RuntimeRouteState serialization with full structure validation."""
runtime_state = self._get_runtime_route_state()
serialized = runtime_state.model_dump()
# Comprehensive validation of RuntimeRouteState structure
assert "routes" in serialized
assert "node_state_mapping" in serialized
assert isinstance(serialized["routes"], dict)
assert isinstance(serialized["node_state_mapping"], dict)
# Validate routes dictionary structure and content
assert len(serialized["routes"]) == 2
assert self._ROUTE_STATE1_ID in serialized["routes"]
assert self._ROUTE_STATE2_ID in serialized["routes"]
assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]
# Validate node_state_mapping dictionary structure and content
assert len(serialized["node_state_mapping"]) == 3
for state_id in [
self._ROUTE_STATE1_ID,
self._ROUTE_STATE2_ID,
self._ROUTE_STATE3_ID,
]:
assert state_id in serialized["node_state_mapping"]
node_data = serialized["node_state_mapping"][state_id]
node_state = runtime_state.node_state_mapping[state_id]
assert node_data["node_id"] == node_state.node_id
assert node_data["status"] == node_state.status
assert node_data["index"] == node_state.index
def test_runtime_route_state_empty_collections(self):
"""Test RuntimeRouteState with empty collections, focusing on default behavior."""
runtime_state = RuntimeRouteState()
serialized = runtime_state.model_dump()
# Focus on default empty collection behavior
assert serialized["routes"] == {}
assert serialized["node_state_mapping"] == {}
assert isinstance(serialized["routes"], dict)
assert isinstance(serialized["node_state_mapping"], dict)
def test_runtime_route_state_json_serialization_structure(self):
"""Test RuntimeRouteState JSON serialization structure."""
node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)
runtime_state = RuntimeRouteState(
routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
)
json_str = runtime_state.model_dump_json()
json_data = json.loads(json_str)
# Focus on JSON structure validation
assert isinstance(json_str, str)
assert isinstance(json_data, dict)
assert "routes" in json_data
assert "node_state_mapping" in json_data
assert json_data["routes"]["source"] == ["target1", "target2"]
assert node_state.id in json_data["node_state_mapping"]
def test_runtime_route_state_deserialization_from_dict(self):
"""Test RuntimeRouteState deserialization from dictionary data."""
node_id = str(uuid.uuid4())
dict_data = {
"routes": {"source_node": ["target_node_1", "target_node_2"]},
"node_state_mapping": {
node_id: {
"id": node_id,
"node_id": "test_node",
"start_at": _TEST_DATETIME,
"status": "running",
"index": 1,
}
},
}
runtime_state = RuntimeRouteState.model_validate(dict_data)
# Focus on deserialization accuracy
assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
assert len(runtime_state.node_state_mapping) == 1
assert node_id in runtime_state.node_state_mapping
deserialized_node = runtime_state.node_state_mapping[node_id]
assert deserialized_node.node_id == "test_node"
assert deserialized_node.status == RouteNodeState.Status.RUNNING
assert deserialized_node.index == 1
def test_runtime_route_state_round_trip_consistency(self):
"""Test RuntimeRouteState round-trip serialization consistency."""
original = self._get_runtime_route_state()
# Dictionary round trip
dict_data = original.model_dump(mode="python")
reconstructed = RuntimeRouteState.model_validate(dict_data)
assert reconstructed == original
dict_data = original.model_dump(mode="json")
reconstructed = RuntimeRouteState.model_validate(dict_data)
assert reconstructed == original
# JSON round trip
json_str = original.model_dump_json()
json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
assert json_reconstructed == original
class TestSerializationEdgeCases:
"""Test edge cases and error conditions for serialization/deserialization."""
def test_invalid_status_deserialization(self):
"""Test deserialization with invalid status values."""
test_datetime = _TEST_DATETIME
invalid_data = {
"node_id": "invalid_test",
"start_at": test_datetime,
"status": "invalid_status",
}
with pytest.raises(ValidationError) as exc_info:
RouteNodeState.model_validate(invalid_data)
assert "status" in str(exc_info.value)
def test_missing_required_fields_deserialization(self):
"""Test deserialization with missing required fields."""
incomplete_data = {"id": str(uuid.uuid4())}
with pytest.raises(ValidationError) as exc_info:
RouteNodeState.model_validate(incomplete_data)
error_str = str(exc_info.value)
assert "node_id" in error_str or "start_at" in error_str
def test_invalid_datetime_deserialization(self):
"""Test deserialization with invalid datetime values."""
invalid_data = {
"node_id": "datetime_test",
"start_at": "invalid_datetime",
"status": "running",
}
with pytest.raises(ValidationError) as exc_info:
RouteNodeState.model_validate(invalid_data)
assert "start_at" in str(exc_info.value)
def test_invalid_routes_structure_deserialization(self):
"""Test RuntimeRouteState deserialization with invalid routes structure."""
invalid_data = {
"routes": "invalid_routes_structure", # Should be dict
"node_state_mapping": {},
}
with pytest.raises(ValidationError) as exc_info:
RuntimeRouteState.model_validate(invalid_data)
assert "routes" in str(exc_info.value)
def test_timezone_handling_in_datetime_fields(self):
"""Test timezone handling in datetime field serialization."""
utc_datetime = datetime.now(UTC)
naive_datetime = utc_datetime.replace(tzinfo=None)
node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime)
dict_ = node_state.model_dump()
assert dict_["start_at"] == naive_datetime
# Test round trip
reconstructed = RouteNodeState.model_validate(dict_)
assert reconstructed.start_at == naive_datetime
assert reconstructed.start_at.tzinfo is None
json = node_state.model_dump_json()
reconstructed = RouteNodeState.model_validate_json(json)
assert reconstructed.start_at == naive_datetime
assert reconstructed.start_at.tzinfo is None

View File

@@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
@@ -27,6 +26,7 @@ from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -171,7 +171,8 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
system_variables=SystemVariable(user_id="aaa", app_id="1", workflow_id="1", files=[]),
user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@@ -293,12 +294,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="what's the weather in SF",
conversation_id="abababa",
),
user_inputs={},
)
@@ -474,12 +475,12 @@ def test_run_branch(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "hi",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="hi",
conversation_id="abababa",
),
user_inputs={"uid": "takato"},
)
@@ -804,18 +805,22 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
system_variables=SystemVariable(
user_id="aaa",
files=[],
),
user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

View File

@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -51,7 +51,7 @@ def test_execute_answer():
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
)

View File

@@ -3,7 +3,6 @@ from collections.abc import Generator
from datetime import UTC, datetime
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
@@ -15,6 +14,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.system_variable import SystemVariable
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
@@ -180,12 +180,12 @@ def test_process():
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="what's the weather in SF",
conversation_id="abababa",
),
user_inputs={},
)

View File

@@ -7,12 +7,13 @@ from core.workflow.nodes.http_request import (
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.system_variable import SystemVariable
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "number"], 42)
@@ -65,7 +66,7 @@ def test_executor_with_json_body_and_number_variable():
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -120,7 +121,7 @@ def test_executor_with_json_body_and_object_variable():
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -174,7 +175,7 @@ def test_executor_with_json_body_and_nested_object_variable():
def test_extract_selectors_from_template_with_newline():
variable_pool = VariablePool()
variable_pool = VariablePool(system_variables=SystemVariable.empty())
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
@@ -201,7 +202,7 @@ def test_extract_selectors_from_template_with_newline():
def test_executor_with_form_data():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
@@ -280,7 +281,11 @@ def test_init_headers():
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
return Executor(
node_data=node_data,
timeout=timeout,
variable_pool=VariablePool(system_variables=SystemVariable.empty()),
)
executor = create_executor("aa\n cc:")
executor._init_headers()
@@ -310,7 +315,11 @@ def test_init_params():
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
return Executor(
node_data=node_data,
timeout=timeout,
variable_pool=VariablePool(system_variables=SystemVariable.empty()),
)
# Test basic key-value pairs
executor = create_executor("key1:value1\nkey2:value2")

View File

@@ -15,6 +15,7 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -40,7 +41,7 @@ def test_http_request_node_binary_file(monkeypatch):
),
)
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
@@ -128,7 +129,7 @@ def test_http_request_node_form_with_file(monkeypatch):
),
)
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
@@ -223,7 +224,7 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
)
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)

View File

@@ -7,7 +7,6 @@ from core.variables.segments import ArrayAnySegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@@ -15,6 +14,7 @@ from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -151,12 +151,12 @@ def test_run():
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)
@@ -368,12 +368,12 @@ def test_run_parallel():
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)
@@ -584,12 +584,12 @@ def test_iteration_run_in_parallel_mode():
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)
@@ -808,12 +808,12 @@ def test_iteration_run_error_handle():
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "dify",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "1",
},
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)

View File

@@ -36,6 +36,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
@@ -104,7 +105,7 @@ def graph() -> Graph:
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
return GraphRuntimeState(
@@ -181,7 +182,7 @@ def test_fetch_files_with_file_segment():
related_id="1",
storage_key="",
)
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], file)
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -209,7 +210,7 @@ def test_fetch_files_with_array_file_segment():
storage_key="",
),
]
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -217,7 +218,7 @@ def test_fetch_files_with_array_file_segment():
def test_fetch_files_with_none_segment():
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], NoneSegment())
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -225,7 +226,7 @@ def test_fetch_files_with_none_segment():
def test_fetch_files_with_array_any_segment():
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -233,7 +234,7 @@ def test_fetch_files_with_array_any_segment():
def test_fetch_files_with_non_existent_variable():
variable_pool = VariablePool()
variable_pool = VariablePool.empty()
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []

View File

@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -53,7 +53,7 @@ def test_execute_answer():
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],

View File

@@ -5,7 +5,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
NodeRunExceptionEvent,
@@ -17,6 +16,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -167,12 +167,12 @@ class ContinueOnErrorTestHelper:
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="clear",
conversation_id="abababa",
),
user_inputs=user_inputs or {"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

View File

@@ -7,12 +7,12 @@ from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.system_variable import SystemVariable
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
from extensions.ext_database import db
from models.enums import UserFrom
@@ -37,9 +37,7 @@ def test_execute_if_else_result_true():
)
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={}
)
pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={})
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
@@ -157,7 +155,7 @@ def test_execute_if_else_result_false():
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
)

View File

@@ -15,6 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.system_variable import SystemVariable
from models import UserFrom, WorkflowType
@@ -34,7 +35,7 @@ def _create_tool_node():
version="1",
)
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
)
node = ToolNode(

View File

@@ -7,12 +7,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -68,7 +68,7 @@ def test_overwrite_string_variable():
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -165,7 +165,7 @@ def test_append_variable_to_array():
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -256,7 +256,7 @@ def test_clear_array():
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],

View File

@@ -5,12 +5,12 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -109,7 +109,7 @@ def test_remove_first_from_array():
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -196,7 +196,7 @@ def test_remove_last_from_array():
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -275,7 +275,7 @@ def test_remove_first_from_empty_array():
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -354,7 +354,7 @@ def test_remove_last_from_empty_array():
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],

View File

@@ -0,0 +1,251 @@
import json
from typing import Any
import pytest
from pydantic import ValidationError
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.workflow.system_variable import SystemVariable
# Test data constants for SystemVariable serialization tests
VALID_BASE_DATA: dict[str, Any] = {
"user_id": "a20f06b1-8703-45ab-937c-860a60072113",
"app_id": "661bed75-458d-49c9-b487-fda0762677b9",
"workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43",
}
COMPLETE_VALID_DATA: dict[str, Any] = {
**VALID_BASE_DATA,
"query": "test query",
"files": [],
"conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9",
"dialogue_count": 5,
"workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5",
}
def create_test_file() -> File:
"""Create a test File object for serialization tests."""
return File(
tenant_id="test-tenant-id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test-file-id",
filename="test.txt",
extension=".txt",
mime_type="text/plain",
size=1024,
storage_key="test-storage-key",
)
class TestSystemVariableSerialization:
"""Focused tests for SystemVariable serialization/deserialization logic."""
def test_basic_deserialization(self):
"""Test successful deserialization from JSON structure with all fields correctly mapped."""
# Test with complete data
system_var = SystemVariable(**COMPLETE_VALID_DATA)
# Verify all fields are correctly mapped
assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
assert system_var.app_id == COMPLETE_VALID_DATA["app_id"]
assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"]
assert system_var.query == COMPLETE_VALID_DATA["query"]
assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"]
assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"]
assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"]
assert system_var.files == []
# Test with minimal data (only required fields)
minimal_var = SystemVariable(**VALID_BASE_DATA)
assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
assert minimal_var.query is None
assert minimal_var.conversation_id is None
assert minimal_var.dialogue_count is None
assert minimal_var.workflow_execution_id is None
assert minimal_var.files == []
def test_alias_handling(self):
"""Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic."""
workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5"
# Test workflow_run_id only (preferred alias)
data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var1 = SystemVariable(**data_run_id)
assert system_var1.workflow_execution_id == workflow_id
# Test workflow_execution_id only (direct field name)
data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var2 = SystemVariable(**data_execution_id)
assert system_var2.workflow_execution_id == workflow_id
# Test both present - workflow_run_id should take precedence
data_both = {
**VALID_BASE_DATA,
"workflow_execution_id": "should-be-ignored",
"workflow_run_id": workflow_id,
}
system_var3 = SystemVariable(**data_both)
assert system_var3.workflow_execution_id == workflow_id
# Test neither present - should be None
system_var4 = SystemVariable(**VALID_BASE_DATA)
assert system_var4.workflow_execution_id is None
def test_serialization_round_trip(self):
"""Test that serialize → deserialize produces the same result with alias handling."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
# Serialize to dict
serialized = original.model_dump(mode="json")
# Verify alias is used in serialization (workflow_run_id, not workflow_execution_id)
assert "workflow_run_id" in serialized
assert "workflow_execution_id" not in serialized
assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize back
deserialized = SystemVariable(**serialized)
# Verify all fields match after round-trip
assert deserialized.user_id == original.user_id
assert deserialized.app_id == original.app_id
assert deserialized.workflow_id == original.workflow_id
assert deserialized.query == original.query
assert deserialized.conversation_id == original.conversation_id
assert deserialized.dialogue_count == original.dialogue_count
assert deserialized.workflow_execution_id == original.workflow_execution_id
assert list(deserialized.files) == list(original.files)
def test_json_round_trip(self):
"""Test JSON serialization/deserialization consistency with proper structure."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
# Serialize to JSON string
json_str = original.model_dump_json()
# Parse JSON and verify structure
json_data = json.loads(json_str)
assert "workflow_run_id" in json_data
assert "workflow_execution_id" not in json_data
assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize from JSON data
deserialized = SystemVariable(**json_data)
# Verify key fields match after JSON round-trip
assert deserialized.workflow_execution_id == original.workflow_execution_id
assert deserialized.user_id == original.user_id
assert deserialized.app_id == original.app_id
assert deserialized.workflow_id == original.workflow_id
def test_files_field_deserialization(self):
"""Test deserialization with File objects in the files field - SystemVariable specific logic."""
# Test with empty files list
data_empty = {**VALID_BASE_DATA, "files": []}
system_var_empty = SystemVariable(**data_empty)
assert system_var_empty.files == []
# Test with single File object
test_file = create_test_file()
data_single = {**VALID_BASE_DATA, "files": [test_file]}
system_var_single = SystemVariable(**data_single)
assert len(system_var_single.files) == 1
assert system_var_single.files[0].filename == "test.txt"
assert system_var_single.files[0].tenant_id == "test-tenant-id"
# Test with multiple File objects
file1 = File(
tenant_id="tenant1",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="doc1.txt",
storage_key="key1",
)
file2 = File(
tenant_id="tenant2",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
filename="image.jpg",
storage_key="key2",
)
data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
system_var_multiple = SystemVariable(**data_multiple)
assert len(system_var_multiple.files) == 2
assert system_var_multiple.files[0].filename == "doc1.txt"
assert system_var_multiple.files[1].filename == "image.jpg"
# Verify files field serialization/deserialization
serialized = system_var_multiple.model_dump(mode="json")
deserialized = SystemVariable(**serialized)
assert len(deserialized.files) == 2
assert deserialized.files[0].filename == "doc1.txt"
assert deserialized.files[1].filename == "image.jpg"
def test_alias_serialization_consistency(self):
"""Test that alias handling works consistently in both serialization directions."""
workflow_id = "test-workflow-id"
# Create with workflow_run_id (alias)
data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var = SystemVariable(**data_with_alias)
# Serialize and verify alias is used
serialized = system_var.model_dump()
assert serialized["workflow_run_id"] == workflow_id
assert "workflow_execution_id" not in serialized
# Deserialize and verify field mapping
deserialized = SystemVariable(**serialized)
assert deserialized.workflow_execution_id == workflow_id
# Test JSON serialization path
json_serialized = json.loads(system_var.model_dump_json())
assert json_serialized["workflow_run_id"] == workflow_id
assert "workflow_execution_id" not in json_serialized
json_deserialized = SystemVariable(**json_serialized)
assert json_deserialized.workflow_execution_id == workflow_id
def test_model_validator_serialization_logic(self):
"""Test the custom model validator behavior for serialization scenarios."""
workflow_id = "test-workflow-execution-id"
# Test direct instantiation with workflow_execution_id (should work)
data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var1 = SystemVariable(**data1)
assert system_var1.workflow_execution_id == workflow_id
# Test serialization of the above (should use alias)
serialized1 = system_var1.model_dump()
assert "workflow_run_id" in serialized1
assert serialized1["workflow_run_id"] == workflow_id
# Test both present - workflow_run_id takes precedence (validator logic)
data2 = {
**VALID_BASE_DATA,
"workflow_execution_id": "should-be-removed",
"workflow_run_id": workflow_id,
}
system_var2 = SystemVariable(**data2)
assert system_var2.workflow_execution_id == workflow_id
# Verify serialization consistency
serialized2 = system_var2.model_dump()
assert serialized2["workflow_run_id"] == workflow_id
def test_constructor_with_extra_key():
# Test that SystemVariable should forbid extra keys
with pytest.raises(ValidationError):
# This should fail because there is an unexpected key.
SystemVariable(invalid_key=1) # type: ignore

View File

@@ -1,17 +1,43 @@
import uuid
from collections import defaultdict
import pytest
from pydantic import ValidationError
from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
)
from core.variables.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
StringVariable,
VariableUnion,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
from factories.variable_factory import build_segment, segment_to_variable
@pytest.fixture
def pool():
return VariablePool(system_variables={}, user_inputs={})
return VariablePool(
system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"),
user_inputs={},
)
@pytest.fixture
@@ -52,18 +78,28 @@ def test_use_long_selector(pool):
class TestVariablePool:
def test_constructor(self):
pool = VariablePool()
# Test with minimal required SystemVariable
minimal_system_vars = SystemVariable(
user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
)
pool = VariablePool(system_variables=minimal_system_vars)
# Test with all parameters
pool = VariablePool(
variable_dictionary={},
user_inputs={},
system_variables={},
system_variables=minimal_system_vars,
environment_variables=[],
conversation_variables=[],
)
# Test with more complex SystemVariable
complex_system_vars = SystemVariable(
user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
)
pool = VariablePool(
user_inputs={"key": "value"},
system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
system_variables=complex_system_vars,
environment_variables=[
segment_to_variable(
segment=build_segment(1),
@@ -80,6 +116,323 @@ class TestVariablePool:
],
)
def test_constructor_with_invalid_system_variable_key(self):
with pytest.raises(ValidationError):
VariablePool(system_variables={"invalid_key": "value"}) # type: ignore
def test_get_system_variables(self):
sys_var = SystemVariable(
user_id="test_user_id",
app_id="test_app_id",
workflow_id="test_workflow_id",
workflow_execution_id="test_execution_123",
query="test query",
conversation_id="test_conv_id",
dialogue_count=5,
)
pool = VariablePool(system_variables=sys_var)
kv = [
("user_id", sys_var.user_id),
("app_id", sys_var.app_id),
("workflow_id", sys_var.workflow_id),
("workflow_run_id", sys_var.workflow_execution_id),
("query", sys_var.query),
("conversation_id", sys_var.conversation_id),
("dialogue_count", sys_var.dialogue_count),
]
for key, expected_value in kv:
segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key])
assert segment is not None
assert segment.value == expected_value
class TestVariablePoolSerialization:
"""Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods.
These tests focus exclusively on serialization/deserialization logic to ensure that
VariablePool data can be properly serialized to dictionaries/JSON and reconstructed
while preserving all data integrity.
"""
_NODE1_ID = "node_1"
_NODE2_ID = "node_2"
_NODE3_ID = "node_3"
def _create_pool_without_file(self):
# Create comprehensive system variables
system_vars = SystemVariable(
user_id="test_user_id",
app_id="test_app_id",
workflow_id="test_workflow_id",
workflow_execution_id="test_execution_123",
query="test query",
conversation_id="test_conv_id",
dialogue_count=5,
)
# Create environment variables with all types including ArrayFileVariable
env_vars: list[VariableUnion] = [
StringVariable(
id="env_string_id",
name="env_string",
value="env_string_value",
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"],
),
IntegerVariable(
id="env_integer_id",
name="env_integer",
value=1,
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"],
),
FloatVariable(
id="env_float_id",
name="env_float",
value=1.0,
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"],
),
]
# Create conversation variables with complex data
conv_vars: list[VariableUnion] = [
StringVariable(
id="conv_string_id",
name="conv_string",
value="conv_string_value",
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"],
),
IntegerVariable(
id="conv_integer_id",
name="conv_integer",
value=1,
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"],
),
FloatVariable(
id="conv_float_id",
name="conv_float",
value=1.0,
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"],
),
ObjectVariable(
id="conv_object_id",
name="conv_object",
value={"key": "value", "nested": {"data": 123}},
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"],
),
ArrayStringVariable(
id="conv_array_string_id",
name="conv_array_string",
value=["conv_array_string_value"],
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"],
),
ArrayNumberVariable(
id="conv_array_number_id",
name="conv_array_number",
value=[1, 1.0],
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"],
),
ArrayObjectVariable(
id="conv_array_object_id",
name="conv_array_object",
value=[{"a": 1}, {"b": "2"}],
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"],
),
]
# Create comprehensive user inputs
user_inputs = {
"string_input": "test_value",
"number_input": 42,
"object_input": {"nested": {"key": "value"}},
"array_input": ["item1", "item2", "item3"],
}
# Create VariablePool
pool = VariablePool(
system_variables=system_vars,
user_inputs=user_inputs,
environment_variables=env_vars,
conversation_variables=conv_vars,
)
return pool
def _add_node_data_to_pool(self, pool: VariablePool, with_file=False):
test_file = File(
tenant_id="test_tenant_id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_related_id",
remote_url="test_url",
filename="test_file.txt",
storage_key="test_storage_key",
)
# Add various segment types to variable dictionary
pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string"))
pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123))
pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67))
pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"}))
if with_file:
pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file))
pool.add((self._NODE1_ID, "none_var"), NoneSegment())
# Add array segments including ArrayFileVariable
pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"]))
pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3]))
pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}]))
if with_file:
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
# Add nested variables
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
def test_system_variables(self):
sys_vars = SystemVariable(
user_id="test_user_id",
app_id="test_app_id",
workflow_id="test_workflow_id",
workflow_execution_id="test_execution_123",
query="test query",
conversation_id="test_conv_id",
dialogue_count=5,
)
pool = VariablePool(system_variables=sys_vars)
json = pool.model_dump_json()
pool2 = VariablePool.model_validate_json(json)
assert pool2.system_variables == sys_vars
for mode in ["json", "python"]:
dict_ = pool.model_dump(mode=mode)
pool2 = VariablePool.model_validate(dict_)
assert pool2.system_variables == sys_vars
def test_pool_without_file_vars(self):
pool = self._create_pool_without_file()
json = pool.model_dump_json()
pool2 = pool.model_validate_json(json)
assert pool2.system_variables == pool.system_variables
assert pool2.conversation_variables == pool.conversation_variables
assert pool2.environment_variables == pool.environment_variables
assert pool2.user_inputs == pool.user_inputs
assert pool2.variable_dictionary == pool.variable_dictionary
assert pool2 == pool
def test_basic_dictionary_round_trip(self):
"""Test basic round-trip serialization: model_dump() → model_validate()"""
# Create a comprehensive VariablePool with all data types
original_pool = self._create_pool_without_file()
self._add_node_data_to_pool(original_pool)
# Serialize to dictionary using Pydantic's model_dump()
serialized_data = original_pool.model_dump()
# Verify serialized data structure
assert isinstance(serialized_data, dict)
assert "system_variables" in serialized_data
assert "user_inputs" in serialized_data
assert "environment_variables" in serialized_data
assert "conversation_variables" in serialized_data
assert "variable_dictionary" in serialized_data
# Deserialize back using Pydantic's model_validate()
reconstructed_pool = VariablePool.model_validate(serialized_data)
# Verify data integrity is preserved
self._assert_pools_equal(original_pool, reconstructed_pool)
def test_json_round_trip(self):
"""Test JSON round-trip serialization: model_dump_json() → model_validate_json()"""
# Create a comprehensive VariablePool with all data types
original_pool = self._create_pool_without_file()
self._add_node_data_to_pool(original_pool)
# Serialize to JSON string using Pydantic's model_dump_json()
json_data = original_pool.model_dump_json()
# Verify JSON is valid string
assert isinstance(json_data, str)
assert len(json_data) > 0
# Deserialize back using Pydantic's model_validate_json()
reconstructed_pool = VariablePool.model_validate_json(json_data)
# Verify data integrity is preserved
self._assert_pools_equal(original_pool, reconstructed_pool)
def test_complex_data_serialization(self):
"""Test serialization of complex data structures including ArrayFileVariable"""
original_pool = self._create_pool_without_file()
self._add_node_data_to_pool(original_pool, with_file=True)
# Test dictionary round-trip
dict_data = original_pool.model_dump()
reconstructed_dict = VariablePool.model_validate(dict_data)
# Test JSON round-trip
json_data = original_pool.model_dump_json()
reconstructed_json = VariablePool.model_validate_json(json_data)
# Verify both reconstructed pools are equivalent
self._assert_pools_equal(reconstructed_dict, reconstructed_json)
# TODO: assert the data for file object...
def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
"""Assert that two VariablePools contain equivalent data"""
# Compare system variables
assert pool1.system_variables == pool2.system_variables
# Compare user inputs
assert dict(pool1.user_inputs) == dict(pool2.user_inputs)
# Compare environment variables count
assert pool1.environment_variables == pool2.environment_variables
# Compare conversation variables count
assert pool1.conversation_variables == pool2.conversation_variables
# Test key variable retrievals to ensure functionality is preserved
test_selectors = [
(SYSTEM_VARIABLE_NODE_ID, "user_id"),
(SYSTEM_VARIABLE_NODE_ID, "app_id"),
(ENVIRONMENT_VARIABLE_NODE_ID, "env_string"),
(ENVIRONMENT_VARIABLE_NODE_ID, "env_number"),
(CONVERSATION_VARIABLE_NODE_ID, "conv_string"),
(self._NODE1_ID, "string_var"),
(self._NODE1_ID, "int_var"),
(self._NODE1_ID, "float_var"),
(self._NODE2_ID, "array_string"),
(self._NODE2_ID, "array_number"),
(self._NODE3_ID, "nested", "deep", "var"),
]
for selector in test_selectors:
val1 = pool1.get(selector)
val2 = pool2.get(selector)
# Both should exist or both should be None
assert (val1 is None) == (val2 is None)
if val1 is not None and val2 is not None:
# Values should be equal
assert val1.value == val2.value
# Value types should be the same (more important than exact class type)
assert val1.value_type == val2.value_type
def test_variable_pool_deserialization_default_dict(self):
variable_pool = VariablePool(
user_inputs={"a": 1, "b": "2"},
system_variables=SystemVariable(workflow_id=str(uuid.uuid4())),
environment_variables=[
StringVariable(name="str_var", value="a"),
],
conversation_variables=[IntegerVariable(name="int_var", value=1)],
)
assert isinstance(variable_pool.variable_dictionary, defaultdict)
json = variable_pool.model_dump_json()
loaded = VariablePool.model_validate_json(json)
assert isinstance(loaded.variable_dictionary, defaultdict)
loaded.add(["non_exist_node", "a"], 1)
pool_dict = variable_pool.model_dump()
loaded = VariablePool.model_validate(pool_dict)
assert isinstance(loaded.variable_dictionary, defaultdict)
loaded.add(["non_exist_node", "a"], 1)

View File

@@ -18,10 +18,10 @@ from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from models.enums import CreatorUserRole
from models.model import AppMode
@@ -67,14 +67,14 @@ def real_app_generate_entity():
@pytest.fixture
def real_workflow_system_variables():
return {
SystemVariableKey.QUERY: "test query",
SystemVariableKey.CONVERSATION_ID: "test-conversation-id",
SystemVariableKey.USER_ID: "test-user-id",
SystemVariableKey.APP_ID: "test-app-id",
SystemVariableKey.WORKFLOW_ID: "test-workflow-id",
SystemVariableKey.WORKFLOW_EXECUTION_ID: "test-workflow-run-id",
}
return SystemVariable(
query="test query",
conversation_id="test-conversation-id",
user_id="test-user-id",
app_id="test-app-id",
workflow_id="test-workflow-id",
workflow_execution_id="test-workflow-run-id",
)
@pytest.fixture

View File

@@ -10,7 +10,7 @@ class TestAppendVariablesRecursively:
def test_append_simple_dict_value(self):
"""Test appending a simple dictionary value"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["output"]
variable_value = {"name": "John", "age": 30}
@@ -33,7 +33,7 @@ class TestAppendVariablesRecursively:
def test_append_object_segment_value(self):
"""Test appending an ObjectSegment value"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["result"]
@@ -60,7 +60,7 @@ class TestAppendVariablesRecursively:
def test_append_nested_dict_value(self):
"""Test appending a nested dictionary value"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["data"]
@@ -97,7 +97,7 @@ class TestAppendVariablesRecursively:
def test_append_non_dict_value(self):
"""Test appending a non-dictionary value (should not recurse)"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["simple"]
variable_value = "simple_string"
@@ -114,7 +114,7 @@ class TestAppendVariablesRecursively:
def test_append_segment_non_object_value(self):
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["text"]
variable_value = StringSegment(value="Hello World")
@@ -132,7 +132,7 @@ class TestAppendVariablesRecursively:
def test_append_empty_dict_value(self):
"""Test appending an empty dictionary value"""
pool = VariablePool()
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["empty"]
variable_value: dict[str, Any] = {}