Feat/support multimodal embedding (#29115)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Jyong
2025-12-09 14:41:46 +08:00
committed by GitHub
parent 77cf8f6c27
commit 9affc546c6
78 changed files with 3230 additions and 713 deletions

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -95,7 +95,7 @@ class TestAddDocumentToIndexTask:
created_by=account.id,
indexing_status="completed",
enabled=True,
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
@@ -172,7 +172,9 @@ class TestAddDocumentToIndexTask:
# Assert: Verify the expected outcomes
# Verify index processor was called correctly
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify database state changes
@@ -204,7 +206,7 @@ class TestAddDocumentToIndexTask:
)
# Update document to use different index type
document.doc_form = IndexType.QA_INDEX
document.doc_form = IndexStructureType.QA_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@@ -221,7 +223,9 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id)
# Assert: Verify different index type handling
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.QA_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters
@@ -360,7 +364,7 @@ class TestAddDocumentToIndexTask:
)
# Update document to use parent-child index type
document.doc_form = IndexType.PARENT_CHILD_INDEX
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@@ -391,7 +395,7 @@ class TestAddDocumentToIndexTask:
# Assert: Verify parent-child index processing
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexType.PARENT_CHILD_INDEX
IndexStructureType.PARENT_CHILD_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
@@ -465,8 +469,10 @@ class TestAddDocumentToIndexTask:
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify index processing occurred with all completed segments
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
# Assert: Verify index processing occurred but with empty documents list
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with all completed segments
@@ -532,7 +538,9 @@ class TestAddDocumentToIndexTask:
assert len(remaining_logs) == 0
# Verify index processing occurred normally
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify segments were enabled
@@ -699,7 +707,9 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id)
# Assert: Verify only eligible segments were processed
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters

View File

@@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@@ -164,7 +164,7 @@ class TestDeleteSegmentFromIndexTask:
document.updated_at = fake.date_time_this_year()
document.doc_type = kwargs.get("doc_type", "text")
document.doc_metadata = kwargs.get("doc_metadata", {})
document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX)
document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX)
document.doc_language = kwargs.get("doc_language", "en")
db_session_with_containers.add(document)
@@ -244,8 +244,11 @@ class TestDeleteSegmentFromIndexTask:
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Extract segment IDs for the task
segment_ids = [segment.id for segment in segments]
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed successfully
assert result is None # Task should return None on success
@@ -279,7 +282,7 @@ class TestDeleteSegmentFromIndexTask:
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
# Execute the task with non-existent dataset
result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id)
result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id, [])
# Verify the task completed without exceptions
assert result is None # Task should return None when dataset not found
@@ -305,7 +308,7 @@ class TestDeleteSegmentFromIndexTask:
index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)]
# Execute the task with non-existent document
result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id, [])
# Verify the task completed without exceptions
assert result is None # Task should return None when document not found
@@ -330,9 +333,10 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Execute the task with disabled document
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without exceptions
assert result is None # Task should return None when document is disabled
@@ -357,9 +361,10 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Execute the task with archived document
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without exceptions
assert result is None # Task should return None when document is archived
@@ -386,9 +391,10 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Execute the task with incomplete indexing
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without exceptions
assert result is None # Task should return None when indexing is not completed
@@ -409,7 +415,11 @@ class TestDeleteSegmentFromIndexTask:
fake = Faker()
# Test different document forms
document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX]
document_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in document_forms:
# Create test data for each document form
@@ -420,13 +430,14 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Mock the index processor
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed successfully
assert result is None
@@ -469,6 +480,7 @@ class TestDeleteSegmentFromIndexTask:
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Mock the index processor to raise an exception
mock_processor = MagicMock()
@@ -476,7 +488,7 @@ class TestDeleteSegmentFromIndexTask:
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task - should not raise exception
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed without raising exceptions
assert result is None # Task should return None even when exceptions occur
@@ -518,7 +530,7 @@ class TestDeleteSegmentFromIndexTask:
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, [])
# Verify the task completed successfully
assert result is None
@@ -555,13 +567,14 @@ class TestDeleteSegmentFromIndexTask:
# Create large number of segments
segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake)
index_node_ids = [segment.index_node_id for segment in segments]
segment_ids = [segment.id for segment in segments]
# Mock the index processor
mock_processor = MagicMock()
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor
# Execute the task
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id)
result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids)
# Verify the task completed successfully
assert result is None

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -95,7 +95,7 @@ class TestEnableSegmentsToIndexTask:
created_by=account.id,
indexing_status="completed",
enabled=True,
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
@@ -166,7 +166,7 @@ class TestEnableSegmentsToIndexTask:
)
# Update document to use different index type
document.doc_form = IndexType.QA_INDEX
document.doc_form = IndexStructureType.QA_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@@ -185,7 +185,9 @@ class TestEnableSegmentsToIndexTask:
enable_segments_to_index_task(segment_ids, dataset.id, document.id)
# Assert: Verify different index type handling
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.QA_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with correct parameters
@@ -328,7 +330,9 @@ class TestEnableSegmentsToIndexTask:
enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id)
# Assert: Verify index processor was created but load was not called
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexStructureType.PARAGRAPH_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_with_parent_child_structure(
@@ -350,7 +354,7 @@ class TestEnableSegmentsToIndexTask:
)
# Update document to use parent-child index type
document.doc_form = IndexType.PARENT_CHILD_INDEX
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
@@ -383,7 +387,7 @@ class TestEnableSegmentsToIndexTask:
# Assert: Verify parent-child index processing
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(
IndexType.PARENT_CHILD_INDEX
IndexStructureType.PARENT_CHILD_INDEX
)
mock_external_service_dependencies["index_processor"].load.assert_called_once()

View File

@@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeConnectionError,
@@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments:
@pytest.fixture
def sample_embedding_result(self):
"""Create a sample TextEmbeddingResult for testing.
"""Create a sample EmbeddingResult for testing.
Returns:
TextEmbeddingResult: Mock embedding result with proper structure
EmbeddingResult: Mock embedding result with proper structure
"""
# Create normalized embedding vectors (dimension 1536 for ada-002)
embedding_vector = np.random.randn(1536)
@@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_vector],
usage=usage,
@@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments:
latency=0.8,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments:
latency=0.6,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=new_embeddings,
usage=usage,
@@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[valid_vector.tolist(), nan_vector],
usage=usage,
@@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[nan_vector],
usage=usage,
@@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching:
latency=0.3,
)
result_ada = TextEmbeddingResult(
result_ada = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_ada],
usage=usage,
)
result_3_small = TextEmbeddingResult(
result_3_small = EmbeddingResult(
model="text-embedding-3-small",
embeddings=[normalized_3_small],
usage=usage,
@@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching:
latency=0.4,
)
result_openai = TextEmbeddingResult(
result_openai = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_openai],
usage=usage_openai,
)
result_cohere = TextEmbeddingResult(
result_cohere = EmbeddingResult(
model="embed-english-v3.0",
embeddings=[normalized_cohere],
usage=usage_cohere,
@@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation:
latency=0.7,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation:
latency=0.3,
)
result_ada = TextEmbeddingResult(
result_ada = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized_ada],
usage=usage_ada,
@@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation:
latency=0.4,
)
result_cohere = TextEmbeddingResult(
result_cohere = EmbeddingResult(
model="embed-english-v3.0",
embeddings=[normalized_cohere],
usage=usage_cohere,
@@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases:
latency=0.1,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases:
latency=1.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases:
latency=0.2,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases:
)
# Model returns embeddings for all texts
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases:
latency=0.8,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases:
latency=0.5,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,
@@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance:
latency=0.5,
)
return TextEmbeddingResult(
return EmbeddingResult(
model="text-embedding-ada-002",
embeddings=embeddings,
usage=usage,
@@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance:
latency=0.3,
)
embedding_result = TextEmbeddingResult(
embedding_result = EmbeddingResult(
model="text-embedding-ada-002",
embeddings=[normalized],
usage=usage,

View File

@@ -62,7 +62,7 @@ from core.indexing_runner import (
IndexingRunner,
)
from core.model_runtime.entities.model_entities import ModelType
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import ChildDocument, Document
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DatasetProcessRule
@@ -112,7 +112,7 @@ def create_mock_dataset_document(
document_id: str | None = None,
dataset_id: str | None = None,
tenant_id: str | None = None,
doc_form: str = IndexType.PARAGRAPH_INDEX,
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
data_source_type: str = "upload_file",
doc_language: str = "English",
) -> Mock:
@@ -133,8 +133,8 @@ def create_mock_dataset_document(
Mock: A configured mock DatasetDocument object with all required attributes.
Example:
>>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX)
>>> assert doc.doc_form == IndexType.QA_INDEX
>>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX)
>>> assert doc.doc_form == IndexStructureType.QA_INDEX
"""
doc = Mock(spec=DatasetDocument)
doc.id = document_id or str(uuid.uuid4())
@@ -276,7 +276,7 @@ class TestIndexingRunnerExtract:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.tenant_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
doc.data_source_type = "upload_file"
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
return doc
@@ -616,7 +616,7 @@ class TestIndexingRunnerLoad:
doc = Mock(spec=DatasetDocument)
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
return doc
@pytest.fixture
@@ -700,7 +700,7 @@ class TestIndexingRunnerLoad:
"""Test loading with parent-child index structure."""
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
sample_dataset.indexing_technique = "high_quality"
# Add child documents
@@ -775,7 +775,7 @@ class TestIndexingRunnerRun:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.tenant_id = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
doc.doc_language = "English"
doc.data_source_type = "upload_file"
doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())}
@@ -802,6 +802,21 @@ class TestIndexingRunnerRun:
mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}}
mock_dependencies["db"].session.scalar.return_value = mock_process_rule
# Mock current_user (Account) for _transform
mock_current_user = MagicMock()
mock_current_user.set_tenant_id = MagicMock()
# Setup db.session.query to return different results based on the model
def mock_query_side_effect(model):
mock_query_result = MagicMock()
if model.__name__ == "Dataset":
mock_query_result.filter_by.return_value.first.return_value = mock_dataset
elif model.__name__ == "Account":
mock_query_result.filter_by.return_value.first.return_value = mock_current_user
return mock_query_result
mock_dependencies["db"].session.query.side_effect = mock_query_side_effect
# Mock processor
mock_processor = MagicMock()
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
@@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments:
doc.id = str(uuid.uuid4())
doc.dataset_id = str(uuid.uuid4())
doc.created_by = str(uuid.uuid4())
doc.doc_form = IndexType.PARAGRAPH_INDEX
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
return doc
@pytest.fixture
@@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments:
"""Test loading segments for parent-child index."""
# Arrange
runner = IndexingRunner()
sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX
sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
# Add child documents
for doc in sample_documents:
@@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate:
tenant_id=tenant_id,
extract_settings=extract_settings,
tmp_processing_rule={"mode": "automatic", "rules": {}},
doc_form=IndexType.PARAGRAPH_INDEX,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)

View File

@@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode
from core.rag.rerank.weight_rerank import WeightRerankRunner
def create_mock_model_instance():
"""Create a properly configured mock ModelInstance for reranking tests."""
mock_instance = Mock(spec=ModelInstance)
# Setup provider_model_bundle chain for check_model_support_vision
mock_instance.provider_model_bundle = Mock()
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
return mock_instance
class TestRerankModelRunner:
"""Unit tests for RerankModelRunner.
@@ -37,10 +49,23 @@ class TestRerankModelRunner:
- Metadata preservation and score injection
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
@pytest.fixture
def mock_model_instance(self):
"""Create a mock ModelInstance for reranking."""
mock_instance = Mock(spec=ModelInstance)
# Setup provider_model_bundle chain for check_model_support_vision
mock_instance.provider_model_bundle = Mock()
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
return mock_instance
@pytest.fixture
@@ -803,7 +828,7 @@ class TestRerankRunnerFactory:
- Parameters are forwarded to runner constructor
"""
# Arrange: Mock model instance
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
# Act: Create runner via factory
runner = RerankRunnerFactory.create_rerank_runner(
@@ -865,7 +890,7 @@ class TestRerankRunnerFactory:
- String values are properly matched
"""
# Arrange: Mock model instance
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
# Act: Create runner using enum value
runner = RerankRunnerFactory.create_rerank_runner(
@@ -886,6 +911,13 @@ class TestRerankIntegration:
- Real-world usage scenarios
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_model_reranking_full_workflow(self):
"""Test complete model-based reranking workflow.
@@ -895,7 +927,7 @@ class TestRerankIntegration:
- Top results are returned correctly
"""
# Arrange: Create mock model and documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -951,7 +983,7 @@ class TestRerankIntegration:
- Normalization is consistent
"""
# Arrange: Create mock model with various scores
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -991,6 +1023,13 @@ class TestRerankEdgeCases:
- Concurrent reranking scenarios
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_with_empty_metadata(self):
"""Test reranking when documents have empty metadata.
@@ -1000,7 +1039,7 @@ class TestRerankEdgeCases:
- Empty metadata documents are processed correctly
"""
# Arrange: Create documents with empty metadata
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1046,7 +1085,7 @@ class TestRerankEdgeCases:
- Score comparison logic works at boundary
"""
# Arrange: Create mock with various scores including negatives
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1082,7 +1121,7 @@ class TestRerankEdgeCases:
- No overflow or precision issues
"""
# Arrange: All documents with perfect scores
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1117,7 +1156,7 @@ class TestRerankEdgeCases:
- Content encoding is preserved
"""
# Arrange: Documents with special characters
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1159,7 +1198,7 @@ class TestRerankEdgeCases:
- Content is not truncated unexpectedly
"""
# Arrange: Documents with very long content
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
long_content = "This is a very long document. " * 1000 # ~30,000 characters
mock_rerank_result = RerankResult(
@@ -1196,7 +1235,7 @@ class TestRerankEdgeCases:
- All documents are processed correctly
"""
# Arrange: Create 100 documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
num_docs = 100
# Create rerank results for all documents
@@ -1287,7 +1326,7 @@ class TestRerankEdgeCases:
- Documents can still be ranked
"""
# Arrange: Empty query
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[
@@ -1325,6 +1364,13 @@ class TestRerankPerformance:
- Score calculation optimization
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_batch_processing(self):
"""Test that documents are processed in a single batch.
@@ -1334,7 +1380,7 @@ class TestRerankPerformance:
- Efficient batch processing
"""
# Arrange: Multiple documents
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)],
@@ -1435,6 +1481,13 @@ class TestRerankErrorHandling:
- Error propagation
"""
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
def test_rerank_model_invocation_error(self):
"""Test handling of model invocation errors.
@@ -1444,7 +1497,7 @@ class TestRerankErrorHandling:
- Error context is preserved
"""
# Arrange: Mock model that raises exception
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed")
documents = [
@@ -1470,7 +1523,7 @@ class TestRerankErrorHandling:
- Invalid results don't corrupt output
"""
# Arrange: Rerank result with invalid index
mock_model_instance = Mock(spec=ModelInstance)
mock_model_instance = create_mock_model_instance()
mock_rerank_result = RerankResult(
model="bge-reranker-base",
docs=[

View File

@@ -425,15 +425,15 @@ class TestRetrievalService:
# ==================== Vector Search Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents):
def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test basic vector/semantic search functionality.
This test validates the core vector search flow:
1. Dataset is retrieved from database
2. embedding_search is called via ThreadPoolExecutor
2. _retrieve is called via ThreadPoolExecutor
3. Documents are added to shared all_documents list
4. Results are returned to caller
@@ -447,28 +447,28 @@ class TestRetrievalService:
# Set up the mock dataset that will be "retrieved" from database
mock_get_dataset.return_value = mock_dataset
# Create a side effect function that simulates embedding_search behavior
# In the real implementation, embedding_search:
# 1. Gets the dataset
# 2. Creates a Vector instance
# 3. Calls search_by_vector with embeddings
# 4. Extends all_documents with results
def side_effect_embedding_search(
# Create a side effect function that simulates _retrieve behavior
# _retrieve modifies the all_documents list in place
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
"""Simulate embedding_search adding documents to the shared list."""
all_documents.extend(sample_documents)
"""Simulate _retrieve adding documents to the shared list."""
if all_documents is not None:
all_documents.extend(sample_documents)
mock_embedding_search.side_effect = side_effect_embedding_search
mock_retrieve.side_effect = side_effect_retrieve
# Define test parameters
query = "What is Python?" # Natural language query
@@ -481,7 +481,7 @@ class TestRetrievalService:
# 1. Check if query is empty (early return if so)
# 2. Get the dataset using _get_dataset
# 3. Create ThreadPoolExecutor
# 4. Submit embedding_search task
# 4. Submit _retrieve task
# 5. Wait for completion
# 6. Return all_documents list
results = RetrievalService.retrieve(
@@ -502,15 +502,13 @@ class TestRetrievalService:
# Verify documents maintain their scores (highest score first in sample_documents)
assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents"
# Verify embedding_search was called exactly once
# Verify _retrieve was called exactly once
# This confirms the search method was invoked by ThreadPoolExecutor
mock_embedding_search.assert_called_once()
mock_retrieve.assert_called_once()
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_document_filter(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test vector search with document ID filtering.
@@ -522,21 +520,25 @@ class TestRetrievalService:
mock_get_dataset.return_value = mock_dataset
filtered_docs = [sample_documents[0]]
def side_effect_embedding_search(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.extend(filtered_docs)
if all_documents is not None:
all_documents.extend(filtered_docs)
mock_embedding_search.side_effect = side_effect_embedding_search
mock_retrieve.side_effect = side_effect_retrieve
document_ids_filter = [sample_documents[0].metadata["document_id"]]
# Act
@@ -552,12 +554,12 @@ class TestRetrievalService:
assert len(results) == 1
assert results[0].metadata["doc_id"] == "doc1"
# Verify document_ids_filter was passed
call_kwargs = mock_embedding_search.call_args.kwargs
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["document_ids_filter"] == document_ids_filter
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test vector search when no results match the query.
@@ -567,8 +569,8 @@ class TestRetrievalService:
"""
# Arrange
mock_get_dataset.return_value = mock_dataset
# embedding_search doesn't add anything to all_documents
mock_embedding_search.side_effect = lambda *args, **kwargs: None
# _retrieve doesn't add anything to all_documents
mock_retrieve.side_effect = lambda *args, **kwargs: None
# Act
results = RetrievalService.retrieve(
@@ -583,9 +585,9 @@ class TestRetrievalService:
# ==================== Keyword Search Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents):
def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test basic keyword search functionality.
@@ -597,12 +599,25 @@ class TestRetrievalService:
# Arrange
mock_get_dataset.return_value = mock_dataset
def side_effect_keyword_search(
flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None
def side_effect_retrieve(
flask_app,
retrieval_method,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.extend(sample_documents)
if all_documents is not None:
all_documents.extend(sample_documents)
mock_keyword_search.side_effect = side_effect_keyword_search
mock_retrieve.side_effect = side_effect_retrieve
query = "Python programming"
top_k = 3
@@ -618,7 +633,7 @@ class TestRetrievalService:
# Assert
assert len(results) == 3
assert all(isinstance(doc, Document) for doc in results)
mock_keyword_search.assert_called_once()
mock_retrieve.assert_called_once()
@patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
@@ -1147,11 +1162,9 @@ class TestRetrievalService:
# ==================== Metadata Filtering Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_metadata_filter(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test vector search with metadata-based document filtering.
@@ -1166,21 +1179,25 @@ class TestRetrievalService:
filtered_doc = sample_documents[0]
filtered_doc.metadata["category"] = "programming"
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.append(filtered_doc)
if all_documents is not None:
all_documents.append(filtered_doc)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
# Act
results = RetrievalService.retrieve(
@@ -1243,9 +1260,9 @@ class TestRetrievalService:
# Assert
assert results == []
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test that exceptions during retrieval are properly handled.
@@ -1256,22 +1273,26 @@ class TestRetrievalService:
# Arrange
mock_get_dataset.return_value = mock_dataset
# Make embedding_search add an exception to the exceptions list
# Make _retrieve add an exception to the exceptions list
def side_effect_with_exception(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
exceptions.append("Search failed")
if exceptions is not None:
exceptions.append("Search failed")
mock_embedding_search.side_effect = side_effect_with_exception
mock_retrieve.side_effect = side_effect_with_exception
# Act & Assert
with pytest.raises(ValueError) as exc_info:
@@ -1286,9 +1307,9 @@ class TestRetrievalService:
# ==================== Score Threshold Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test vector search with score threshold filtering.
@@ -1306,21 +1327,25 @@ class TestRetrievalService:
provider="dify",
)
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
all_documents.append(high_score_doc)
if all_documents is not None:
all_documents.append(high_score_doc)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
score_threshold = 0.8
@@ -1339,9 +1364,9 @@ class TestRetrievalService:
# ==================== Top-K Limiting Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset):
def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset):
"""
Test that retrieval respects top_k parameter.
@@ -1362,22 +1387,26 @@ class TestRetrievalService:
for i in range(10)
]
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
# Return only top_k documents
all_documents.extend(many_docs[:top_k])
if all_documents is not None:
all_documents.extend(many_docs[:top_k])
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
top_k = 3
@@ -1390,9 +1419,9 @@ class TestRetrievalService:
)
# Assert
# Verify top_k was passed to embedding_search
assert mock_embedding_search.called
call_kwargs = mock_embedding_search.call_args.kwargs
# Verify _retrieve was called
assert mock_retrieve.called
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["top_k"] == top_k
# Verify we got the right number of results
assert len(results) == top_k
@@ -1421,11 +1450,9 @@ class TestRetrievalService:
# ==================== Reranking Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_semantic_search_with_reranking(
self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
):
def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
"""
Test semantic search with reranking model.
@@ -1439,22 +1466,26 @@ class TestRetrievalService:
# Simulate reranking changing order
reranked_docs = list(reversed(sample_documents))
def side_effect_embedding(
def side_effect_retrieve(
flask_app,
dataset_id,
query,
top_k,
score_threshold,
reranking_model,
all_documents,
retrieval_method,
exceptions,
dataset,
query=None,
top_k=4,
score_threshold=None,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
attachment_id=None,
all_documents=None,
exceptions=None,
):
# embedding_search handles reranking internally
all_documents.extend(reranked_docs)
# _retrieve handles reranking internally
if all_documents is not None:
all_documents.extend(reranked_docs)
mock_embedding_search.side_effect = side_effect_embedding
mock_retrieve.side_effect = side_effect_retrieve
reranking_model = {
"reranking_provider_name": "cohere",
@@ -1473,7 +1504,7 @@ class TestRetrievalService:
# Assert
# For semantic search with reranking, reranking_model should be passed
assert len(results) == 3
call_kwargs = mock_embedding_search.call_args.kwargs
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["reranking_model"] == reranking_model

View File

@@ -8,7 +8,9 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
[
("...Hello, World!", "Hello, World!"),
("。测试中文标点", "测试中文标点"),
("!@#Test symbols", "Test symbols"),
# Note: ! is not in the removal pattern, only @# are removed, leaving "!Test symbols"
# The pattern intentionally excludes ! as per #11868 fix
("@#Test symbols", "Test symbols"),
("Hello, World!", "Hello, World!"),
("", ""),
(" ", " "),