refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)
refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025) This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization. Key changes: - Introduce distinct `value_type` tags for Integer and Float segments/variables - Add `VariableUnion` and `SegmentUnion` types for proper type discrimination - Leverage Pydantic's discriminated union feature for seamless serialization/deserialization - Enable accurate serialization of data structures containing these types Closes #22024.
This commit is contained in:
@@ -1,14 +1,49 @@
|
||||
import dataclasses
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, StringVariable
|
||||
from core.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
SegmentUnion,
|
||||
StringSegment,
|
||||
get_segment_discriminator,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.variables import (
|
||||
ArrayAnyVariable,
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneVariable,
|
||||
ObjectVariable,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VariableUnion,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def test_segment_group_to_text():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey("user_id"): "fake-user-id",
|
||||
},
|
||||
system_variables=SystemVariable(user_id="fake-user-id"),
|
||||
user_inputs={},
|
||||
environment_variables=[
|
||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
||||
@@ -30,7 +65,7 @@ def test_segment_group_to_text():
|
||||
|
||||
def test_convert_constant_to_segment_group():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
@@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group():
|
||||
|
||||
def test_convert_variable_to_segment_group():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey("user_id"): "fake-user-id",
|
||||
},
|
||||
system_variables=SystemVariable(user_id="fake-user-id"),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
@@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group():
|
||||
assert segments_group.log == "fake-user-id"
|
||||
assert isinstance(segments_group.value[0], StringVariable)
|
||||
assert segments_group.value[0].value == "fake-user-id"
|
||||
|
||||
|
||||
class _Segments(BaseModel):
|
||||
segments: list[SegmentUnion]
|
||||
|
||||
|
||||
class _Variables(BaseModel):
|
||||
variables: list[VariableUnion]
|
||||
|
||||
|
||||
def create_test_file(
|
||||
file_type: FileType = FileType.DOCUMENT,
|
||||
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
|
||||
filename: str = "test.txt",
|
||||
extension: str = ".txt",
|
||||
mime_type: str = "text/plain",
|
||||
size: int = 1024,
|
||||
) -> File:
|
||||
"""Factory function to create File objects for testing"""
|
||||
return File(
|
||||
tenant_id="test-tenant",
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
|
||||
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
storage_key="test-storage-key",
|
||||
)
|
||||
|
||||
|
||||
class TestSegmentDumpAndLoad:
|
||||
"""Test suite for segment and variable serialization/deserialization"""
|
||||
|
||||
def test_segments(self):
|
||||
"""Test basic segment serialization compatibility"""
|
||||
model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
|
||||
json = model.model_dump_json()
|
||||
print("Json: ", json)
|
||||
loaded = _Segments.model_validate_json(json)
|
||||
assert loaded == model
|
||||
|
||||
def test_segment_number(self):
|
||||
"""Test number segment serialization compatibility"""
|
||||
model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
|
||||
json = model.model_dump_json()
|
||||
print("Json: ", json)
|
||||
loaded = _Segments.model_validate_json(json)
|
||||
assert loaded == model
|
||||
|
||||
def test_variables(self):
|
||||
"""Test variable serialization compatibility"""
|
||||
model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
|
||||
json = model.model_dump_json()
|
||||
print("Json: ", json)
|
||||
restored = _Variables.model_validate_json(json)
|
||||
assert restored == model
|
||||
|
||||
def test_all_segments_serialization(self):
|
||||
"""Test serialization/deserialization of all segment types"""
|
||||
# Create one instance of each segment type
|
||||
test_file = create_test_file()
|
||||
|
||||
all_segments: list[SegmentUnion] = [
|
||||
NoneSegment(),
|
||||
StringSegment(value="test string"),
|
||||
IntegerSegment(value=42),
|
||||
FloatSegment(value=3.14),
|
||||
ObjectSegment(value={"key": "value", "number": 123}),
|
||||
FileSegment(value=test_file),
|
||||
ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
|
||||
ArrayStringSegment(value=["hello", "world"]),
|
||||
ArrayNumberSegment(value=[1, 2.5, 3]),
|
||||
ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
|
||||
ArrayFileSegment(value=[]), # Empty array to avoid file complexity
|
||||
]
|
||||
|
||||
# Test serialization and deserialization
|
||||
model = _Segments(segments=all_segments)
|
||||
json_str = model.model_dump_json()
|
||||
loaded = _Segments.model_validate_json(json_str)
|
||||
|
||||
# Verify all segments are preserved
|
||||
assert len(loaded.segments) == len(all_segments)
|
||||
|
||||
for original, loaded_segment in zip(all_segments, loaded.segments):
|
||||
assert type(loaded_segment) == type(original)
|
||||
assert loaded_segment.value_type == original.value_type
|
||||
|
||||
# For file segments, compare key properties instead of exact equality
|
||||
if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
|
||||
orig_file = original.value
|
||||
loaded_file = loaded_segment.value
|
||||
assert isinstance(orig_file, File)
|
||||
assert isinstance(loaded_file, File)
|
||||
assert loaded_file.tenant_id == orig_file.tenant_id
|
||||
assert loaded_file.type == orig_file.type
|
||||
assert loaded_file.filename == orig_file.filename
|
||||
else:
|
||||
assert loaded_segment.value == original.value
|
||||
|
||||
def test_all_variables_serialization(self):
|
||||
"""Test serialization/deserialization of all variable types"""
|
||||
# Create one instance of each variable type
|
||||
test_file = create_test_file()
|
||||
|
||||
all_variables: list[VariableUnion] = [
|
||||
NoneVariable(name="none_var"),
|
||||
StringVariable(value="test string", name="string_var"),
|
||||
IntegerVariable(value=42, name="int_var"),
|
||||
FloatVariable(value=3.14, name="float_var"),
|
||||
ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
|
||||
FileVariable(value=test_file, name="file_var"),
|
||||
ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
|
||||
ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
|
||||
ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
|
||||
ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
|
||||
ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity
|
||||
]
|
||||
|
||||
# Test serialization and deserialization
|
||||
model = _Variables(variables=all_variables)
|
||||
json_str = model.model_dump_json()
|
||||
loaded = _Variables.model_validate_json(json_str)
|
||||
|
||||
# Verify all variables are preserved
|
||||
assert len(loaded.variables) == len(all_variables)
|
||||
|
||||
for original, loaded_variable in zip(all_variables, loaded.variables):
|
||||
assert type(loaded_variable) == type(original)
|
||||
assert loaded_variable.value_type == original.value_type
|
||||
assert loaded_variable.name == original.name
|
||||
|
||||
# For file variables, compare key properties instead of exact equality
|
||||
if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
|
||||
orig_file = original.value
|
||||
loaded_file = loaded_variable.value
|
||||
assert isinstance(orig_file, File)
|
||||
assert isinstance(loaded_file, File)
|
||||
assert loaded_file.tenant_id == orig_file.tenant_id
|
||||
assert loaded_file.type == orig_file.type
|
||||
assert loaded_file.filename == orig_file.filename
|
||||
else:
|
||||
assert loaded_variable.value == original.value
|
||||
|
||||
def test_segment_discriminator_function_for_segment_types(self):
|
||||
"""Test the segment discriminator function"""
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestCase:
|
||||
segment: Segment
|
||||
expected_segment_type: SegmentType
|
||||
|
||||
file1 = create_test_file()
|
||||
file2 = create_test_file(filename="test2.txt")
|
||||
|
||||
cases = [
|
||||
TestCase(
|
||||
NoneSegment(),
|
||||
SegmentType.NONE,
|
||||
),
|
||||
TestCase(
|
||||
StringSegment(value=""),
|
||||
SegmentType.STRING,
|
||||
),
|
||||
TestCase(
|
||||
FloatSegment(value=0.0),
|
||||
SegmentType.FLOAT,
|
||||
),
|
||||
TestCase(
|
||||
IntegerSegment(value=0),
|
||||
SegmentType.INTEGER,
|
||||
),
|
||||
TestCase(
|
||||
ObjectSegment(value={}),
|
||||
SegmentType.OBJECT,
|
||||
),
|
||||
TestCase(
|
||||
FileSegment(value=file1),
|
||||
SegmentType.FILE,
|
||||
),
|
||||
TestCase(
|
||||
ArrayAnySegment(value=[0, 0.0, ""]),
|
||||
SegmentType.ARRAY_ANY,
|
||||
),
|
||||
TestCase(
|
||||
ArrayStringSegment(value=[""]),
|
||||
SegmentType.ARRAY_STRING,
|
||||
),
|
||||
TestCase(
|
||||
ArrayNumberSegment(value=[0, 0.0]),
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
),
|
||||
TestCase(
|
||||
ArrayObjectSegment(value=[{}]),
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
),
|
||||
TestCase(
|
||||
ArrayFileSegment(value=[file1, file2]),
|
||||
SegmentType.ARRAY_FILE,
|
||||
),
|
||||
]
|
||||
|
||||
for test_case in cases:
|
||||
segment = test_case.segment
|
||||
assert get_segment_discriminator(segment) == test_case.expected_segment_type, (
|
||||
f"get_segment_discriminator failed for type {type(segment)}"
|
||||
)
|
||||
model_dict = segment.model_dump(mode="json")
|
||||
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
|
||||
f"get_segment_discriminator failed for serialized form of type {type(segment)}"
|
||||
)
|
||||
|
||||
def test_variable_discriminator_function_for_variable_types(self):
|
||||
"""Test the variable discriminator function"""
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestCase:
|
||||
variable: Variable
|
||||
expected_segment_type: SegmentType
|
||||
|
||||
file1 = create_test_file()
|
||||
file2 = create_test_file(filename="test2.txt")
|
||||
|
||||
cases = [
|
||||
TestCase(
|
||||
NoneVariable(name="none_var"),
|
||||
SegmentType.NONE,
|
||||
),
|
||||
TestCase(
|
||||
StringVariable(value="test", name="string_var"),
|
||||
SegmentType.STRING,
|
||||
),
|
||||
TestCase(
|
||||
FloatVariable(value=0.0, name="float_var"),
|
||||
SegmentType.FLOAT,
|
||||
),
|
||||
TestCase(
|
||||
IntegerVariable(value=0, name="int_var"),
|
||||
SegmentType.INTEGER,
|
||||
),
|
||||
TestCase(
|
||||
ObjectVariable(value={}, name="object_var"),
|
||||
SegmentType.OBJECT,
|
||||
),
|
||||
TestCase(
|
||||
FileVariable(value=file1, name="file_var"),
|
||||
SegmentType.FILE,
|
||||
),
|
||||
TestCase(
|
||||
SecretVariable(value="secret", name="secret_var"),
|
||||
SegmentType.SECRET,
|
||||
),
|
||||
TestCase(
|
||||
ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"),
|
||||
SegmentType.ARRAY_ANY,
|
||||
),
|
||||
TestCase(
|
||||
ArrayStringVariable(value=[""], name="array_string_var"),
|
||||
SegmentType.ARRAY_STRING,
|
||||
),
|
||||
TestCase(
|
||||
ArrayNumberVariable(value=[0, 0.0], name="array_number_var"),
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
),
|
||||
TestCase(
|
||||
ArrayObjectVariable(value=[{}], name="array_object_var"),
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
),
|
||||
TestCase(
|
||||
ArrayFileVariable(value=[file1, file2], name="array_file_var"),
|
||||
SegmentType.ARRAY_FILE,
|
||||
),
|
||||
]
|
||||
|
||||
for test_case in cases:
|
||||
variable = test_case.variable
|
||||
assert get_segment_discriminator(variable) == test_case.expected_segment_type, (
|
||||
f"get_segment_discriminator failed for type {type(variable)}"
|
||||
)
|
||||
model_dict = variable.model_dump(mode="json")
|
||||
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
|
||||
f"get_segment_discriminator failed for serialized form of type {type(variable)}"
|
||||
)
|
||||
|
||||
def test_invlaid_value_for_discriminator(self):
|
||||
# Test invalid cases
|
||||
assert get_segment_discriminator({"value_type": "invalid"}) is None
|
||||
assert get_segment_discriminator({}) is None
|
||||
assert get_segment_discriminator("not_a_dict") is None
|
||||
assert get_segment_discriminator(42) is None
|
||||
assert get_segment_discriminator(object) is None
|
||||
|
||||
60
api/tests/unit_tests/core/variables/test_segment_type.py
Normal file
60
api/tests/unit_tests/core/variables/test_segment_type.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class TestSegmentTypeIsArrayType:
|
||||
"""
|
||||
Test class for SegmentType.is_array_type method.
|
||||
|
||||
Provides comprehensive coverage of all SegmentType values to ensure
|
||||
correct identification of array and non-array types.
|
||||
"""
|
||||
|
||||
def test_is_array_type(self):
|
||||
"""
|
||||
Test that all SegmentType enum values are covered in our test cases.
|
||||
|
||||
Ensures comprehensive coverage by verifying that every SegmentType
|
||||
value is tested for the is_array_type method.
|
||||
"""
|
||||
# Arrange
|
||||
all_segment_types = set(SegmentType)
|
||||
expected_array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
]
|
||||
expected_non_array_types = [
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.STRING,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
SegmentType.GROUP,
|
||||
]
|
||||
|
||||
for seg_type in expected_array_types:
|
||||
assert seg_type.is_array_type()
|
||||
|
||||
for seg_type in expected_non_array_types:
|
||||
assert not seg_type.is_array_type()
|
||||
|
||||
# Act & Assert
|
||||
covered_types = set(expected_array_types) | set(expected_non_array_types)
|
||||
assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests"
|
||||
|
||||
def test_all_enum_values_are_supported(self):
|
||||
"""
|
||||
Test that all enum values are supported and return boolean values.
|
||||
|
||||
Validates that every SegmentType enum value can be processed by
|
||||
is_array_type method and returns a boolean value.
|
||||
"""
|
||||
enum_values: list[SegmentType] = list(SegmentType)
|
||||
for seg_type in enum_values:
|
||||
is_array = seg_type.is_array_type()
|
||||
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
|
||||
@@ -11,6 +11,7 @@ from core.variables import (
|
||||
SegmentType,
|
||||
StringVariable,
|
||||
)
|
||||
from core.variables.variables import Variable
|
||||
|
||||
|
||||
def test_frozen_variables():
|
||||
@@ -75,7 +76,7 @@ def test_object_variable_to_object():
|
||||
|
||||
|
||||
def test_variable_to_object():
|
||||
var = StringVariable(name="text", value="text")
|
||||
var: Variable = StringVariable(name="text", value="text")
|
||||
assert var.to_object() == "text"
|
||||
var = IntegerVariable(name="integer", value=42)
|
||||
assert var.to_object() == 42
|
||||
|
||||
Reference in New Issue
Block a user