chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -22,15 +22,14 @@ class WeaviateConfig(BaseModel):
api_key: Optional[str] = None
batch_size: int = 100
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
super().__init__(collection_name)
self._client = self._init_client(config)
@@ -43,10 +42,7 @@ class WeaviateVector(BaseVector):
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
@@ -68,10 +64,10 @@ class WeaviateVector(BaseVector):
def get_collection_name(self, dataset: Dataset) -> str:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
if not class_prefix.endswith("_Node"):
# original class_prefix
class_prefix += '_Node'
class_prefix += "_Node"
return class_prefix
@@ -79,10 +75,7 @@ class WeaviateVector(BaseVector):
return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
# create collection
@@ -91,9 +84,9 @@ class WeaviateVector(BaseVector):
self.add_texts(texts, embeddings)
def _create_collection(self):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
schema = self._default_schema(self._collection_name)
@@ -129,17 +122,9 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")
def delete(self):
# check whether the index already exists
@@ -154,11 +139,19 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
if not self._client.schema.contains(schema):
return False
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}).with_limit(1).do()
result = (
self._client.query.get(collection_name)
.with_additional(["id"])
.with_where(
{
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}
)
.with_limit(1)
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
@@ -211,13 +204,13 @@ class WeaviateVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
# check score threshold
if score > score_threshold:
doc.metadata['score'] = score
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -240,15 +233,15 @@ class WeaviateVector(BaseVector):
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
query_obj = query_obj.with_additional(["vector"])
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
additional = res.pop('_additional')
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs
def _default_schema(self, index_name: str) -> dict:
@@ -271,20 +264,19 @@ class WeaviateVector(BaseVector):
class WeaviateVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT,
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
attributes=attributes
attributes=attributes,
)