chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -33,28 +33,29 @@ class RelytConfig(BaseModel):
password: str
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
    """Ensure every required Relyt connection setting is present and non-empty.

    Args:
        values: raw config mapping handed to the model before field parsing.

    Returns:
        The unmodified ``values`` mapping when validation passes.

    Raises:
        ValueError: naming the first missing ``RELYT_*`` config entry.
    """
    # Use .get() so an absent key raises the intended ValueError below
    # instead of an unrelated KeyError from direct subscripting.
    for field in ("host", "port", "user", "password", "database"):
        if not values.get(field):
            raise ValueError(f"config RELYT_{field.upper()} is required")
    return values
class RelytVector(BaseVector):
def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
    """Set up the Relyt vector store and open a SQLAlchemy engine for it."""
    super().__init__(collection_name)
    self._client_config = config
    self._group_id = group_id
    self._fields = []
    # Fixed dimensionality expected for stored embeddings.
    self.embedding_dimension = 1536
    # Relyt speaks the PostgreSQL wire protocol, so connect via psycopg2.
    self._url = (
        f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
    )
    self.client = create_engine(self._url)
@@ -70,9 +71,9 @@ class RelytVector(BaseVector):
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
@@ -110,7 +111,7 @@ class RelytVector(BaseVector):
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
for metadata in metadatas:
metadata['group_id'] = self._group_id
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
# Define the table schema
@@ -127,9 +128,7 @@ class RelytVector(BaseVector):
chunks_table_data = []
with self.client.connect() as conn:
with conn.begin():
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
chunks_table_data.append(
{
"id": chunk_id,
@@ -196,15 +195,13 @@ class RelytVector(BaseVector):
return False
def delete_by_metadata_field(self, key: str, value: str):
    """Delete every stored record whose metadata field ``key`` equals ``value``."""
    matching_ids = self.get_ids_by_metadata_field(key, value)
    if not matching_ids:
        # Nothing matched; skip issuing an empty delete.
        return
    self.delete_by_uuids(matching_ids)
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self.client) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
)
@@ -228,38 +225,34 @@ class RelytVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list["Document"]:
    """Return the documents most similar to ``query_vector``.

    Keyword args:
        top_k: maximum number of candidates to fetch; defaults to 4 when
            absent (previously ``int(None)`` raised a TypeError).
        filter: optional metadata filter forwarded to the underlying query.
        score_threshold: minimum similarity required; falsy values mean 0.0.

    Returns:
        Documents whose similarity (1 - distance) exceeds the threshold.
    """
    top_k = kwargs.get("top_k")
    results = self.similarity_search_with_score_by_vector(
        k=int(top_k) if top_k is not None else 4,
        embedding=query_vector,
        filter=kwargs.get("filter"),
    )
    # Hoisted out of the result loop: the threshold never changes per item.
    score_threshold = kwargs.get("score_threshold") or 0.0
    # ``score`` is a distance, so similarity is (1 - score).
    return [document for document, score in results if 1 - score > score_threshold]
def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError(
"Could not import Row from sqlalchemy.engine. "
"Please 'pip install sqlalchemy>=1.4'."
)
raise ImportError("Could not import Row from sqlalchemy.engine. " "Please 'pip install sqlalchemy>=1.4'.")
filter_condition = ""
if filter is not None:
conditions = [
f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
if len(value) > 1
else f"metadata->>{key!r} = {value[0]!r}"
for key, value in filter.items()
]
@@ -305,13 +298,12 @@ class RelytVector(BaseVector):
class RelytVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name))
return RelytVector(
collection_name=collection_name,
@@ -322,5 +314,5 @@ class RelytVectorFactory(AbstractVectorFactory):
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
),
group_id=dataset.id
group_id=dataset.id,
)