refactor: simplify variable pool key structure and improve type safety (#23732)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -69,8 +69,12 @@ def test_get_file_attribute(pool, file):
|
||||
|
||||
|
||||
def test_use_long_selector(pool):
|
||||
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
|
||||
# The add method now only accepts 2-element selectors (node_id, variable_name)
|
||||
# Store nested data as an ObjectSegment instead
|
||||
nested_data = {"part_2": "test_value"}
|
||||
pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data))
|
||||
|
||||
# The get method supports longer selectors for nested access
|
||||
result = pool.get(("node_1", "part_1", "part_2"))
|
||||
assert result is not None
|
||||
assert result.value == "test_value"
|
||||
@@ -280,8 +284,10 @@ class TestVariablePoolSerialization:
|
||||
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
|
||||
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
|
||||
|
||||
# Add nested variables
|
||||
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
|
||||
# Add nested variables as ObjectSegment
|
||||
# The add method only accepts 2-element selectors
|
||||
nested_obj = {"deep": {"var": "deep_value"}}
|
||||
pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj))
|
||||
|
||||
def test_system_variables(self):
|
||||
sys_vars = SystemVariable(
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables.segments import ObjectSegment, StringSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.utils.variable_utils import append_variables_recursively
|
||||
|
||||
|
||||
class TestAppendVariablesRecursively:
|
||||
"""Test cases for append_variables_recursively function"""
|
||||
|
||||
def test_append_simple_dict_value(self):
|
||||
"""Test appending a simple dictionary value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["output"]
|
||||
variable_value = {"name": "John", "age": 30}
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert main_var.value == variable_value
|
||||
|
||||
# Check that nested variables are added recursively
|
||||
name_var = pool.get([node_id] + variable_key_list + ["name"])
|
||||
assert name_var is not None
|
||||
assert name_var.value == "John"
|
||||
|
||||
age_var = pool.get([node_id] + variable_key_list + ["age"])
|
||||
assert age_var is not None
|
||||
assert age_var.value == 30
|
||||
|
||||
def test_append_object_segment_value(self):
|
||||
"""Test appending an ObjectSegment value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["result"]
|
||||
|
||||
# Create an ObjectSegment
|
||||
obj_data = {"status": "success", "code": 200}
|
||||
variable_value = ObjectSegment(value=obj_data)
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert isinstance(main_var, ObjectSegment)
|
||||
assert main_var.value == obj_data
|
||||
|
||||
# Check that nested variables are added recursively
|
||||
status_var = pool.get([node_id] + variable_key_list + ["status"])
|
||||
assert status_var is not None
|
||||
assert status_var.value == "success"
|
||||
|
||||
code_var = pool.get([node_id] + variable_key_list + ["code"])
|
||||
assert code_var is not None
|
||||
assert code_var.value == 200
|
||||
|
||||
def test_append_nested_dict_value(self):
|
||||
"""Test appending a nested dictionary value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["data"]
|
||||
|
||||
variable_value = {
|
||||
"user": {
|
||||
"profile": {"name": "Alice", "email": "alice@example.com"},
|
||||
"settings": {"theme": "dark", "notifications": True},
|
||||
},
|
||||
"metadata": {"version": "1.0", "timestamp": 1234567890},
|
||||
}
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check deeply nested variables
|
||||
name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"])
|
||||
assert name_var is not None
|
||||
assert name_var.value == "Alice"
|
||||
|
||||
email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"])
|
||||
assert email_var is not None
|
||||
assert email_var.value == "alice@example.com"
|
||||
|
||||
theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"])
|
||||
assert theme_var is not None
|
||||
assert theme_var.value == "dark"
|
||||
|
||||
notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"])
|
||||
assert notifications_var is not None
|
||||
assert notifications_var.value == 1 # Boolean True is converted to integer 1
|
||||
|
||||
version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"])
|
||||
assert version_var is not None
|
||||
assert version_var.value == "1.0"
|
||||
|
||||
def test_append_non_dict_value(self):
|
||||
"""Test appending a non-dictionary value (should not recurse)"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["simple"]
|
||||
variable_value = "simple_string"
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that only the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert main_var.value == variable_value
|
||||
|
||||
# Ensure no additional variables are created
|
||||
assert len(pool.variable_dictionary[node_id]) == 1
|
||||
|
||||
def test_append_segment_non_object_value(self):
|
||||
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["text"]
|
||||
variable_value = StringSegment(value="Hello World")
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that only the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert isinstance(main_var, StringSegment)
|
||||
assert main_var.value == "Hello World"
|
||||
|
||||
# Ensure no additional variables are created
|
||||
assert len(pool.variable_dictionary[node_id]) == 1
|
||||
|
||||
def test_append_empty_dict_value(self):
|
||||
"""Test appending an empty dictionary value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["empty"]
|
||||
variable_value: dict[str, Any] = {}
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert main_var.value == {}
|
||||
|
||||
# Ensure only the main variable is created (no recursion for empty dict)
|
||||
assert len(pool.variable_dictionary[node_id]) == 1
|
||||
Reference in New Issue
Block a user