chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel):
|
||||
namespace_password: str = (None,)
|
||||
metrics: str = ("cosine",)
|
||||
read_timeout: int = 60000
|
||||
|
||||
def to_analyticdb_client_params(self):
|
||||
return {
|
||||
"access_key_id": self.access_key_id,
|
||||
@@ -37,6 +38,7 @@ class AnalyticdbConfig(BaseModel):
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
|
||||
|
||||
class AnalyticdbVector(BaseVector):
|
||||
_instance = None
|
||||
_init = False
|
||||
@@ -57,9 +59,7 @@ class AnalyticdbVector(BaseVector):
|
||||
except:
|
||||
raise ImportError(_import_err_msg)
|
||||
self.config = config
|
||||
self._client_config = open_api_models.Config(
|
||||
user_agent="dify", **config.to_analyticdb_client_params()
|
||||
)
|
||||
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
|
||||
self._client = Client(self._client_config)
|
||||
self._initialize()
|
||||
AnalyticdbVector._init = True
|
||||
@@ -77,6 +77,7 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def _initialize_vector_database(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.InitVectorDatabaseRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@@ -88,6 +89,7 @@ class AnalyticdbVector(BaseVector):
|
||||
def _create_namespace_if_not_exists(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
@@ -109,13 +111,12 @@ class AnalyticdbVector(BaseVector):
|
||||
)
|
||||
self._client.create_namespace(request)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"failed to create namespace {self.config.namespace}: {e}"
|
||||
)
|
||||
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
@@ -149,9 +150,7 @@ class AnalyticdbVector(BaseVector):
|
||||
)
|
||||
self._client.create_collection(request)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"failed to create collection {self._collection_name}: {e}"
|
||||
)
|
||||
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def get_type(self) -> str:
|
||||
@@ -162,10 +161,9 @@ class AnalyticdbVector(BaseVector):
|
||||
self._create_collection_if_not_exists(dimension)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(
|
||||
self, documents: list[Document], embeddings: list[list[float]], **kwargs
|
||||
):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
|
||||
for doc, embedding in zip(documents, embeddings, strict=True):
|
||||
metadata = {
|
||||
@@ -191,6 +189,7 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@@ -202,13 +201,14 @@ class AnalyticdbVector(BaseVector):
|
||||
vector=None,
|
||||
content=None,
|
||||
top_k=1,
|
||||
filter=f"ref_doc_id='{id}'"
|
||||
filter=f"ref_doc_id='{id}'",
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
return len(response.body.matches.match) > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
ids_str = ",".join(f"'{id}'" for id in ids)
|
||||
ids_str = f"({ids_str})"
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
@@ -224,6 +224,7 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@@ -235,15 +236,10 @@ class AnalyticdbVector(BaseVector):
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
||||
def search_by_vector(
|
||||
self, query_vector: list[float], **kwargs: Any
|
||||
) -> list[Document]:
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
score_threshold = (
|
||||
kwargs.get("score_threshold", 0.0)
|
||||
if kwargs.get("score_threshold", 0.0)
|
||||
else 0.0
|
||||
)
|
||||
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@@ -270,11 +266,8 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
score_threshold = (
|
||||
kwargs.get("score_threshold", 0.0)
|
||||
if kwargs.get("score_threshold", 0.0)
|
||||
else 0.0
|
||||
)
|
||||
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@@ -304,6 +297,7 @@ class AnalyticdbVector(BaseVector):
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.DeleteCollectionRequest(
|
||||
collection=self._collection_name,
|
||||
dbinstance_id=self.config.instance_id,
|
||||
@@ -315,19 +309,16 @@ class AnalyticdbVector(BaseVector):
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class AnalyticdbVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"][
|
||||
"class_prefix"
|
||||
]
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
|
||||
)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
|
||||
|
||||
# handle optional params
|
||||
if dify_config.ANALYTICDB_KEY_ID is None:
|
||||
|
||||
Reference in New Issue
Block a user