refactor: implement tenant self queue for rag tasks (#27559)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
hj24
2025-11-06 21:25:50 +08:00
committed by GitHub
parent f4c82d0010
commit 37903722fe
24 changed files with 3433 additions and 94 deletions

View File

@@ -0,0 +1,301 @@
"""
Unit tests for TenantIsolatedTaskQueue.
These tests verify the Redis-based task queue functionality for tenant-specific
task management with proper serialization and deserialization.
"""
import json
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from pydantic import ValidationError
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
class TestTaskWrapper:
"""Test cases for TaskWrapper serialization/deserialization."""
def test_serialize_simple_data(self):
"""Test serialization of simple data types."""
data = {"key": "value", "number": 42, "list": [1, 2, 3]}
wrapper = TaskWrapper(data=data)
serialized = wrapper.serialize()
assert isinstance(serialized, str)
# Verify it's valid JSON
parsed = json.loads(serialized)
assert parsed["data"] == data
def test_serialize_complex_data(self):
"""Test serialization of complex nested data."""
data = {
"nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}},
"unicode": "测试中文",
"special_chars": "!@#$%^&*()",
}
wrapper = TaskWrapper(data=data)
serialized = wrapper.serialize()
parsed = json.loads(serialized)
assert parsed["data"] == data
def test_deserialize_valid_data(self):
"""Test deserialization of valid JSON data."""
original_data = {"key": "value", "number": 42}
# Serialize using TaskWrapper to get the correct format
wrapper = TaskWrapper(data=original_data)
serialized = wrapper.serialize()
wrapper = TaskWrapper.deserialize(serialized)
assert wrapper.data == original_data
def test_deserialize_invalid_json(self):
"""Test deserialization handles invalid JSON gracefully."""
invalid_json = "{invalid json}"
# Pydantic will raise ValidationError for invalid JSON
with pytest.raises(ValidationError):
TaskWrapper.deserialize(invalid_json)
def test_serialize_ensure_ascii_false(self):
"""Test that serialization preserves Unicode characters."""
data = {"chinese": "中文测试", "emoji": "🚀"}
wrapper = TaskWrapper(data=data)
serialized = wrapper.serialize()
assert "中文测试" in serialized
assert "🚀" in serialized
class TestTenantIsolatedTaskQueue:
"""Test cases for TenantIsolatedTaskQueue functionality."""
@pytest.fixture
def mock_redis_client(self):
"""Mock Redis client for testing."""
mock_redis = MagicMock()
return mock_redis
@pytest.fixture
def sample_queue(self, mock_redis_client):
"""Create a sample TenantIsolatedTaskQueue instance."""
return TenantIsolatedTaskQueue("tenant-123", "test-key")
def test_initialization(self, sample_queue):
"""Test queue initialization with correct key generation."""
assert sample_queue._tenant_id == "tenant-123"
assert sample_queue._unique_key == "test-key"
assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123"
assert sample_queue._task_key == "tenant_test-key_task:tenant-123"
@patch("core.rag.pipeline.queue.redis_client")
def test_get_task_key_exists(self, mock_redis, sample_queue):
"""Test getting task key when it exists."""
mock_redis.get.return_value = "1"
result = sample_queue.get_task_key()
assert result == "1"
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_get_task_key_not_exists(self, mock_redis, sample_queue):
"""Test getting task key when it doesn't exist."""
mock_redis.get.return_value = None
result = sample_queue.get_task_key()
assert result is None
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
"""Test setting task waiting flag with default TTL."""
sample_queue.set_task_waiting_time()
mock_redis.setex.assert_called_once_with(
"tenant_test-key_task:tenant-123",
3600, # DEFAULT_TASK_TTL
1,
)
@patch("core.rag.pipeline.queue.redis_client")
def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
"""Test setting task waiting flag with custom TTL."""
custom_ttl = 1800
sample_queue.set_task_waiting_time(custom_ttl)
mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1)
@patch("core.rag.pipeline.queue.redis_client")
def test_delete_task_key(self, mock_redis, sample_queue):
"""Test deleting task key."""
sample_queue.delete_task_key()
mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_push_tasks_string_list(self, mock_redis, sample_queue):
"""Test pushing string tasks directly."""
tasks = ["task1", "task2", "task3"]
sample_queue.push_tasks(tasks)
mock_redis.lpush.assert_called_once_with(
"tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3"
)
@patch("core.rag.pipeline.queue.redis_client")
def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
"""Test pushing mixed string and object tasks."""
tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"]
sample_queue.push_tasks(tasks)
# Verify lpush was called
mock_redis.lpush.assert_called_once()
call_args = mock_redis.lpush.call_args
# Check queue name
assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"
# Check serialized tasks
serialized_tasks = call_args[0][1:]
assert len(serialized_tasks) == 3
assert serialized_tasks[0] == "string_task"
assert serialized_tasks[2] == "another_string"
# Check object task is serialized as TaskWrapper JSON (without prefix)
# It should be a valid JSON string that can be deserialized by TaskWrapper
wrapper = TaskWrapper.deserialize(serialized_tasks[1])
assert wrapper.data == {"object_task": "data", "id": 123}
@patch("core.rag.pipeline.queue.redis_client")
def test_push_tasks_empty_list(self, mock_redis, sample_queue):
"""Test pushing empty task list."""
sample_queue.push_tasks([])
mock_redis.lpush.assert_called_once_with("tenant_self_test-key_task_queue:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_default_count(self, mock_redis, sample_queue):
"""Test pulling tasks with default count (1)."""
mock_redis.rpop.side_effect = ["task1", None]
result = sample_queue.pull_tasks()
assert result == ["task1"]
assert mock_redis.rpop.call_count == 1
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
"""Test pulling tasks with custom count."""
# First test: pull 3 tasks
mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]
result = sample_queue.pull_tasks(3)
assert result == ["task1", "task2", "task3"]
assert mock_redis.rpop.call_count == 3
# Reset mock for second test
mock_redis.reset_mock()
mock_redis.rpop.side_effect = ["task1", "task2", None]
result = sample_queue.pull_tasks(3)
assert result == ["task1", "task2"]
assert mock_redis.rpop.call_count == 3
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
"""Test pulling tasks with zero count returns empty list."""
result = sample_queue.pull_tasks(0)
assert result == []
mock_redis.rpop.assert_not_called()
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
"""Test pulling tasks with negative count returns empty list."""
result = sample_queue.pull_tasks(-1)
assert result == []
mock_redis.rpop.assert_not_called()
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
"""Test pulling tasks that include wrapped objects."""
# Create a wrapped task
task_data = {"task_id": 123, "data": "test"}
wrapper = TaskWrapper(data=task_data)
wrapped_task = wrapper.serialize()
mock_redis.rpop.side_effect = [
"string_task",
wrapped_task.encode("utf-8"), # Simulate bytes from Redis
None,
]
result = sample_queue.pull_tasks(2)
assert len(result) == 2
assert result[0] == "string_task"
assert result[1] == {"task_id": 123, "data": "test"}
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
"""Test pulling tasks with invalid JSON falls back to string."""
# Invalid JSON string that cannot be deserialized
invalid_json = "invalid json data"
mock_redis.rpop.side_effect = [invalid_json, None]
result = sample_queue.pull_tasks(1)
assert result == [invalid_json]
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
"""Test pulling tasks handles bytes from Redis correctly."""
mock_redis.rpop.side_effect = [
b"task1", # bytes
"task2", # string
None,
]
result = sample_queue.pull_tasks(2)
assert result == ["task1", "task2"]
@patch("core.rag.pipeline.queue.redis_client")
def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
"""Test complex object serialization and deserialization roundtrip."""
complex_task = {
"id": uuid4().hex,
"data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}},
"metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]},
}
# Push the complex task
sample_queue.push_tasks([complex_task])
# Verify it was serialized as TaskWrapper JSON
call_args = mock_redis.lpush.call_args
wrapped_task = call_args[0][1]
# Verify it's a valid TaskWrapper JSON (starts with {"data":)
assert wrapped_task.startswith('{"data":')
# Verify it can be deserialized
wrapper = TaskWrapper.deserialize(wrapped_task)
assert wrapper.data == complex_task
# Simulate pulling it back
mock_redis.rpop.return_value = wrapped_task
result = sample_queue.pull_tasks(1)
assert len(result) == 1
assert result[0] == complex_task

View File

@@ -0,0 +1,317 @@
from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
class DocumentIndexingTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantIsolatedTaskQueue."""
queue = Mock(spec=TenantIsolatedTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_document_task_proxy(
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
) -> DocumentIndexingTaskProxy:
"""Create DocumentIndexingTaskProxy instance for testing."""
if document_ids is None:
document_ids = ["doc-1", "doc-2", "doc-3"]
return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
class TestDocumentIndexingTaskProxy:
"""Test cases for DocumentIndexingTaskProxy class."""
def test_initialization(self):
"""Test DocumentIndexingTaskProxy initialization."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2", "doc-3"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
@patch("services.document_indexing_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(mock_task)
# Assert
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
assert len(pushed_tasks) == 1
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
mock_task.delay.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_default_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_priority_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_direct_queue = Mock()
# Act
proxy._send_to_priority_direct_queue()
# Assert
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.TEAM
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# If billing enabled with non sandbox plan, should send to priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
# Act
proxy._dispatch()
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_delay_method(self, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy.delay()
# Assert
# If billing enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once()
def test_document_task_dataclass(self):
"""Test DocumentTask dataclass."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2"]
# Act
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
# Assert
assert task.tenant_id == tenant_id
assert task.dataset_id == dataset_id
assert task.document_ids == document_ids
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
def test_initialization_with_empty_document_ids(self):
"""Test initialization with empty document_ids list."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = []
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
def test_initialization_with_single_document_id(self):
"""Test initialization with single document_id."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids

View File

@@ -0,0 +1,483 @@
import json
from unittest.mock import Mock, patch
import pytest
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
class RagPipelineTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantIsolatedTaskQueue."""
queue = Mock(spec=TenantIsolatedTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_rag_pipeline_invoke_entity(
pipeline_id: str = "pipeline-123",
user_id: str = "user-456",
tenant_id: str = "tenant-789",
workflow_id: str = "workflow-101",
streaming: bool = True,
workflow_execution_id: str | None = None,
workflow_thread_pool_id: str | None = None,
) -> RagPipelineInvokeEntity:
"""Create RagPipelineInvokeEntity instance for testing."""
return RagPipelineInvokeEntity(
pipeline_id=pipeline_id,
application_generate_entity={"key": "value"},
user_id=user_id,
tenant_id=tenant_id,
workflow_id=workflow_id,
streaming=streaming,
workflow_execution_id=workflow_execution_id,
workflow_thread_pool_id=workflow_thread_pool_id,
)
@staticmethod
def create_rag_pipeline_task_proxy(
dataset_tenant_id: str = "tenant-123",
user_id: str = "user-456",
rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None,
) -> RagPipelineTaskProxy:
"""Create RagPipelineTaskProxy instance for testing."""
if rag_pipeline_invoke_entities is None:
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
@staticmethod
def create_mock_upload_file(file_id: str = "file-123") -> Mock:
"""Create mock upload file."""
upload_file = Mock()
upload_file.id = file_id
return upload_file
class TestRagPipelineTaskProxy:
"""Test cases for RagPipelineTaskProxy class."""
def test_initialization(self):
"""Test RagPipelineTaskProxy initialization."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy._dataset_tenant_id == dataset_tenant_id
assert proxy._user_id == user_id
assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
assert proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "pipeline"
def test_initialization_with_empty_entities(self):
"""Test initialization with empty rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = []
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy._dataset_tenant_id == dataset_tenant_id
assert proxy._user_id == user_id
assert proxy._rag_pipeline_invoke_entities == []
def test_initialization_with_multiple_entities(self):
"""Test initialization with multiple rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"),
]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert len(proxy._rag_pipeline_invoke_entities) == 3
assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-123"
mock_file_service_class.assert_called_once_with(mock_db.engine)
# Verify upload_text was called with correct parameters
mock_file_service.upload_text.assert_called_once()
call_args = mock_file_service.upload_text.call_args
json_text, name, user_id, tenant_id = call_args[0]
assert name == "rag_pipeline_invoke_entities.json"
assert user_id == "user-456"
assert tenant_id == "tenant-123"
# Verify JSON content
parsed_json = json.loads(json_text)
assert len(parsed_json) == 1
assert parsed_json[0]["pipeline_id"] == "pipeline-123"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method with multiple entities."""
# Arrange
entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
]
proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-456"
# Verify JSON content contains both entities
call_args = mock_file_service.upload_text.call_args
json_text = call_args[0][0]
parsed_json = json.loads(json_text)
assert len(parsed_json) == 2
assert parsed_json[0]["pipeline_id"] == "pipeline-1"
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(upload_file_id, mock_task)
# If sent to direct queue, tenant_isolated_task_queue should not be called
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
# Celery should be called directly
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If task key exists, should push tasks to the queue
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id])
# Celery should not be called directly
mock_task.delay.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If no task key, should set task waiting time key first
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
)
# The first task should be sent to celery directly, so push tasks should not be called
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_default_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_direct_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_direct_queue(upload_file_id)
# Assert
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_enabled_non_sandbox_plan(
self, mock_db, mock_file_service_class, mock_feature_service
):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.TEAM
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is disabled, for example: self-hosted or enterprise, should send to priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
"""Test _dispatch method when upload_file_id is empty."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = Mock()
mock_upload_file.id = "" # Empty file ID
mock_file_service.upload_text.return_value = mock_upload_file
# Act & Assert
with pytest.raises(ValueError, match="upload_file_id is empty"):
proxy._dispatch()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._dispatch = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy.delay()
# Assert
proxy._dispatch.assert_called_once()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.logger")
def test_delay_method_with_empty_entities(self, mock_logger):
"""Test delay method with empty rag_pipeline_invoke_entities."""
# Arrange
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
# Act
proxy.delay()
# Assert
mock_logger.warning.assert_called_once_with(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456"
)