improve: generalize vector factory classes and vector type (#5033)
This commit is contained in:
@@ -1,12 +1,19 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from models.dataset import Dataset
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base
|
||||
except ImportError:
|
||||
@@ -53,7 +60,7 @@ class RelytVector(BaseVector):
|
||||
self._group_id = group_id
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'relyt'
|
||||
return VectorType.RELYT
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_params = {}
|
||||
@@ -240,10 +247,10 @@ class RelytVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
) -> list[tuple[Document, float]]:
|
||||
# Add the filter if provided
|
||||
try:
|
||||
@@ -298,3 +305,28 @@ class RelytVector(BaseVector):
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# milvus/zilliz/relyt doesn't support bm25 search
|
||||
return []
|
||||
|
||||
|
||||
class RelytVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return RelytVector(
|
||||
collection_name=collection_name,
|
||||
config=RelytConfig(
|
||||
host=config.get('RELYT_HOST'),
|
||||
port=config.get('RELYT_PORT'),
|
||||
user=config.get('RELYT_USER'),
|
||||
password=config.get('RELYT_PASSWORD'),
|
||||
database=config.get('RELYT_DATABASE'),
|
||||
),
|
||||
group_id=dataset.id
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user