refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization.

Key changes:
- Introduce distinct `value_type` tags for Integer and Float segments/variables
- Add `VariableUnion` and `SegmentUnion` types for proper type discrimination
- Leverage Pydantic's discriminated union feature for seamless serialization/deserialization
- Enable accurate serialization of data structures containing these types

Closes #22024.
This commit is contained in:
QuantumGhost
2025-07-16 12:31:37 +08:00
committed by GitHub
parent 229b4d621e
commit 2c1ab4879f
58 changed files with 2325 additions and 328 deletions

View File

@@ -1,7 +1,7 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
@@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, File]
@@ -23,31 +24,31 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Segment]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
# TODO: This user inputs is not used for pool.
# The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
default_factory=dict,
)
system_variables: Mapping[SystemVariableKey, Any] = Field(
system_variables: SystemVariable = Field(
description="System variables",
default_factory=dict,
)
environment_variables: Sequence[Variable] = Field(
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list,
)
conversation_variables: Sequence[Variable] = Field(
conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list,
)
def model_post_init(self, context: Any, /) -> None:
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Create a mapping from field names to SystemVariableKey enum values
self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
@@ -83,8 +84,22 @@ class VariablePool(BaseModel):
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = variable
key, hash_key = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
return selector[0], hash(tuple(selector[1:]))
def _has(self, selector: Sequence[str]) -> bool:
key, hash_key = self._selector_to_keys(selector)
if key not in self.variable_dictionary:
return False
if hash_key not in self.variable_dictionary[key]:
return False
return True
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
@@ -102,8 +117,8 @@ class VariablePool(BaseModel):
if len(selector) < MIN_SELECTORS_LENGTH:
return None
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
key, hash_key = self._selector_to_keys(selector)
value: Segment | None = self.variable_dictionary[key].get(hash_key)
if value is None:
selector, attr = selector[:-1], selector[-1]
@@ -136,8 +151,9 @@ class VariablePool(BaseModel):
if len(selector) == 1:
self.variable_dictionary[selector[0]] = {}
return
key, hash_key = self._selector_to_keys(selector)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]].pop(hash_key, None)
self.variable_dictionary[key].pop(hash_key, None)
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
@@ -154,3 +170,20 @@ class VariablePool(BaseModel):
if isinstance(segment, FileSegment):
return segment
return None
def _add_system_variables(self, system_variable: SystemVariable):
sys_var_mapping = system_variable.to_dict()
for key, value in sys_var_mapping.items():
if value is None:
continue
selector = (SYSTEM_VARIABLE_NODE_ID, key)
# If the system variable already exists, do not add it again.
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
self.add(selector, value) # type: ignore
@classmethod
def empty(cls) -> "VariablePool":
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())