feat: knowledge pipeline (#25360)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: jyong <718720800@qq.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: quicksand <quicksandzn@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Hanqing Zhao <sherry9277@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Harry <xh001x@hotmail.com>
This commit is contained in:
@@ -104,6 +104,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||
patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
|
||||
):
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
mock_naive_utc_now.return_value = current_time
|
||||
@@ -114,6 +115,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"db_session": mock_db,
|
||||
"naive_utc_now": mock_naive_utc_now,
|
||||
"current_time": current_time,
|
||||
"has_dataset_same_name": has_dataset_same_name,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
@@ -190,9 +192,9 @@ class TestDatasetServiceUpdateDataset:
|
||||
"external_knowledge_api_id": "new_api_id",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify permission check was called
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
|
||||
# Verify dataset and binding updates
|
||||
@@ -214,6 +216,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
@@ -227,6 +230,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
@@ -250,6 +254,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"external_knowledge_id": "knowledge_id",
|
||||
"external_knowledge_api_id": "api_id",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
@@ -280,6 +285,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify permission check was called
|
||||
@@ -320,6 +326,8 @@ class TestDatasetServiceUpdateDataset:
|
||||
"embedding_model": None, # Should be filtered out
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with filtered data
|
||||
@@ -356,6 +364,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@@ -402,6 +411,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@@ -453,6 +463,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@@ -505,6 +516,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@@ -558,6 +570,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"indexing_technique": "high_quality", # Same as current
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@@ -588,6 +601,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
@@ -604,6 +618,8 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
update_data = {"name": "new_name"}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@@ -628,6 +644,8 @@ class TestDatasetServiceUpdateDataset:
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(Exception) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
|
||||
590
api/tests/unit_tests/services/test_variable_truncator.py
Normal file
590
api/tests/unit_tests/services/test_variable_truncator.py
Normal file
@@ -0,0 +1,590 @@
|
||||
"""
|
||||
Comprehensive unit tests for VariableTruncator class based on current implementation.
|
||||
|
||||
This test suite covers all functionality of the current VariableTruncator including:
|
||||
- JSON size calculation for different data types
|
||||
- String, array, and object truncation logic
|
||||
- Segment-based truncation interface
|
||||
- Helper methods for budget-based truncation
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from services.variable_truncator import (
|
||||
MaxDepthExceededError,
|
||||
TruncationResult,
|
||||
UnknownTypeError,
|
||||
VariableTruncator,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def file() -> File:
    """Build a local-storage document ``File`` whose identifiers are freshly generated UUIDs."""
    file_id = str(uuid4())  # Generate new UUID for File.id
    tenant_id = str(uuid.uuid4())
    related_id = str(uuid.uuid4())

    attributes = {
        "id": file_id,
        "tenant_id": tenant_id,
        "type": FileType.DOCUMENT,
        "transfer_method": FileTransferMethod.LOCAL_FILE,
        "related_id": related_id,
        "filename": "test_file.txt",
        "extension": ".txt",
        "mime_type": "text/plain",
        "size": 1024,
        "storage_key": "initial_key",
    }
    return File(**attributes)
|
||||
|
||||
|
||||
# ``json.dumps`` preconfigured to emit no whitespace after "," and ":"; used below to
# obtain the compact serialized length that calculated sizes are compared against.
_compact_json_dumps = functools.partial(json.dumps, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestCalculateJsonSize:
    """Test calculate_json_size method with different data types."""

    @pytest.fixture
    def truncator(self):
        """Return a truncator constructed with no explicit limits."""
        return VariableTruncator()

    def test_string_size_calculation(self):
        """Test JSON size calculation for strings."""
        # Simple ASCII string
        assert VariableTruncator.calculate_json_size("hello") == 7  # "hello" + 2 quotes

        # Empty string
        assert VariableTruncator.calculate_json_size("") == 2  # Just quotes

        # Unicode string: 2 characters + 2 quotes. The expected value counts characters,
        # not UTF-8 bytes (the two characters alone encode to 6 bytes).
        assert VariableTruncator.calculate_json_size("你好") == 4

    def test_number_size_calculation(self, truncator):
        """Test JSON size calculation for numbers."""
        # Expected sizes equal the length of each number's decimal text, sign included.
        assert truncator.calculate_json_size(123) == 3
        assert truncator.calculate_json_size(12.34) == 5
        assert truncator.calculate_json_size(-456) == 4
        assert truncator.calculate_json_size(0) == 1

    def test_boolean_size_calculation(self, truncator):
        """Test JSON size calculation for booleans."""
        assert truncator.calculate_json_size(True) == 4  # "true"
        assert truncator.calculate_json_size(False) == 5  # "false"

    def test_null_size_calculation(self, truncator):
        """Test JSON size calculation for None/null."""
        assert truncator.calculate_json_size(None) == 4  # "null"

    def test_array_size_calculation(self, truncator):
        """Test JSON size calculation for arrays."""
        # Empty array
        assert truncator.calculate_json_size([]) == 2  # "[]"

        # Simple array
        simple_array = [1, 2, 3]
        # [1,2,3] = 1 + 1 + 1 + 1 + 1 + 2 = 7 (numbers + commas + brackets)
        assert truncator.calculate_json_size(simple_array) == 7

        # Array with strings
        string_array = ["a", "b"]
        # ["a","b"] = 3 + 3 + 1 + 2 = 9 (quoted strings + comma + brackets)
        assert truncator.calculate_json_size(string_array) == 9

    def test_object_size_calculation(self, truncator):
        """Test JSON size calculation for objects."""
        # Empty object
        assert truncator.calculate_json_size({}) == 2  # "{}"

        # Simple object
        simple_obj = {"a": 1}
        # {"a":1} = 3 + 1 + 1 + 2 = 7 (key + colon + value + brackets)
        assert truncator.calculate_json_size(simple_obj) == 7

        # Multiple keys
        multi_obj = {"a": 1, "b": 2}
        # {"a":1,"b":2} = 3 + 1 + 1 + 1 + 3 + 1 + 1 + 2 = 13
        assert truncator.calculate_json_size(multi_obj) == 13

    def test_nested_structure_size_calculation(self, truncator):
        """Test JSON size calculation for nested structures."""
        nested = {"items": [1, 2, {"nested": "value"}]}
        size = truncator.calculate_json_size(nested)
        assert size > 0  # Should calculate without error

        # Compare against the length of the compact (no-whitespace) JSON serialization.
        actual_json = _compact_json_dumps(nested)
        # The payload is pure ASCII, so its byte length equals its character length here;
        # the test nevertheless only requires the two sizes to agree within 5.
        assert abs(size - len(actual_json.encode())) <= 5

    def test_calculate_json_size_max_depth_exceeded(self, truncator):
        """Test that calculate_json_size raises MaxDepthExceededError on very deep nesting."""
        # Create deeply nested structure
        nested: dict[str, Any] = {"level": 0}
        current = nested
        # 105 additional levels; NOTE(review): presumably chosen to exceed the truncator's
        # maximum supported depth - confirm against VariableTruncator.
        for i in range(105):  # Create deep nesting
            current["next"] = {"level": i + 1}
            current = current["next"]

        # Nesting this deep must be rejected with MaxDepthExceededError.
        with pytest.raises(MaxDepthExceededError):
            truncator.calculate_json_size(nested)

    def test_calculate_json_size_unknown_type(self, truncator):
        """Test that calculate_json_size raises error for unknown types."""

        # A plain user-defined class, i.e. not one of the JSON-compatible types covered above.
        class CustomType:
            pass

        with pytest.raises(UnknownTypeError):
            truncator.calculate_json_size(CustomType())
|
||||
|
||||
|
||||
class TestStringTruncation:
    """Test string truncation functionality."""

    # Size budget handed to ``_truncate_string`` in every test; also used as the
    # fixture's ``string_length_limit`` so the two values cannot drift apart.
    LENGTH_LIMIT = 10

    @pytest.fixture
    def small_truncator(self):
        """Return a truncator whose string length limit equals ``LENGTH_LIMIT``."""
        return VariableTruncator(string_length_limit=self.LENGTH_LIMIT)

    def test_short_string_no_truncation(self, small_truncator):
        """Test that short strings are not truncated."""
        short_str = "hello"
        result = small_truncator._truncate_string(short_str, self.LENGTH_LIMIT)
        assert result.value == short_str
        assert result.truncated is False
        # The reported size must agree with the standalone size calculation.
        assert result.value_size == VariableTruncator.calculate_json_size(short_str)

    def test_long_string_truncation(self, small_truncator: VariableTruncator):
        """Test that long strings are truncated with ellipsis."""
        long_str = "this is a very long string that exceeds the limit"
        result = small_truncator._truncate_string(long_str, self.LENGTH_LIMIT)

        assert result.truncated is True
        # Only the first 5 characters survive, followed by the "..." marker.
        assert result.value == long_str[:5] + "..."
        # Consistent with the 8-character result plus its two JSON quotes.
        assert result.value_size == 10

    def test_exact_limit_string(self, small_truncator: VariableTruncator):
        """Test string whose character count equals the limit."""
        # Exactly 10 chars; the test expects truncation even though the raw length
        # does not exceed LENGTH_LIMIT.
        exact_str = "1234567890"
        result = small_truncator._truncate_string(exact_str, self.LENGTH_LIMIT)
        assert result.value == "12345..."
        assert result.truncated is True
        assert result.value_size == 10
|
||||
|
||||
|
||||
class TestArrayTruncation:
    """Exercise ``VariableTruncator._truncate_array`` with element-count and size budgets."""

    @pytest.fixture
    def small_truncator(self):
        """Return a truncator limited to 3 array elements and 100 bytes."""
        limits = {"array_element_limit": 3, "max_size_bytes": 100}
        return VariableTruncator(**limits)

    def test_small_array_no_truncation(self, small_truncator: VariableTruncator):
        """An array below every limit is returned unchanged and not flagged."""
        values = [1, 2]
        outcome = small_truncator._truncate_array(values, 1000)
        assert outcome.value == values
        assert outcome.truncated is False

    def test_array_element_limit_truncation(self, small_truncator: VariableTruncator):
        """An array longer than the element limit keeps only its leading elements."""
        # Six elements against an element limit of three.
        values = [1, 2, 3, 4, 5, 6]
        outcome = small_truncator._truncate_array(values, 1000)

        assert outcome.truncated is True
        assert outcome.value == [1, 2, 3]

    def test_array_size_budget_truncation(self, small_truncator: VariableTruncator):
        """Long string elements are cut down until the array fits the size budget."""
        budget = 50
        oversized = ["very long string " * 5, "another long string " * 5]
        outcome = small_truncator._truncate_array(oversized, budget)

        assert outcome.truncated is True
        # Elements stay strings; only their content is shortened.
        assert all(isinstance(element, str) for element in outcome.value)
        assert VariableTruncator.calculate_json_size(outcome.value) <= budget

    def test_array_with_nested_objects(self, small_truncator):
        """Dict elements remain dicts and the element limit still applies."""
        records = [
            {"name": "item1", "data": "some data"},
            {"name": "item2", "data": "more data"},
            {"name": "item3", "data": "even more data"},
        ]
        outcome = small_truncator._truncate_array(records, 30)

        assert isinstance(outcome.value, list)
        assert len(outcome.value) <= 3
        assert all(isinstance(element, dict) for element in outcome.value)
|
||||
|
||||
|
||||
class TestObjectTruncation:
    """Exercise ``VariableTruncator._truncate_object`` under various size budgets."""

    @pytest.fixture
    def small_truncator(self):
        """Return a truncator capped at 100 bytes."""
        return VariableTruncator(max_size_bytes=100)

    def test_small_object_no_truncation(self, small_truncator):
        """An object well inside the budget comes back unchanged."""
        payload = {"a": 1, "b": 2}
        outcome = small_truncator._truncate_object(payload, 1000)
        assert outcome.value == payload
        assert outcome.truncated is False

    def test_empty_object_no_truncation(self, small_truncator):
        """An empty object is never reported as truncated."""
        payload = {}
        outcome = small_truncator._truncate_object(payload, 100)
        assert outcome.value == payload
        assert outcome.truncated is False

    def test_object_value_truncation(self, small_truncator):
        """Long values are shortened (or keys dropped) so the object fits the budget."""
        source = {
            "key1": "very long string " * 10,
            "key2": "another long string " * 10,
            "key3": "third long string " * 10,
        }
        outcome = small_truncator._truncate_object(source, 80)

        assert outcome.truncated is True
        assert isinstance(outcome.value, dict)

        # No key may appear that was absent from the input.
        assert set(outcome.value.keys()).issubset(source.keys())

        # Every surviving string value is at most as long as its original.
        surviving_strings = {name: text for name, text in outcome.value.items() if isinstance(text, str)}
        for name, text in surviving_strings.items():
            assert len(text) <= len(source[name])

    def test_object_key_dropping(self, small_truncator):
        """With many keys and a tight budget, some keys are dropped."""
        source = {f"key{i:02d}": f"value{i}" for i in range(20)}
        outcome = small_truncator._truncate_object(source, 50)

        assert outcome.truncated is True
        assert len(outcome.value) < len(source)

        # Surviving keys must still be in sorted order.
        kept_keys = list(outcome.value.keys())
        assert kept_keys == sorted(kept_keys)

    def test_object_with_nested_structures(self, small_truncator):
        """Nested arrays and objects are handled and the result is still a dict."""
        source = {
            "simple": "value",
            "array": [1, 2, 3, 4, 5],
            "nested": {"inner": "data", "more": ["a", "b", "c"]},
        }
        outcome = small_truncator._truncate_object(source, 60)

        assert isinstance(outcome.value, dict)
|
||||
|
||||
|
||||
class TestSegmentBasedTruncation:
    """Test the main truncate method that works with Segments."""

    @pytest.fixture
    def truncator(self):
        """Return a truncator constructed with no explicit limits."""
        return VariableTruncator()

    @pytest.fixture
    def small_truncator(self):
        """Return a truncator with tight limits: 20-char strings, 3 elements, 200 bytes."""
        return VariableTruncator(string_length_limit=20, array_element_limit=3, max_size_bytes=200)

    def test_integer_segment_no_truncation(self, truncator):
        """Test that an integer segment passes through untruncated."""
        segment = IntegerSegment(value=12345)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_boolean_as_integer_segment(self, truncator):
        """Test that a boolean wrapped in IntegerSegment yields an integer segment equal to 1."""
        segment = IntegerSegment(value=True)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert isinstance(result.result, IntegerSegment)
        assert result.result.value == 1  # True converted to 1

    def test_float_segment_no_truncation(self, truncator):
        """Test that a float segment passes through untruncated."""
        segment = FloatSegment(value=123.456)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_none_segment_no_truncation(self, truncator):
        """Test that a None segment passes through untruncated."""
        segment = NoneSegment()
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_file_segment_no_truncation(self, truncator, file):
        """Test that a file segment passes through untruncated."""
        file_segment = FileSegment(value=file)
        result = truncator.truncate(file_segment)
        assert result.result == file_segment
        assert result.truncated is False

    def test_array_file_segment_no_truncation(self, truncator, file):
        """Test that an array-of-files segment with 20 entries passes through untruncated."""
        array_file_segment = ArrayFileSegment(value=[file] * 20)
        result = truncator.truncate(array_file_segment)
        assert result.result == array_file_segment
        assert result.truncated is False

    def test_string_segment_small_no_truncation(self, truncator):
        """Test small string segments are not truncated."""
        segment = StringSegment(value="hello world")
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_string_segment_large_truncation(self, small_truncator):
        """Test large string segments are truncated and end with the "..." marker."""
        long_text = "this is a very long string that will definitely exceed the limit"
        segment = StringSegment(value=long_text)
        result = small_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
        assert len(result.result.value) < len(long_text)
        assert result.result.value.endswith("...")

    def test_array_segment_small_no_truncation(self, truncator):
        """Test small array segments are not truncated."""
        # Imported locally, as in the rest of this module.
        from factories.variable_factory import build_segment

        segment = build_segment([1, 2, 3])
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_array_segment_large_truncation(self, small_truncator):
        """Test large array segments are truncated."""
        # Imported locally, as in the rest of this module.
        from factories.variable_factory import build_segment

        large_array = list(range(10))  # Exceeds element limit of 3
        segment = build_segment(large_array)
        result = small_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, ArraySegment)
        assert len(result.result.value) <= 3

    def test_object_segment_small_no_truncation(self, truncator):
        """Test small object segments are not truncated."""
        segment = ObjectSegment(value={"key": "value"})
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_object_segment_large_truncation(self, small_truncator):
        """Test large object segments are truncated."""
        large_obj = {f"key{i}": f"very long value {i}" * 5 for i in range(5)}
        segment = ObjectSegment(value=large_obj)
        result = small_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, ObjectSegment)
        # Object should be smaller or equal than original
        original_size = small_truncator.calculate_json_size(large_obj)
        result_size = small_truncator.calculate_json_size(result.result.value)
        assert result_size <= original_size

    def test_final_size_fallback_to_json_string(self):
        """Test final fallback when truncated result still exceeds size limit."""
        # Create data that will still be large after initial truncation
        large_nested_data = {"data": ["very long string " * 5] * 5, "more": {"nested": "content " * 20}}
        segment = ObjectSegment(value=large_nested_data)

        # Use very small limit to force JSON string fallback
        tiny_truncator = VariableTruncator(max_size_bytes=50)
        result = tiny_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        # The object is expected to come back as a string segment, not an object segment.
        assert isinstance(result.result, StringSegment)
        # Should be JSON string with possible truncation
        assert len(result.result.value) <= 53  # 50 + "..." = 53

    def test_final_size_fallback_string_truncation(self):
        """Test final fallback for string that still exceeds limit."""
        # Far longer than both limits configured on the truncator below.
        very_long_string = "x" * 6000
        segment = StringSegment(value=very_long_string)

        # Use small limit to test string fallback path
        tiny_truncator = VariableTruncator(string_length_limit=100, max_size_bytes=50)
        result = tiny_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
        # Should be truncated due to string limit or final size limit
        assert len(result.result.value) <= 1000  # Much smaller than original
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Test edge cases and error conditions."""

    def test_empty_inputs(self):
        """Test truncator with empty inputs."""
        truncator = VariableTruncator()

        # Empty string
        result = truncator.truncate(StringSegment(value=""))
        assert not result.truncated
        assert result.result.value == ""

        # Empty array
        # Imported locally, as in the rest of this module.
        from factories.variable_factory import build_segment

        result = truncator.truncate(build_segment([]))
        assert not result.truncated
        assert result.result.value == []

        # Empty object
        result = truncator.truncate(ObjectSegment(value={}))
        assert not result.truncated
        assert result.result.value == {}

    def test_zero_and_negative_limits(self):
        """Test that the constructor rejects zero or too-small limits with ValueError."""
        # A string length limit of 3 is rejected.
        # NOTE(review): presumably because it leaves no room for content next to the
        # "..." marker - confirm against VariableTruncator's validation.
        with pytest.raises(ValueError):
            VariableTruncator(string_length_limit=3)

        # Zero array element limit
        with pytest.raises(ValueError):
            VariableTruncator(array_element_limit=0)

        # Zero overall size limit
        with pytest.raises(ValueError):
            VariableTruncator(max_size_bytes=0)

    def test_unicode_and_special_characters(self):
        """Test truncator with unicode and special characters."""
        truncator = VariableTruncator(string_length_limit=10)

        # Unicode characters
        unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀"  # Each emoji counts as 1 character
        result = truncator.truncate(StringSegment(value=unicode_text))
        # NOTE(review): the literal above is exactly 10 code points, so this guard is
        # never true and the assertion below never runs. Lengthen the input or assert
        # the expected outcome unconditionally once the intended behavior at the
        # boundary is confirmed.
        if len(unicode_text) > 10:
            assert result.truncated is True

        # Special JSON characters
        special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
        result = truncator.truncate(StringSegment(value=special_chars))
        assert isinstance(result.result, StringSegment)
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
    """Test realistic integration scenarios."""

    def test_workflow_output_scenario(self):
        """Test truncation of typical workflow output data."""
        truncator = VariableTruncator()

        workflow_data = {
            "result": "success",
            "data": {
                "users": [
                    {"id": 1, "name": "Alice", "email": "alice@example.com"},
                    {"id": 2, "name": "Bob", "email": "bob@example.com"},
                ]
                * 3,  # Multiply to make it larger
                "metadata": {
                    "count": 6,
                    "processing_time": "1.23s",
                    "details": "x" * 200,  # Long string but not too long
                },
            },
        }

        segment = ObjectSegment(value=workflow_data)
        result = truncator.truncate(segment)

        # Only the result's types are checked: the complex nested structure may come back
        # either as an object segment or as a string segment.
        assert isinstance(result, TruncationResult)
        assert isinstance(result.result, (ObjectSegment, StringSegment))

    def test_large_text_processing_scenario(self):
        """Test truncation of large text data."""
        truncator = VariableTruncator(string_length_limit=100)

        # 35 characters repeated 20 times = 700 characters, well above the limit of 100.
        large_text = "This is a very long text document. " * 20  # Make it larger than limit

        segment = StringSegment(value=large_text)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
        assert len(result.result.value) <= 103  # 100 + "..."
        assert result.result.value.endswith("...")

    def test_mixed_data_types_scenario(self):
        """Test truncation with mixed data types in complex structure."""
        truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)

        mixed_data = {
            "strings": ["short", "medium length", "very long string " * 3],
            "numbers": [1, 2.5, 999999],
            "booleans": [True, False, True],
            "nested": {
                "more_strings": ["nested string " * 2],
                "more_numbers": list(range(5)),
                "deep": {"level": 3, "content": "deep content " * 3},
            },
            "nulls": [None, None],
        }

        segment = ObjectSegment(value=mixed_data)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        # The size comparison below runs only when truncation was reported AND the result
        # is still an object segment; otherwise this test checks the result type alone.
        if result.truncated:
            # Verify the result is smaller or equal than original
            original_size = truncator.calculate_json_size(mixed_data)
            if isinstance(result.result, ObjectSegment):
                result_size = truncator.calculate_json_size(result.result.value)
                assert result_size <= original_size
|
||||
@@ -0,0 +1,377 @@
|
||||
"""Simplified unit tests for DraftVarLoader focusing on core functionality."""
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.variables.segments import ObjectSegment, StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from models.model import UploadFile
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from services.workflow_draft_variable_service import DraftVarLoader
|
||||
|
||||
|
||||
class TestDraftVarLoaderSimple:
|
||||
"""Simplified unit tests for DraftVarLoader core methods."""
|
||||
|
||||
@pytest.fixture
def mock_engine(self) -> Engine:
    """Provide a ``Mock`` whose attribute surface is specced on ``Engine``."""
    engine_double = Mock(spec=Engine)
    return engine_double
|
||||
|
||||
@pytest.fixture
def draft_var_loader(self, mock_engine):
    """Build the ``DraftVarLoader`` under test on top of the mocked engine."""
    loader_kwargs = {
        "engine": mock_engine,
        "app_id": "test-app-id",
        "tenant_id": "test-tenant-id",
        "fallback_variables": [],
    }
    return DraftVarLoader(**loader_kwargs)
|
||||
|
||||
def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
    """Call ``_load_offloaded_variable`` for a STRING-typed offloaded variable.

    ``storage`` and ``factories.variable_factory.segment_to_variable`` are
    patched; the test checks the returned selector tuple, the returned
    variable's fields, and the storage key that was loaded.
    """
    expected_text = "This is the full string content"

    # Mocked model objects: upload file -> variable file -> draft variable.
    stored_file = Mock(spec=UploadFile, key="storage/key/test.txt")
    offload_record = Mock(spec=WorkflowDraftVariableFile, value_type=SegmentType.STRING, upload_file=stored_file)

    draft_row = Mock(spec=WorkflowDraftVariable)
    # configure_mock is used because ``name`` cannot be passed to Mock() as an attribute.
    draft_row.configure_mock(
        **{
            "id": "draft-var-id",
            "node_id": "test-node-id",
            "name": "test_variable",
            "description": "test description",
            "get_selector.return_value": ["test-node-id", "test_variable"],
            "variable_file": offload_record,
        }
    )

    factory_result = Mock()
    factory_result.configure_mock(
        id="draft-var-id",
        name="test_variable",
        value=StringSegment(value=expected_text),
    )

    with (
        patch("services.workflow_draft_variable_service.storage") as mock_storage,
        patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable,
    ):
        mock_storage.load.return_value = expected_text.encode()
        mock_segment_to_variable.return_value = factory_result

        # Execute the method
        selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_row)

        # Verify results
        assert selector_tuple == ("test-node-id", "test_variable")
        assert variable.id == "draft-var-id"
        assert variable.name == "test_variable"
        assert variable.description == "test description"
        assert variable.value == expected_text

        # Verify storage was called correctly
        mock_storage.load.assert_called_once_with("storage/key/test.txt")
|
||||
|
||||
def test_load_offloaded_variable_object_type_unit(self, draft_var_loader):
    """Call ``_load_offloaded_variable`` for an OBJECT-typed offloaded variable.

    Storage returns compact JSON for the object;
    ``WorkflowDraftVariable.build_segment_with_type`` and
    ``segment_to_variable`` are patched, and the test checks both the result
    and the arguments passed to the segment builder.
    """
    expected_object = {"key1": "value1", "key2": 42}
    serialized = json.dumps(expected_object, ensure_ascii=False, separators=(",", ":"))

    # Mocked model objects: upload file -> variable file -> draft variable.
    stored_file = Mock(spec=UploadFile, key="storage/key/test.json")
    offload_record = Mock(spec=WorkflowDraftVariableFile, value_type=SegmentType.OBJECT, upload_file=stored_file)

    draft_row = Mock(spec=WorkflowDraftVariable)
    draft_row.configure_mock(
        **{
            "id": "draft-var-id",
            "node_id": "test-node-id",
            "name": "test_object",
            "description": "test description",
            "get_selector.return_value": ["test-node-id", "test_object"],
            "variable_file": offload_record,
        }
    )

    object_segment = ObjectSegment(value=expected_object)
    factory_result = Mock()
    factory_result.configure_mock(id="draft-var-id", name="test_object", value=object_segment)

    with (
        patch("services.workflow_draft_variable_service.storage") as mock_storage,
        patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment,
        patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable,
    ):
        mock_storage.load.return_value = serialized.encode()
        mock_build_segment.return_value = object_segment
        mock_segment_to_variable.return_value = factory_result

        # Execute the method
        selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_row)

        # Verify results
        assert selector_tuple == ("test-node-id", "test_object")
        assert variable.id == "draft-var-id"
        assert variable.name == "test_object"
        assert variable.description == "test description"
        assert variable.value == expected_object

        # Verify method calls
        mock_storage.load.assert_called_once_with("storage/key/test.json")
        mock_build_segment.assert_called_once_with(SegmentType.OBJECT, expected_object)
|
||||
|
||||
def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader):
    """Expect ``AssertionError`` when the draft variable has no ``variable_file``."""
    orphan_row = Mock(spec=WorkflowDraftVariable, variable_file=None)

    with pytest.raises(AssertionError):
        draft_var_loader._load_offloaded_variable(orphan_row)
|
||||
|
||||
def test_load_offloaded_variable_missing_upload_file_unit(self, draft_var_loader):
    """Expect ``AssertionError`` when the variable file has no ``upload_file``."""
    dangling_record = Mock(spec=WorkflowDraftVariableFile, upload_file=None)
    draft_row = Mock(spec=WorkflowDraftVariable, variable_file=dangling_record)

    with pytest.raises(AssertionError):
        draft_var_loader._load_offloaded_variable(draft_row)
|
||||
|
||||
def test_load_variables_empty_selectors_unit(self, draft_var_loader):
    """``load_variables`` given no selectors should yield an empty list."""
    no_selectors = []
    assert draft_var_loader.load_variables(no_selectors) == []
|
||||
|
||||
def test_selector_to_tuple_unit(self, draft_var_loader):
    """``_selector_to_tuple`` should keep only the first two selector parts."""
    three_part_selector = ["node_id", "var_name", "extra_field"]
    expected_pair = ("node_id", "var_name")

    assert draft_var_loader._selector_to_tuple(three_part_selector) == expected_pair
|
||||
|
||||
def test_load_offloaded_variable_number_type_unit(self, draft_var_loader):
    """Call ``_load_offloaded_variable`` for a NUMBER-typed offloaded variable.

    Storage returns the JSON encoding of a float; the segment builder and
    ``segment_to_variable`` are patched, and the test checks the result's
    identifying fields plus the arguments passed to the segment builder.
    """
    from core.variables.segments import FloatSegment

    expected_number = 123.45
    serialized = json.dumps(expected_number)

    # Mocked model objects: upload file -> variable file -> draft variable.
    stored_file = Mock(spec=UploadFile, key="storage/key/test_number.json")
    offload_record = Mock(spec=WorkflowDraftVariableFile, value_type=SegmentType.NUMBER, upload_file=stored_file)

    draft_row = Mock(spec=WorkflowDraftVariable)
    draft_row.configure_mock(
        **{
            "id": "draft-var-id",
            "node_id": "test-node-id",
            "name": "test_number",
            "description": "test number description",
            "get_selector.return_value": ["test-node-id", "test_number"],
            "variable_file": offload_record,
        }
    )

    number_segment = FloatSegment(value=expected_number)
    factory_result = Mock()
    factory_result.configure_mock(id="draft-var-id", name="test_number", value=number_segment)

    with (
        patch("services.workflow_draft_variable_service.storage") as mock_storage,
        patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment,
        patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable,
    ):
        mock_storage.load.return_value = serialized.encode()
        mock_build_segment.return_value = number_segment
        mock_segment_to_variable.return_value = factory_result

        # Execute the method
        selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_row)

        # Verify results
        assert selector_tuple == ("test-node-id", "test_number")
        assert variable.id == "draft-var-id"
        assert variable.name == "test_number"
        assert variable.description == "test number description"

        # Verify method calls
        mock_storage.load.assert_called_once_with("storage/key/test_number.json")
        mock_build_segment.assert_called_once_with(SegmentType.NUMBER, expected_number)
|
||||
|
||||
def test_load_offloaded_variable_array_type_unit(self, draft_var_loader):
    """Call ``_load_offloaded_variable`` for an ARRAY_ANY-typed offloaded variable.

    Storage returns the JSON encoding of a list of strings; the segment
    builder and ``segment_to_variable`` are patched, and the test checks the
    result's identifying fields plus the arguments passed to the builder.
    """
    from core.variables.segments import ArrayAnySegment

    expected_items = ["item1", "item2", "item3"]
    serialized = json.dumps(expected_items)

    # Mocked model objects: upload file -> variable file -> draft variable.
    stored_file = Mock(spec=UploadFile, key="storage/key/test_array.json")
    offload_record = Mock(spec=WorkflowDraftVariableFile, value_type=SegmentType.ARRAY_ANY, upload_file=stored_file)

    draft_row = Mock(spec=WorkflowDraftVariable)
    draft_row.configure_mock(
        **{
            "id": "draft-var-id",
            "node_id": "test-node-id",
            "name": "test_array",
            "description": "test array description",
            "get_selector.return_value": ["test-node-id", "test_array"],
            "variable_file": offload_record,
        }
    )

    array_segment = ArrayAnySegment(value=expected_items)
    factory_result = Mock()
    factory_result.configure_mock(id="draft-var-id", name="test_array", value=array_segment)

    with (
        patch("services.workflow_draft_variable_service.storage") as mock_storage,
        patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment,
        patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable,
    ):
        mock_storage.load.return_value = serialized.encode()
        mock_build_segment.return_value = array_segment
        mock_segment_to_variable.return_value = factory_result

        # Execute the method
        selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_row)

        # Verify results
        assert selector_tuple == ("test-node-id", "test_array")
        assert variable.id == "draft-var-id"
        assert variable.name == "test_array"
        assert variable.description == "test array description"

        # Verify method calls
        mock_storage.load.assert_called_once_with("storage/key/test_array.json")
        mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, expected_items)
|
||||
|
||||
def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader):
    """Run ``load_variables`` over one inline and one offloaded draft variable.

    The session, draft-variable service, storage key loader, variable factory,
    storage and the loader's own ``_load_offloaded_variable`` are all patched.
    The test expects two results, one service lookup, and exactly one
    offloaded load for the truncated variable.
    """
    selectors = [["node1", "regular_var"], ["node2", "offloaded_var"]]

    # Draft variable reporting is_truncated() == False.
    inline_row = Mock(spec=WorkflowDraftVariable)
    inline_row.configure_mock(
        **{
            "is_truncated.return_value": False,
            "node_id": "node1",
            "name": "regular_var",
            "get_value.return_value": StringSegment(value="regular_value"),
            "get_selector.return_value": ["node1", "regular_var"],
            "id": "regular-var-id",
            "description": "regular description",
        }
    )

    # Draft variable reporting is_truncated() == True, backed by a variable file.
    stored_file = Mock(spec=UploadFile, key="storage/key/offloaded.txt")
    offload_record = Mock(spec=WorkflowDraftVariableFile, value_type=SegmentType.STRING, upload_file=stored_file)
    offloaded_row = Mock(spec=WorkflowDraftVariable)
    offloaded_row.configure_mock(
        **{
            "is_truncated.return_value": True,
            "node_id": "node2",
            "name": "offloaded_var",
            "get_selector.return_value": ["node2", "offloaded_var"],
            "variable_file": offload_record,
            "id": "offloaded-var-id",
            "description": "offloaded description",
        }
    )

    service_double = Mock()
    service_double.get_draft_variables_by_selectors.return_value = [inline_row, offloaded_row]

    # Stand-ins for the variables produced for each draft row.
    inline_variable = Mock()
    inline_variable.selector = ["node1", "regular_var"]
    offloaded_variable = Mock()
    offloaded_variable.selector = ["node2", "offloaded_var"]
    offloaded_pair = (("node2", "offloaded_var"), offloaded_variable)

    # NOTE(review): the last patch targets ``concurrent.futures.ThreadPoolExecutor``
    # while the all-offloaded test patches the name inside the service module;
    # confirm which of the two ``load_variables`` actually resolves.
    with (
        patch("services.workflow_draft_variable_service.Session") as mock_session_cls,
        patch("services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=service_double),
        patch("services.workflow_draft_variable_service.StorageKeyLoader"),
        patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable,
        patch("services.workflow_draft_variable_service.storage") as mock_storage,
        patch.object(draft_var_loader, "_load_offloaded_variable") as mock_load_offloaded,
        patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_cls,
    ):
        mock_session_cls.return_value.__enter__.return_value = Mock()
        mock_segment_to_variable.return_value = inline_variable
        mock_storage.load.return_value = b"offloaded_content"
        mock_load_offloaded.return_value = offloaded_pair

        executor_double = Mock()
        mock_executor_cls.return_value.__enter__.return_value = executor_double
        executor_double.map.return_value = [offloaded_pair]

        # Execute the method
        result = draft_var_loader.load_variables(selectors)

        # Verify results
        assert len(result) == 2

        # Verify service method was called
        service_double.get_draft_variables_by_selectors.assert_called_once_with(
            draft_var_loader._app_id, selectors
        )

        # Verify offloaded variable loading was called
        mock_load_offloaded.assert_called_once_with(offloaded_row)
|
||||
|
||||
def test_load_variables_all_offloaded_variables_unit(self, draft_var_loader):
    """Run ``load_variables`` when every draft variable reports as truncated.

    ``ThreadPoolExecutor`` is patched in the service module; the test expects
    it to be created with ``max_workers=10`` and its ``map`` to be called once.
    """
    selectors = [["node1", "offloaded_var1"], ["node2", "offloaded_var2"]]

    # One truncated draft variable per selector.
    draft_rows = []
    for node_id, var_name in selectors:
        truncated_row = Mock(spec=WorkflowDraftVariable)
        truncated_row.configure_mock(
            **{"is_truncated.return_value": True, "node_id": node_id, "name": var_name}
        )
        draft_rows.append(truncated_row)

    service_double = Mock()
    service_double.get_draft_variables_by_selectors.return_value = draft_rows

    with (
        patch("services.workflow_draft_variable_service.Session") as mock_session_cls,
        patch("services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=service_double),
        patch("services.workflow_draft_variable_service.StorageKeyLoader"),
        patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls,
    ):
        mock_session_cls.return_value.__enter__.return_value = Mock()

        executor_double = Mock()
        mock_executor_cls.return_value.__enter__.return_value = executor_double
        executor_double.map.return_value = [
            (("node1", "offloaded_var1"), Mock()),
            (("node2", "offloaded_var2"), Mock()),
        ]

        # Execute the method
        result = draft_var_loader.load_variables(selectors)

        # Verify results - since we have only offloaded variables, should have 2 results
        assert len(result) == 2

        # Verify ThreadPoolExecutor was used
        mock_executor_cls.assert_called_once_with(max_workers=10)
        executor_double.map.assert_called_once()
|
||||
@@ -1,16 +1,26 @@
|
||||
import dataclasses
|
||||
import secrets
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables import StringSegment
|
||||
from core.variables.segments import StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDraftVariable,
|
||||
WorkflowDraftVariableFile,
|
||||
WorkflowNodeExecutionModel,
|
||||
is_system_variable_editable,
|
||||
)
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
VariableResetError,
|
||||
@@ -37,6 +47,7 @@ class TestDraftVariableSaver:
|
||||
|
||||
def test__should_variable_be_visible(self):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = Account(id=str(uuid.uuid4()))
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
@@ -44,6 +55,7 @@ class TestDraftVariableSaver:
|
||||
node_id="test_node_id",
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="test_execution_id",
|
||||
user=mock_user,
|
||||
)
|
||||
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
|
||||
assert saver._should_variable_be_visible("123", NodeType.START, "output") == True
|
||||
@@ -83,6 +95,7 @@ class TestDraftVariableSaver:
|
||||
]
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = MagicMock()
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
@@ -90,6 +103,7 @@ class TestDraftVariableSaver:
|
||||
node_id=_NODE_ID,
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="test_execution_id",
|
||||
user=mock_user,
|
||||
)
|
||||
for idx, c in enumerate(cases, 1):
|
||||
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||
@@ -97,6 +111,76 @@ class TestDraftVariableSaver:
|
||||
assert node_id == c.expected_node_id, fail_msg
|
||||
assert name == c.expected_name, fail_msg
|
||||
|
||||
@pytest.fixture
def mock_session(self):
    """Provide a ``Session`` mock whose ``get_bind()`` returns an ``Engine`` mock.

    ``Engine`` comes from the module-level ``from sqlalchemy import Engine``;
    the former function-scope re-import was redundant and has been dropped.
    """
    session_double = MagicMock(spec=Session)
    session_double.get_bind.return_value = MagicMock(spec=Engine)
    return session_double
|
||||
|
||||
@pytest.fixture
def draft_saver(self, mock_session):
    """Build a ``DraftVariableSaver`` for an LLM node with a mocked ``Account``."""
    # Mocked user carrying only the id and tenant id set below.
    account_double = MagicMock(spec=Account)
    account_double.configure_mock(id="test-user-id", tenant_id="test-tenant-id")

    saver_kwargs = {
        "session": mock_session,
        "app_id": "test-app-id",
        "node_id": "test-node-id",
        "node_type": NodeType.LLM,
        "node_execution_id": "test-execution-id",
        "user": account_double,
    }
    return DraftVariableSaver(**saver_kwargs)
|
||||
|
||||
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
    """Create a draft variable while the offload hook returns ``None``.

    With ``_try_offload_large_variable`` patched to return ``None``, the
    created draft variable is expected to have no ``file_id``.
    """
    offload_target = "services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
    with patch(offload_target) as _mock_try_offload:
        _mock_try_offload.return_value = None
        segment = StringSegment(value="small value")

        draft_var = draft_saver._create_draft_variable(name="small_var", value=segment, visible=True)

        # Should not have large variable metadata
        assert draft_var.file_id is None
        # The original re-assigned ``return_value = None`` here after the
        # assertion; that statement had no effect and was removed.
|
||||
|
||||
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
    """Create a draft variable while the offload hook reports an offloaded file.

    ``_try_offload_large_variable`` is patched to return a
    ``(segment, WorkflowDraftVariableFile)`` pair; the created draft variable
    is expected to reference that file through ``file_id``.
    """
    offload_target = "services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
    with patch(offload_target) as _mock_try_offload:
        segment = StringSegment(value="small value")
        offloaded_file = WorkflowDraftVariableFile(
            id=str(uuidv7()),
            size=1024,
            length=10,
            value_type=SegmentType.ARRAY_STRING,
            upload_file_id=str(uuid.uuid4()),
        )

        _mock_try_offload.return_value = segment, offloaded_file
        draft_var = draft_saver._create_draft_variable(name="small_var", value=segment, visible=True)

        # Should point at the offloaded variable file
        assert draft_var.file_id == offloaded_file.id
|
||||
|
||||
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
    """Save two outputs and expect one batch upsert carrying two draft variables."""
    node_outputs = {
        "result": {"data": "test_output"},
        "metadata": {"type": "llm_response"},
    }

    draft_saver.save(outputs=node_outputs)

    # Should batch upsert draft variables
    mock_batch_upsert.assert_called_once()

    # The draft variables are the second positional argument of the upsert call.
    upserted_vars = mock_batch_upsert.call_args.args[1]
    assert len(upserted_vars) == 2
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableService:
|
||||
def _get_test_app_id(self):
|
||||
@@ -115,6 +199,7 @@ class TestWorkflowDraftVariableService:
|
||||
created_by="test_user_id",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
|
||||
def test_reset_conversation_variable(self, mock_session):
|
||||
@@ -225,7 +310,7 @@ class TestWorkflowDraftVariableService:
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"test_var": "output_value"}
|
||||
mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
@@ -298,7 +383,7 @@ class TestWorkflowDraftVariableService:
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"sys.files": "[]"}
|
||||
mock_execution.load_full_outputs.return_value = {"sys.files": "[]"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
@@ -330,7 +415,7 @@ class TestWorkflowDraftVariableService:
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"sys.query": "reset query"}
|
||||
mock_execution.load_full_outputs.return_value = {"sys.query": "reset query"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
|
||||
Reference in New Issue
Block a user