improve: generalize vector factory classes and vector type (#5033)

This commit is contained in:
Bowen Liang
2024-06-08 22:29:24 +08:00
committed by GitHub
parent 3b62ab564a
commit bdad993901
12 changed files with 343 additions and 233 deletions

View File

@@ -1,14 +1,20 @@
import json
import logging
from typing import Any, Optional
from uuid import uuid4
from flask import current_app
from pydantic import BaseModel, root_validator
from pymilvus import MilvusClient, MilvusException, connections
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
@@ -55,7 +61,7 @@ class MilvusVector(BaseVector):
self._fields = []
def get_type(self) -> str:
    """Return the identifier of the vector store backend handled by this class.

    Returns:
        The ``VectorType.MILVUS`` constant. The shared constant is used
        instead of a hard-coded string literal so that all vector stores
        report their type through ``VectorType``.
    """
    # NOTE(review): presumably VectorType is a str-valued enum whose MILVUS
    # member equals 'milvus', keeping the `-> str` contract — confirm.
    return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
@@ -254,10 +260,36 @@ class MilvusVector(BaseVector):
schema=schema, index_param=index_params,
consistency_level=self._consistency_level)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
    """Create a Milvus client for the server described by ``config``.

    Args:
        config: Object exposing ``secure``, ``host``, ``port``, ``user``,
            ``password`` and ``database`` attributes.

    Returns:
        A ``MilvusClient`` built for ``http(s)://<host>:<port>`` using the
        configured credentials and database name.
    """
    # Only the URI scheme depends on the `secure` flag; the rest is shared.
    scheme = "https" if config.secure else "http"
    uri = f"{scheme}://{config.host}:{config.port}"
    # Construct the client exactly once so no extra, never-closed client
    # instance is created and then discarded.
    return MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database)
class MilvusVectorFactory(AbstractVectorFactory):
    """Factory that builds a ``MilvusVector`` for a dataset."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
        """Build the ``MilvusVector`` bound to the dataset's collection.

        If the dataset already has an index struct, its stored
        ``vector_store.class_prefix`` is reused as the collection name.
        Otherwise a collection name is generated from the dataset id and a
        new index struct is written to ``dataset.index_struct`` as JSON.

        Args:
            dataset: Dataset whose vectors are stored in Milvus.
            attributes: Not used by this factory.
            embeddings: Not used by this factory.

        Returns:
            A ``MilvusVector`` configured from the ``MILVUS_*`` entries of
            the current Flask application's config.
        """
        if dataset.index_struct_dict:
            collection_name: str = dataset.index_struct_dict['vector_store']['class_prefix']
        else:
            collection_name = Dataset.gen_collection_name_by_id(dataset.id)
            # Record the index under the Milvus type: this factory creates a
            # Milvus collection, so tagging it as WEAVIATE would mislabel
            # the stored index struct.
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.MILVUS, collection_name))

        config = current_app.config
        return MilvusVector(
            collection_name=collection_name,
            config=MilvusConfig(
                host=config.get('MILVUS_HOST'),
                port=config.get('MILVUS_PORT'),
                user=config.get('MILVUS_USER'),
                password=config.get('MILVUS_PASSWORD'),
                secure=config.get('MILVUS_SECURE'),
                database=config.get('MILVUS_DATABASE'),
            )
        )