Feat/support multimodal embedding (#29115)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
@@ -5,6 +6,7 @@ from typing import Any
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.model_runtime.entities import LLMMode
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@@ -32,6 +34,7 @@ class HitTestingService:
|
||||
account: Account,
|
||||
retrieval_model: Any, # FIXME drop this any
|
||||
external_retrieval_model: dict,
|
||||
attachment_ids: list | None = None,
|
||||
limit: int = 10,
|
||||
):
|
||||
start = time.perf_counter()
|
||||
@@ -41,7 +44,7 @@ class HitTestingService:
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
document_ids_filter = None
|
||||
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
|
||||
if metadata_filtering_conditions:
|
||||
if metadata_filtering_conditions and query:
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
from core.app.app_config.entities import MetadataFilteringCondition
|
||||
@@ -66,6 +69,7 @@ class HitTestingService:
|
||||
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
attachment_ids=attachment_ids,
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
@@ -80,17 +84,24 @@ class HitTestingService:
|
||||
|
||||
end = time.perf_counter()
|
||||
logger.debug("Hit testing retrieve in %s seconds", end - start)
|
||||
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id,
|
||||
content=query,
|
||||
source="hit_testing",
|
||||
source_app_id=None,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
dataset_queries = []
|
||||
if query:
|
||||
content = {"content_type": QueryType.TEXT_QUERY, "content": query}
|
||||
dataset_queries.append(content)
|
||||
if attachment_ids:
|
||||
for attachment_id in attachment_ids:
|
||||
content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}
|
||||
dataset_queries.append(content)
|
||||
if dataset_queries:
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id,
|
||||
content=json.dumps(dataset_queries),
|
||||
source="hit_testing",
|
||||
source_app_id=None,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
return cls.compact_retrieve_response(query, all_documents)
|
||||
@@ -168,9 +179,14 @@ class HitTestingService:
|
||||
@classmethod
|
||||
def hit_testing_args_check(cls, args):
|
||||
query = args["query"]
|
||||
attachment_ids = args["attachment_ids"]
|
||||
|
||||
if not query or len(query) > 250:
|
||||
raise ValueError("Query is required and cannot exceed 250 characters")
|
||||
if not attachment_ids and not query:
|
||||
raise ValueError("Query or attachment_ids is required")
|
||||
if query and len(query) > 250:
|
||||
raise ValueError("Query cannot exceed 250 characters")
|
||||
if attachment_ids and not isinstance(attachment_ids, list):
|
||||
raise ValueError("Attachment_ids must be a list")
|
||||
|
||||
@staticmethod
|
||||
def escape_query_for_search(query: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user