improve: generalize vector factory classes and vector type (#5033)

This commit is contained in:
Bowen Liang
2024-06-08 22:29:24 +08:00
committed by GitHub
parent 3b62ab564a
commit bdad993901
12 changed files with 343 additions and 233 deletions

View File

@@ -1,3 +1,4 @@
import json
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
@@ -5,6 +6,7 @@ from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import qdrant_client
from flask import current_app
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
@@ -17,10 +19,15 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetCollectionBinding
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@@ -69,7 +76,7 @@ class QdrantVector(BaseVector):
self._group_id = group_id
def get_type(self) -> str:
    """Return the vector store type identifier for this implementation.

    Returns:
        The ``QDRANT`` member of ``VectorType``.
    """
    # The listing previously carried two consecutive returns (the literal
    # 'qdrant' followed by the enum member), which left the second one
    # unreachable. Only the enum-based return is kept.
    # NOTE(review): presumably VectorType.QDRANT compares equal to the
    # string 'qdrant' -- confirm in core.rag.datasource.vdb.vector_type.
    return VectorType.QDRANT
def to_index_struct(self) -> dict:
return {
@@ -408,3 +415,40 @@ class QdrantVector(BaseVector):
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
class QdrantVectorFactory(AbstractVectorFactory):
    """Factory that builds a ``QdrantVector`` for a given dataset."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
        """Resolve the collection name for ``dataset`` and build its vector store.

        Args:
            dataset: Dataset whose collection binding / index struct decides
                the collection name; its ``id`` is used as the group id.
            attributes: Not used by this implementation.
            embeddings: Not used by this implementation.

        Returns:
            A ``QdrantVector`` configured from the current Flask app config.

        Raises:
            ValueError: If ``dataset.collection_binding_id`` is set but no
                matching ``DatasetCollectionBinding`` row is found.
        """
        if dataset.collection_binding_id:
            # A bound dataset takes its collection name from the binding row.
            binding = (
                db.session.query(DatasetCollectionBinding)
                .filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
                .one_or_none()
            )
            if not binding:
                raise ValueError('Dataset Collection Bindings is not exist!')
            collection_name = binding.collection_name
        elif dataset.index_struct_dict:
            # Reuse the name already recorded in the dataset's index struct.
            stored_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            collection_name = stored_prefix
        else:
            # Nothing recorded yet: derive the name from the dataset id.
            collection_name = Dataset.gen_collection_name_by_id(dataset.id)

        # Record the index struct on the dataset when it has none yet.
        if not dataset.index_struct_dict:
            index_struct = self.gen_index_struct_dict(VectorType.QDRANT, collection_name)
            dataset.index_struct = json.dumps(index_struct)

        app_config = current_app.config
        group_id = dataset.id
        qdrant_config = QdrantConfig(
            endpoint=app_config.get('QDRANT_URL'),
            api_key=app_config.get('QDRANT_API_KEY'),
            root_path=app_config.root_path,
            timeout=app_config.get('QDRANT_CLIENT_TIMEOUT'),
            grpc_port=app_config.get('QDRANT_GRPC_PORT'),
            prefer_grpc=app_config.get('QDRANT_GRPC_ENABLED'),
        )
        return QdrantVector(
            collection_name=collection_name,
            group_id=group_id,
            config=qdrant_config,
        )