improve: generalize vector factory classes and vector type (#5033)
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
import weaviate
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
@@ -59,7 +64,7 @@ class WeaviateVector(BaseVector):
|
||||
return client
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'weaviate'
|
||||
return VectorType.WEAVIATE
|
||||
|
||||
def get_collection_name(self, dataset: Dataset) -> str:
|
||||
if dataset.index_struct_dict:
|
||||
@@ -255,3 +260,25 @@ class WeaviateVector(BaseVector):
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value.isoformat()
|
||||
return value
|
||||
|
||||
|
||||
class WeaviateVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||
|
||||
return WeaviateVector(
|
||||
collection_name=collection_name,
|
||||
config=WeaviateConfig(
|
||||
endpoint=current_app.config.get('WEAVIATE_ENDPOINT'),
|
||||
api_key=current_app.config.get('WEAVIATE_API_KEY'),
|
||||
batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE'))
|
||||
),
|
||||
attributes=attributes
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user