chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -27,44 +27,39 @@ class MilvusConfig(BaseModel):
batch_size: int = 100
database: str = "default"
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values.get('uri'):
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get('user'):
if not values.get("user"):
raise ValueError("config MILVUS_USER is required")
if not values.get('password'):
if not values.get("password"):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'uri': self.uri,
'token': self.token,
'user': self.user,
'password': self.password,
'db_name': self.database,
"uri": self.uri,
"token": self.token,
"user": self.user,
"password": self.password,
"db_name": self.database,
}
class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._consistency_level = "Session"
self._fields = []
def get_type(self) -> str:
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
@@ -75,7 +70,7 @@ class MilvusVector(BaseVector):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
Field.METADATA_KEY.value: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count
@@ -84,22 +79,20 @@ class MilvusVector(BaseVector):
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
raise e
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
@@ -107,17 +100,15 @@ class MilvusVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
if self._client.has_collection(self._collection_name):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {ids}',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
)
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
@@ -130,29 +121,28 @@ class MilvusVector(BaseVector):
if not self._client.has_collection(self._collection_name):
return False
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"]
)
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result['entity'].get(Field.METADATA_KEY.value)
metadata['score'] = result['distance']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
@@ -161,11 +151,11 @@ class MilvusVector(BaseVector):
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
# Grab the existing collection if it exists
@@ -180,19 +170,11 @@ class MilvusVector(BaseVector):
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create the schema for the collection
schema = CollectionSchema(fields)
@@ -208,9 +190,12 @@ class MilvusVector(BaseVector):
# Create the collection
collection_name = self._collection_name
self._client.create_collection(collection_name=collection_name,
schema=schema, index_params=index_params_obj,
consistency_level=self._consistency_level)
self._client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
@@ -221,13 +206,12 @@ class MilvusVector(BaseVector):
class MilvusVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
return MilvusVector(
collection_name=collection_name,
@@ -237,5 +221,5 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
database=dify_config.MILVUS_DATABASE,
)
),
)