improve: generalize vector factory classes and vector type (#5033)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
@@ -5,6 +6,7 @@ from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
import qdrant_client
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client.http import models as rest
|
||||
from qdrant_client.http.models import (
|
||||
@@ -17,10 +19,15 @@ from qdrant_client.http.models import (
|
||||
)
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
@@ -69,7 +76,7 @@ class QdrantVector(BaseVector):
|
||||
self._group_id = group_id
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'qdrant'
|
||||
return VectorType.QDRANT
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
@@ -408,3 +415,40 @@ class QdrantVector(BaseVector):
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||
)
|
||||
|
||||
|
||||
class QdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
else:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
if not dataset.index_struct_dict:
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return QdrantVector(
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=QdrantConfig(
|
||||
endpoint=config.get('QDRANT_URL'),
|
||||
api_key=config.get('QDRANT_API_KEY'),
|
||||
root_path=config.root_path,
|
||||
timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
|
||||
grpc_port=config.get('QDRANT_GRPC_PORT'),
|
||||
prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user