chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -33,28 +33,29 @@ class RelytConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
if not values["host"]:
|
||||
raise ValueError("config RELYT_HOST is required")
|
||||
if not values['port']:
|
||||
if not values["port"]:
|
||||
raise ValueError("config RELYT_PORT is required")
|
||||
if not values['user']:
|
||||
if not values["user"]:
|
||||
raise ValueError("config RELYT_USER is required")
|
||||
if not values['password']:
|
||||
if not values["password"]:
|
||||
raise ValueError("config RELYT_PASSWORD is required")
|
||||
if not values['database']:
|
||||
if not values["database"]:
|
||||
raise ValueError("config RELYT_DATABASE is required")
|
||||
return values
|
||||
|
||||
|
||||
class RelytVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
|
||||
super().__init__(collection_name)
|
||||
self.embedding_dimension = 1536
|
||||
self._client_config = config
|
||||
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
self._url = (
|
||||
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
)
|
||||
self.client = create_engine(self._url)
|
||||
self._fields = []
|
||||
self._group_id = group_id
|
||||
@@ -70,9 +71,9 @@ class RelytVector(BaseVector):
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def create_collection(self, dimension: int):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
@@ -110,7 +111,7 @@ class RelytVector(BaseVector):
|
||||
ids = [str(uuid.uuid1()) for _ in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
for metadata in metadatas:
|
||||
metadata['group_id'] = self._group_id
|
||||
metadata["group_id"] = self._group_id
|
||||
texts = [d.page_content for d in documents]
|
||||
|
||||
# Define the table schema
|
||||
@@ -127,9 +128,7 @@ class RelytVector(BaseVector):
|
||||
chunks_table_data = []
|
||||
with self.client.connect() as conn:
|
||||
with conn.begin():
|
||||
for document, metadata, chunk_id, embedding in zip(
|
||||
texts, metadatas, ids, embeddings
|
||||
):
|
||||
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
|
||||
chunks_table_data.append(
|
||||
{
|
||||
"id": chunk_id,
|
||||
@@ -196,15 +195,13 @@ class RelytVector(BaseVector):
|
||||
return False
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self.delete_by_uuids(ids)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
with Session(self.client) as session:
|
||||
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
|
||||
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
|
||||
select_statement = sql_text(
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
|
||||
)
|
||||
@@ -228,38 +225,34 @@ class RelytVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
results = self.similarity_search_with_score_by_vector(
|
||||
k=int(kwargs.get('top_k')),
|
||||
embedding=query_vector,
|
||||
filter=kwargs.get('filter')
|
||||
k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter")
|
||||
)
|
||||
|
||||
# Organize results.
|
||||
docs = []
|
||||
for document, score in results:
|
||||
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
|
||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||
if 1 - score > score_threshold:
|
||||
docs.append(document)
|
||||
return docs
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
) -> list[tuple[Document, float]]:
|
||||
# Add the filter if provided
|
||||
try:
|
||||
from sqlalchemy.engine import Row
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import Row from sqlalchemy.engine. "
|
||||
"Please 'pip install sqlalchemy>=1.4'."
|
||||
)
|
||||
raise ImportError("Could not import Row from sqlalchemy.engine. " "Please 'pip install sqlalchemy>=1.4'.")
|
||||
|
||||
filter_condition = ""
|
||||
if filter is not None:
|
||||
conditions = [
|
||||
f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
|
||||
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
|
||||
if len(value) > 1
|
||||
else f"metadata->>{key!r} = {value[0]!r}"
|
||||
for key, value in filter.items()
|
||||
]
|
||||
@@ -305,13 +298,12 @@ class RelytVector(BaseVector):
|
||||
class RelytVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name))
|
||||
|
||||
return RelytVector(
|
||||
collection_name=collection_name,
|
||||
@@ -322,5 +314,5 @@ class RelytVectorFactory(AbstractVectorFactory):
|
||||
password=dify_config.RELYT_PASSWORD,
|
||||
database=dify_config.RELYT_DATABASE,
|
||||
),
|
||||
group_id=dataset.id
|
||||
group_id=dataset.id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user