Feat/support multimodal embedding (#29115)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
@@ -1,3 +1,4 @@
import base64
import logging
import time
from abc import ABC, abstractmethod
@@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Dataset, Whitelist
from models.model import UploadFile

logger = logging.getLogger(__name__)

@@ -203,6 +207,47 @@ class Vector:
                self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
        logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)

    def create_multimodal(self, file_documents: list | None = None, **kwargs):
        if file_documents:
            start = time.time()
            logger.info("start embedding %s files %s", len(file_documents), start)
            batch_size = 1000
            total_batches = (len(file_documents) + batch_size - 1) // batch_size
            for i in range(0, len(file_documents), batch_size):
                batch = file_documents[i : i + batch_size]
                batch_start = time.time()
                logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))

                # Batch query all upload files to avoid N+1 queries
                attachment_ids = [doc.metadata["doc_id"] for doc in batch]
                stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
                upload_files = db.session.scalars(stmt).all()
                upload_file_map = {str(f.id): f for f in upload_files}

                file_base64_list = []
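                # Only documents whose upload file is actually found are kept,
                # so real_batch stays aligned one-to-one with file_base64_list
                # (and hence with the embeddings returned below).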
                real_batch = []
                for document in batch:
                    attachment_id = document.metadata["doc_id"]
                    doc_type = document.metadata["doc_type"]
                    upload_file = upload_file_map.get(attachment_id)
                    if upload_file:
                        blob = storage.load_once(upload_file.key)
                        file_base64_str = base64.b64encode(blob).decode()
                        file_base64_list.append(
                            {
                                "content": file_base64_str,
                                "content_type": doc_type,
                                "file_id": attachment_id,
                            }
                        )
                        real_batch.append(document)
                batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
                logger.info(
                    "Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
                )
                self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
            logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)

    def add_texts(self, documents: list[Document], **kwargs):
        if kwargs.get("duplicate_check", False):
            documents = self._filter_duplicate_texts(documents)
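For context, a minimal sketch of how the new create_multimodal path might be driven. It is not part of this commit: the dataset lookup and upload-file id are placeholders, and it assumes an app/database context and that the Vector class lives in core.rag.datasource.vdb.vector_factory, as in upstream Dify; only the "doc_id"/"doc_type" metadata keys mirror what the method reads above.

# Hypothetical driver for the new multimodal path -- illustration only.
from core.rag.datasource.vdb.vector_factory import Vector  # assumed module path
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset

dataset = db.session.query(Dataset).first()  # placeholder dataset row
vector = Vector(dataset=dataset)
file_documents = [
    Document(
        page_content="",  # the file bytes are loaded from storage; only metadata is read
        metadata={"doc_id": "<upload-file-id>", "doc_type": DocType.IMAGE},
    )
]
vector.create_multimodal(file_documents)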
@@ -223,6 +268,22 @@ class Vector:
        query_vector = self._embeddings.embed_query(query)
        return self._vector_processor.search_by_vector(query_vector, **kwargs)

    def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
        upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()

        if not upload_file:
            return []
        blob = storage.load_once(upload_file.key)
        file_base64_str = base64.b64encode(blob).decode()
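        # Note: content_type is hard-coded to DocType.IMAGE below, so file
        # search currently treats every queried file as an image.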
        multimodal_vector = self._embeddings.embed_multimodal_query(
            {
                "content": file_base64_str,
                "content_type": DocType.IMAGE,
                "file_id": file_id,
            }
        )
        return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return self._vector_processor.search_by_full_text(query, **kwargs)
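And a matching sketch for query-by-image: search_by_file loads the stored blob, base64-encodes it, embeds it with embed_multimodal_query, and reuses the ordinary vector search. The file id below is a placeholder, and forwarding of extra kwargs such as top_k to the underlying store is assumed to behave as it does for search_by_vector.

# Hypothetical query-by-image call -- illustration only.
results = vector.search_by_file("<uploaded-image-file-id>", top_k=4)
for doc in results:
    print(doc.metadata.get("doc_id"), doc.page_content[:80])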