refactor: implement tenant self queue for rag tasks (#27559)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
301
api/tests/unit_tests/core/rag/pipeline/test_queue.py
Normal file
301
api/tests/unit_tests/core/rag/pipeline/test_queue.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Unit tests for TenantIsolatedTaskQueue.
|
||||
|
||||
These tests verify the Redis-based task queue functionality for tenant-specific
|
||||
task management with proper serialization and deserialization.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
|
||||
|
||||
|
||||
class TestTaskWrapper:
|
||||
"""Test cases for TaskWrapper serialization/deserialization."""
|
||||
|
||||
def test_serialize_simple_data(self):
|
||||
"""Test serialization of simple data types."""
|
||||
data = {"key": "value", "number": 42, "list": [1, 2, 3]}
|
||||
wrapper = TaskWrapper(data=data)
|
||||
|
||||
serialized = wrapper.serialize()
|
||||
assert isinstance(serialized, str)
|
||||
|
||||
# Verify it's valid JSON
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed["data"] == data
|
||||
|
||||
def test_serialize_complex_data(self):
|
||||
"""Test serialization of complex nested data."""
|
||||
data = {
|
||||
"nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}},
|
||||
"unicode": "测试中文",
|
||||
"special_chars": "!@#$%^&*()",
|
||||
}
|
||||
wrapper = TaskWrapper(data=data)
|
||||
|
||||
serialized = wrapper.serialize()
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed["data"] == data
|
||||
|
||||
def test_deserialize_valid_data(self):
|
||||
"""Test deserialization of valid JSON data."""
|
||||
original_data = {"key": "value", "number": 42}
|
||||
# Serialize using TaskWrapper to get the correct format
|
||||
wrapper = TaskWrapper(data=original_data)
|
||||
serialized = wrapper.serialize()
|
||||
|
||||
wrapper = TaskWrapper.deserialize(serialized)
|
||||
assert wrapper.data == original_data
|
||||
|
||||
def test_deserialize_invalid_json(self):
|
||||
"""Test deserialization handles invalid JSON gracefully."""
|
||||
invalid_json = "{invalid json}"
|
||||
|
||||
# Pydantic will raise ValidationError for invalid JSON
|
||||
with pytest.raises(ValidationError):
|
||||
TaskWrapper.deserialize(invalid_json)
|
||||
|
||||
def test_serialize_ensure_ascii_false(self):
|
||||
"""Test that serialization preserves Unicode characters."""
|
||||
data = {"chinese": "中文测试", "emoji": "🚀"}
|
||||
wrapper = TaskWrapper(data=data)
|
||||
|
||||
serialized = wrapper.serialize()
|
||||
assert "中文测试" in serialized
|
||||
assert "🚀" in serialized
|
||||
|
||||
|
||||
class TestTenantIsolatedTaskQueue:
|
||||
"""Test cases for TenantIsolatedTaskQueue functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self):
|
||||
"""Mock Redis client for testing."""
|
||||
mock_redis = MagicMock()
|
||||
return mock_redis
|
||||
|
||||
@pytest.fixture
|
||||
def sample_queue(self, mock_redis_client):
|
||||
"""Create a sample TenantIsolatedTaskQueue instance."""
|
||||
return TenantIsolatedTaskQueue("tenant-123", "test-key")
|
||||
|
||||
def test_initialization(self, sample_queue):
|
||||
"""Test queue initialization with correct key generation."""
|
||||
assert sample_queue._tenant_id == "tenant-123"
|
||||
assert sample_queue._unique_key == "test-key"
|
||||
assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123"
|
||||
assert sample_queue._task_key == "tenant_test-key_task:tenant-123"
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_get_task_key_exists(self, mock_redis, sample_queue):
|
||||
"""Test getting task key when it exists."""
|
||||
mock_redis.get.return_value = "1"
|
||||
|
||||
result = sample_queue.get_task_key()
|
||||
|
||||
assert result == "1"
|
||||
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_get_task_key_not_exists(self, mock_redis, sample_queue):
|
||||
"""Test getting task key when it doesn't exist."""
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = sample_queue.get_task_key()
|
||||
|
||||
assert result is None
|
||||
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
|
||||
"""Test setting task waiting flag with default TTL."""
|
||||
sample_queue.set_task_waiting_time()
|
||||
|
||||
mock_redis.setex.assert_called_once_with(
|
||||
"tenant_test-key_task:tenant-123",
|
||||
3600, # DEFAULT_TASK_TTL
|
||||
1,
|
||||
)
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
|
||||
"""Test setting task waiting flag with custom TTL."""
|
||||
custom_ttl = 1800
|
||||
sample_queue.set_task_waiting_time(custom_ttl)
|
||||
|
||||
mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1)
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_delete_task_key(self, mock_redis, sample_queue):
|
||||
"""Test deleting task key."""
|
||||
sample_queue.delete_task_key()
|
||||
|
||||
mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_push_tasks_string_list(self, mock_redis, sample_queue):
|
||||
"""Test pushing string tasks directly."""
|
||||
tasks = ["task1", "task2", "task3"]
|
||||
|
||||
sample_queue.push_tasks(tasks)
|
||||
|
||||
mock_redis.lpush.assert_called_once_with(
|
||||
"tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3"
|
||||
)
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
|
||||
"""Test pushing mixed string and object tasks."""
|
||||
tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"]
|
||||
|
||||
sample_queue.push_tasks(tasks)
|
||||
|
||||
# Verify lpush was called
|
||||
mock_redis.lpush.assert_called_once()
|
||||
call_args = mock_redis.lpush.call_args
|
||||
|
||||
# Check queue name
|
||||
assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"
|
||||
|
||||
# Check serialized tasks
|
||||
serialized_tasks = call_args[0][1:]
|
||||
assert len(serialized_tasks) == 3
|
||||
assert serialized_tasks[0] == "string_task"
|
||||
assert serialized_tasks[2] == "another_string"
|
||||
|
||||
# Check object task is serialized as TaskWrapper JSON (without prefix)
|
||||
# It should be a valid JSON string that can be deserialized by TaskWrapper
|
||||
wrapper = TaskWrapper.deserialize(serialized_tasks[1])
|
||||
assert wrapper.data == {"object_task": "data", "id": 123}
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_push_tasks_empty_list(self, mock_redis, sample_queue):
|
||||
"""Test pushing empty task list."""
|
||||
sample_queue.push_tasks([])
|
||||
|
||||
mock_redis.lpush.assert_called_once_with("tenant_self_test-key_task_queue:tenant-123")
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_default_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with default count (1)."""
|
||||
mock_redis.rpop.side_effect = ["task1", None]
|
||||
|
||||
result = sample_queue.pull_tasks()
|
||||
|
||||
assert result == ["task1"]
|
||||
assert mock_redis.rpop.call_count == 1
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with custom count."""
|
||||
# First test: pull 3 tasks
|
||||
mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]
|
||||
|
||||
result = sample_queue.pull_tasks(3)
|
||||
|
||||
assert result == ["task1", "task2", "task3"]
|
||||
assert mock_redis.rpop.call_count == 3
|
||||
|
||||
# Reset mock for second test
|
||||
mock_redis.reset_mock()
|
||||
mock_redis.rpop.side_effect = ["task1", "task2", None]
|
||||
|
||||
result = sample_queue.pull_tasks(3)
|
||||
|
||||
assert result == ["task1", "task2"]
|
||||
assert mock_redis.rpop.call_count == 3
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with zero count returns empty list."""
|
||||
result = sample_queue.pull_tasks(0)
|
||||
|
||||
assert result == []
|
||||
mock_redis.rpop.assert_not_called()
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with negative count returns empty list."""
|
||||
result = sample_queue.pull_tasks(-1)
|
||||
|
||||
assert result == []
|
||||
mock_redis.rpop.assert_not_called()
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks that include wrapped objects."""
|
||||
# Create a wrapped task
|
||||
task_data = {"task_id": 123, "data": "test"}
|
||||
wrapper = TaskWrapper(data=task_data)
|
||||
wrapped_task = wrapper.serialize()
|
||||
|
||||
mock_redis.rpop.side_effect = [
|
||||
"string_task",
|
||||
wrapped_task.encode("utf-8"), # Simulate bytes from Redis
|
||||
None,
|
||||
]
|
||||
|
||||
result = sample_queue.pull_tasks(2)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0] == "string_task"
|
||||
assert result[1] == {"task_id": 123, "data": "test"}
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks with invalid JSON falls back to string."""
|
||||
# Invalid JSON string that cannot be deserialized
|
||||
invalid_json = "invalid json data"
|
||||
mock_redis.rpop.side_effect = [invalid_json, None]
|
||||
|
||||
result = sample_queue.pull_tasks(1)
|
||||
|
||||
assert result == [invalid_json]
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
|
||||
"""Test pulling tasks handles bytes from Redis correctly."""
|
||||
mock_redis.rpop.side_effect = [
|
||||
b"task1", # bytes
|
||||
"task2", # string
|
||||
None,
|
||||
]
|
||||
|
||||
result = sample_queue.pull_tasks(2)
|
||||
|
||||
assert result == ["task1", "task2"]
|
||||
|
||||
@patch("core.rag.pipeline.queue.redis_client")
|
||||
def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
|
||||
"""Test complex object serialization and deserialization roundtrip."""
|
||||
complex_task = {
|
||||
"id": uuid4().hex,
|
||||
"data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}},
|
||||
"metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]},
|
||||
}
|
||||
|
||||
# Push the complex task
|
||||
sample_queue.push_tasks([complex_task])
|
||||
|
||||
# Verify it was serialized as TaskWrapper JSON
|
||||
call_args = mock_redis.lpush.call_args
|
||||
wrapped_task = call_args[0][1]
|
||||
# Verify it's a valid TaskWrapper JSON (starts with {"data":)
|
||||
assert wrapped_task.startswith('{"data":')
|
||||
|
||||
# Verify it can be deserialized
|
||||
wrapper = TaskWrapper.deserialize(wrapped_task)
|
||||
assert wrapper.data == complex_task
|
||||
|
||||
# Simulate pulling it back
|
||||
mock_redis.rpop.return_value = wrapped_task
|
||||
result = sample_queue.pull_tasks(1)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] == complex_task
|
||||
@@ -0,0 +1,317 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
|
||||
|
||||
class DocumentIndexingTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
features.billing.enabled = billing_enabled
|
||||
features.billing.subscription = Mock()
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||
"""Create mock TenantIsolatedTaskQueue."""
|
||||
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||
queue.push_tasks = Mock()
|
||||
queue.set_task_waiting_time = Mock()
|
||||
return queue
|
||||
|
||||
@staticmethod
|
||||
def create_document_task_proxy(
|
||||
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
|
||||
) -> DocumentIndexingTaskProxy:
|
||||
"""Create DocumentIndexingTaskProxy instance for testing."""
|
||||
if document_ids is None:
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
|
||||
class TestDocumentIndexingTaskProxy:
|
||||
"""Test cases for DocumentIndexingTaskProxy class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test DocumentIndexingTaskProxy initialization."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
|
||||
# Act
|
||||
features1 = proxy.features
|
||||
features2 = proxy.features # Second call should use cached property
|
||||
|
||||
# Assert
|
||||
assert features1 == mock_features
|
||||
assert features2 == mock_features
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=True
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
|
||||
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
|
||||
assert len(pushed_tasks) == 1
|
||||
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
|
||||
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
|
||||
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
|
||||
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=False
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_default_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_direct_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.TEAM
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing enabled with non sandbox plan, should send to priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_delay_method(self, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
# If billing enabled with sandbox plan, should send to default tenant queue
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
def test_document_task_dataclass(self):
|
||||
"""Test DocumentTask dataclass."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1", "doc-2"]
|
||||
|
||||
# Act
|
||||
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
|
||||
|
||||
# Assert
|
||||
assert task.tenant_id == tenant_id
|
||||
assert task.dataset_id == dataset_id
|
||||
assert task.document_ids == document_ids
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
def test_initialization_with_empty_document_ids(self):
|
||||
"""Test initialization with empty document_ids list."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = []
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
|
||||
def test_initialization_with_single_document_id(self):
|
||||
"""Test initialization with single document_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1"]
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
483
api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py
Normal file
483
api/tests/unit_tests/services/test_rag_pipeline_task_proxy.py
Normal file
@@ -0,0 +1,483 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
|
||||
|
||||
|
||||
class RagPipelineTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
features.billing.enabled = billing_enabled
|
||||
features.billing.subscription = Mock()
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||
"""Create mock TenantIsolatedTaskQueue."""
|
||||
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||
queue.push_tasks = Mock()
|
||||
queue.set_task_waiting_time = Mock()
|
||||
return queue
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_invoke_entity(
|
||||
pipeline_id: str = "pipeline-123",
|
||||
user_id: str = "user-456",
|
||||
tenant_id: str = "tenant-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
streaming: bool = True,
|
||||
workflow_execution_id: str | None = None,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> RagPipelineInvokeEntity:
|
||||
"""Create RagPipelineInvokeEntity instance for testing."""
|
||||
return RagPipelineInvokeEntity(
|
||||
pipeline_id=pipeline_id,
|
||||
application_generate_entity={"key": "value"},
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_id=workflow_id,
|
||||
streaming=streaming,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_task_proxy(
|
||||
dataset_tenant_id: str = "tenant-123",
|
||||
user_id: str = "user-456",
|
||||
rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None,
|
||||
) -> RagPipelineTaskProxy:
|
||||
"""Create RagPipelineTaskProxy instance for testing."""
|
||||
if rag_pipeline_invoke_entities is None:
|
||||
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
|
||||
return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
@staticmethod
|
||||
def create_mock_upload_file(file_id: str = "file-123") -> Mock:
|
||||
"""Create mock upload file."""
|
||||
upload_file = Mock()
|
||||
upload_file.id = file_id
|
||||
return upload_file
|
||||
|
||||
|
||||
class TestRagPipelineTaskProxy:
|
||||
"""Test cases for RagPipelineTaskProxy class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test RagPipelineTaskProxy initialization."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert proxy._dataset_tenant_id == dataset_tenant_id
|
||||
assert proxy._user_id == user_id
|
||||
assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
|
||||
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "pipeline"
|
||||
|
||||
def test_initialization_with_empty_entities(self):
|
||||
"""Test initialization with empty rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = []
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert proxy._dataset_tenant_id == dataset_tenant_id
|
||||
assert proxy._user_id == user_id
|
||||
assert proxy._rag_pipeline_invoke_entities == []
|
||||
|
||||
def test_initialization_with_multiple_entities(self):
|
||||
"""Test initialization with multiple rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = [
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"),
|
||||
]
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert len(proxy._rag_pipeline_invoke_entities) == 3
|
||||
assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
|
||||
assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
|
||||
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
|
||||
# Act
|
||||
features1 = proxy.features
|
||||
features2 = proxy.features # Second call should use cached property
|
||||
|
||||
# Assert
|
||||
assert features1 == mock_features
|
||||
assert features2 == mock_features
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
|
||||
"""Test _upload_invoke_entities method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
result = proxy._upload_invoke_entities()
|
||||
|
||||
# Assert
|
||||
assert result == "file-123"
|
||||
mock_file_service_class.assert_called_once_with(mock_db.engine)
|
||||
|
||||
# Verify upload_text was called with correct parameters
|
||||
mock_file_service.upload_text.assert_called_once()
|
||||
call_args = mock_file_service.upload_text.call_args
|
||||
json_text, name, user_id, tenant_id = call_args[0]
|
||||
|
||||
assert name == "rag_pipeline_invoke_entities.json"
|
||||
assert user_id == "user-456"
|
||||
assert tenant_id == "tenant-123"
|
||||
|
||||
# Verify JSON content
|
||||
parsed_json = json.loads(json_text)
|
||||
assert len(parsed_json) == 1
|
||||
assert parsed_json[0]["pipeline_id"] == "pipeline-123"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
|
||||
"""Test _upload_invoke_entities method with multiple entities."""
|
||||
# Arrange
|
||||
entities = [
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
|
||||
]
|
||||
proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
result = proxy._upload_invoke_entities()
|
||||
|
||||
# Assert
|
||||
assert result == "file-456"
|
||||
|
||||
# Verify JSON content contains both entities
|
||||
call_args = mock_file_service.upload_text.call_args
|
||||
json_text = call_args[0][0]
|
||||
parsed_json = json.loads(json_text)
|
||||
assert len(parsed_json) == 2
|
||||
assert parsed_json[0]["pipeline_id"] == "pipeline-1"
|
||||
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(upload_file_id, mock_task)
|
||||
|
||||
# If sent to direct queue, tenant_isolated_task_queue should not be called
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
# Celery should be called directly
|
||||
mock_task.delay.assert_called_once_with(
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
|
||||
)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=True
|
||||
)
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(upload_file_id, mock_task)
|
||||
|
||||
# If task key exists, should push tasks to the queue
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id])
|
||||
# Celery should not be called directly
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=False
|
||||
)
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(upload_file_id, mock_task)
|
||||
|
||||
# If no task key, should set task waiting time key first
|
||||
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
|
||||
)
|
||||
|
||||
# The first task should be sent to celery directly, so push tasks should not be called
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_default_tenant_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_tenant_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_direct_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_direct_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is enabled with sandbox plan, should send to default tenant queue
|
||||
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(
|
||||
self, mock_db, mock_file_service_class, mock_feature_service
|
||||
):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.TEAM
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_direct_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
|
||||
"""Test _dispatch method when upload_file_id is empty."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = Mock()
|
||||
mock_upload_file.id = "" # Empty file ID
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="upload_file_id is empty"):
|
||||
proxy._dispatch()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._dispatch = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
proxy._dispatch.assert_called_once()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.logger")
|
||||
def test_delay_method_with_empty_entities(self, mock_logger):
|
||||
"""Test delay method with empty rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456"
|
||||
)
|
||||
Reference in New Issue
Block a user