refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization.

Key changes:
- Introduce distinct `value_type` tags for Integer and Float segments/variables
- Add `VariableUnion` and `SegmentUnion` types for proper type discrimination
- Leverage Pydantic's discriminated union feature for seamless serialization/deserialization
- Enable accurate serialization of data structures containing these types

Closes #22024.
This commit is contained in:
QuantumGhost
2025-07-16 12:31:37 +08:00
committed by GitHub
parent 229b4d621e
commit 2c1ab4879f
58 changed files with 2325 additions and 328 deletions

View File

@@ -1,11 +1,29 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal, Optional
from pydantic import BaseModel, Field
from pydantic import AfterValidator, BaseModel, Field
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
_VALID_VAR_TYPE = frozenset(
[
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
]
)
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
if seg_type not in _VALID_VAR_TYPE:
raise ValueError(...)
return seg_type
class LoopVariableData(BaseModel):
"""
@@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
"""
label: str
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None

View File

@@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
IntegerSegment,
ObjectSegment,
Segment,
SegmentType,
StringSegment,
)
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
@@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]):
return variable_mapping
@staticmethod
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
"string": (StringSegment, SegmentType.STRING),
"number": (IntegerSegment, SegmentType.NUMBER),
"object": (ObjectSegment, SegmentType.OBJECT),
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
}
if var_type in ["array[string]", "array[number]", "array[object]"]:
if value:
if value and isinstance(value, str):
value = json.loads(value)
else:
value = []
segment_info = segment_mapping.get(var_type)
if not segment_info:
raise ValueError(f"Invalid variable type: {var_type}")
segment_class, value_type = segment_info
return segment_class(value=value, value_type=value_type)
try:
return build_segment_with_type(var_type, value)
except TypeMismatchError as type_exc:
# Attempt to parse the value as a JSON-encoded string, if applicable.
if not isinstance(value, str):
raise
try:
value = json.loads(value)
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)