refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)
This PR addresses serialization issues in the `VariablePool` model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both the Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization.

Key changes:

- Introduce distinct `value_type` tags for Integer and Float segments/variables
- Add `VariableUnion` and `SegmentUnion` types for proper type discrimination
- Leverage Pydantic's discriminated-union feature for seamless serialization/deserialization
- Enable accurate serialization of data structures containing these types

Closes #22024.
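For context on the mechanism: Pydantic picks the concrete model in a discriminated union by inspecting a tag field, so two classes sharing one tag (as the integer and float types did with `SegmentType.NUMBER`) cannot be told apart when loading JSON. A minimal sketch of the pattern, using illustrative stand-in classes rather than the real definitions in `core.variables`:

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag


class IntegerSegment(BaseModel):
    value_type: Literal["integer"] = "integer"
    value: int


class FloatSegment(BaseModel):
    value_type: Literal["float"] = "float"
    value: float


def get_segment_discriminator(v) -> str | None:
    # Works on both model instances and the dicts produced by model_dump().
    if isinstance(v, BaseModel):
        return getattr(v, "value_type", None)
    if isinstance(v, dict):
        return v.get("value_type")
    return None  # unknown input; Pydantic reports a validation error


SegmentUnion = Annotated[
    Union[
        Annotated[IntegerSegment, Tag("integer")],
        Annotated[FloatSegment, Tag("float")],
    ],
    Discriminator(get_segment_discriminator),
]


class Pool(BaseModel):
    segments: list[SegmentUnion]


pool = Pool(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
# With distinct tags, the round trip restores the exact classes.
assert Pool.model_validate_json(pool.model_dump_json()) == pool
```

With a shared `"number"` tag the loader could not decide between the two classes; the distinct tags are what make `model_validate_json` in the tests below deterministic.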
@@ -1,14 +1,49 @@
import dataclasses

from pydantic import BaseModel

from core.file import File, FileTransferMethod, FileType
from core.helper import encrypter
from core.variables import SecretVariable, StringVariable
from core.variables.segments import (
    ArrayAnySegment,
    ArrayFileSegment,
    ArrayNumberSegment,
    ArrayObjectSegment,
    ArrayStringSegment,
    FileSegment,
    FloatSegment,
    IntegerSegment,
    NoneSegment,
    ObjectSegment,
    Segment,
    SegmentUnion,
    StringSegment,
    get_segment_discriminator,
)
from core.variables.types import SegmentType
from core.variables.variables import (
    ArrayAnyVariable,
    ArrayFileVariable,
    ArrayNumberVariable,
    ArrayObjectVariable,
    ArrayStringVariable,
    FileVariable,
    FloatVariable,
    IntegerVariable,
    NoneVariable,
    ObjectVariable,
    SecretVariable,
    StringVariable,
    Variable,
    VariableUnion,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable


def test_segment_group_to_text():
    variable_pool = VariablePool(
-        system_variables={
-            SystemVariableKey("user_id"): "fake-user-id",
-        },
+        system_variables=SystemVariable(user_id="fake-user-id"),
        user_inputs={},
        environment_variables=[
            SecretVariable(name="secret_key", value="fake-secret-key"),
@@ -30,7 +65,7 @@ def test_segment_group_to_text():

def test_convert_constant_to_segment_group():
    variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
@@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group():

def test_convert_variable_to_segment_group():
    variable_pool = VariablePool(
-        system_variables={
-            SystemVariableKey("user_id"): "fake-user-id",
-        },
+        system_variables=SystemVariable(user_id="fake-user-id"),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
@@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group():
    assert segments_group.log == "fake-user-id"
    assert isinstance(segments_group.value[0], StringVariable)
    assert segments_group.value[0].value == "fake-user-id"


class _Segments(BaseModel):
    segments: list[SegmentUnion]


class _Variables(BaseModel):
    variables: list[VariableUnion]


def create_test_file(
    file_type: FileType = FileType.DOCUMENT,
    transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
    filename: str = "test.txt",
    extension: str = ".txt",
    mime_type: str = "text/plain",
    size: int = 1024,
) -> File:
    """Factory function to create File objects for testing"""
    return File(
        tenant_id="test-tenant",
        type=file_type,
        transfer_method=transfer_method,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=size,
        related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
        remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
        storage_key="test-storage-key",
    )


class TestSegmentDumpAndLoad:
    """Test suite for segment and variable serialization/deserialization"""

    def test_segments(self):
        """Test basic segment serialization compatibility"""
        model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
        json = model.model_dump_json()
        print("Json: ", json)
        loaded = _Segments.model_validate_json(json)
        assert loaded == model

    def test_segment_number(self):
        """Test number segment serialization compatibility"""
        model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
        json = model.model_dump_json()
        print("Json: ", json)
        loaded = _Segments.model_validate_json(json)
        assert loaded == model

    def test_variables(self):
        """Test variable serialization compatibility"""
        model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
        json = model.model_dump_json()
        print("Json: ", json)
        restored = _Variables.model_validate_json(json)
        assert restored == model

    def test_all_segments_serialization(self):
        """Test serialization/deserialization of all segment types"""
        # Create one instance of each segment type
        test_file = create_test_file()

        all_segments: list[SegmentUnion] = [
            NoneSegment(),
            StringSegment(value="test string"),
            IntegerSegment(value=42),
            FloatSegment(value=3.14),
            ObjectSegment(value={"key": "value", "number": 123}),
            FileSegment(value=test_file),
            ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
            ArrayStringSegment(value=["hello", "world"]),
            ArrayNumberSegment(value=[1, 2.5, 3]),
            ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
            ArrayFileSegment(value=[]),  # Empty array to avoid file complexity
        ]

        # Test serialization and deserialization
        model = _Segments(segments=all_segments)
        json_str = model.model_dump_json()
        loaded = _Segments.model_validate_json(json_str)

        # Verify all segments are preserved
        assert len(loaded.segments) == len(all_segments)

        for original, loaded_segment in zip(all_segments, loaded.segments):
            assert type(loaded_segment) == type(original)
            assert loaded_segment.value_type == original.value_type

            # For file segments, compare key properties instead of exact equality
            if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
                orig_file = original.value
                loaded_file = loaded_segment.value
                assert isinstance(orig_file, File)
                assert isinstance(loaded_file, File)
                assert loaded_file.tenant_id == orig_file.tenant_id
                assert loaded_file.type == orig_file.type
                assert loaded_file.filename == orig_file.filename
            else:
                assert loaded_segment.value == original.value

    def test_all_variables_serialization(self):
        """Test serialization/deserialization of all variable types"""
        # Create one instance of each variable type
        test_file = create_test_file()

        all_variables: list[VariableUnion] = [
            NoneVariable(name="none_var"),
            StringVariable(value="test string", name="string_var"),
            IntegerVariable(value=42, name="int_var"),
            FloatVariable(value=3.14, name="float_var"),
            ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
            FileVariable(value=test_file, name="file_var"),
            ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
            ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
            ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
            ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
            ArrayFileVariable(value=[], name="array_file_var"),  # Empty array to avoid file complexity
        ]

        # Test serialization and deserialization
        model = _Variables(variables=all_variables)
        json_str = model.model_dump_json()
        loaded = _Variables.model_validate_json(json_str)

        # Verify all variables are preserved
        assert len(loaded.variables) == len(all_variables)

        for original, loaded_variable in zip(all_variables, loaded.variables):
            assert type(loaded_variable) == type(original)
            assert loaded_variable.value_type == original.value_type
            assert loaded_variable.name == original.name

            # For file variables, compare key properties instead of exact equality
            if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
                orig_file = original.value
                loaded_file = loaded_variable.value
                assert isinstance(orig_file, File)
                assert isinstance(loaded_file, File)
                assert loaded_file.tenant_id == orig_file.tenant_id
                assert loaded_file.type == orig_file.type
                assert loaded_file.filename == orig_file.filename
            else:
                assert loaded_variable.value == original.value

    def test_segment_discriminator_function_for_segment_types(self):
        """Test the segment discriminator function"""

        @dataclasses.dataclass
        class TestCase:
            segment: Segment
            expected_segment_type: SegmentType

        file1 = create_test_file()
        file2 = create_test_file(filename="test2.txt")

        cases = [
            TestCase(
                NoneSegment(),
                SegmentType.NONE,
            ),
            TestCase(
                StringSegment(value=""),
                SegmentType.STRING,
            ),
            TestCase(
                FloatSegment(value=0.0),
                SegmentType.FLOAT,
            ),
            TestCase(
                IntegerSegment(value=0),
                SegmentType.INTEGER,
            ),
            TestCase(
                ObjectSegment(value={}),
                SegmentType.OBJECT,
            ),
            TestCase(
                FileSegment(value=file1),
                SegmentType.FILE,
            ),
            TestCase(
                ArrayAnySegment(value=[0, 0.0, ""]),
                SegmentType.ARRAY_ANY,
            ),
            TestCase(
                ArrayStringSegment(value=[""]),
                SegmentType.ARRAY_STRING,
            ),
            TestCase(
                ArrayNumberSegment(value=[0, 0.0]),
                SegmentType.ARRAY_NUMBER,
            ),
            TestCase(
                ArrayObjectSegment(value=[{}]),
                SegmentType.ARRAY_OBJECT,
            ),
            TestCase(
                ArrayFileSegment(value=[file1, file2]),
                SegmentType.ARRAY_FILE,
            ),
        ]

        for test_case in cases:
            segment = test_case.segment
            assert get_segment_discriminator(segment) == test_case.expected_segment_type, (
                f"get_segment_discriminator failed for type {type(segment)}"
            )
            model_dict = segment.model_dump(mode="json")
            assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
                f"get_segment_discriminator failed for serialized form of type {type(segment)}"
            )

    def test_variable_discriminator_function_for_variable_types(self):
        """Test the variable discriminator function"""

        @dataclasses.dataclass
        class TestCase:
            variable: Variable
            expected_segment_type: SegmentType

        file1 = create_test_file()
        file2 = create_test_file(filename="test2.txt")

        cases = [
            TestCase(
                NoneVariable(name="none_var"),
                SegmentType.NONE,
            ),
            TestCase(
                StringVariable(value="test", name="string_var"),
                SegmentType.STRING,
            ),
            TestCase(
                FloatVariable(value=0.0, name="float_var"),
                SegmentType.FLOAT,
            ),
            TestCase(
                IntegerVariable(value=0, name="int_var"),
                SegmentType.INTEGER,
            ),
            TestCase(
                ObjectVariable(value={}, name="object_var"),
                SegmentType.OBJECT,
            ),
            TestCase(
                FileVariable(value=file1, name="file_var"),
                SegmentType.FILE,
            ),
            TestCase(
                SecretVariable(value="secret", name="secret_var"),
                SegmentType.SECRET,
            ),
            TestCase(
                ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"),
                SegmentType.ARRAY_ANY,
            ),
            TestCase(
                ArrayStringVariable(value=[""], name="array_string_var"),
                SegmentType.ARRAY_STRING,
            ),
            TestCase(
                ArrayNumberVariable(value=[0, 0.0], name="array_number_var"),
                SegmentType.ARRAY_NUMBER,
            ),
            TestCase(
                ArrayObjectVariable(value=[{}], name="array_object_var"),
                SegmentType.ARRAY_OBJECT,
            ),
            TestCase(
                ArrayFileVariable(value=[file1, file2], name="array_file_var"),
                SegmentType.ARRAY_FILE,
            ),
        ]

        for test_case in cases:
            variable = test_case.variable
            assert get_segment_discriminator(variable) == test_case.expected_segment_type, (
                f"get_segment_discriminator failed for type {type(variable)}"
            )
            model_dict = variable.model_dump(mode="json")
            assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
                f"get_segment_discriminator failed for serialized form of type {type(variable)}"
            )

    def test_invalid_value_for_discriminator(self):
        # Test invalid cases
        assert get_segment_discriminator({"value_type": "invalid"}) is None
        assert get_segment_discriminator({}) is None
        assert get_segment_discriminator("not_a_dict") is None
        assert get_segment_discriminator(42) is None
        assert get_segment_discriminator(object) is None
api/tests/unit_tests/core/variables/test_segment_type.py (new file, 60 lines)
@@ -0,0 +1,60 @@
from core.variables.types import SegmentType


class TestSegmentTypeIsArrayType:
    """
    Test class for SegmentType.is_array_type method.

    Provides comprehensive coverage of all SegmentType values to ensure
    correct identification of array and non-array types.
    """

    def test_is_array_type(self):
        """
        Test that all SegmentType enum values are covered in our test cases.

        Ensures comprehensive coverage by verifying that every SegmentType
        value is tested for the is_array_type method.
        """
        # Arrange
        all_segment_types = set(SegmentType)
        expected_array_types = [
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
        ]
        expected_non_array_types = [
            SegmentType.INTEGER,
            SegmentType.FLOAT,
            SegmentType.NUMBER,
            SegmentType.STRING,
            SegmentType.OBJECT,
            SegmentType.SECRET,
            SegmentType.FILE,
            SegmentType.NONE,
            SegmentType.GROUP,
        ]

        for seg_type in expected_array_types:
            assert seg_type.is_array_type()

        for seg_type in expected_non_array_types:
            assert not seg_type.is_array_type()

        # Act & Assert
        covered_types = set(expected_array_types) | set(expected_non_array_types)
        assert covered_types == all_segment_types, "All SegmentType values should be covered in tests"

    def test_all_enum_values_are_supported(self):
        """
        Test that all enum values are supported and return boolean values.

        Validates that every SegmentType enum value can be processed by
        the is_array_type method and returns a boolean value.
        """
        enum_values: list[SegmentType] = list(SegmentType)
        for seg_type in enum_values:
            is_array = seg_type.is_array_type()
            assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
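For context, the `is_array_type` helper exercised above amounts to a membership check on the enum. A hedged sketch, assuming `SegmentType` is a string enum — the member names come from this test, but the string values here are illustrative assumptions:

```python
from enum import StrEnum


class SegmentType(StrEnum):
    # Scalar tags; INTEGER and FLOAT are now distinct from the legacy NUMBER.
    NONE = "none"
    STRING = "string"
    NUMBER = "number"
    INTEGER = "integer"
    FLOAT = "float"
    OBJECT = "object"
    SECRET = "secret"
    FILE = "file"
    GROUP = "group"
    # Array tags
    ARRAY_ANY = "array[any]"
    ARRAY_STRING = "array[string]"
    ARRAY_NUMBER = "array[number]"
    ARRAY_OBJECT = "array[object]"
    ARRAY_FILE = "array[file]"

    def is_array_type(self) -> bool:
        # Membership test against the five array members.
        return self in {
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
        }
```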
@@ -11,6 +11,7 @@ from core.variables import (
    SegmentType,
    StringVariable,
)
+from core.variables.variables import Variable


def test_frozen_variables():
@@ -75,7 +76,7 @@ def test_object_variable_to_object():


def test_variable_to_object():
-    var = StringVariable(name="text", value="text")
+    var: Variable = StringVariable(name="text", value="text")
    assert var.to_object() == "text"
    var = IntegerVariable(name="integer", value=42)
    assert var.to_object() == 42
@@ -0,0 +1,146 @@
import time
from decimal import Decimal

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from core.workflow.system_variable import SystemVariable


def create_test_graph_runtime_state() -> GraphRuntimeState:
    """Factory function to create a GraphRuntimeState with non-empty values for testing."""
    # Create a variable pool with system variables
    system_vars = SystemVariable(
        user_id="test_user_123",
        app_id="test_app_456",
        workflow_id="test_workflow_789",
        workflow_execution_id="test_execution_001",
        query="test query",
        conversation_id="test_conv_123",
        dialogue_count=5,
    )
    variable_pool = VariablePool(system_variables=system_vars)

    # Add some variables to the variable pool
    variable_pool.add(["test_node", "test_var"], "test_value")
    variable_pool.add(["another_node", "another_var"], 42)

    # Create LLM usage with realistic values
    llm_usage = LLMUsage(
        prompt_tokens=150,
        prompt_unit_price=Decimal("0.001"),
        prompt_price_unit=Decimal(1000),
        prompt_price=Decimal("0.15"),
        completion_tokens=75,
        completion_unit_price=Decimal("0.002"),
        completion_price_unit=Decimal(1000),
        completion_price=Decimal("0.15"),
        total_tokens=225,
        total_price=Decimal("0.30"),
        currency="USD",
        latency=1.25,
    )

    # Create runtime route state with some node states
    node_run_state = RuntimeRouteState()
    node_state = node_run_state.create_node_state("test_node_1")
    node_run_state.add_route(node_state.id, "target_node_id")

    return GraphRuntimeState(
        variable_pool=variable_pool,
        start_at=time.perf_counter(),
        total_tokens=100,
        llm_usage=llm_usage,
        outputs={
            "string_output": "test result",
            "int_output": 42,
            "float_output": 3.14,
            "list_output": ["item1", "item2", "item3"],
            "dict_output": {"key1": "value1", "key2": 123},
            "nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
        },
        node_run_steps=5,
        node_run_state=node_run_state,
    )


def test_basic_round_trip_serialization():
    """Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
    # Create a state with non-empty values
    original_state = create_test_graph_runtime_state()

    # Serialize to JSON and deserialize back
    json_data = original_state.model_dump_json()
    deserialized_state = GraphRuntimeState.model_validate_json(json_data)

    # Core test: ensure the round-trip preserves all values
    assert deserialized_state == original_state

    # Round-trip through a Python-mode dict as well
    dict_data = original_state.model_dump(mode="python")
    deserialized_state = GraphRuntimeState.model_validate(dict_data)
    assert deserialized_state == original_state

    # And through a JSON-mode dict
    dict_data = original_state.model_dump(mode="json")
    deserialized_state = GraphRuntimeState.model_validate(dict_data)
    assert deserialized_state == original_state


def test_outputs_field_round_trip():
    """Test the problematic outputs field maintains values through round-trip serialization."""
    original_state = create_test_graph_runtime_state()

    # Serialize and deserialize
    json_data = original_state.model_dump_json()
    deserialized_state = GraphRuntimeState.model_validate_json(json_data)

    # Verify the outputs field specifically maintains its values
    assert deserialized_state.outputs == original_state.outputs
    assert deserialized_state == original_state


def test_empty_outputs_round_trip():
    """Test round-trip serialization with empty outputs field."""
    variable_pool = VariablePool.empty()
    original_state = GraphRuntimeState(
        variable_pool=variable_pool,
        start_at=time.perf_counter(),
        outputs={},  # Empty outputs
    )

    json_data = original_state.model_dump_json()
    deserialized_state = GraphRuntimeState.model_validate_json(json_data)

    assert deserialized_state == original_state


def test_llm_usage_round_trip():
    # Create LLM usage with specific decimal values
    llm_usage = LLMUsage(
        prompt_tokens=100,
        prompt_unit_price=Decimal("0.0015"),
        prompt_price_unit=Decimal(1000),
        prompt_price=Decimal("0.15"),
        completion_tokens=50,
        completion_unit_price=Decimal("0.003"),
        completion_price_unit=Decimal(1000),
        completion_price=Decimal("0.15"),
        total_tokens=150,
        total_price=Decimal("0.30"),
        currency="USD",
        latency=2.5,
    )

    json_data = llm_usage.model_dump_json()
    deserialized = LLMUsage.model_validate_json(json_data)
    assert deserialized == llm_usage

    dict_data = llm_usage.model_dump(mode="python")
    deserialized = LLMUsage.model_validate(dict_data)
    assert deserialized == llm_usage

    dict_data = llm_usage.model_dump(mode="json")
    deserialized = LLMUsage.model_validate(dict_data)
    assert deserialized == llm_usage
@@ -0,0 +1,401 @@
import json
import uuid
from datetime import UTC, datetime

import pytest
from pydantic import ValidationError

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState

_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)


class TestRouteNodeStateSerialization:
    """Test cases for RouteNodeState Pydantic serialization/deserialization."""

    def _test_route_node_state(self):
        """Build a RouteNodeState populated with all core fields for serialization tests."""
        node_run_result = NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs={"input_key": "input_value"},
            outputs={"output_key": "output_value"},
        )

        node_state = RouteNodeState(
            node_id="comprehensive_test_node",
            start_at=_TEST_DATETIME,
            finished_at=_TEST_DATETIME,
            status=RouteNodeState.Status.SUCCESS,
            node_run_result=node_run_result,
            index=5,
            paused_at=_TEST_DATETIME,
            paused_by="user_123",
            failed_reason="test_reason",
        )
        return node_state

    def test_route_node_state_comprehensive_field_validation(self):
        """Test comprehensive RouteNodeState serialization with all core fields validation."""
        node_state = self._test_route_node_state()
        serialized = node_state.model_dump()

        # Comprehensive validation of all RouteNodeState fields
        assert serialized["node_id"] == "comprehensive_test_node"
        assert serialized["status"] == RouteNodeState.Status.SUCCESS
        assert serialized["start_at"] == _TEST_DATETIME
        assert serialized["finished_at"] == _TEST_DATETIME
        assert serialized["paused_at"] == _TEST_DATETIME
        assert serialized["paused_by"] == "user_123"
        assert serialized["failed_reason"] == "test_reason"
        assert serialized["index"] == 5
        assert "id" in serialized
        assert isinstance(serialized["id"], str)
        uuid.UUID(serialized["id"])  # Validate UUID format

        # Validate nested NodeRunResult structure
        assert serialized["node_run_result"] is not None
        assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
        assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
        assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}

    def test_route_node_state_minimal_required_fields(self):
        """Test RouteNodeState with only required fields, focusing on defaults."""
        node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)

        serialized = node_state.model_dump()

        # Focus on required fields and default values (not re-testing all fields)
        assert serialized["node_id"] == "minimal_node"
        assert serialized["start_at"] == _TEST_DATETIME
        assert serialized["status"] == RouteNodeState.Status.RUNNING  # Default status
        assert serialized["index"] == 1  # Default index
        assert serialized["node_run_result"] is None  # Default None
        json = node_state.model_dump_json()
        deserialized = RouteNodeState.model_validate_json(json)
        assert deserialized == node_state

    def test_route_node_state_deserialization_from_dict(self):
        """Test RouteNodeState deserialization from dictionary data."""
        test_datetime = datetime(2024, 1, 15, 10, 30, 45)
        test_id = str(uuid.uuid4())

        dict_data = {
            "id": test_id,
            "node_id": "deserialized_node",
            "start_at": test_datetime,
            "status": "success",
            "finished_at": test_datetime,
            "index": 3,
        }

        node_state = RouteNodeState.model_validate(dict_data)

        # Focus on deserialization accuracy
        assert node_state.id == test_id
        assert node_state.node_id == "deserialized_node"
        assert node_state.start_at == test_datetime
        assert node_state.status == RouteNodeState.Status.SUCCESS
        assert node_state.finished_at == test_datetime
        assert node_state.index == 3

    def test_route_node_state_round_trip_consistency(self):
        node_states = (
            self._test_route_node_state(),
            RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
        )
        for node_state in node_states:
            json = node_state.model_dump_json()
            deserialized = RouteNodeState.model_validate_json(json)
            assert deserialized == node_state

            dict_ = node_state.model_dump(mode="python")
            deserialized = RouteNodeState.model_validate(dict_)
            assert deserialized == node_state

            dict_ = node_state.model_dump(mode="json")
            deserialized = RouteNodeState.model_validate(dict_)
            assert deserialized == node_state


class TestRouteNodeStateEnumSerialization:
    """Dedicated tests for RouteNodeState Status enum serialization behavior."""

    def test_status_enum_model_dump_behavior(self):
        """Test Status enum serialization in model_dump() for python and json modes."""

        for status_enum in RouteNodeState.Status:
            node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum)
            serialized = node_state.model_dump(mode="python")
            assert serialized["status"] == status_enum
            serialized = node_state.model_dump(mode="json")
            assert serialized["status"] == status_enum.value

    def test_status_enum_json_serialization_behavior(self):
        """Test Status enum serialization in JSON returns string values."""
        test_datetime = datetime(2024, 1, 15, 10, 30, 45)

        enum_to_string_mapping = {
            RouteNodeState.Status.RUNNING: "running",
            RouteNodeState.Status.SUCCESS: "success",
            RouteNodeState.Status.FAILED: "failed",
            RouteNodeState.Status.PAUSED: "paused",
            RouteNodeState.Status.EXCEPTION: "exception",
        }

        for status_enum, expected_string in enum_to_string_mapping.items():
            node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum)

            json_data = json.loads(node_state.model_dump_json())
            assert json_data["status"] == expected_string

    def test_status_enum_deserialization_from_string(self):
        """Test Status enum deserialization from string values."""
        test_datetime = datetime(2024, 1, 15, 10, 30, 45)

        string_to_enum_mapping = {
            "running": RouteNodeState.Status.RUNNING,
            "success": RouteNodeState.Status.SUCCESS,
            "failed": RouteNodeState.Status.FAILED,
            "paused": RouteNodeState.Status.PAUSED,
            "exception": RouteNodeState.Status.EXCEPTION,
        }

        for status_string, expected_enum in string_to_enum_mapping.items():
            dict_data = {
                "node_id": "enum_deserialize_test",
                "start_at": test_datetime,
                "status": status_string,
            }

            node_state = RouteNodeState.model_validate(dict_data)
            assert node_state.status == expected_enum


class TestRuntimeRouteStateSerialization:
    """Test cases for RuntimeRouteState Pydantic serialization/deserialization."""

    _NODE1_ID = "node_1"
    _ROUTE_STATE1_ID = str(uuid.uuid4())
    _NODE2_ID = "node_2"
    _ROUTE_STATE2_ID = str(uuid.uuid4())
    _NODE3_ID = "node_3"
    _ROUTE_STATE3_ID = str(uuid.uuid4())

    def _get_runtime_route_state(self):
        # Create node states with different configurations
        node_state_1 = RouteNodeState(
            id=self._ROUTE_STATE1_ID,
            node_id=self._NODE1_ID,
            start_at=_TEST_DATETIME,
            index=1,
        )
        node_state_2 = RouteNodeState(
            id=self._ROUTE_STATE2_ID,
            node_id=self._NODE2_ID,
            start_at=_TEST_DATETIME,
            status=RouteNodeState.Status.SUCCESS,
            finished_at=_TEST_DATETIME,
            index=2,
        )
        node_state_3 = RouteNodeState(
            id=self._ROUTE_STATE3_ID,
            node_id=self._NODE3_ID,
            start_at=_TEST_DATETIME,
            status=RouteNodeState.Status.FAILED,
            failed_reason="Test failure",
            index=3,
        )

        runtime_state = RuntimeRouteState(
            routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
            node_state_mapping={
                node_state_1.id: node_state_1,
                node_state_2.id: node_state_2,
                node_state_3.id: node_state_3,
            },
        )

        return runtime_state

    def test_runtime_route_state_comprehensive_structure_validation(self):
        """Test comprehensive RuntimeRouteState serialization with full structure validation."""
        runtime_state = self._get_runtime_route_state()
        serialized = runtime_state.model_dump()

        # Comprehensive validation of RuntimeRouteState structure
        assert "routes" in serialized
        assert "node_state_mapping" in serialized
        assert isinstance(serialized["routes"], dict)
        assert isinstance(serialized["node_state_mapping"], dict)

        # Validate routes dictionary structure and content
        assert len(serialized["routes"]) == 2
        assert self._ROUTE_STATE1_ID in serialized["routes"]
        assert self._ROUTE_STATE2_ID in serialized["routes"]
        assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
        assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]

        # Validate node_state_mapping dictionary structure and content
        assert len(serialized["node_state_mapping"]) == 3
        for state_id in [
            self._ROUTE_STATE1_ID,
            self._ROUTE_STATE2_ID,
            self._ROUTE_STATE3_ID,
        ]:
            assert state_id in serialized["node_state_mapping"]
            node_data = serialized["node_state_mapping"][state_id]
            node_state = runtime_state.node_state_mapping[state_id]
            assert node_data["node_id"] == node_state.node_id
            assert node_data["status"] == node_state.status
            assert node_data["index"] == node_state.index

    def test_runtime_route_state_empty_collections(self):
        """Test RuntimeRouteState with empty collections, focusing on default behavior."""
        runtime_state = RuntimeRouteState()
        serialized = runtime_state.model_dump()

        # Focus on default empty collection behavior
        assert serialized["routes"] == {}
        assert serialized["node_state_mapping"] == {}
        assert isinstance(serialized["routes"], dict)
        assert isinstance(serialized["node_state_mapping"], dict)

    def test_runtime_route_state_json_serialization_structure(self):
        """Test RuntimeRouteState JSON serialization structure."""
        node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)

        runtime_state = RuntimeRouteState(
            routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
        )

        json_str = runtime_state.model_dump_json()
        json_data = json.loads(json_str)

        # Focus on JSON structure validation
        assert isinstance(json_str, str)
        assert isinstance(json_data, dict)
        assert "routes" in json_data
        assert "node_state_mapping" in json_data
        assert json_data["routes"]["source"] == ["target1", "target2"]
        assert node_state.id in json_data["node_state_mapping"]

    def test_runtime_route_state_deserialization_from_dict(self):
        """Test RuntimeRouteState deserialization from dictionary data."""
        node_id = str(uuid.uuid4())

        dict_data = {
            "routes": {"source_node": ["target_node_1", "target_node_2"]},
            "node_state_mapping": {
                node_id: {
                    "id": node_id,
                    "node_id": "test_node",
                    "start_at": _TEST_DATETIME,
                    "status": "running",
                    "index": 1,
                }
            },
        }

        runtime_state = RuntimeRouteState.model_validate(dict_data)

        # Focus on deserialization accuracy
        assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
        assert len(runtime_state.node_state_mapping) == 1
        assert node_id in runtime_state.node_state_mapping

        deserialized_node = runtime_state.node_state_mapping[node_id]
        assert deserialized_node.node_id == "test_node"
        assert deserialized_node.status == RouteNodeState.Status.RUNNING
        assert deserialized_node.index == 1

    def test_runtime_route_state_round_trip_consistency(self):
        """Test RuntimeRouteState round-trip serialization consistency."""
        original = self._get_runtime_route_state()

        # Dictionary round trip
        dict_data = original.model_dump(mode="python")
        reconstructed = RuntimeRouteState.model_validate(dict_data)
        assert reconstructed == original

        dict_data = original.model_dump(mode="json")
        reconstructed = RuntimeRouteState.model_validate(dict_data)
        assert reconstructed == original

        # JSON round trip
        json_str = original.model_dump_json()
        json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
        assert json_reconstructed == original


class TestSerializationEdgeCases:
    """Test edge cases and error conditions for serialization/deserialization."""

    def test_invalid_status_deserialization(self):
        """Test deserialization with invalid status values."""
        test_datetime = _TEST_DATETIME
        invalid_data = {
            "node_id": "invalid_test",
            "start_at": test_datetime,
            "status": "invalid_status",
        }

        with pytest.raises(ValidationError) as exc_info:
            RouteNodeState.model_validate(invalid_data)
        assert "status" in str(exc_info.value)

    def test_missing_required_fields_deserialization(self):
        """Test deserialization with missing required fields."""
        incomplete_data = {"id": str(uuid.uuid4())}

        with pytest.raises(ValidationError) as exc_info:
            RouteNodeState.model_validate(incomplete_data)
        error_str = str(exc_info.value)
        assert "node_id" in error_str or "start_at" in error_str

    def test_invalid_datetime_deserialization(self):
        """Test deserialization with invalid datetime values."""
        invalid_data = {
            "node_id": "datetime_test",
            "start_at": "invalid_datetime",
            "status": "running",
        }

        with pytest.raises(ValidationError) as exc_info:
            RouteNodeState.model_validate(invalid_data)
        assert "start_at" in str(exc_info.value)

    def test_invalid_routes_structure_deserialization(self):
        """Test RuntimeRouteState deserialization with invalid routes structure."""
        invalid_data = {
            "routes": "invalid_routes_structure",  # Should be dict
            "node_state_mapping": {},
        }

        with pytest.raises(ValidationError) as exc_info:
            RuntimeRouteState.model_validate(invalid_data)
        assert "routes" in str(exc_info.value)

    def test_timezone_handling_in_datetime_fields(self):
        """Test timezone handling in datetime field serialization."""
        utc_datetime = datetime.now(UTC)
        naive_datetime = utc_datetime.replace(tzinfo=None)

        node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime)
        dict_ = node_state.model_dump()

        assert dict_["start_at"] == naive_datetime

        # Test round trip
        reconstructed = RouteNodeState.model_validate(dict_)
        assert reconstructed.start_at == naive_datetime
        assert reconstructed.start_at.tzinfo is None

        json = node_state.model_dump_json()

        reconstructed = RouteNodeState.model_validate_json(json)
        assert reconstructed.start_at == naive_datetime
        assert reconstructed.start_at.tzinfo is None
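The hunks that follow all apply one mechanical migration: the plain dict keyed by `SystemVariableKey` passed as `system_variables` becomes a typed `SystemVariable` model. A rough sketch of the shape these call sites imply — the field names and the `empty()` constructor appear in this diff, while the types and defaults are assumptions:

```python
from collections.abc import Sequence
from typing import Any, Optional

from pydantic import BaseModel, Field


class SystemVariable(BaseModel):
    # Field names inferred from the constructor calls in the hunks below.
    user_id: Optional[str] = None
    app_id: Optional[str] = None
    workflow_id: Optional[str] = None
    workflow_execution_id: Optional[str] = None
    query: Optional[str] = None
    conversation_id: Optional[str] = None
    dialogue_count: Optional[int] = None
    files: Sequence[Any] = Field(default_factory=list)

    @classmethod
    def empty(cls) -> "SystemVariable":
        # Used by tests that need a pool without any system variables.
        return cls()
```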
@@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
    BaseNodeEvent,
    GraphRunFailedEvent,
@@ -27,6 +26,7 @@ from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -171,7 +171,8 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
    graph = Graph.init(graph_config=graph_config)

    variable_pool = VariablePool(
-        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
+        system_variables=SystemVariable(user_id="aaa", app_id="1", workflow_id="1", files=[]),
+        user_inputs={"query": "hi"},
    )

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@@ -293,12 +294,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
    graph = Graph.init(graph_config=graph_config)

    variable_pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "what's the weather in SF",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "aaa",
-        },
+        system_variables=SystemVariable(
+            user_id="aaa",
+            files=[],
+            query="what's the weather in SF",
+            conversation_id="abababa",
+        ),
        user_inputs={},
    )
@@ -474,12 +475,12 @@ def test_run_branch(mock_close, mock_remove):
    graph = Graph.init(graph_config=graph_config)

    variable_pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "hi",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "aaa",
-        },
+        system_variables=SystemVariable(
+            user_id="aaa",
+            files=[],
+            query="hi",
+            conversation_id="abababa",
+        ),
        user_inputs={"uid": "takato"},
    )
@@ -804,18 +805,22 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):

    # construct variable pool
    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "dify",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "1",
-        },
+        system_variables=SystemVariable(
+            user_id="1",
+            files=[],
+            query="dify",
+            conversation_id="abababa",
+        ),
        user_inputs={},
        environment_variables=[],
    )
    pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
    variable_pool = VariablePool(
-        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
+        system_variables=SystemVariable(
+            user_id="aaa",
+            files=[],
+        ),
+        user_inputs={"query": "hi"},
    )

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -51,7 +51,7 @@ def test_execute_answer():

    # construct variable pool
    pool = VariablePool(
-        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+        system_variables=SystemVariable(user_id="aaa", files=[]),
        user_inputs={},
        environment_variables=[],
    )
@@ -3,7 +3,6 @@ from collections.abc import Generator
from datetime import UTC, datetime

from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    NodeRunStartedEvent,
@@ -15,6 +14,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
+from core.workflow.system_variable import SystemVariable


def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
@@ -180,12 +180,12 @@ def test_process():
    graph = Graph.init(graph_config=graph_config)

    variable_pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "what's the weather in SF",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "aaa",
-        },
+        system_variables=SystemVariable(
+            user_id="aaa",
+            files=[],
+            query="what's the weather in SF",
+            conversation_id="abababa",
+        ),
        user_inputs={},
    )
@@ -7,12 +7,13 @@ from core.workflow.nodes.http_request import (
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor
+from core.workflow.system_variable import SystemVariable


def test_executor_with_json_body_and_number_variable():
    # Prepare the variable pool
    variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable.empty(),
        user_inputs={},
    )
    variable_pool.add(["pre_node_id", "number"], 42)
@@ -65,7 +66,7 @@ def test_executor_with_json_body_and_number_variable():

def test_executor_with_json_body_and_object_variable():
    # Prepare the variable pool
    variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable.empty(),
        user_inputs={},
    )
    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -120,7 +121,7 @@ def test_executor_with_json_body_and_object_variable():

def test_executor_with_json_body_and_nested_object_variable():
    # Prepare the variable pool
    variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable.empty(),
        user_inputs={},
    )
    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -174,7 +175,7 @@ def test_executor_with_json_body_and_nested_object_variable():

def test_extract_selectors_from_template_with_newline():
-    variable_pool = VariablePool()
+    variable_pool = VariablePool(system_variables=SystemVariable.empty())
    variable_pool.add(("node_id", "custom_query"), "line1\nline2")
    node_data = HttpRequestNodeData(
        title="Test JSON Body with Nested Object Variable",
@@ -201,7 +202,7 @@ def test_extract_selectors_from_template_with_newline():
def test_executor_with_form_data():
    # Prepare the variable pool
    variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable.empty(),
        user_inputs={},
    )
    variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
@@ -280,7 +281,11 @@ def test_init_headers():
            authorization=HttpRequestNodeAuthorization(type="no-auth"),
        )
        timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
-        return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
+        return Executor(
+            node_data=node_data,
+            timeout=timeout,
+            variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+        )

    executor = create_executor("aa\n cc:")
    executor._init_headers()
@@ -310,7 +315,11 @@ def test_init_params():
            authorization=HttpRequestNodeAuthorization(type="no-auth"),
        )
        timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
-        return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
+        return Executor(
+            node_data=node_data,
+            timeout=timeout,
+            variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+        )

    # Test basic key-value pairs
    executor = create_executor("key1:value1\nkey2:value2")
@@ -15,6 +15,7 @@ from core.workflow.nodes.http_request import (
    HttpRequestNodeBody,
    HttpRequestNodeData,
)
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -40,7 +41,7 @@ def test_http_request_node_binary_file(monkeypatch):
        ),
    )
    variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable.empty(),
        user_inputs={},
    )
    variable_pool.add(
@@ -128,7 +129,7 @@ def test_http_request_node_form_with_file(monkeypatch):
        ),
    )
    variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable.empty(),
        user_inputs={},
    )
    variable_pool.add(
@@ -223,7 +224,7 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
    )

    variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable.empty(),
        user_inputs={},
    )
@@ -7,7 +7,6 @@ from core.variables.segments import ArrayAnySegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@@ -15,6 +14,7 @@ from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -151,12 +151,12 @@ def test_run():

    # construct variable pool
    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "dify",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "1",
-        },
+        system_variables=SystemVariable(
+            user_id="1",
+            files=[],
+            query="dify",
+            conversation_id="abababa",
+        ),
        user_inputs={},
        environment_variables=[],
    )
@@ -368,12 +368,12 @@ def test_run_parallel():

    # construct variable pool
    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "dify",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "1",
-        },
+        system_variables=SystemVariable(
+            user_id="1",
+            files=[],
+            query="dify",
+            conversation_id="abababa",
+        ),
        user_inputs={},
        environment_variables=[],
    )
@@ -584,12 +584,12 @@ def test_iteration_run_in_parallel_mode():

    # construct variable pool
    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "dify",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "1",
-        },
+        system_variables=SystemVariable(
+            user_id="1",
+            files=[],
+            query="dify",
+            conversation_id="abababa",
+        ),
        user_inputs={},
        environment_variables=[],
    )
@@ -808,12 +808,12 @@ def test_iteration_run_error_handle():

    # construct variable pool
    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "dify",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "1",
-        },
+        system_variables=SystemVariable(
+            user_id="1",
+            files=[],
+            query="dify",
+            conversation_id="abababa",
+        ),
        user_inputs={},
        environment_variables=[],
    )
@@ -36,6 +36,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
@@ -104,7 +105,7 @@ def graph() -> Graph:
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
    variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable.empty(),
        user_inputs={},
    )
    return GraphRuntimeState(
@@ -181,7 +182,7 @@ def test_fetch_files_with_file_segment():
        related_id="1",
        storage_key="",
    )
-    variable_pool = VariablePool()
+    variable_pool = VariablePool.empty()
    variable_pool.add(["sys", "files"], file)

    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -209,7 +210,7 @@ def test_fetch_files_with_array_file_segment():
            storage_key="",
        ),
    ]
-    variable_pool = VariablePool()
+    variable_pool = VariablePool.empty()
    variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))

    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -217,7 +218,7 @@ def test_fetch_files_with_array_file_segment():


def test_fetch_files_with_none_segment():
-    variable_pool = VariablePool()
+    variable_pool = VariablePool.empty()
    variable_pool.add(["sys", "files"], NoneSegment())

    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -225,7 +226,7 @@ def test_fetch_files_with_none_segment():


def test_fetch_files_with_array_any_segment():
-    variable_pool = VariablePool()
+    variable_pool = VariablePool.empty()
    variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))

    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -233,7 +234,7 @@ def test_fetch_files_with_array_any_segment():


def test_fetch_files_with_non_existent_variable():
-    variable_pool = VariablePool()
+    variable_pool = VariablePool.empty()
    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
    assert result == []
@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -53,7 +53,7 @@ def test_execute_answer():
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
||||
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
|
||||
@@ -5,7 +5,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import (
     GraphRunPartialSucceededEvent,
     NodeRunExceptionEvent,
@@ -17,6 +16,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.graph_engine.graph_engine import GraphEngine
 from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.system_variable import SystemVariable
 from models.enums import UserFrom
 from models.workflow import WorkflowType

@@ -167,12 +167,12 @@ class ContinueOnErrorTestHelper:
         """Helper method to create a graph engine instance for testing"""
         graph = Graph.init(graph_config=graph_config)
         variable_pool = VariablePool(
-            system_variables={
-                SystemVariableKey.QUERY: "clear",
-                SystemVariableKey.FILES: [],
-                SystemVariableKey.CONVERSATION_ID: "abababa",
-                SystemVariableKey.USER_ID: "aaa",
-            },
+            system_variables=SystemVariable(
+                user_id="aaa",
+                files=[],
+                query="clear",
+                conversation_id="abababa",
+            ),
             user_inputs=user_inputs or {"uid": "takato"},
         )
         graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@@ -7,12 +7,12 @@ from core.file import File, FileTransferMethod, FileType
 from core.variables import ArrayFileSegment
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.if_else.entities import IfElseNodeData
 from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.system_variable import SystemVariable
 from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
 from extensions.ext_database import db
 from models.enums import UserFrom
@@ -37,9 +37,7 @@ def test_execute_if_else_result_true():
     )

     # construct variable pool
-    pool = VariablePool(
-        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={}
-    )
+    pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={})
     pool.add(["start", "array_contains"], ["ab", "def"])
     pool.add(["start", "array_not_contains"], ["ac", "def"])
     pool.add(["start", "contains"], "cabcde")
@@ -157,7 +155,7 @@ def test_execute_if_else_result_false():

     # construct variable pool
     pool = VariablePool(
-        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+        system_variables=SystemVariable(user_id="aaa", files=[]),
         user_inputs={},
         environment_variables=[],
     )
@@ -15,6 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy
 from core.workflow.nodes.event import RunCompletedEvent
 from core.workflow.nodes.tool import ToolNode
 from core.workflow.nodes.tool.entities import ToolNodeData
+from core.workflow.system_variable import SystemVariable
 from models import UserFrom, WorkflowType

@@ -34,7 +35,7 @@ def _create_tool_node():
         version="1",
     )
     variable_pool = VariablePool(
-        system_variables={},
+        system_variables=SystemVariable.empty(),
         user_inputs={},
     )
     node = ToolNode(
@@ -7,12 +7,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.variables import ArrayStringVariable, StringVariable
 from core.workflow.conversation_variable_updater import ConversationVariableUpdater
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
 from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
+from core.workflow.system_variable import SystemVariable
 from models.enums import UserFrom
 from models.workflow import WorkflowType

@@ -68,7 +68,7 @@ def test_overwrite_string_variable():

     # construct variable pool
     variable_pool = VariablePool(
-        system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
+        system_variables=SystemVariable(conversation_id=conversation_id),
         user_inputs={},
         environment_variables=[],
         conversation_variables=[conversation_variable],
@@ -165,7 +165,7 @@ def test_append_variable_to_array():
     conversation_id = str(uuid.uuid4())

     variable_pool = VariablePool(
-        system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
+        system_variables=SystemVariable(conversation_id=conversation_id),
         user_inputs={},
         environment_variables=[],
         conversation_variables=[conversation_variable],
@@ -256,7 +256,7 @@ def test_clear_array():

     conversation_id = str(uuid.uuid4())
     variable_pool = VariablePool(
-        system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
+        system_variables=SystemVariable(conversation_id=conversation_id),
         user_inputs={},
         environment_variables=[],
         conversation_variables=[conversation_variable],
@@ -5,12 +5,12 @@ from uuid import uuid4
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.variables import ArrayStringVariable
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
 from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
+from core.workflow.system_variable import SystemVariable
 from models.enums import UserFrom
 from models.workflow import WorkflowType

@@ -109,7 +109,7 @@ def test_remove_first_from_array():
     )

     variable_pool = VariablePool(
-        system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+        system_variables=SystemVariable(conversation_id="conversation_id"),
         user_inputs={},
         environment_variables=[],
         conversation_variables=[conversation_variable],
@@ -196,7 +196,7 @@ def test_remove_last_from_array():
     )

     variable_pool = VariablePool(
-        system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+        system_variables=SystemVariable(conversation_id="conversation_id"),
         user_inputs={},
         environment_variables=[],
         conversation_variables=[conversation_variable],
@@ -275,7 +275,7 @@ def test_remove_first_from_empty_array():
     )

     variable_pool = VariablePool(
-        system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+        system_variables=SystemVariable(conversation_id="conversation_id"),
         user_inputs={},
         environment_variables=[],
         conversation_variables=[conversation_variable],
@@ -354,7 +354,7 @@ def test_remove_last_from_empty_array():
     )

     variable_pool = VariablePool(
-        system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+        system_variables=SystemVariable(conversation_id="conversation_id"),
         user_inputs={},
         environment_variables=[],
         conversation_variables=[conversation_variable],
api/tests/unit_tests/core/workflow/test_system_variable.py (new file, 251 lines)
@@ -0,0 +1,251 @@
import json
from typing import Any

import pytest
from pydantic import ValidationError

from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.workflow.system_variable import SystemVariable

# Test data constants for SystemVariable serialization tests
VALID_BASE_DATA: dict[str, Any] = {
    "user_id": "a20f06b1-8703-45ab-937c-860a60072113",
    "app_id": "661bed75-458d-49c9-b487-fda0762677b9",
    "workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43",
}

COMPLETE_VALID_DATA: dict[str, Any] = {
    **VALID_BASE_DATA,
    "query": "test query",
    "files": [],
    "conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9",
    "dialogue_count": 5,
    "workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5",
}


def create_test_file() -> File:
    """Create a test File object for serialization tests."""
    return File(
        tenant_id="test-tenant-id",
        type=FileType.DOCUMENT,
        transfer_method=FileTransferMethod.LOCAL_FILE,
        related_id="test-file-id",
        filename="test.txt",
        extension=".txt",
        mime_type="text/plain",
        size=1024,
        storage_key="test-storage-key",
    )


class TestSystemVariableSerialization:
    """Focused tests for SystemVariable serialization/deserialization logic."""

    def test_basic_deserialization(self):
        """Test successful deserialization from a JSON structure with all fields correctly mapped."""
        # Test with complete data
        system_var = SystemVariable(**COMPLETE_VALID_DATA)

        # Verify all fields are correctly mapped
        assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
        assert system_var.app_id == COMPLETE_VALID_DATA["app_id"]
        assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"]
        assert system_var.query == COMPLETE_VALID_DATA["query"]
        assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"]
        assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"]
        assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"]
        assert system_var.files == []

        # Test with minimal data (only required fields)
        minimal_var = SystemVariable(**VALID_BASE_DATA)
        assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
        assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
        assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
        assert minimal_var.query is None
        assert minimal_var.conversation_id is None
        assert minimal_var.dialogue_count is None
        assert minimal_var.workflow_execution_id is None
        assert minimal_var.files == []

    def test_alias_handling(self):
        """Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic."""
        workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5"

        # Test workflow_run_id only (preferred alias)
        data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
        system_var1 = SystemVariable(**data_run_id)
        assert system_var1.workflow_execution_id == workflow_id

        # Test workflow_execution_id only (direct field name)
        data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
        system_var2 = SystemVariable(**data_execution_id)
        assert system_var2.workflow_execution_id == workflow_id

        # Test both present - workflow_run_id should take precedence
        data_both = {
            **VALID_BASE_DATA,
            "workflow_execution_id": "should-be-ignored",
            "workflow_run_id": workflow_id,
        }
        system_var3 = SystemVariable(**data_both)
        assert system_var3.workflow_execution_id == workflow_id

        # Test neither present - should be None
        system_var4 = SystemVariable(**VALID_BASE_DATA)
        assert system_var4.workflow_execution_id is None

    def test_serialization_round_trip(self):
        """Test that serialize → deserialize produces the same result with alias handling."""
        # Create original SystemVariable
        original = SystemVariable(**COMPLETE_VALID_DATA)

        # Serialize to dict
        serialized = original.model_dump(mode="json")

        # Verify alias is used in serialization (workflow_run_id, not workflow_execution_id)
        assert "workflow_run_id" in serialized
        assert "workflow_execution_id" not in serialized
        assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]

        # Deserialize back
        deserialized = SystemVariable(**serialized)

        # Verify all fields match after round-trip
        assert deserialized.user_id == original.user_id
        assert deserialized.app_id == original.app_id
        assert deserialized.workflow_id == original.workflow_id
        assert deserialized.query == original.query
        assert deserialized.conversation_id == original.conversation_id
        assert deserialized.dialogue_count == original.dialogue_count
        assert deserialized.workflow_execution_id == original.workflow_execution_id
        assert list(deserialized.files) == list(original.files)

    def test_json_round_trip(self):
        """Test JSON serialization/deserialization consistency with proper structure."""
        # Create original SystemVariable
        original = SystemVariable(**COMPLETE_VALID_DATA)

        # Serialize to JSON string
        json_str = original.model_dump_json()

        # Parse JSON and verify structure
        json_data = json.loads(json_str)
        assert "workflow_run_id" in json_data
        assert "workflow_execution_id" not in json_data
        assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]

        # Deserialize from JSON data
        deserialized = SystemVariable(**json_data)

        # Verify key fields match after JSON round-trip
        assert deserialized.workflow_execution_id == original.workflow_execution_id
        assert deserialized.user_id == original.user_id
        assert deserialized.app_id == original.app_id
        assert deserialized.workflow_id == original.workflow_id

    def test_files_field_deserialization(self):
        """Test deserialization with File objects in the files field - SystemVariable specific logic."""
        # Test with empty files list
        data_empty = {**VALID_BASE_DATA, "files": []}
        system_var_empty = SystemVariable(**data_empty)
        assert system_var_empty.files == []

        # Test with single File object
        test_file = create_test_file()
        data_single = {**VALID_BASE_DATA, "files": [test_file]}
        system_var_single = SystemVariable(**data_single)
        assert len(system_var_single.files) == 1
        assert system_var_single.files[0].filename == "test.txt"
        assert system_var_single.files[0].tenant_id == "test-tenant-id"

        # Test with multiple File objects
        file1 = File(
            tenant_id="tenant1",
            type=FileType.DOCUMENT,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="file1",
            filename="doc1.txt",
            storage_key="key1",
        )
        file2 = File(
            tenant_id="tenant2",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/image.jpg",
            filename="image.jpg",
            storage_key="key2",
        )

        data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
        system_var_multiple = SystemVariable(**data_multiple)
        assert len(system_var_multiple.files) == 2
        assert system_var_multiple.files[0].filename == "doc1.txt"
        assert system_var_multiple.files[1].filename == "image.jpg"

        # Verify files field serialization/deserialization
        serialized = system_var_multiple.model_dump(mode="json")
        deserialized = SystemVariable(**serialized)
        assert len(deserialized.files) == 2
        assert deserialized.files[0].filename == "doc1.txt"
        assert deserialized.files[1].filename == "image.jpg"

    def test_alias_serialization_consistency(self):
        """Test that alias handling works consistently in both serialization directions."""
        workflow_id = "test-workflow-id"

        # Create with workflow_run_id (alias)
        data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
        system_var = SystemVariable(**data_with_alias)

        # Serialize and verify the alias is used
        serialized = system_var.model_dump()
        assert serialized["workflow_run_id"] == workflow_id
        assert "workflow_execution_id" not in serialized

        # Deserialize and verify field mapping
        deserialized = SystemVariable(**serialized)
        assert deserialized.workflow_execution_id == workflow_id

        # Test JSON serialization path
        json_serialized = json.loads(system_var.model_dump_json())
        assert json_serialized["workflow_run_id"] == workflow_id
        assert "workflow_execution_id" not in json_serialized

        json_deserialized = SystemVariable(**json_serialized)
        assert json_deserialized.workflow_execution_id == workflow_id

    def test_model_validator_serialization_logic(self):
        """Test the custom model validator behavior for serialization scenarios."""
        workflow_id = "test-workflow-execution-id"

        # Test direct instantiation with workflow_execution_id (should work)
        data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
        system_var1 = SystemVariable(**data1)
        assert system_var1.workflow_execution_id == workflow_id

        # Test serialization of the above (should use alias)
        serialized1 = system_var1.model_dump()
        assert "workflow_run_id" in serialized1
        assert serialized1["workflow_run_id"] == workflow_id

        # Test both present - workflow_run_id takes precedence (validator logic)
        data2 = {
            **VALID_BASE_DATA,
            "workflow_execution_id": "should-be-removed",
            "workflow_run_id": workflow_id,
        }
        system_var2 = SystemVariable(**data2)
        assert system_var2.workflow_execution_id == workflow_id

        # Verify serialization consistency
        serialized2 = system_var2.model_dump()
        assert serialized2["workflow_run_id"] == workflow_id


def test_constructor_with_extra_key():
    # Test that SystemVariable forbids extra keys
    with pytest.raises(ValidationError):
        # This should fail because there is an unexpected key.
        SystemVariable(invalid_key=1)  # type: ignore
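Taken together, these tests pin down three behaviors: the field is populated from either the alias `workflow_run_id` or the field name `workflow_execution_id`, with the alias winning when both are present; dumps always emit the alias; and unknown keys are rejected (test_constructor_with_extra_key above). A minimal Pydantic v2 sketch with those properties; the class name is hypothetical and the real model may implement them differently:

from typing import Any

from pydantic import BaseModel, ConfigDict, Field, model_validator

class SystemVariableSketch(BaseModel):
    # Hypothetical stand-in for core.workflow.system_variable.SystemVariable.
    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    user_id: str | None = None
    app_id: str | None = None
    workflow_id: str | None = None
    workflow_execution_id: str | None = Field(default=None, alias="workflow_run_id")

    @model_validator(mode="before")
    @classmethod
    def _alias_wins(cls, data: Any) -> Any:
        # When both spellings are supplied, drop the field name so the
        # alias takes precedence, matching test_alias_handling.
        if isinstance(data, dict) and "workflow_run_id" in data:
            data = dict(data)
            data.pop("workflow_execution_id", None)
        return data

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        kwargs.setdefault("by_alias", True)  # always serialize the alias
        return super().model_dump(**kwargs)

    def model_dump_json(self, **kwargs: Any) -> str:
        kwargs.setdefault("by_alias", True)
        return super().model_dump_json(**kwargs)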
@@ -1,17 +1,43 @@
+import uuid
+from collections import defaultdict

 import pytest
+from pydantic import ValidationError

 from core.file import File, FileTransferMethod, FileType
 from core.variables import FileSegment, StringSegment
-from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
+from core.variables.segments import (
+    ArrayAnySegment,
+    ArrayFileSegment,
+    ArrayNumberSegment,
+    ArrayObjectSegment,
+    ArrayStringSegment,
+    FloatSegment,
+    IntegerSegment,
+    NoneSegment,
+    ObjectSegment,
+)
+from core.variables.variables import (
+    ArrayNumberVariable,
+    ArrayObjectVariable,
+    ArrayStringVariable,
+    FloatVariable,
+    IntegerVariable,
+    ObjectVariable,
+    StringVariable,
+    VariableUnion,
+)
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.system_variable import SystemVariable
 from factories.variable_factory import build_segment, segment_to_variable


 @pytest.fixture
 def pool():
-    return VariablePool(system_variables={}, user_inputs={})
+    return VariablePool(
+        system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"),
+        user_inputs={},
+    )


 @pytest.fixture
@@ -52,18 +78,28 @@ def test_use_long_selector(pool):

 class TestVariablePool:
     def test_constructor(self):
-        pool = VariablePool()
+        # Test with minimal required SystemVariable
+        minimal_system_vars = SystemVariable(
+            user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
+        )
+        pool = VariablePool(system_variables=minimal_system_vars)

+        # Test with all parameters
         pool = VariablePool(
             variable_dictionary={},
             user_inputs={},
-            system_variables={},
+            system_variables=minimal_system_vars,
             environment_variables=[],
             conversation_variables=[],
         )

+        # Test with more complex SystemVariable
+        complex_system_vars = SystemVariable(
+            user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
+        )
         pool = VariablePool(
             user_inputs={"key": "value"},
-            system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
+            system_variables=complex_system_vars,
             environment_variables=[
                 segment_to_variable(
                     segment=build_segment(1),
@@ -80,6 +116,323 @@ class TestVariablePool:
             ],
         )

+    def test_constructor_with_invalid_system_variable_key(self):
+        with pytest.raises(ValidationError):
+            VariablePool(system_variables={"invalid_key": "value"})  # type: ignore
+
+    def test_get_system_variables(self):
+        sys_var = SystemVariable(
+            user_id="test_user_id",
+            app_id="test_app_id",
+            workflow_id="test_workflow_id",
+            workflow_execution_id="test_execution_123",
+            query="test query",
+            conversation_id="test_conv_id",
+            dialogue_count=5,
+        )
+        pool = VariablePool(system_variables=sys_var)
+
+        kv = [
+            ("user_id", sys_var.user_id),
+            ("app_id", sys_var.app_id),
+            ("workflow_id", sys_var.workflow_id),
+            ("workflow_run_id", sys_var.workflow_execution_id),
+            ("query", sys_var.query),
+            ("conversation_id", sys_var.conversation_id),
+            ("dialogue_count", sys_var.dialogue_count),
+        ]
+        for key, expected_value in kv:
+            segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key])
+            assert segment is not None
+            assert segment.value == expected_value
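Worth noting: the lookup key here is the serialized alias ("workflow_run_id"), not the Python field name, which suggests the pool indexes system variables by their dumped names. A sketch of that indexing step, offered only as an assumption about the mechanism:

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable

def index_system_variables(pool: VariablePool, sys_var: SystemVariable) -> None:
    # Hypothetical helper: register each populated system field under the
    # "sys" node by its serialized (alias) name, so that
    # pool.get(["sys", "workflow_run_id"]) resolves as asserted above.
    for key, value in sys_var.model_dump(mode="python").items():
        if value is not None:
            pool.add((SYSTEM_VARIABLE_NODE_ID, key), value)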
+class TestVariablePoolSerialization:
+    """Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods.
+
+    These tests focus exclusively on serialization/deserialization logic to ensure that
+    VariablePool data can be properly serialized to dictionaries/JSON and reconstructed
+    while preserving all data integrity.
+    """
+
+    _NODE1_ID = "node_1"
+    _NODE2_ID = "node_2"
+    _NODE3_ID = "node_3"
+
+    def _create_pool_without_file(self):
+        # Create comprehensive system variables
+        system_vars = SystemVariable(
+            user_id="test_user_id",
+            app_id="test_app_id",
+            workflow_id="test_workflow_id",
+            workflow_execution_id="test_execution_123",
+            query="test query",
+            conversation_id="test_conv_id",
+            dialogue_count=5,
+        )
+
+        # Create environment variables covering string, integer, and float types
+        env_vars: list[VariableUnion] = [
+            StringVariable(
+                id="env_string_id",
+                name="env_string",
+                value="env_string_value",
+                selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"],
+            ),
+            IntegerVariable(
+                id="env_integer_id",
+                name="env_integer",
+                value=1,
+                selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"],
+            ),
+            FloatVariable(
+                id="env_float_id",
+                name="env_float",
+                value=1.0,
+                selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"],
+            ),
+        ]
+
+        # Create conversation variables with complex data
+        conv_vars: list[VariableUnion] = [
+            StringVariable(
+                id="conv_string_id",
+                name="conv_string",
+                value="conv_string_value",
+                selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"],
+            ),
+            IntegerVariable(
+                id="conv_integer_id",
+                name="conv_integer",
+                value=1,
+                selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"],
+            ),
+            FloatVariable(
+                id="conv_float_id",
+                name="conv_float",
+                value=1.0,
+                selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"],
+            ),
+            ObjectVariable(
+                id="conv_object_id",
+                name="conv_object",
+                value={"key": "value", "nested": {"data": 123}},
+                selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"],
+            ),
+            ArrayStringVariable(
+                id="conv_array_string_id",
+                name="conv_array_string",
+                value=["conv_array_string_value"],
+                selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"],
+            ),
+            ArrayNumberVariable(
+                id="conv_array_number_id",
+                name="conv_array_number",
+                value=[1, 1.0],
+                selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"],
+            ),
+            ArrayObjectVariable(
+                id="conv_array_object_id",
+                name="conv_array_object",
+                value=[{"a": 1}, {"b": "2"}],
+                selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"],
+            ),
+        ]
+
+        # Create comprehensive user inputs
+        user_inputs = {
+            "string_input": "test_value",
+            "number_input": 42,
+            "object_input": {"nested": {"key": "value"}},
+            "array_input": ["item1", "item2", "item3"],
+        }
+
+        # Create VariablePool
+        pool = VariablePool(
+            system_variables=system_vars,
+            user_inputs=user_inputs,
+            environment_variables=env_vars,
+            conversation_variables=conv_vars,
+        )
+        return pool
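The `list[VariableUnion]` annotations are the point of this PR: per the description, VariableUnion is a Pydantic discriminated union over the concrete Variable classes, keyed by the now-distinct value_type tags (integer vs float instead of a shared number). A self-contained, reduced sketch of the idea with hypothetical toy models; the real union in core.variables.variables covers every subclass and may use a callable discriminator such as get_segment_discriminator:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter

class IntegerVariableSketch(BaseModel):
    value_type: Literal["integer"] = "integer"
    name: str
    value: int

class FloatVariableSketch(BaseModel):
    value_type: Literal["float"] = "float"
    name: str
    value: float

class StringVariableSketch(BaseModel):
    value_type: Literal["string"] = "string"
    name: str
    value: str

# With unique value_type tags, Pydantic can dispatch on that single field.
VariableUnionSketch = Annotated[
    Union[IntegerVariableSketch, FloatVariableSketch, StringVariableSketch],
    Field(discriminator="value_type"),
]

# Round-trip: the tag "integer" (no longer a shared "number") selects the
# right class during validation.
adapter = TypeAdapter(VariableUnionSketch)
restored = adapter.validate_python({"value_type": "integer", "name": "n", "value": 1})
assert isinstance(restored, IntegerVariableSketch)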
+    def _add_node_data_to_pool(self, pool: VariablePool, with_file=False):
+        test_file = File(
+            tenant_id="test_tenant_id",
+            type=FileType.DOCUMENT,
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="test_related_id",
+            remote_url="test_url",
+            filename="test_file.txt",
+            storage_key="test_storage_key",
+        )
+
+        # Add various segment types to the variable dictionary
+        pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string"))
+        pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123))
+        pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67))
+        pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"}))
+        if with_file:
+            pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file))
+        pool.add((self._NODE1_ID, "none_var"), NoneSegment())
+
+        # Add array segments, including ArrayFileSegment when requested
+        pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"]))
+        pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3]))
+        pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}]))
+        if with_file:
+            pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
+        pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
+
+        # Add nested variables
+        pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
+    def test_system_variables(self):
+        sys_vars = SystemVariable(
+            user_id="test_user_id",
+            app_id="test_app_id",
+            workflow_id="test_workflow_id",
+            workflow_execution_id="test_execution_123",
+            query="test query",
+            conversation_id="test_conv_id",
+            dialogue_count=5,
+        )
+        pool = VariablePool(system_variables=sys_vars)
+        json = pool.model_dump_json()
+        pool2 = VariablePool.model_validate_json(json)
+        assert pool2.system_variables == sys_vars
+
+        for mode in ["json", "python"]:
+            dict_ = pool.model_dump(mode=mode)
+            pool2 = VariablePool.model_validate(dict_)
+            assert pool2.system_variables == sys_vars
+
+    def test_pool_without_file_vars(self):
+        pool = self._create_pool_without_file()
+        json = pool.model_dump_json()
+        pool2 = pool.model_validate_json(json)
+        assert pool2.system_variables == pool.system_variables
+        assert pool2.conversation_variables == pool.conversation_variables
+        assert pool2.environment_variables == pool.environment_variables
+        assert pool2.user_inputs == pool.user_inputs
+        assert pool2.variable_dictionary == pool.variable_dictionary
+        assert pool2 == pool
+    def test_basic_dictionary_round_trip(self):
+        """Test basic round-trip serialization: model_dump() → model_validate()"""
+        # Create a comprehensive VariablePool with all data types
+        original_pool = self._create_pool_without_file()
+        self._add_node_data_to_pool(original_pool)
+
+        # Serialize to dictionary using Pydantic's model_dump()
+        serialized_data = original_pool.model_dump()
+
+        # Verify serialized data structure
+        assert isinstance(serialized_data, dict)
+        assert "system_variables" in serialized_data
+        assert "user_inputs" in serialized_data
+        assert "environment_variables" in serialized_data
+        assert "conversation_variables" in serialized_data
+        assert "variable_dictionary" in serialized_data
+
+        # Deserialize back using Pydantic's model_validate()
+        reconstructed_pool = VariablePool.model_validate(serialized_data)
+
+        # Verify data integrity is preserved
+        self._assert_pools_equal(original_pool, reconstructed_pool)
+
+    def test_json_round_trip(self):
+        """Test JSON round-trip serialization: model_dump_json() → model_validate_json()"""
+        # Create a comprehensive VariablePool with all data types
+        original_pool = self._create_pool_without_file()
+        self._add_node_data_to_pool(original_pool)
+
+        # Serialize to JSON string using Pydantic's model_dump_json()
+        json_data = original_pool.model_dump_json()
+
+        # Verify JSON is a valid string
+        assert isinstance(json_data, str)
+        assert len(json_data) > 0
+
+        # Deserialize back using Pydantic's model_validate_json()
+        reconstructed_pool = VariablePool.model_validate_json(json_data)
+
+        # Verify data integrity is preserved
+        self._assert_pools_equal(original_pool, reconstructed_pool)
+
+    def test_complex_data_serialization(self):
+        """Test serialization of complex data structures including ArrayFileVariable"""
+        original_pool = self._create_pool_without_file()
+        self._add_node_data_to_pool(original_pool, with_file=True)
+
+        # Test dictionary round-trip
+        dict_data = original_pool.model_dump()
+        reconstructed_dict = VariablePool.model_validate(dict_data)
+
+        # Test JSON round-trip
+        json_data = original_pool.model_dump_json()
+        reconstructed_json = VariablePool.model_validate_json(json_data)
+
+        # Verify both reconstructed pools are equivalent
+        self._assert_pools_equal(reconstructed_dict, reconstructed_json)
+        # TODO: assert the data for file object...
+    def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
+        """Assert that two VariablePools contain equivalent data"""
+
+        # Compare system variables
+        assert pool1.system_variables == pool2.system_variables
+
+        # Compare user inputs
+        assert dict(pool1.user_inputs) == dict(pool2.user_inputs)
+
+        # Compare environment variables
+        assert pool1.environment_variables == pool2.environment_variables
+
+        # Compare conversation variables
+        assert pool1.conversation_variables == pool2.conversation_variables
+
+        # Test key variable retrievals to ensure functionality is preserved
+        test_selectors = [
+            (SYSTEM_VARIABLE_NODE_ID, "user_id"),
+            (SYSTEM_VARIABLE_NODE_ID, "app_id"),
+            (ENVIRONMENT_VARIABLE_NODE_ID, "env_string"),
+            (ENVIRONMENT_VARIABLE_NODE_ID, "env_number"),
+            (CONVERSATION_VARIABLE_NODE_ID, "conv_string"),
+            (self._NODE1_ID, "string_var"),
+            (self._NODE1_ID, "int_var"),
+            (self._NODE1_ID, "float_var"),
+            (self._NODE2_ID, "array_string"),
+            (self._NODE2_ID, "array_number"),
+            (self._NODE3_ID, "nested", "deep", "var"),
+        ]
+
+        for selector in test_selectors:
+            val1 = pool1.get(selector)
+            val2 = pool2.get(selector)
+
+            # Both should exist or both should be None
+            assert (val1 is None) == (val2 is None)
+
+            if val1 is not None and val2 is not None:
+                # Values should be equal
+                assert val1.value == val2.value
+                # Value types should be the same (more important than exact class type)
+                assert val1.value_type == val2.value_type
+
+    def test_variable_pool_deserialization_default_dict(self):
+        variable_pool = VariablePool(
+            user_inputs={"a": 1, "b": "2"},
+            system_variables=SystemVariable(workflow_id=str(uuid.uuid4())),
+            environment_variables=[
+                StringVariable(name="str_var", value="a"),
+            ],
+            conversation_variables=[IntegerVariable(name="int_var", value=1)],
+        )
+        assert isinstance(variable_pool.variable_dictionary, defaultdict)
+        json = variable_pool.model_dump_json()
+        loaded = VariablePool.model_validate_json(json)
+        assert isinstance(loaded.variable_dictionary, defaultdict)
+
+        loaded.add(["non_exist_node", "a"], 1)
+
+        pool_dict = variable_pool.model_dump()
+        loaded = VariablePool.model_validate(pool_dict)
+        assert isinstance(loaded.variable_dictionary, defaultdict)
+        loaded.add(["non_exist_node", "a"], 1)
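The defaultdict assertions are the subtle part: a naive model_validate would rebuild variable_dictionary as a plain dict, and the `loaded.add(["non_exist_node", ...])` calls above would then fail on the missing node bucket. A sketch of a field validator that restores the defaultdict after deserialization, assuming that is roughly how VariablePool achieves it:

from collections import defaultdict
from typing import Any

from pydantic import BaseModel, Field, field_validator

class PoolSketch(BaseModel):
    # Hypothetical reduction of VariablePool.variable_dictionary.
    variable_dictionary: dict[str, dict[str, Any]] = Field(default_factory=dict)

    @field_validator("variable_dictionary", mode="after")
    @classmethod
    def _as_defaultdict(cls, value: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
        # Re-wrap so lookups on unseen node IDs create empty buckets.
        return defaultdict(dict, value)

restored = PoolSketch.model_validate({"variable_dictionary": {"node": {"k": 1}}})
assert isinstance(restored.variable_dictionary, defaultdict)
restored.variable_dictionary["unseen_node"]["x"] = 1  # no KeyError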
@@ -18,10 +18,10 @@ from core.workflow.entities.workflow_node_execution import (
     WorkflowNodeExecutionMetadataKey,
     WorkflowNodeExecutionStatus,
 )
-from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes import NodeType
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
 from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
 from models.enums import CreatorUserRole
 from models.model import AppMode
@@ -67,14 +67,14 @@ def real_app_generate_entity():

 @pytest.fixture
 def real_workflow_system_variables():
-    return {
-        SystemVariableKey.QUERY: "test query",
-        SystemVariableKey.CONVERSATION_ID: "test-conversation-id",
-        SystemVariableKey.USER_ID: "test-user-id",
-        SystemVariableKey.APP_ID: "test-app-id",
-        SystemVariableKey.WORKFLOW_ID: "test-workflow-id",
-        SystemVariableKey.WORKFLOW_EXECUTION_ID: "test-workflow-run-id",
-    }
+    return SystemVariable(
+        query="test query",
+        conversation_id="test-conversation-id",
+        user_id="test-user-id",
+        app_id="test-app-id",
+        workflow_id="test-workflow-id",
+        workflow_execution_id="test-workflow-run-id",
+    )


 @pytest.fixture
@@ -10,7 +10,7 @@ class TestAppendVariablesRecursively:

     def test_append_simple_dict_value(self):
         """Test appending a simple dictionary value"""
-        pool = VariablePool()
+        pool = VariablePool.empty()
         node_id = "test_node"
         variable_key_list = ["output"]
         variable_value = {"name": "John", "age": 30}
@@ -33,7 +33,7 @@ class TestAppendVariablesRecursively:

     def test_append_object_segment_value(self):
         """Test appending an ObjectSegment value"""
-        pool = VariablePool()
+        pool = VariablePool.empty()
         node_id = "test_node"
         variable_key_list = ["result"]

@@ -60,7 +60,7 @@ class TestAppendVariablesRecursively:

     def test_append_nested_dict_value(self):
         """Test appending a nested dictionary value"""
-        pool = VariablePool()
+        pool = VariablePool.empty()
         node_id = "test_node"
         variable_key_list = ["data"]

@@ -97,7 +97,7 @@ class TestAppendVariablesRecursively:

     def test_append_non_dict_value(self):
         """Test appending a non-dictionary value (should not recurse)"""
-        pool = VariablePool()
+        pool = VariablePool.empty()
         node_id = "test_node"
         variable_key_list = ["simple"]
         variable_value = "simple_string"
@@ -114,7 +114,7 @@ class TestAppendVariablesRecursively:

     def test_append_segment_non_object_value(self):
         """Test appending a Segment that is not ObjectSegment (should not recurse)"""
-        pool = VariablePool()
+        pool = VariablePool.empty()
         node_id = "test_node"
         variable_key_list = ["text"]
         variable_value = StringSegment(value="Hello World")
@@ -132,7 +132,7 @@ class TestAppendVariablesRecursively:

     def test_append_empty_dict_value(self):
         """Test appending an empty dictionary value"""
-        pool = VariablePool()
+        pool = VariablePool.empty()
         node_id = "test_node"
         variable_key_list = ["empty"]
         variable_value: dict[str, Any] = {}
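These tests target a recursive append helper whose implementation is outside this diff. A sketch of the behavior they pin down (add the value at the current path, then recurse only into dicts and ObjectSegments), with the function name assumed:

from typing import Any

from core.variables.segments import ObjectSegment
from core.workflow.entities.variable_pool import VariablePool

def append_variables_recursively(
    pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: Any
) -> None:
    # Hypothetical reconstruction of the helper under test.
    pool.add([node_id, *variable_key_list], variable_value)
    # Unwrap ObjectSegment so its dict payload is walked too.
    if isinstance(variable_value, ObjectSegment):
        variable_value = variable_value.value
    if isinstance(variable_value, dict):
        for key, child in variable_value.items():
            append_variables_recursively(pool, node_id, [*variable_key_list, key], child)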