feat: knowledge pipeline (#25360)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
This commit is contained in:
-LAN-
2025-09-18 12:49:10 +08:00
committed by GitHub
parent 7dadb33003
commit 85cda47c70
1772 changed files with 102407 additions and 31710 deletions

View File

@@ -104,6 +104,7 @@ class TestDatasetServiceUpdateDataset:
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
patch("extensions.ext_database.db.session") as mock_db,
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
):
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_naive_utc_now.return_value = current_time
@@ -114,6 +115,7 @@ class TestDatasetServiceUpdateDataset:
"db_session": mock_db,
"naive_utc_now": mock_naive_utc_now,
"current_time": current_time,
"has_dataset_same_name": has_dataset_same_name,
}
@pytest.fixture
@@ -190,9 +192,9 @@ class TestDatasetServiceUpdateDataset:
"external_knowledge_api_id": "new_api_id",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify permission check was called
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
# Verify dataset and binding updates
@@ -214,6 +216,7 @@ class TestDatasetServiceUpdateDataset:
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
@@ -227,6 +230,7 @@ class TestDatasetServiceUpdateDataset:
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
@@ -250,6 +254,7 @@ class TestDatasetServiceUpdateDataset:
"external_knowledge_id": "knowledge_id",
"external_knowledge_api_id": "api_id",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
@@ -280,6 +285,7 @@ class TestDatasetServiceUpdateDataset:
"embedding_model": "text-embedding-ada-002",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify permission check was called
@@ -320,6 +326,8 @@ class TestDatasetServiceUpdateDataset:
"embedding_model": None, # Should be filtered out
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with filtered data
@@ -356,6 +364,7 @@ class TestDatasetServiceUpdateDataset:
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
@@ -402,6 +411,7 @@ class TestDatasetServiceUpdateDataset:
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
@@ -453,6 +463,7 @@ class TestDatasetServiceUpdateDataset:
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
@@ -505,6 +516,7 @@ class TestDatasetServiceUpdateDataset:
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
@@ -558,6 +570,7 @@ class TestDatasetServiceUpdateDataset:
"indexing_technique": "high_quality", # Same as current
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
@@ -588,6 +601,7 @@ class TestDatasetServiceUpdateDataset:
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
@@ -604,6 +618,8 @@ class TestDatasetServiceUpdateDataset:
update_data = {"name": "new_name"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(NoPermissionError):
DatasetService.update_dataset("dataset-123", update_data, user)
@@ -628,6 +644,8 @@ class TestDatasetServiceUpdateDataset:
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(Exception) as context:
DatasetService.update_dataset("dataset-123", update_data, user)

View File

@@ -0,0 +1,590 @@
"""
Comprehensive unit tests for VariableTruncator class based on current implementation.
This test suite covers all functionality of the current VariableTruncator including:
- JSON size calculation for different data types
- String, array, and object truncation logic
- Segment-based truncation interface
- Helper methods for budget-based truncation
- Edge cases and error handling
"""
import functools
import json
import uuid
from typing import Any
from uuid import uuid4
import pytest
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables.segments import (
ArrayFileSegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from services.variable_truncator import (
MaxDepthExceededError,
TruncationResult,
UnknownTypeError,
VariableTruncator,
)
@pytest.fixture
def file() -> File:
return File(
id=str(uuid4()), # Generate new UUID for File.id
tenant_id=str(uuid.uuid4()),
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id=str(uuid.uuid4()),
filename="test_file.txt",
extension=".txt",
mime_type="text/plain",
size=1024,
storage_key="initial_key",
)
_compact_json_dumps = functools.partial(json.dumps, separators=(",", ":"))
class TestCalculateJsonSize:
"""Test calculate_json_size method with different data types."""
@pytest.fixture
def truncator(self):
return VariableTruncator()
def test_string_size_calculation(self):
"""Test JSON size calculation for strings."""
# Simple ASCII string
assert VariableTruncator.calculate_json_size("hello") == 7 # "hello" + 2 quotes
# Empty string
assert VariableTruncator.calculate_json_size("") == 2 # Just quotes
# Unicode string
assert VariableTruncator.calculate_json_size("你好") == 4
def test_number_size_calculation(self, truncator):
"""Test JSON size calculation for numbers."""
assert truncator.calculate_json_size(123) == 3
assert truncator.calculate_json_size(12.34) == 5
assert truncator.calculate_json_size(-456) == 4
assert truncator.calculate_json_size(0) == 1
def test_boolean_size_calculation(self, truncator):
"""Test JSON size calculation for booleans."""
assert truncator.calculate_json_size(True) == 4 # "true"
assert truncator.calculate_json_size(False) == 5 # "false"
def test_null_size_calculation(self, truncator):
"""Test JSON size calculation for None/null."""
assert truncator.calculate_json_size(None) == 4 # "null"
def test_array_size_calculation(self, truncator):
"""Test JSON size calculation for arrays."""
# Empty array
assert truncator.calculate_json_size([]) == 2 # "[]"
# Simple array
simple_array = [1, 2, 3]
# [1,2,3] = 1 + 1 + 1 + 1 + 1 + 2 = 7 (numbers + commas + brackets)
assert truncator.calculate_json_size(simple_array) == 7
# Array with strings
string_array = ["a", "b"]
# ["a","b"] = 3 + 3 + 1 + 2 = 9 (quoted strings + comma + brackets)
assert truncator.calculate_json_size(string_array) == 9
def test_object_size_calculation(self, truncator):
"""Test JSON size calculation for objects."""
# Empty object
assert truncator.calculate_json_size({}) == 2 # "{}"
# Simple object
simple_obj = {"a": 1}
# {"a":1} = 3 + 1 + 1 + 2 = 7 (key + colon + value + brackets)
assert truncator.calculate_json_size(simple_obj) == 7
# Multiple keys
multi_obj = {"a": 1, "b": 2}
# {"a":1,"b":2} = 3 + 1 + 1 + 1 + 3 + 1 + 1 + 2 = 13
assert truncator.calculate_json_size(multi_obj) == 13
def test_nested_structure_size_calculation(self, truncator):
"""Test JSON size calculation for nested structures."""
nested = {"items": [1, 2, {"nested": "value"}]}
size = truncator.calculate_json_size(nested)
assert size > 0 # Should calculate without error
# Verify it matches actual JSON length roughly
actual_json = _compact_json_dumps(nested)
# Should be close but not exact due to UTF-8 encoding considerations
assert abs(size - len(actual_json.encode())) <= 5
def test_calculate_json_size_max_depth_exceeded(self, truncator):
"""Test that calculate_json_size handles deep nesting gracefully."""
# Create deeply nested structure
nested: dict[str, Any] = {"level": 0}
current = nested
for i in range(105): # Create deep nesting
current["next"] = {"level": i + 1}
current = current["next"]
# Should either raise an error or handle gracefully
with pytest.raises(MaxDepthExceededError):
truncator.calculate_json_size(nested)
def test_calculate_json_size_unknown_type(self, truncator):
"""Test that calculate_json_size raises error for unknown types."""
class CustomType:
pass
with pytest.raises(UnknownTypeError):
truncator.calculate_json_size(CustomType())
class TestStringTruncation:
LENGTH_LIMIT = 10
"""Test string truncation functionality."""
@pytest.fixture
def small_truncator(self):
return VariableTruncator(string_length_limit=10)
def test_short_string_no_truncation(self, small_truncator):
"""Test that short strings are not truncated."""
short_str = "hello"
result = small_truncator._truncate_string(short_str, self.LENGTH_LIMIT)
assert result.value == short_str
assert result.truncated is False
assert result.value_size == VariableTruncator.calculate_json_size(short_str)
def test_long_string_truncation(self, small_truncator: VariableTruncator):
"""Test that long strings are truncated with ellipsis."""
long_str = "this is a very long string that exceeds the limit"
result = small_truncator._truncate_string(long_str, self.LENGTH_LIMIT)
assert result.truncated is True
assert result.value == long_str[:5] + "..."
assert result.value_size == 10 # 10 chars + "..."
def test_exact_limit_string(self, small_truncator: VariableTruncator):
"""Test string exactly at limit."""
exact_str = "1234567890" # Exactly 10 chars
result = small_truncator._truncate_string(exact_str, self.LENGTH_LIMIT)
assert result.value == "12345..."
assert result.truncated is True
assert result.value_size == 10
class TestArrayTruncation:
"""Test array truncation functionality."""
@pytest.fixture
def small_truncator(self):
return VariableTruncator(array_element_limit=3, max_size_bytes=100)
def test_small_array_no_truncation(self, small_truncator: VariableTruncator):
"""Test that small arrays are not truncated."""
small_array = [1, 2]
result = small_truncator._truncate_array(small_array, 1000)
assert result.value == small_array
assert result.truncated is False
def test_array_element_limit_truncation(self, small_truncator: VariableTruncator):
"""Test that arrays over element limit are truncated."""
large_array = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
result = small_truncator._truncate_array(large_array, 1000)
assert result.truncated is True
assert result.value == [1, 2, 3]
def test_array_size_budget_truncation(self, small_truncator: VariableTruncator):
"""Test array truncation due to size budget constraints."""
# Create array with strings that will exceed size budget
large_strings = ["very long string " * 5, "another long string " * 5]
result = small_truncator._truncate_array(large_strings, 50)
assert result.truncated is True
# Should have truncated the strings within the array
for item in result.value:
assert isinstance(item, str)
assert VariableTruncator.calculate_json_size(result.value) <= 50
def test_array_with_nested_objects(self, small_truncator):
"""Test array truncation with nested objects."""
nested_array = [
{"name": "item1", "data": "some data"},
{"name": "item2", "data": "more data"},
{"name": "item3", "data": "even more data"},
]
result = small_truncator._truncate_array(nested_array, 30)
assert isinstance(result.value, list)
assert len(result.value) <= 3
for item in result.value:
assert isinstance(item, dict)
class TestObjectTruncation:
"""Test object truncation functionality."""
@pytest.fixture
def small_truncator(self):
return VariableTruncator(max_size_bytes=100)
def test_small_object_no_truncation(self, small_truncator):
"""Test that small objects are not truncated."""
small_obj = {"a": 1, "b": 2}
result = small_truncator._truncate_object(small_obj, 1000)
assert result.value == small_obj
assert result.truncated is False
def test_empty_object_no_truncation(self, small_truncator):
"""Test that empty objects are not truncated."""
empty_obj = {}
result = small_truncator._truncate_object(empty_obj, 100)
assert result.value == empty_obj
assert result.truncated is False
def test_object_value_truncation(self, small_truncator):
"""Test object truncation where values are truncated to fit budget."""
obj_with_long_values = {
"key1": "very long string " * 10,
"key2": "another long string " * 10,
"key3": "third long string " * 10,
}
result = small_truncator._truncate_object(obj_with_long_values, 80)
assert result.truncated is True
assert isinstance(result.value, dict)
assert set(result.value.keys()).issubset(obj_with_long_values.keys())
# Values should be truncated if they exist
for key, value in result.value.items():
if isinstance(value, str):
original_value = obj_with_long_values[key]
# Value should be same or smaller
assert len(value) <= len(original_value)
def test_object_key_dropping(self, small_truncator):
"""Test object truncation where keys are dropped due to size constraints."""
large_obj = {f"key{i:02d}": f"value{i}" for i in range(20)}
result = small_truncator._truncate_object(large_obj, 50)
assert result.truncated is True
assert len(result.value) < len(large_obj)
# Should maintain sorted key order
result_keys = list(result.value.keys())
assert result_keys == sorted(result_keys)
def test_object_with_nested_structures(self, small_truncator):
"""Test object truncation with nested arrays and objects."""
nested_obj = {"simple": "value", "array": [1, 2, 3, 4, 5], "nested": {"inner": "data", "more": ["a", "b", "c"]}}
result = small_truncator._truncate_object(nested_obj, 60)
assert isinstance(result.value, dict)
class TestSegmentBasedTruncation:
"""Test the main truncate method that works with Segments."""
@pytest.fixture
def truncator(self):
return VariableTruncator()
@pytest.fixture
def small_truncator(self):
return VariableTruncator(string_length_limit=20, array_element_limit=3, max_size_bytes=200)
def test_integer_segment_no_truncation(self, truncator):
"""Test that integer segments are never truncated."""
segment = IntegerSegment(value=12345)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_boolean_as_integer_segment(self, truncator):
"""Test boolean values in IntegerSegment are converted to int."""
segment = IntegerSegment(value=True)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert isinstance(result.result, IntegerSegment)
assert result.result.value == 1 # True converted to 1
def test_float_segment_no_truncation(self, truncator):
"""Test that float segments are never truncated."""
segment = FloatSegment(value=123.456)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_none_segment_no_truncation(self, truncator):
"""Test that None segments are never truncated."""
segment = NoneSegment()
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_file_segment_no_truncation(self, truncator, file):
"""Test that file segments are never truncated."""
file_segment = FileSegment(value=file)
result = truncator.truncate(file_segment)
assert result.result == file_segment
assert result.truncated is False
def test_array_file_segment_no_truncation(self, truncator, file):
"""Test that array file segments are never truncated."""
array_file_segment = ArrayFileSegment(value=[file] * 20)
result = truncator.truncate(array_file_segment)
assert result.result == array_file_segment
assert result.truncated is False
def test_string_segment_small_no_truncation(self, truncator):
"""Test small string segments are not truncated."""
segment = StringSegment(value="hello world")
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_string_segment_large_truncation(self, small_truncator):
"""Test large string segments are truncated."""
long_text = "this is a very long string that will definitely exceed the limit"
segment = StringSegment(value=long_text)
result = small_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, StringSegment)
assert len(result.result.value) < len(long_text)
assert result.result.value.endswith("...")
def test_array_segment_small_no_truncation(self, truncator):
"""Test small array segments are not truncated."""
from factories.variable_factory import build_segment
segment = build_segment([1, 2, 3])
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_array_segment_large_truncation(self, small_truncator):
"""Test large array segments are truncated."""
from factories.variable_factory import build_segment
large_array = list(range(10)) # Exceeds element limit of 3
segment = build_segment(large_array)
result = small_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, ArraySegment)
assert len(result.result.value) <= 3
def test_object_segment_small_no_truncation(self, truncator):
"""Test small object segments are not truncated."""
segment = ObjectSegment(value={"key": "value"})
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_object_segment_large_truncation(self, small_truncator):
"""Test large object segments are truncated."""
large_obj = {f"key{i}": f"very long value {i}" * 5 for i in range(5)}
segment = ObjectSegment(value=large_obj)
result = small_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, ObjectSegment)
# Object should be smaller or equal than original
original_size = small_truncator.calculate_json_size(large_obj)
result_size = small_truncator.calculate_json_size(result.result.value)
assert result_size <= original_size
def test_final_size_fallback_to_json_string(self, small_truncator):
"""Test final fallback when truncated result still exceeds size limit."""
# Create data that will still be large after initial truncation
large_nested_data = {"data": ["very long string " * 5] * 5, "more": {"nested": "content " * 20}}
segment = ObjectSegment(value=large_nested_data)
# Use very small limit to force JSON string fallback
tiny_truncator = VariableTruncator(max_size_bytes=50)
result = tiny_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, StringSegment)
# Should be JSON string with possible truncation
assert len(result.result.value) <= 53 # 50 + "..." = 53
def test_final_size_fallback_string_truncation(self, small_truncator):
"""Test final fallback for string that still exceeds limit."""
# Create very long string that exceeds string length limit
very_long_string = "x" * 6000 # Exceeds default string_length_limit of 5000
segment = StringSegment(value=very_long_string)
# Use small limit to test string fallback path
tiny_truncator = VariableTruncator(string_length_limit=100, max_size_bytes=50)
result = tiny_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, StringSegment)
# Should be truncated due to string limit or final size limit
assert len(result.result.value) <= 1000 # Much smaller than original
class TestEdgeCases:
"""Test edge cases and error conditions."""
def test_empty_inputs(self):
"""Test truncator with empty inputs."""
truncator = VariableTruncator()
# Empty string
result = truncator.truncate(StringSegment(value=""))
assert not result.truncated
assert result.result.value == ""
# Empty array
from factories.variable_factory import build_segment
result = truncator.truncate(build_segment([]))
assert not result.truncated
assert result.result.value == []
# Empty object
result = truncator.truncate(ObjectSegment(value={}))
assert not result.truncated
assert result.result.value == {}
def test_zero_and_negative_limits(self):
"""Test truncator behavior with zero or very small limits."""
# Zero string limit
with pytest.raises(ValueError):
truncator = VariableTruncator(string_length_limit=3)
with pytest.raises(ValueError):
truncator = VariableTruncator(array_element_limit=0)
with pytest.raises(ValueError):
truncator = VariableTruncator(max_size_bytes=0)
def test_unicode_and_special_characters(self):
"""Test truncator with unicode and special characters."""
truncator = VariableTruncator(string_length_limit=10)
# Unicode characters
unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀" # Each emoji counts as 1 character
result = truncator.truncate(StringSegment(value=unicode_text))
if len(unicode_text) > 10:
assert result.truncated is True
# Special JSON characters
special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
result = truncator.truncate(StringSegment(value=special_chars))
assert isinstance(result.result, StringSegment)
class TestIntegrationScenarios:
"""Test realistic integration scenarios."""
def test_workflow_output_scenario(self):
"""Test truncation of typical workflow output data."""
truncator = VariableTruncator()
workflow_data = {
"result": "success",
"data": {
"users": [
{"id": 1, "name": "Alice", "email": "alice@example.com"},
{"id": 2, "name": "Bob", "email": "bob@example.com"},
]
* 3, # Multiply to make it larger
"metadata": {
"count": 6,
"processing_time": "1.23s",
"details": "x" * 200, # Long string but not too long
},
},
}
segment = ObjectSegment(value=workflow_data)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert isinstance(result.result, (ObjectSegment, StringSegment))
# Should handle complex nested structure appropriately
def test_large_text_processing_scenario(self):
"""Test truncation of large text data."""
truncator = VariableTruncator(string_length_limit=100)
large_text = "This is a very long text document. " * 20 # Make it larger than limit
segment = StringSegment(value=large_text)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, StringSegment)
assert len(result.result.value) <= 103 # 100 + "..."
assert result.result.value.endswith("...")
def test_mixed_data_types_scenario(self):
"""Test truncation with mixed data types in complex structure."""
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
mixed_data = {
"strings": ["short", "medium length", "very long string " * 3],
"numbers": [1, 2.5, 999999],
"booleans": [True, False, True],
"nested": {
"more_strings": ["nested string " * 2],
"more_numbers": list(range(5)),
"deep": {"level": 3, "content": "deep content " * 3},
},
"nulls": [None, None],
}
segment = ObjectSegment(value=mixed_data)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
# Should handle all data types appropriately
if result.truncated:
# Verify the result is smaller or equal than original
original_size = truncator.calculate_json_size(mixed_data)
if isinstance(result.result, ObjectSegment):
result_size = truncator.calculate_json_size(result.result.value)
assert result_size <= original_size

View File

@@ -0,0 +1,377 @@
"""Simplified unit tests for DraftVarLoader focusing on core functionality."""
import json
from unittest.mock import Mock, patch
import pytest
from sqlalchemy import Engine
from core.variables.segments import ObjectSegment, StringSegment
from core.variables.types import SegmentType
from models.model import UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from services.workflow_draft_variable_service import DraftVarLoader
class TestDraftVarLoaderSimple:
"""Simplified unit tests for DraftVarLoader core methods."""
@pytest.fixture
def mock_engine(self) -> Engine:
return Mock(spec=Engine)
@pytest.fixture
def draft_var_loader(self, mock_engine):
"""Create DraftVarLoader instance for testing."""
return DraftVarLoader(
engine=mock_engine, app_id="test-app-id", tenant_id="test-tenant-id", fallback_variables=[]
)
def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with string type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test.txt"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.STRING
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_variable"
draft_var.description = "test description"
draft_var.get_selector.return_value = ["test-node-id", "test_variable"]
draft_var.variable_file = variable_file
test_content = "This is the full string content"
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_content.encode()
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_variable"
mock_variable.value = StringSegment(value=test_content)
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_variable")
assert variable.id == "draft-var-id"
assert variable.name == "test_variable"
assert variable.description == "test description"
assert variable.value == test_content
# Verify storage was called correctly
mock_storage.load.assert_called_once_with("storage/key/test.txt")
def test_load_offloaded_variable_object_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with object type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test.json"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.OBJECT
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_object"
draft_var.description = "test description"
draft_var.get_selector.return_value = ["test-node-id", "test_object"]
draft_var.variable_file = variable_file
test_object = {"key1": "value1", "key2": 42}
test_json_content = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
mock_segment = ObjectSegment(value=test_object)
mock_build_segment.return_value = mock_segment
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_object"
mock_variable.value = mock_segment
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_object")
assert variable.id == "draft-var-id"
assert variable.name == "test_object"
assert variable.description == "test description"
assert variable.value == test_object
# Verify method calls
mock_storage.load.assert_called_once_with("storage/key/test.json")
mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object)
def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader):
"""Test that assertion error is raised when variable_file is None."""
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.variable_file = None
with pytest.raises(AssertionError):
draft_var_loader._load_offloaded_variable(draft_var)
def test_load_offloaded_variable_missing_upload_file_unit(self, draft_var_loader):
"""Test that assertion error is raised when upload_file is None."""
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.upload_file = None
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.variable_file = variable_file
with pytest.raises(AssertionError):
draft_var_loader._load_offloaded_variable(draft_var)
def test_load_variables_empty_selectors_unit(self, draft_var_loader):
"""Test load_variables returns empty list for empty selectors."""
result = draft_var_loader.load_variables([])
assert result == []
def test_selector_to_tuple_unit(self, draft_var_loader):
"""Test _selector_to_tuple method."""
selector = ["node_id", "var_name", "extra_field"]
result = draft_var_loader._selector_to_tuple(selector)
assert result == ("node_id", "var_name")
def test_load_offloaded_variable_number_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with number type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test_number.json"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.NUMBER
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_number"
draft_var.description = "test number description"
draft_var.get_selector.return_value = ["test-node-id", "test_number"]
draft_var.variable_file = variable_file
test_number = 123.45
test_json_content = json.dumps(test_number)
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
from core.variables.segments import FloatSegment
mock_segment = FloatSegment(value=test_number)
mock_build_segment.return_value = mock_segment
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_number"
mock_variable.value = mock_segment
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_number")
assert variable.id == "draft-var-id"
assert variable.name == "test_number"
assert variable.description == "test number description"
# Verify method calls
mock_storage.load.assert_called_once_with("storage/key/test_number.json")
mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number)
def test_load_offloaded_variable_array_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with array type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test_array.json"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.ARRAY_ANY
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_array"
draft_var.description = "test array description"
draft_var.get_selector.return_value = ["test-node-id", "test_array"]
draft_var.variable_file = variable_file
test_array = ["item1", "item2", "item3"]
test_json_content = json.dumps(test_array)
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
from core.variables.segments import ArrayAnySegment
mock_segment = ArrayAnySegment(value=test_array)
mock_build_segment.return_value = mock_segment
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_array"
mock_variable.value = mock_segment
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_array")
assert variable.id == "draft-var-id"
assert variable.name == "test_array"
assert variable.description == "test array description"
# Verify method calls
mock_storage.load.assert_called_once_with("storage/key/test_array.json")
mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array)
def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader):
"""Test load_variables method with mix of regular and offloaded variables."""
selectors = [["node1", "regular_var"], ["node2", "offloaded_var"]]
# Mock regular variable
regular_draft_var = Mock(spec=WorkflowDraftVariable)
regular_draft_var.is_truncated.return_value = False
regular_draft_var.node_id = "node1"
regular_draft_var.name = "regular_var"
regular_draft_var.get_value.return_value = StringSegment(value="regular_value")
regular_draft_var.get_selector.return_value = ["node1", "regular_var"]
regular_draft_var.id = "regular-var-id"
regular_draft_var.description = "regular description"
# Mock offloaded variable
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/offloaded.txt"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.STRING
variable_file.upload_file = upload_file
offloaded_draft_var = Mock(spec=WorkflowDraftVariable)
offloaded_draft_var.is_truncated.return_value = True
offloaded_draft_var.node_id = "node2"
offloaded_draft_var.name = "offloaded_var"
offloaded_draft_var.get_selector.return_value = ["node2", "offloaded_var"]
offloaded_draft_var.variable_file = variable_file
offloaded_draft_var.id = "offloaded-var-id"
offloaded_draft_var.description = "offloaded description"
draft_vars = [regular_draft_var, offloaded_draft_var]
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
mock_session = Mock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_service = Mock()
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
with patch(
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
):
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
# Mock regular variable creation
regular_variable = Mock()
regular_variable.selector = ["node1", "regular_var"]
# Mock offloaded variable creation
offloaded_variable = Mock()
offloaded_variable.selector = ["node2", "offloaded_var"]
mock_segment_to_variable.return_value = regular_variable
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = b"offloaded_content"
with patch.object(draft_var_loader, "_load_offloaded_variable") as mock_load_offloaded:
mock_load_offloaded.return_value = (("node2", "offloaded_var"), offloaded_variable)
with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_cls:
mock_executor = Mock()
mock_executor_cls.return_value.__enter__.return_value = mock_executor
mock_executor.map.return_value = [(("node2", "offloaded_var"), offloaded_variable)]
# Execute the method
result = draft_var_loader.load_variables(selectors)
# Verify results
assert len(result) == 2
# Verify service method was called
mock_service.get_draft_variables_by_selectors.assert_called_once_with(
draft_var_loader._app_id, selectors
)
# Verify offloaded variable loading was called
mock_load_offloaded.assert_called_once_with(offloaded_draft_var)
def test_load_variables_all_offloaded_variables_unit(self, draft_var_loader):
"""Test load_variables method with only offloaded variables."""
selectors = [["node1", "offloaded_var1"], ["node2", "offloaded_var2"]]
# Mock first offloaded variable
offloaded_var1 = Mock(spec=WorkflowDraftVariable)
offloaded_var1.is_truncated.return_value = True
offloaded_var1.node_id = "node1"
offloaded_var1.name = "offloaded_var1"
# Mock second offloaded variable
offloaded_var2 = Mock(spec=WorkflowDraftVariable)
offloaded_var2.is_truncated.return_value = True
offloaded_var2.node_id = "node2"
offloaded_var2.name = "offloaded_var2"
draft_vars = [offloaded_var1, offloaded_var2]
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
mock_session = Mock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_service = Mock()
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
with patch(
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
):
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
with patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls:
mock_executor = Mock()
mock_executor_cls.return_value.__enter__.return_value = mock_executor
mock_executor.map.return_value = [
(("node1", "offloaded_var1"), Mock()),
(("node2", "offloaded_var2"), Mock()),
]
# Execute the method
result = draft_var_loader.load_variables(selectors)
# Verify results - since we have only offloaded variables, should have 2 results
assert len(result) == 2
# Verify ThreadPoolExecutor was used
mock_executor_cls.assert_called_once_with(max_workers=10)
mock_executor.map.assert_called_once()

View File

@@ -1,16 +1,26 @@
import dataclasses
import secrets
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session
from core.variables import StringSegment
from core.variables.segments import StringSegment
from core.variables.types import SegmentType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType
from libs.uuid_utils import uuidv7
from models.account import Account
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from models.workflow import (
Workflow,
WorkflowDraftVariable,
WorkflowDraftVariableFile,
WorkflowNodeExecutionModel,
is_system_variable_editable,
)
from services.workflow_draft_variable_service import (
DraftVariableSaver,
VariableResetError,
@@ -37,6 +47,7 @@ class TestDraftVariableSaver:
def test__should_variable_be_visible(self):
mock_session = MagicMock(spec=Session)
mock_user = Account(id=str(uuid.uuid4()))
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -44,6 +55,7 @@ class TestDraftVariableSaver:
node_id="test_node_id",
node_type=NodeType.START,
node_execution_id="test_execution_id",
user=mock_user,
)
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
assert saver._should_variable_be_visible("123", NodeType.START, "output") == True
@@ -83,6 +95,7 @@ class TestDraftVariableSaver:
]
mock_session = MagicMock(spec=Session)
mock_user = MagicMock()
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -90,6 +103,7 @@ class TestDraftVariableSaver:
node_id=_NODE_ID,
node_type=NodeType.START,
node_execution_id="test_execution_id",
user=mock_user,
)
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
@@ -97,6 +111,76 @@ class TestDraftVariableSaver:
assert node_id == c.expected_node_id, fail_msg
assert name == c.expected_name, fail_msg
@pytest.fixture
def mock_session(self):
"""Mock SQLAlchemy session."""
from sqlalchemy import Engine
mock_session = MagicMock(spec=Session)
mock_engine = MagicMock(spec=Engine)
mock_session.get_bind.return_value = mock_engine
return mock_session
@pytest.fixture
def draft_saver(self, mock_session):
"""Create DraftVariableSaver instance with user context."""
# Create a mock user
mock_user = MagicMock(spec=Account)
mock_user.id = "test-user-id"
mock_user.tenant_id = "test-tenant-id"
return DraftVariableSaver(
session=mock_session,
app_id="test-app-id",
node_id="test-node-id",
node_type=NodeType.LLM,
node_execution_id="test-execution-id",
user=mock_user,
)
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
) as _mock_try_offload:
_mock_try_offload.return_value = None
mock_segment = StringSegment(value="small value")
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
# Should not have large variable metadata
assert draft_var.file_id is None
_mock_try_offload.return_value = None
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
) as _mock_try_offload:
mock_segment = StringSegment(value="small value")
mock_draft_var_file = WorkflowDraftVariableFile(
id=str(uuidv7()),
size=1024,
length=10,
value_type=SegmentType.ARRAY_STRING,
upload_file_id=str(uuid.uuid4()),
)
_mock_try_offload.return_value = mock_segment, mock_draft_var_file
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
# Should not have large variable metadata
assert draft_var.file_id == mock_draft_var_file.id
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
"""Test complete save workflow."""
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
draft_saver.save(outputs=outputs)
# Should batch upsert draft variables
mock_batch_upsert.assert_called_once()
draft_vars = mock_batch_upsert.call_args[0][1]
assert len(draft_vars) == 2
class TestWorkflowDraftVariableService:
def _get_test_app_id(self):
@@ -115,6 +199,7 @@ class TestWorkflowDraftVariableService:
created_by="test_user_id",
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
def test_reset_conversation_variable(self, mock_session):
@@ -225,7 +310,7 @@ class TestWorkflowDraftVariableService:
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"test_var": "output_value"}
mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
@@ -298,7 +383,7 @@ class TestWorkflowDraftVariableService:
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.files": "[]"}
mock_execution.load_full_outputs.return_value = {"sys.files": "[]"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
@@ -330,7 +415,7 @@ class TestWorkflowDraftVariableService:
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.query": "reset query"}
mock_execution.load_full_outputs.return_value = {"sys.query": "reset query"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()