improve: generalize vector factory classes and vector type (#5033)

This commit is contained in:
Bowen Liang
2024-06-08 22:29:24 +08:00
committed by GitHub
parent 3b62ab564a
commit bdad993901
12 changed files with 343 additions and 233 deletions

View File

@@ -1,12 +1,19 @@
import json
import uuid
from typing import Any, Optional
from flask import current_app
from pydantic import BaseModel, root_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from models.dataset import Dataset
try:
from sqlalchemy.orm import declarative_base
except ImportError:
@@ -53,7 +60,7 @@ class RelytVector(BaseVector):
self._group_id = group_id
def get_type(self) -> str:
    """Return the identifier of the vector store backing this instance.

    The identifier is taken from the shared ``VectorType`` constants rather
    than a hard-coded ``'relyt'`` literal, so it stays in step with the
    value the factory records via ``gen_index_struct_dict(VectorType.RELYT, ...)``.

    NOTE(review): presumably ``VectorType`` is a str-valued enum whose
    ``RELYT`` member equals ``'relyt'`` -- confirm against
    ``core.rag.datasource.vdb.vector_type`` before relying on string equality.
    """
    # A single return: the previous literal return made this line unreachable.
    return VectorType.RELYT
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {}
@@ -240,10 +247,10 @@ class RelytVector(BaseVector):
return docs
def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
@@ -298,3 +305,28 @@ class RelytVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    """Full-text search entry point; always yields an empty result.

    BM25-style full-text search is not supported by this store (the same
    holds for milvus/zilliz), so ``query`` and ``kwargs`` are accepted but
    ignored.

    Returns:
        A new empty list on every call.
    """
    no_matches: list[Document] = []
    return no_matches
class RelytVectorFactory(AbstractVectorFactory):
    """Factory that builds a ``RelytVector`` for a dataset."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
        """Create the ``RelytVector`` bound to ``dataset``.

        The collection name is read from the dataset's stored index struct
        (``vector_store.class_prefix``) when one exists. Otherwise a name is
        generated from the dataset id and a fresh index struct, serialized as
        JSON, is assigned to ``dataset.index_struct``.

        Connection settings come from the Flask app config keys
        ``RELYT_HOST``, ``RELYT_PORT``, ``RELYT_USER``, ``RELYT_PASSWORD`` and
        ``RELYT_DATABASE``.

        Args:
            dataset: Dataset whose vectors the store will hold; its id is
                also passed to the store as ``group_id``.
            attributes: Not used by this factory.
            embeddings: Not used by this factory.

        Returns:
            A ``RelytVector`` configured for the dataset's collection.
        """
        if not dataset.index_struct_dict:
            # No index struct yet: derive a collection name and record it
            # on the dataset object.
            collection_name = Dataset.gen_collection_name_by_id(dataset.id)
            index_struct = self.gen_index_struct_dict(VectorType.RELYT, collection_name)
            dataset.index_struct = json.dumps(index_struct)
        else:
            # Reuse the collection recorded in the existing index struct.
            collection_name = dataset.index_struct_dict['vector_store']['class_prefix']

        app_config = current_app.config
        relyt_config = RelytConfig(
            host=app_config.get('RELYT_HOST'),
            port=app_config.get('RELYT_PORT'),
            user=app_config.get('RELYT_USER'),
            password=app_config.get('RELYT_PASSWORD'),
            database=app_config.get('RELYT_DATABASE'),
        )
        return RelytVector(
            collection_name=collection_name,
            config=relyt_config,
            group_id=dataset.id,
        )