improve: generalize vector factory classes and vector type (#5033)
This commit is contained in:
@@ -1,14 +1,20 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pymilvus import MilvusClient, MilvusException, connections
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,7 +61,7 @@ class MilvusVector(BaseVector):
|
||||
self._fields = []
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'milvus'
|
||||
return VectorType.MILVUS
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_params = {
|
||||
@@ -254,10 +260,36 @@ class MilvusVector(BaseVector):
|
||||
schema=schema, index_param=index_params,
|
||||
consistency_level=self._consistency_level)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _init_client(self, config) -> MilvusClient:
|
||||
if config.secure:
|
||||
uri = "https://" + str(config.host) + ":" + str(config.port)
|
||||
else:
|
||||
uri = "http://" + str(config.host) + ":" + str(config.port)
|
||||
client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database)
|
||||
client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database)
|
||||
return client
|
||||
|
||||
|
||||
class MilvusVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return MilvusVector(
|
||||
collection_name=collection_name,
|
||||
config=MilvusConfig(
|
||||
host=config.get('MILVUS_HOST'),
|
||||
port=config.get('MILVUS_PORT'),
|
||||
user=config.get('MILVUS_USER'),
|
||||
password=config.get('MILVUS_PASSWORD'),
|
||||
secure=config.get('MILVUS_SECURE'),
|
||||
database=config.get('MILVUS_DATABASE'),
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user