chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -22,15 +22,14 @@ class WeaviateConfig(BaseModel):
|
||||
api_key: Optional[str] = None
|
||||
batch_size: int = 100
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
if not values["endpoint"]:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
return values
|
||||
|
||||
|
||||
class WeaviateVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
|
||||
super().__init__(collection_name)
|
||||
self._client = self._init_client(config)
|
||||
@@ -43,10 +42,7 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
try:
|
||||
client = weaviate.Client(
|
||||
url=config.endpoint,
|
||||
auth_client_secret=auth_config,
|
||||
timeout_config=(5, 60),
|
||||
startup_period=None
|
||||
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise ConnectionError("Vector database connection error")
|
||||
@@ -68,10 +64,10 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
def get_collection_name(self, dataset: Dataset) -> str:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
if not class_prefix.endswith("_Node"):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
class_prefix += "_Node"
|
||||
|
||||
return class_prefix
|
||||
|
||||
@@ -79,10 +75,7 @@ class WeaviateVector(BaseVector):
|
||||
return Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name}
|
||||
}
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
# create collection
|
||||
@@ -91,9 +84,9 @@ class WeaviateVector(BaseVector):
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def _create_collection(self):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
schema = self._default_schema(self._collection_name)
|
||||
@@ -129,17 +122,9 @@ class WeaviateVector(BaseVector):
|
||||
# check whether the index already exists
|
||||
schema = self._default_schema(self._collection_name)
|
||||
if self._client.schema.contains(schema):
|
||||
where_filter = {
|
||||
"operator": "Equal",
|
||||
"path": [key],
|
||||
"valueText": value
|
||||
}
|
||||
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
|
||||
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._collection_name,
|
||||
where=where_filter,
|
||||
output='minimal'
|
||||
)
|
||||
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")
|
||||
|
||||
def delete(self):
|
||||
# check whether the index already exists
|
||||
@@ -154,11 +139,19 @@ class WeaviateVector(BaseVector):
|
||||
# check whether the index already exists
|
||||
if not self._client.schema.contains(schema):
|
||||
return False
|
||||
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
|
||||
"path": ["doc_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": id,
|
||||
}).with_limit(1).do()
|
||||
result = (
|
||||
self._client.query.get(collection_name)
|
||||
.with_additional(["id"])
|
||||
.with_where(
|
||||
{
|
||||
"path": ["doc_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": id,
|
||||
}
|
||||
)
|
||||
.with_limit(1)
|
||||
.do()
|
||||
)
|
||||
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
@@ -211,13 +204,13 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
# check score threshold
|
||||
if score > score_threshold:
|
||||
doc.metadata['score'] = score
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -240,15 +233,15 @@ class WeaviateVector(BaseVector):
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
query_obj = query_obj.with_additional(["vector"])
|
||||
properties = ['text']
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
|
||||
properties = ["text"]
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][collection_name]:
|
||||
text = res.pop(Field.TEXT_KEY.value)
|
||||
additional = res.pop('_additional')
|
||||
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
|
||||
additional = res.pop("_additional")
|
||||
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
|
||||
return docs
|
||||
|
||||
def _default_schema(self, index_name: str) -> dict:
|
||||
@@ -271,20 +264,19 @@ class WeaviateVector(BaseVector):
|
||||
class WeaviateVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||
|
||||
return WeaviateVector(
|
||||
collection_name=collection_name,
|
||||
config=WeaviateConfig(
|
||||
endpoint=dify_config.WEAVIATE_ENDPOINT,
|
||||
api_key=dify_config.WEAVIATE_API_KEY,
|
||||
batch_size=dify_config.WEAVIATE_BATCH_SIZE
|
||||
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
|
||||
),
|
||||
attributes=attributes
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user