refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization.

Key changes:
- Introduce distinct `value_type` tags for Integer and Float segments/variables
- Add `VariableUnion` and `SegmentUnion` types for proper type discrimination
- Leverage Pydantic's discriminated union feature for seamless serialization/deserialization
- Enable accurate serialization of data structures containing these types

Closes #22024.
This commit is contained in:
QuantumGhost
2025-07-16 12:31:37 +08:00
committed by GitHub
parent 229b4d621e
commit 2c1ab4879f
58 changed files with 2325 additions and 328 deletions

View File

@@ -1,9 +1,9 @@
import json
import sys
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Annotated, Any, TypeAlias
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File
@@ -11,6 +11,11 @@ from .types import SegmentType
class Segment(BaseModel):
"""Segment is runtime type used during the execution of workflow.
Note: this class is abstract, you should use subclasses of this class instead.
"""
model_config = ConfigDict(frozen=True)
value_type: SegmentType
@@ -73,7 +78,7 @@ class StringSegment(Segment):
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value_type: SegmentType = SegmentType.FLOAT
value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
@@ -92,7 +97,7 @@ class FloatSegment(Segment):
class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value_type: SegmentType = SegmentType.INTEGER
value: int
@@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment):
@property
def text(self) -> str:
return ""
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
elif isinstance(v, dict):
value_type = v.get("value_type")
if value_type is None:
return None
try:
seg_type = SegmentType(value_type)
except ValueError:
return None
return seg_type
else:
# return None if the discriminator value isn't found
return None
# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Segment` for type hinting when serialization is not required.
#
# Note:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
| Annotated[StringSegment, Tag(SegmentType.STRING)]
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
),
Discriminator(get_segment_discriminator),
]

View File

@@ -1,8 +1,27 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Optional
from core.file.models import File
class ArrayValidation(StrEnum):
"""Strategy for validating array elements"""
# Skip element validation (only check array container)
NONE = "none"
# Validate the first element (if array is non-empty)
FIRST = "first"
# Validate all elements in the array.
ALL = "all"
class SegmentType(StrEnum):
NUMBER = "number"
INTEGER = "integer"
FLOAT = "float"
STRING = "string"
OBJECT = "object"
SECRET = "secret"
@@ -19,16 +38,141 @@ class SegmentType(StrEnum):
GROUP = "group"
def is_array_type(self):
def is_array_type(self) -> bool:
return self in _ARRAY_TYPES
@classmethod
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
"""
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
Returns `None` if no appropriate `SegmentType` can be determined for the given `value`.
For example, this may occur if the input is a generic Python object of type `object`.
"""
if isinstance(value, list):
elem_types: set[SegmentType] = set()
for i in value:
segment_type = cls.infer_segment_type(i)
if segment_type is None:
return None
elem_types.add(segment_type)
if len(elem_types) != 1:
if elem_types.issubset(_NUMERICAL_TYPES):
return SegmentType.ARRAY_NUMBER
return SegmentType.ARRAY_ANY
elif all(i.is_array_type() for i in elem_types):
return SegmentType.ARRAY_ANY
match elem_types.pop():
case SegmentType.STRING:
return SegmentType.ARRAY_STRING
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return SegmentType.ARRAY_NUMBER
case SegmentType.OBJECT:
return SegmentType.ARRAY_OBJECT
case SegmentType.FILE:
return SegmentType.ARRAY_FILE
case SegmentType.NONE:
return SegmentType.ARRAY_ANY
case _:
# This should be unreachable.
raise ValueError(f"not supported value {value}")
if value is None:
return SegmentType.NONE
elif isinstance(value, int) and not isinstance(value, bool):
return SegmentType.INTEGER
elif isinstance(value, float):
return SegmentType.FLOAT
elif isinstance(value, str):
return SegmentType.STRING
elif isinstance(value, dict):
return SegmentType.OBJECT
elif isinstance(value, File):
return SegmentType.FILE
elif isinstance(value, str):
return SegmentType.STRING
else:
return None
def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
if not isinstance(value, list):
return False
# Skip element validation if array is empty
if len(value) == 0:
return True
if self == SegmentType.ARRAY_ANY:
return True
element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]
if array_validation == ArrayValidation.NONE:
return True
elif array_validation == ArrayValidation.FIRST:
return element_type.is_valid(value[0])
else:
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
"""
Check if a value matches the segment type.
Users of `SegmentType` should call this method, instead of using
`isinstance` manually.
Args:
value: The value to validate
array_validation: Validation strategy for array types (ignored for non-array types)
Returns:
True if the value matches the type under the given validation strategy
"""
if self.is_array_type():
return self._validate_array(value, array_validation)
elif self == SegmentType.NUMBER:
return isinstance(value, (int, float))
elif self == SegmentType.STRING:
return isinstance(value, str)
elif self == SegmentType.OBJECT:
return isinstance(value, dict)
elif self == SegmentType.SECRET:
return isinstance(value, str)
elif self == SegmentType.FILE:
return isinstance(value, File)
elif self == SegmentType.NONE:
return value is None
else:
raise AssertionError("this statement should be unreachable.")
def exposed_type(self) -> "SegmentType":
"""Returns the type exposed to the frontend.
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
"""
if self in (SegmentType.INTEGER, SegmentType.FLOAT):
return SegmentType.NUMBER
return self
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have correpond element type.
SegmentType.ARRAY_STRING: SegmentType.STRING,
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
SegmentType.ARRAY_FILE: SegmentType.FILE,
}
_ARRAY_TYPES = frozenset(
[
list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
+ [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
)
_NUMERICAL_TYPES = frozenset(
[
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
]
)

View File

@@ -1,8 +1,8 @@
from collections.abc import Sequence
from typing import cast
from typing import Annotated, TypeAlias, cast
from uuid import uuid4
from pydantic import Field
from pydantic import Discriminator, Field, Tag
from core.helper import encrypter
@@ -20,6 +20,7 @@ from .segments import (
ObjectSegment,
Segment,
StringSegment,
get_segment_discriminator,
)
from .types import SegmentType
@@ -27,6 +28,10 @@ from .types import SegmentType
class Variable(Segment):
"""
A variable is a segment that has a name.
It is mainly used to store segments and their selector in VariablePool.
Note: this class is abstract, you should use subclasses of this class instead.
"""
id: str = Field(
@@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
#
# Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
VariableUnion: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]
| Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
| Annotated[FileVariable, Tag(SegmentType.FILE)]
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),
]