refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)
refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025) This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization. Key changes: - Introduce distinct `value_type` tags for Integer and Float segments/variables - Add `VariableUnion` and `SegmentUnion` types for proper type discrimination - Leverage Pydantic's discriminated union feature for seamless serialization/deserialization - Enable accurate serialization of data structures containing these types Closes #22024.
This commit is contained in:
@@ -1,11 +1,29 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import AfterValidator, BaseModel, Field
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
_VALID_VAR_TYPE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
||||
if seg_type not in _VALID_VAR_TYPE:
|
||||
raise ValueError(...)
|
||||
return seg_type
|
||||
|
||||
|
||||
class LoopVariableData(BaseModel):
|
||||
"""
|
||||
@@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
|
||||
"""
|
||||
|
||||
label: str
|
||||
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: Optional[Any | list[str]] = None
|
||||
|
||||
|
||||
@@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import (
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
IntegerSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
SegmentType,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
return variable_mapping
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
|
||||
"string": (StringSegment, SegmentType.STRING),
|
||||
"number": (IntegerSegment, SegmentType.NUMBER),
|
||||
"object": (ObjectSegment, SegmentType.OBJECT),
|
||||
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
|
||||
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
|
||||
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
|
||||
}
|
||||
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
||||
if value:
|
||||
if value and isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = []
|
||||
segment_info = segment_mapping.get(var_type)
|
||||
if not segment_info:
|
||||
raise ValueError(f"Invalid variable type: {var_type}")
|
||||
segment_class, value_type = segment_info
|
||||
return segment_class(value=value, value_type=value_type)
|
||||
try:
|
||||
return build_segment_with_type(var_type, value)
|
||||
except TypeMismatchError as type_exc:
|
||||
# Attempt to parse the value as a JSON-encoded string, if applicable.
|
||||
if not isinstance(value, str):
|
||||
raise
|
||||
try:
|
||||
value = json.loads(value)
|
||||
except ValueError:
|
||||
raise type_exc
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
||||
Reference in New Issue
Block a user