chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -28,11 +28,11 @@ class OpenSearchConfig(BaseModel):
|
||||
password: Optional[str] = None
|
||||
secure: bool = False
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get('host'):
|
||||
if not values.get("host"):
|
||||
raise ValueError("config OPENSEARCH_HOST is required")
|
||||
if not values.get('port'):
|
||||
if not values.get("port"):
|
||||
raise ValueError("config OPENSEARCH_PORT is required")
|
||||
return values
|
||||
|
||||
@@ -44,19 +44,18 @@ class OpenSearchConfig(BaseModel):
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
'hosts': [{'host': self.host, 'port': self.port}],
|
||||
'use_ssl': self.secure,
|
||||
'verify_certs': self.secure,
|
||||
"hosts": [{"host": self.host, "port": self.port}],
|
||||
"use_ssl": self.secure,
|
||||
"verify_certs": self.secure,
|
||||
}
|
||||
if self.user and self.password:
|
||||
params['http_auth'] = (self.user, self.password)
|
||||
params["http_auth"] = (self.user, self.password)
|
||||
if self.secure:
|
||||
params['ssl_context'] = self.create_ssl_context()
|
||||
params["ssl_context"] = self.create_ssl_context()
|
||||
return params
|
||||
|
||||
|
||||
class OpenSearchVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: OpenSearchConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
@@ -81,7 +80,7 @@ class OpenSearchVector(BaseVector):
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
}
|
||||
},
|
||||
}
|
||||
actions.append(action)
|
||||
|
||||
@@ -90,8 +89,8 @@ class OpenSearchVector(BaseVector):
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||
if response['hits']['hits']:
|
||||
return [hit['_id'] for hit in response['hits']['hits']]
|
||||
if response["hits"]["hits"]:
|
||||
return [hit["_id"] for hit in response["hits"]["hits"]]
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -110,7 +109,7 @@ class OpenSearchVector(BaseVector):
|
||||
actual_ids = []
|
||||
|
||||
for doc_id in ids:
|
||||
es_ids = self.get_ids_by_metadata_field('doc_id', doc_id)
|
||||
es_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
|
||||
if es_ids:
|
||||
actual_ids.extend(es_ids)
|
||||
else:
|
||||
@@ -122,9 +121,9 @@ class OpenSearchVector(BaseVector):
|
||||
helpers.bulk(self._client, actions)
|
||||
except BulkIndexError as e:
|
||||
for error in e.errors:
|
||||
delete_error = error.get('delete', {})
|
||||
status = delete_error.get('status')
|
||||
doc_id = delete_error.get('_id')
|
||||
delete_error = error.get("delete", {})
|
||||
status = delete_error.get("status")
|
||||
doc_id = delete_error.get("_id")
|
||||
|
||||
if status == 404:
|
||||
logger.warning(f"Document not found for deletion: {doc_id}")
|
||||
@@ -151,15 +150,8 @@ class OpenSearchVector(BaseVector):
|
||||
raise ValueError("All elements in query_vector should be floats")
|
||||
|
||||
query = {
|
||||
"size": kwargs.get('top_k', 4),
|
||||
"query": {
|
||||
"knn": {
|
||||
Field.VECTOR.value: {
|
||||
Field.VECTOR.value: query_vector,
|
||||
"k": kwargs.get('top_k', 4)
|
||||
}
|
||||
}
|
||||
}
|
||||
"size": kwargs.get("top_k", 4),
|
||||
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -169,17 +161,17 @@ class OpenSearchVector(BaseVector):
|
||||
raise
|
||||
|
||||
docs = []
|
||||
for hit in response['hits']['hits']:
|
||||
metadata = hit['_source'].get(Field.METADATA_KEY.value, {})
|
||||
for hit in response["hits"]["hits"]:
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
|
||||
|
||||
# Make sure metadata is a dictionary
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
metadata['score'] = hit['_score']
|
||||
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
|
||||
if hit['_score'] > score_threshold:
|
||||
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
metadata["score"] = hit["_score"]
|
||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||
if hit["_score"] > score_threshold:
|
||||
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
@@ -190,32 +182,28 @@ class OpenSearchVector(BaseVector):
|
||||
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
|
||||
|
||||
docs = []
|
||||
for hit in response['hits']['hits']:
|
||||
metadata = hit['_source'].get(Field.METADATA_KEY.value)
|
||||
vector = hit['_source'].get(Field.VECTOR.value)
|
||||
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
|
||||
for hit in response["hits"]["hits"]:
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY.value)
|
||||
vector = hit["_source"].get(Field.VECTOR.value)
|
||||
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
|
||||
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
):
|
||||
lock_name = f'vector_indexing_lock_{self._collection_name.lower()}'
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}'
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name.lower()} already exists.")
|
||||
return
|
||||
|
||||
if not self._client.indices.exists(index=self._collection_name.lower()):
|
||||
index_body = {
|
||||
"settings": {
|
||||
"index": {
|
||||
"knn": True
|
||||
}
|
||||
},
|
||||
"settings": {"index": {"knn": True}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {"type": "text"},
|
||||
@@ -226,20 +214,17 @@ class OpenSearchVector(BaseVector):
|
||||
"name": "hnsw",
|
||||
"space_type": "l2",
|
||||
"engine": "faiss",
|
||||
"parameters": {
|
||||
"ef_construction": 64,
|
||||
"m": 8
|
||||
}
|
||||
}
|
||||
"parameters": {"ef_construction": 64, "m": 8},
|
||||
},
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
|
||||
@@ -248,17 +233,14 @@ class OpenSearchVector(BaseVector):
|
||||
|
||||
|
||||
class OpenSearchVectorFactory(AbstractVectorFactory):
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
|
||||
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
|
||||
|
||||
open_search_config = OpenSearchConfig(
|
||||
host=dify_config.OPENSEARCH_HOST,
|
||||
@@ -268,7 +250,4 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
||||
secure=dify_config.OPENSEARCH_SECURE,
|
||||
)
|
||||
|
||||
return OpenSearchVector(
|
||||
collection_name=collection_name,
|
||||
config=open_search_config
|
||||
)
|
||||
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
|
||||
|
||||
Reference in New Issue
Block a user