refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization.

Key changes:
- Introduce distinct `value_type` tags for Integer and Float segments/variables
- Add `VariableUnion` and `SegmentUnion` types for proper type discrimination
- Leverage Pydantic's discriminated union feature for seamless serialization/deserialization
- Enable accurate serialization of data structures containing these types

Closes #22024.
This commit is contained in:
QuantumGhost
2025-07-16 12:31:37 +08:00
committed by GitHub
parent 229b4d621e
commit 2c1ab4879f
58 changed files with 2325 additions and 328 deletions

View File

@@ -1,14 +1,49 @@
import dataclasses
from pydantic import BaseModel
from core.file import File, FileTransferMethod, FileType
from core.helper import encrypter
from core.variables import SecretVariable, StringVariable
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
SegmentUnion,
StringSegment,
get_segment_discriminator,
)
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
VariableUnion,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
def test_segment_group_to_text():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
@@ -30,7 +65,7 @@ def test_segment_group_to_text():
def test_convert_constant_to_segment_group():
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group():
assert segments_group.log == "fake-user-id"
assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id"
class _Segments(BaseModel):
    """Pydantic wrapper around a list of segments, used to exercise dump/load round-trips."""

    segments: list[SegmentUnion]
class _Variables(BaseModel):
    """Pydantic wrapper around a list of variables, used to exercise dump/load round-trips."""

    variables: list[VariableUnion]
def create_test_file(
    file_type: FileType = FileType.DOCUMENT,
    transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
    filename: str = "test.txt",
    extension: str = ".txt",
    mime_type: str = "text/plain",
    size: int = 1024,
) -> File:
    """Build a ``File`` with a fixed tenant id and storage key for use in tests.

    Exactly one of ``related_id`` / ``remote_url`` is populated: a remote-URL
    transfer gets the URL, every other transfer method gets the related id.
    """
    if transfer_method == FileTransferMethod.REMOTE_URL:
        related_id = None
        remote_url = "https://example.com/file.txt"
    else:
        related_id = "test-file-id"
        remote_url = None

    return File(
        tenant_id="test-tenant",
        type=file_type,
        transfer_method=transfer_method,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=size,
        related_id=related_id,
        remote_url=remote_url,
        storage_key="test-storage-key",
    )
class TestSegmentDumpAndLoad:
    """Serialization round-trip and discriminator tests for segments and variables."""

    @staticmethod
    def _assert_same_file(original, loaded) -> None:
        """Compare the key properties of two files instead of requiring exact equality."""
        assert isinstance(original, File)
        assert isinstance(loaded, File)
        assert loaded.tenant_id == original.tenant_id
        assert loaded.type == original.type
        assert loaded.filename == original.filename

    @staticmethod
    def _assert_discriminator(value, expected_segment_type) -> None:
        """Check the discriminator for both the model instance and its JSON-mode dump."""
        assert get_segment_discriminator(value) == expected_segment_type, (
            f"get_segment_discriminator failed for type {type(value)}"
        )
        model_dict = value.model_dump(mode="json")
        assert get_segment_discriminator(model_dict) == expected_segment_type, (
            f"get_segment_discriminator failed for serialized form of type {type(value)}"
        )

    def test_segments(self):
        """An integer and a string segment survive a JSON dump/load round-trip."""
        model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
        payload = model.model_dump_json()
        print("Json: ", payload)
        loaded = _Segments.model_validate_json(payload)
        assert loaded == model

    def test_segment_number(self):
        """Integer and float segments stay distinct across a JSON round-trip."""
        model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
        payload = model.model_dump_json()
        print("Json: ", payload)
        loaded = _Segments.model_validate_json(payload)
        assert loaded == model

    def test_variables(self):
        """An integer and a string variable survive a JSON dump/load round-trip."""
        model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
        payload = model.model_dump_json()
        print("Json: ", payload)
        restored = _Variables.model_validate_json(payload)
        assert restored == model

    def test_all_segments_serialization(self):
        """Test serialization/deserialization of all segment types"""
        # Create one instance of each segment type
        test_file = create_test_file()
        all_segments: list[SegmentUnion] = [
            NoneSegment(),
            StringSegment(value="test string"),
            IntegerSegment(value=42),
            FloatSegment(value=3.14),
            ObjectSegment(value={"key": "value", "number": 123}),
            FileSegment(value=test_file),
            ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
            ArrayStringSegment(value=["hello", "world"]),
            ArrayNumberSegment(value=[1, 2.5, 3]),
            ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
            ArrayFileSegment(value=[]),  # Empty array to avoid file complexity
        ]

        # Test serialization and deserialization
        model = _Segments(segments=all_segments)
        json_str = model.model_dump_json()
        loaded = _Segments.model_validate_json(json_str)

        # Verify all segments are preserved
        assert len(loaded.segments) == len(all_segments)
        for original, loaded_segment in zip(all_segments, loaded.segments):
            # Classes are compared by identity: the exact concrete type must be restored.
            assert type(loaded_segment) is type(original)
            assert loaded_segment.value_type == original.value_type

            # For file segments, compare key properties instead of exact equality
            if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
                self._assert_same_file(original.value, loaded_segment.value)
            else:
                assert loaded_segment.value == original.value

    def test_all_variables_serialization(self):
        """Test serialization/deserialization of all variable types"""
        # Create one instance of each variable type
        test_file = create_test_file()
        all_variables: list[VariableUnion] = [
            NoneVariable(name="none_var"),
            StringVariable(value="test string", name="string_var"),
            IntegerVariable(value=42, name="int_var"),
            FloatVariable(value=3.14, name="float_var"),
            ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
            FileVariable(value=test_file, name="file_var"),
            ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
            ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
            ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
            ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
            ArrayFileVariable(value=[], name="array_file_var"),  # Empty array to avoid file complexity
        ]

        # Test serialization and deserialization
        model = _Variables(variables=all_variables)
        json_str = model.model_dump_json()
        loaded = _Variables.model_validate_json(json_str)

        # Verify all variables are preserved
        assert len(loaded.variables) == len(all_variables)
        for original, loaded_variable in zip(all_variables, loaded.variables):
            # Classes are compared by identity: the exact concrete type must be restored.
            assert type(loaded_variable) is type(original)
            assert loaded_variable.value_type == original.value_type
            assert loaded_variable.name == original.name

            # For file variables, compare key properties instead of exact equality
            if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
                self._assert_same_file(original.value, loaded_variable.value)
            else:
                assert loaded_variable.value == original.value

    def test_segment_discriminator_function_for_segment_types(self):
        """Test the segment discriminator function"""

        @dataclasses.dataclass
        class TestCase:
            segment: Segment
            expected_segment_type: SegmentType

        file1 = create_test_file()
        file2 = create_test_file(filename="test2.txt")

        cases = [
            TestCase(
                NoneSegment(),
                SegmentType.NONE,
            ),
            TestCase(
                StringSegment(value=""),
                SegmentType.STRING,
            ),
            TestCase(
                FloatSegment(value=0.0),
                SegmentType.FLOAT,
            ),
            TestCase(
                IntegerSegment(value=0),
                SegmentType.INTEGER,
            ),
            TestCase(
                ObjectSegment(value={}),
                SegmentType.OBJECT,
            ),
            TestCase(
                FileSegment(value=file1),
                SegmentType.FILE,
            ),
            TestCase(
                ArrayAnySegment(value=[0, 0.0, ""]),
                SegmentType.ARRAY_ANY,
            ),
            TestCase(
                ArrayStringSegment(value=[""]),
                SegmentType.ARRAY_STRING,
            ),
            TestCase(
                ArrayNumberSegment(value=[0, 0.0]),
                SegmentType.ARRAY_NUMBER,
            ),
            TestCase(
                ArrayObjectSegment(value=[{}]),
                SegmentType.ARRAY_OBJECT,
            ),
            TestCase(
                ArrayFileSegment(value=[file1, file2]),
                SegmentType.ARRAY_FILE,
            ),
        ]

        for test_case in cases:
            self._assert_discriminator(test_case.segment, test_case.expected_segment_type)

    def test_variable_discriminator_function_for_variable_types(self):
        """Test the variable discriminator function"""

        @dataclasses.dataclass
        class TestCase:
            variable: Variable
            expected_segment_type: SegmentType

        file1 = create_test_file()
        file2 = create_test_file(filename="test2.txt")

        cases = [
            TestCase(
                NoneVariable(name="none_var"),
                SegmentType.NONE,
            ),
            TestCase(
                StringVariable(value="test", name="string_var"),
                SegmentType.STRING,
            ),
            TestCase(
                FloatVariable(value=0.0, name="float_var"),
                SegmentType.FLOAT,
            ),
            TestCase(
                IntegerVariable(value=0, name="int_var"),
                SegmentType.INTEGER,
            ),
            TestCase(
                ObjectVariable(value={}, name="object_var"),
                SegmentType.OBJECT,
            ),
            TestCase(
                FileVariable(value=file1, name="file_var"),
                SegmentType.FILE,
            ),
            TestCase(
                SecretVariable(value="secret", name="secret_var"),
                SegmentType.SECRET,
            ),
            TestCase(
                ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"),
                SegmentType.ARRAY_ANY,
            ),
            TestCase(
                ArrayStringVariable(value=[""], name="array_string_var"),
                SegmentType.ARRAY_STRING,
            ),
            TestCase(
                ArrayNumberVariable(value=[0, 0.0], name="array_number_var"),
                SegmentType.ARRAY_NUMBER,
            ),
            TestCase(
                ArrayObjectVariable(value=[{}], name="array_object_var"),
                SegmentType.ARRAY_OBJECT,
            ),
            TestCase(
                ArrayFileVariable(value=[file1, file2], name="array_file_var"),
                SegmentType.ARRAY_FILE,
            ),
        ]

        for test_case in cases:
            self._assert_discriminator(test_case.variable, test_case.expected_segment_type)

    def test_invlaid_value_for_discriminator(self):
        """Inputs without a recognizable ``value_type`` make the discriminator return None."""
        # Test invalid cases
        assert get_segment_discriminator({"value_type": "invalid"}) is None
        assert get_segment_discriminator({}) is None
        assert get_segment_discriminator("not_a_dict") is None
        assert get_segment_discriminator(42) is None
        assert get_segment_discriminator(object) is None

View File

@@ -0,0 +1,60 @@
from core.variables.types import SegmentType
class TestSegmentTypeIsArrayType:
    """
    Test class for SegmentType.is_array_type method.

    Provides comprehensive coverage of all SegmentType values to ensure
    correct identification of array and non-array types.
    """

    def test_is_array_type(self):
        """
        Test is_array_type for every SegmentType value.

        Asserts that each expected array type reports True, each expected
        non-array type reports False, and that the two expectation lists
        together cover every member of the SegmentType enum.
        """
        # Arrange
        all_segment_types = set(SegmentType)
        expected_array_types = [
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
        ]
        expected_non_array_types = [
            SegmentType.INTEGER,
            SegmentType.FLOAT,
            SegmentType.NUMBER,
            SegmentType.STRING,
            SegmentType.OBJECT,
            SegmentType.SECRET,
            SegmentType.FILE,
            SegmentType.NONE,
            SegmentType.GROUP,
        ]

        # Act & Assert
        for seg_type in expected_array_types:
            assert seg_type.is_array_type(), f"{seg_type} should be reported as an array type"

        for seg_type in expected_non_array_types:
            assert not seg_type.is_array_type(), f"{seg_type} should not be reported as an array type"

        # Guard against new enum members being added without updating the lists above.
        covered_types = set(expected_array_types) | set(expected_non_array_types)
        assert covered_types == all_segment_types, "All SegmentType values should be covered in tests"

    def test_all_enum_values_are_supported(self):
        """
        Test that all enum values are supported and return boolean values.

        Validates that every SegmentType enum value can be processed by
        is_array_type method and returns a boolean value.
        """
        enum_values: list[SegmentType] = list(SegmentType)
        for seg_type in enum_values:
            is_array = seg_type.is_array_type()
            assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"

View File

@@ -11,6 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
from core.variables.variables import Variable
def test_frozen_variables():
@@ -75,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
var = StringVariable(name="text", value="text")
var: Variable = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42