chore: add ast-grep rule to convert Optional[T] to T | None (#25560)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
||||
account: str
|
||||
account_password: str
|
||||
namespace: str = "dify"
|
||||
namespace_password: Optional[str] = None
|
||||
namespace_password: str | None = None
|
||||
metrics: str = "cosine"
|
||||
read_timeout: int = 60000
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import chromadb
|
||||
from chromadb import QueryResult, Settings
|
||||
@@ -20,8 +20,8 @@ class ChromaConfig(BaseModel):
|
||||
port: int
|
||||
tenant: str
|
||||
database: str
|
||||
auth_provider: Optional[str] = None
|
||||
auth_credentials: Optional[str] = None
|
||||
auth_provider: str | None = None
|
||||
auth_credentials: str | None = None
|
||||
|
||||
def to_chroma_params(self):
|
||||
settings = Settings(
|
||||
|
||||
@@ -84,7 +84,7 @@ class ClickzettaConnectionPool:
|
||||
self._pool_locks: dict[str, threading.Lock] = {}
|
||||
self._max_pool_size = 5 # Maximum connections per configuration
|
||||
self._connection_timeout = 300 # 5 minutes timeout
|
||||
self._cleanup_thread: Optional[threading.Thread] = None
|
||||
self._cleanup_thread: threading.Thread | None = None
|
||||
self._shutdown = False
|
||||
self._start_cleanup_thread()
|
||||
|
||||
@@ -303,8 +303,8 @@ class ClickzettaVector(BaseVector):
|
||||
"""
|
||||
|
||||
# Class-level write queue and lock for serializing writes
|
||||
_write_queue: Optional[queue.Queue] = None
|
||||
_write_thread: Optional[threading.Thread] = None
|
||||
_write_queue: queue.Queue | None = None
|
||||
_write_thread: threading.Thread | None = None
|
||||
_write_lock = threading.Lock()
|
||||
_shutdown = False
|
||||
|
||||
@@ -328,7 +328,7 @@ class ClickzettaVector(BaseVector):
|
||||
|
||||
def __init__(self, vector_instance: "ClickzettaVector"):
|
||||
self.vector = vector_instance
|
||||
self.connection: Optional[Connection] = None
|
||||
self.connection: Connection | None = None
|
||||
|
||||
def __enter__(self) -> "Connection":
|
||||
self.connection = self.vector._get_connection()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
|
||||
@@ -22,8 +22,8 @@ class ElasticSearchJaVector(ElasticSearchVector):
|
||||
def create_collection(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||
index_params: Optional[dict] = None,
|
||||
metadatas: list[dict[Any, Any]] | None = None,
|
||||
index_params: dict | None = None,
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
@@ -24,18 +24,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ElasticSearchConfig(BaseModel):
|
||||
# Regular Elasticsearch config
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
host: str | None = None
|
||||
port: int | None = None
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
|
||||
# Elastic Cloud specific config
|
||||
cloud_url: Optional[str] = None # Cloud URL for Elasticsearch Cloud
|
||||
api_key: Optional[str] = None
|
||||
cloud_url: str | None = None # Cloud URL for Elasticsearch Cloud
|
||||
api_key: str | None = None
|
||||
|
||||
# Common config
|
||||
use_cloud: bool = False
|
||||
ca_certs: Optional[str] = None
|
||||
ca_certs: str | None = None
|
||||
verify_certs: bool = False
|
||||
request_timeout: int = 100000
|
||||
retry_on_timeout: bool = True
|
||||
@@ -256,8 +256,8 @@ class ElasticSearchVector(BaseVector):
|
||||
def create_collection(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||
index_params: Optional[dict] = None,
|
||||
metadatas: list[dict[Any, Any]] | None = None,
|
||||
index_params: dict | None = None,
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -157,8 +157,8 @@ class HuaweiCloudVector(BaseVector):
|
||||
def create_collection(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||
index_params: Optional[dict] = None,
|
||||
metadatas: list[dict[Any, Any]] | None = None,
|
||||
index_params: dict | None = None,
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
|
||||
@@ -2,7 +2,7 @@ import copy
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from opensearchpy import OpenSearch, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
@@ -29,10 +29,10 @@ UGC_INDEX_PREFIX = "ugc_index"
|
||||
|
||||
class LindormVectorStoreConfig(BaseModel):
|
||||
hosts: str
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
using_ugc: Optional[bool] = False
|
||||
request_timeout: Optional[float] = 1.0 # timeout units: s
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
using_ugc: bool | None = False
|
||||
request_timeout: float | None = 1.0 # timeout units: s
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -448,13 +448,13 @@ def default_text_search_query(
|
||||
query_text: str,
|
||||
k: int = 4,
|
||||
text_field: str = Field.CONTENT_KEY.value,
|
||||
must: Optional[list[dict]] = None,
|
||||
must_not: Optional[list[dict]] = None,
|
||||
should: Optional[list[dict]] = None,
|
||||
must: list[dict] | None = None,
|
||||
must_not: list[dict] | None = None,
|
||||
should: list[dict] | None = None,
|
||||
minimum_should_match: int = 0,
|
||||
filters: Optional[list[dict]] = None,
|
||||
routing: Optional[str] = None,
|
||||
routing_field: Optional[str] = None,
|
||||
filters: list[dict] | None = None,
|
||||
routing: str | None = None,
|
||||
routing_field: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
query_clause: dict[str, Any] = {}
|
||||
@@ -505,13 +505,13 @@ def default_vector_search_query(
|
||||
query_vector: list[float],
|
||||
k: int = 4,
|
||||
min_score: str = "0.0",
|
||||
ef_search: Optional[str] = None, # only for hnsw
|
||||
nprobe: Optional[str] = None, # "2000"
|
||||
reorder_factor: Optional[str] = None, # "20"
|
||||
client_refactor: Optional[str] = None, # "true"
|
||||
ef_search: str | None = None, # only for hnsw
|
||||
nprobe: str | None = None, # "2000"
|
||||
reorder_factor: str | None = None, # "20"
|
||||
client_refactor: str | None = None, # "true"
|
||||
vector_field: str = Field.VECTOR.value,
|
||||
filters: Optional[list[dict]] = None,
|
||||
filter_type: Optional[str] = None,
|
||||
filters: list[dict] | None = None,
|
||||
filter_type: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if filters is not None:
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, Concatenate, Optional, ParamSpec, TypeVar
|
||||
from typing import Any, Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from mo_vector.client import MoVectorClient # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -74,7 +74,7 @@ class MatrixoneVector(BaseVector):
|
||||
self.client = self._get_client(len(embeddings[0]), True)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient:
|
||||
def _get_client(self, dimension: int | None = None, create_table: bool = False) -> MoVectorClient:
|
||||
"""
|
||||
Create a new client for the collection.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -26,13 +26,13 @@ class MilvusConfig(BaseModel):
|
||||
"""
|
||||
|
||||
uri: str # Milvus server URI
|
||||
token: Optional[str] = None # Optional token for authentication
|
||||
user: Optional[str] = None # Username for authentication
|
||||
password: Optional[str] = None # Password for authentication
|
||||
token: str | None = None # Optional token for authentication
|
||||
user: str | None = None # Username for authentication
|
||||
password: str | None = None # Password for authentication
|
||||
batch_size: int = 100 # Batch size for operations
|
||||
database: str = "default" # Database name
|
||||
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||
analyzer_params: Optional[str] = None # Analyzer params
|
||||
analyzer_params: str | None = None # Analyzer params
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -79,7 +79,7 @@ class MilvusVector(BaseVector):
|
||||
self._load_collection_fields()
|
||||
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
|
||||
|
||||
def _load_collection_fields(self, fields: Optional[list[str]] = None):
|
||||
def _load_collection_fields(self, fields: list[str] | None = None):
|
||||
if fields is None:
|
||||
# Load collection fields from remote server
|
||||
collection_info = self._client.describe_collection(self._collection_name)
|
||||
@@ -292,7 +292,7 @@ class MilvusVector(BaseVector):
|
||||
)
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
|
||||
):
|
||||
"""
|
||||
Create a new collection in Milvus with the specified schema and index parameters.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
@@ -26,10 +26,10 @@ class OpenSearchConfig(BaseModel):
|
||||
secure: bool = False # use_ssl
|
||||
verify_certs: bool = True
|
||||
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
|
||||
user: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
aws_region: Optional[str] = None
|
||||
aws_service: Optional[str] = None
|
||||
user: str | None = None
|
||||
password: str | None = None
|
||||
aws_region: str | None = None
|
||||
aws_service: str | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -236,7 +236,7 @@ class OpenSearchVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import qdrant_client
|
||||
from flask import current_app
|
||||
@@ -46,7 +46,7 @@ class PathQdrantParams(BaseModel):
|
||||
|
||||
class UrlQdrantParams(BaseModel):
|
||||
url: str
|
||||
api_key: Optional[str]
|
||||
api_key: str | None
|
||||
timeout: float
|
||||
verify: bool
|
||||
grpc_port: int
|
||||
@@ -55,9 +55,9 @@ class UrlQdrantParams(BaseModel):
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str] = None
|
||||
api_key: str | None = None
|
||||
timeout: float = 20
|
||||
root_path: Optional[str] = None
|
||||
root_path: str | None = None
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
replication_factor: int = 1
|
||||
@@ -189,10 +189,10 @@ class QdrantVector(BaseVector):
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
metadatas: list[dict] | None = None,
|
||||
ids: Sequence[str] | None = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
group_id: str | None = None,
|
||||
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
@@ -234,7 +234,7 @@ class QdrantVector(BaseVector):
|
||||
def _build_payloads(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]],
|
||||
metadatas: list[dict] | None,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Column, String, Table, create_engine, insert
|
||||
@@ -160,7 +160,7 @@ class RelytVector(BaseVector):
|
||||
else:
|
||||
return None
|
||||
|
||||
def delete_by_uuids(self, ids: Optional[list[str]] = None):
|
||||
def delete_by_uuids(self, ids: list[str] | None = None):
|
||||
"""Delete by vector IDs.
|
||||
|
||||
Args:
|
||||
@@ -241,7 +241,7 @@ class RelytVector(BaseVector):
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
) -> list[tuple[Document, float]]:
|
||||
# Add the filter if provided
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import tablestore # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -22,11 +22,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TableStoreConfig(BaseModel):
|
||||
access_key_id: Optional[str] = None
|
||||
access_key_secret: Optional[str] = None
|
||||
instance_name: Optional[str] = None
|
||||
endpoint: Optional[str] = None
|
||||
normalize_full_text_bm25_score: Optional[bool] = False
|
||||
access_key_id: str | None = None
|
||||
access_key_secret: str | None = None
|
||||
instance_name: str | None = None
|
||||
endpoint: str | None = None
|
||||
normalize_full_text_bm25_score: bool | None = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tcvdb_text.encoder import BM25Encoder # type: ignore
|
||||
@@ -24,10 +24,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class TencentConfig(BaseModel):
|
||||
url: str
|
||||
api_key: Optional[str] = None
|
||||
api_key: str | None = None
|
||||
timeout: float = 30
|
||||
username: Optional[str] = None
|
||||
database: Optional[str] = None
|
||||
username: str | None = None
|
||||
database: str | None = None
|
||||
index_type: str = "HNSW"
|
||||
metric_type: str = "IP"
|
||||
shard: int = 1
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import qdrant_client
|
||||
import requests
|
||||
@@ -45,9 +45,9 @@ if TYPE_CHECKING:
|
||||
|
||||
class TidbOnQdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str] = None
|
||||
api_key: str | None = None
|
||||
timeout: float = 20
|
||||
root_path: Optional[str] = None
|
||||
root_path: str | None = None
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
replication_factor: int = 1
|
||||
@@ -180,10 +180,10 @@ class TidbOnQdrantVector(BaseVector):
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
metadatas: list[dict] | None = None,
|
||||
ids: Sequence[str] | None = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
group_id: str | None = None,
|
||||
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
@@ -225,7 +225,7 @@ class TidbOnQdrantVector(BaseVector):
|
||||
def _build_payloads(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]],
|
||||
metadatas: list[dict] | None,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -32,7 +32,7 @@ class AbstractVectorFactory(ABC):
|
||||
|
||||
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: Optional[list] = None):
|
||||
def __init__(self, dataset: Dataset, attributes: list | None = None):
|
||||
if attributes is None:
|
||||
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
|
||||
self._dataset = dataset
|
||||
@@ -180,7 +180,7 @@ class Vector:
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
def create(self, texts: Optional[list] = None, **kwargs):
|
||||
def create(self, texts: list | None = None, **kwargs):
|
||||
if texts:
|
||||
start = time.time()
|
||||
logger.info("start embedding %s texts %s", len(texts), start)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import weaviate # type: ignore
|
||||
@@ -19,7 +19,7 @@ from models.dataset import Dataset
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str] = None
|
||||
api_key: str | None = None
|
||||
batch_size: int = 100
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
Reference in New Issue
Block a user