refactor: simplify variable pool key structure and improve type safety (#23732)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-08-11 18:10:04 +08:00
committed by GitHub
parent 223c1a8089
commit 577062b93a
10 changed files with 102 additions and 259 deletions

View File

@@ -69,8 +69,12 @@ def test_get_file_attribute(pool, file):
def test_use_long_selector(pool):
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
# The add method now only accepts 2-element selectors (node_id, variable_name)
# Store nested data as an ObjectSegment instead
nested_data = {"part_2": "test_value"}
pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data))
# The get method supports longer selectors for nested access
result = pool.get(("node_1", "part_1", "part_2"))
assert result is not None
assert result.value == "test_value"
@@ -280,8 +284,10 @@ class TestVariablePoolSerialization:
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
# Add nested variables
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
# Add nested variables as ObjectSegment
# The add method only accepts 2-element selectors
nested_obj = {"deep": {"var": "deep_value"}}
pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj))
def test_system_variables(self):
sys_vars = SystemVariable(

View File

@@ -1,148 +0,0 @@
from typing import Any
from core.variables.segments import ObjectSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils.variable_utils import append_variables_recursively
class TestAppendVariablesRecursively:
"""Test cases for append_variables_recursively function"""
def test_append_simple_dict_value(self):
"""Test appending a simple dictionary value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["output"]
variable_value = {"name": "John", "age": 30}
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert main_var.value == variable_value
# Check that nested variables are added recursively
name_var = pool.get([node_id] + variable_key_list + ["name"])
assert name_var is not None
assert name_var.value == "John"
age_var = pool.get([node_id] + variable_key_list + ["age"])
assert age_var is not None
assert age_var.value == 30
def test_append_object_segment_value(self):
"""Test appending an ObjectSegment value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["result"]
# Create an ObjectSegment
obj_data = {"status": "success", "code": 200}
variable_value = ObjectSegment(value=obj_data)
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert isinstance(main_var, ObjectSegment)
assert main_var.value == obj_data
# Check that nested variables are added recursively
status_var = pool.get([node_id] + variable_key_list + ["status"])
assert status_var is not None
assert status_var.value == "success"
code_var = pool.get([node_id] + variable_key_list + ["code"])
assert code_var is not None
assert code_var.value == 200
def test_append_nested_dict_value(self):
"""Test appending a nested dictionary value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["data"]
variable_value = {
"user": {
"profile": {"name": "Alice", "email": "alice@example.com"},
"settings": {"theme": "dark", "notifications": True},
},
"metadata": {"version": "1.0", "timestamp": 1234567890},
}
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check deeply nested variables
name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"])
assert name_var is not None
assert name_var.value == "Alice"
email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"])
assert email_var is not None
assert email_var.value == "alice@example.com"
theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"])
assert theme_var is not None
assert theme_var.value == "dark"
notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"])
assert notifications_var is not None
assert notifications_var.value == 1 # Boolean True is converted to integer 1
version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"])
assert version_var is not None
assert version_var.value == "1.0"
def test_append_non_dict_value(self):
"""Test appending a non-dictionary value (should not recurse)"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["simple"]
variable_value = "simple_string"
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that only the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert main_var.value == variable_value
# Ensure no additional variables are created
assert len(pool.variable_dictionary[node_id]) == 1
def test_append_segment_non_object_value(self):
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["text"]
variable_value = StringSegment(value="Hello World")
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that only the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert isinstance(main_var, StringSegment)
assert main_var.value == "Hello World"
# Ensure no additional variables are created
assert len(pool.variable_dictionary[node_id]) == 1
def test_append_empty_dict_value(self):
"""Test appending an empty dictionary value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["empty"]
variable_value: dict[str, Any] = {}
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert main_var.value == {}
# Ensure only the main variable is created (no recursion for empty dict)
assert len(pool.variable_dictionary[node_id]) == 1