chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -27,21 +27,20 @@ class ChromaConfig(BaseModel):
|
||||
settings = Settings(
|
||||
# auth
|
||||
chroma_client_auth_provider=self.auth_provider,
|
||||
chroma_client_auth_credentials=self.auth_credentials
|
||||
chroma_client_auth_credentials=self.auth_credentials,
|
||||
)
|
||||
|
||||
return {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'ssl': False,
|
||||
'tenant': self.tenant,
|
||||
'database': self.database,
|
||||
'settings': settings,
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"ssl": False,
|
||||
"tenant": self.tenant,
|
||||
"database": self.database,
|
||||
"settings": settings,
|
||||
}
|
||||
|
||||
|
||||
class ChromaVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: ChromaConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
@@ -58,9 +57,9 @@ class ChromaVector(BaseVector):
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(self, collection_name: str):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
self._client.get_or_create_collection(collection_name)
|
||||
@@ -76,7 +75,7 @@ class ChromaVector(BaseVector):
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(where={key: {'$eq': value}})
|
||||
collection.delete(where={key: {"$eq": value}})
|
||||
|
||||
def delete(self):
|
||||
self._client.delete_collection(self._collection_name)
|
||||
@@ -93,26 +92,26 @@ class ChromaVector(BaseVector):
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
|
||||
ids: list[str] = results['ids'][0]
|
||||
documents: list[str] = results['documents'][0]
|
||||
metadatas: dict[str, Any] = results['metadatas'][0]
|
||||
distances: list[float] = results['distances'][0]
|
||||
ids: list[str] = results["ids"][0]
|
||||
documents: list[str] = results["documents"][0]
|
||||
metadatas: dict[str, Any] = results["metadatas"][0]
|
||||
distances: list[float] = results["distances"][0]
|
||||
|
||||
docs = []
|
||||
for index in range(len(ids)):
|
||||
distance = distances[index]
|
||||
metadata = metadatas[index]
|
||||
if distance >= score_threshold:
|
||||
metadata['score'] = distance
|
||||
metadata["score"] = distance
|
||||
doc = Document(
|
||||
page_content=documents[index],
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -123,15 +122,12 @@ class ChromaVector(BaseVector):
|
||||
class ChromaVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
index_struct_dict = {
|
||||
"type": VectorType.CHROMA,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
return ChromaVector(
|
||||
|
||||
Reference in New Issue
Block a user