chore: add ast-grep rule to convert Optional[T] to T | None (#25560)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Author: -LAN-
Date: 2025-09-15 13:06:33 +08:00 (committed by GitHub)
Parent: 2e44ebe98d
Commit: bab4975809
394 changed files with 2555 additions and 2792 deletions
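
The ast-grep rule file added by this commit is not included in the hunks shown below. As a rough sketch only — assuming ast-grep's standard YAML rule format, with an illustrative rule id that is not taken from the commit — the rewrite rule could look like:

# Hypothetical sketch; the actual rule file is not shown in this excerpt.
id: optional-to-union-none   # illustrative id
language: python
rule:
  pattern: Optional[$T]      # match any Optional[...] annotation
fix: $T | None               # rewrite to the PEP 604 union spelling

Applied in ast-grep's fix mode (e.g. ast-grep scan --update-all), a rule of this shape performs exactly the kind of mechanical, repo-wide rewrite visible in the hunks below.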

View File

@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Any, Optional
+from typing import Any
 import orjson
 from pydantic import BaseModel
@@ -143,7 +143,7 @@ class Jieba(BaseKeyword):
         storage.delete(file_key)
         storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
-    def _get_dataset_keyword_table(self) -> Optional[dict]:
+    def _get_dataset_keyword_table(self) -> dict | None:
         dataset_keyword_table = self.dataset.dataset_keyword_table
         if dataset_keyword_table:
             keyword_table_dict = dataset_keyword_table.keyword_table_dict

View File

@@ -1,5 +1,5 @@
 import re
-from typing import Optional, cast
+from typing import cast
 class JiebaKeywordTableHandler:
@@ -10,7 +10,7 @@ class JiebaKeywordTableHandler:
         jieba.analyse.default_tfidf.stop_words = STOPWORDS  # type: ignore
-    def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
+    def extract_keywords(self, text: str, max_keywords_per_chunk: int | None = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
         import jieba.analyse  # type: ignore
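
Both spellings are interchangeable at runtime on Python 3.10 and later (PEP 604), including on defaulted parameters like max_keywords_per_chunk above. A quick self-contained check (illustrative, not part of the commit):

import types
from typing import Optional

# The PEP 604 union compares equal to the typing.Optional form:
assert (int | None) == Optional[int]
assert isinstance(int | None, types.UnionType)

# Only the annotation changes; a non-None default such as 10 is unaffected:
def demo(max_keywords: int | None = 10) -> int:
    """Illustrative signature mirroring the rewrite above."""
    return max_keywords if max_keywords is not None else 0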

View File

@@ -1,6 +1,5 @@
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
-from typing import Optional
 from flask import Flask, current_app
 from sqlalchemy import select
@@ -39,11 +38,11 @@ class RetrievalService:
         dataset_id: str,
         query: str,
         top_k: int,
-        score_threshold: Optional[float] = 0.0,
-        reranking_model: Optional[dict] = None,
+        score_threshold: float | None = 0.0,
+        reranking_model: dict | None = None,
         reranking_mode: str = "reranking_model",
-        weights: Optional[dict] = None,
-        document_ids_filter: Optional[list[str]] = None,
+        weights: dict | None = None,
+        document_ids_filter: list[str] | None = None,
     ):
         if not query:
             return []
@@ -125,8 +124,8 @@
         cls,
         dataset_id: str,
         query: str,
-        external_retrieval_model: Optional[dict] = None,
-        metadata_filtering_conditions: Optional[dict] = None,
+        external_retrieval_model: dict | None = None,
+        metadata_filtering_conditions: dict | None = None,
     ):
         stmt = select(Dataset).where(Dataset.id == dataset_id)
         dataset = db.session.scalar(stmt)
@@ -145,7 +144,7 @@ class RetrievalService:
         return all_documents
     @classmethod
-    def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
+    def _get_dataset(cls, dataset_id: str) -> Dataset | None:
         with Session(db.engine) as session:
             return session.query(Dataset).where(Dataset.id == dataset_id).first()
@@ -158,7 +157,7 @@ class RetrievalService:
         top_k: int,
         all_documents: list,
         exceptions: list,
-        document_ids_filter: Optional[list[str]] = None,
+        document_ids_filter: list[str] | None = None,
     ):
         with flask_app.app_context():
             try:
@@ -182,12 +181,12 @@
         dataset_id: str,
         query: str,
         top_k: int,
-        score_threshold: Optional[float],
-        reranking_model: Optional[dict],
+        score_threshold: float | None,
+        reranking_model: dict | None,
         all_documents: list,
         retrieval_method: str,
         exceptions: list,
-        document_ids_filter: Optional[list[str]] = None,
+        document_ids_filter: list[str] | None = None,
     ):
         with flask_app.app_context():
             try:
@@ -235,12 +234,12 @@
         dataset_id: str,
         query: str,
         top_k: int,
-        score_threshold: Optional[float],
-        reranking_model: Optional[dict],
+        score_threshold: float | None,
+        reranking_model: dict | None,
         all_documents: list,
         retrieval_method: str,
         exceptions: list,
-        document_ids_filter: Optional[list[str]] = None,
+        document_ids_filter: list[str] | None = None,
     ):
         with flask_app.app_context():
             try:

View File

@@ -1,5 +1,5 @@
 import json
-from typing import Any, Optional
+from typing import Any
 from pydantic import BaseModel, model_validator
@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
     account: str
     account_password: str
     namespace: str = "dify"
-    namespace_password: Optional[str] = None
+    namespace_password: str | None = None
     metrics: str = "cosine"
     read_timeout: int = 60000

View File

@@ -1,5 +1,5 @@
 import json
-from typing import Any, Optional
+from typing import Any
 import chromadb
 from chromadb import QueryResult, Settings
@@ -20,8 +20,8 @@ class ChromaConfig(BaseModel):
     port: int
     tenant: str
     database: str
-    auth_provider: Optional[str] = None
-    auth_credentials: Optional[str] = None
+    auth_provider: str | None = None
+    auth_credentials: str | None = None
     def to_chroma_params(self):
         settings = Settings(

View File

@@ -84,7 +84,7 @@ class ClickzettaConnectionPool:
         self._pool_locks: dict[str, threading.Lock] = {}
         self._max_pool_size = 5  # Maximum connections per configuration
         self._connection_timeout = 300  # 5 minutes timeout
-        self._cleanup_thread: Optional[threading.Thread] = None
+        self._cleanup_thread: threading.Thread | None = None
         self._shutdown = False
         self._start_cleanup_thread()
@@ -303,8 +303,8 @@ class ClickzettaVector(BaseVector):
     """
     # Class-level write queue and lock for serializing writes
-    _write_queue: Optional[queue.Queue] = None
-    _write_thread: Optional[threading.Thread] = None
+    _write_queue: queue.Queue | None = None
+    _write_thread: threading.Thread | None = None
     _write_lock = threading.Lock()
     _shutdown = False
@@ -328,7 +328,7 @@ class ClickzettaVector(BaseVector):
     def __init__(self, vector_instance: "ClickzettaVector"):
         self.vector = vector_instance
-        self.connection: Optional[Connection] = None
+        self.connection: Connection | None = None
     def __enter__(self) -> "Connection":
         self.connection = self.vector._get_connection()

View File

@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Optional
+from typing import Any
 from flask import current_app
@@ -22,8 +22,8 @@ class ElasticSearchJaVector(ElasticSearchVector):
     def create_collection(
         self,
         embeddings: list[list[float]],
-        metadatas: Optional[list[dict[Any, Any]]] = None,
-        index_params: Optional[dict] = None,
+        metadatas: list[dict[Any, Any]] | None = None,
+        index_params: dict | None = None,
     ):
         lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):

View File

@@ -1,7 +1,7 @@
 import json
 import logging
 import math
-from typing import Any, Optional, cast
+from typing import Any, cast
 from urllib.parse import urlparse
 import requests
@@ -24,18 +24,18 @@ logger = logging.getLogger(__name__)
 class ElasticSearchConfig(BaseModel):
     # Regular Elasticsearch config
-    host: Optional[str] = None
-    port: Optional[int] = None
-    username: Optional[str] = None
-    password: Optional[str] = None
+    host: str | None = None
+    port: int | None = None
+    username: str | None = None
+    password: str | None = None
     # Elastic Cloud specific config
-    cloud_url: Optional[str] = None  # Cloud URL for Elasticsearch Cloud
-    api_key: Optional[str] = None
+    cloud_url: str | None = None  # Cloud URL for Elasticsearch Cloud
+    api_key: str | None = None
     # Common config
     use_cloud: bool = False
-    ca_certs: Optional[str] = None
+    ca_certs: str | None = None
     verify_certs: bool = False
     request_timeout: int = 100000
     retry_on_timeout: bool = True
@@ -256,8 +256,8 @@ class ElasticSearchVector(BaseVector):
     def create_collection(
         self,
         embeddings: list[list[float]],
-        metadatas: Optional[list[dict[Any, Any]]] = None,
-        index_params: Optional[dict] = None,
+        metadatas: list[dict[Any, Any]] | None = None,
+        index_params: dict | None = None,
     ):
         lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):

View File

@@ -1,7 +1,7 @@
 import json
 import logging
 import ssl
-from typing import Any, Optional
+from typing import Any
 from elasticsearch import Elasticsearch
 from pydantic import BaseModel, model_validator
@@ -157,8 +157,8 @@ class HuaweiCloudVector(BaseVector):
     def create_collection(
         self,
         embeddings: list[list[float]],
-        metadatas: Optional[list[dict[Any, Any]]] = None,
-        index_params: Optional[dict] = None,
+        metadatas: list[dict[Any, Any]] | None = None,
+        index_params: dict | None = None,
     ):
         lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):

View File

@@ -2,7 +2,7 @@ import copy
 import json
 import logging
 import time
-from typing import Any, Optional
+from typing import Any
 from opensearchpy import OpenSearch, helpers
 from opensearchpy.helpers import BulkIndexError
@@ -29,10 +29,10 @@ UGC_INDEX_PREFIX = "ugc_index"
 class LindormVectorStoreConfig(BaseModel):
     hosts: str
-    username: Optional[str] = None
-    password: Optional[str] = None
-    using_ugc: Optional[bool] = False
-    request_timeout: Optional[float] = 1.0  # timeout units: s
+    username: str | None = None
+    password: str | None = None
+    using_ugc: bool | None = False
+    request_timeout: float | None = 1.0  # timeout units: s
     @model_validator(mode="before")
     @classmethod
@@ -448,13 +448,13 @@ def default_text_search_query(
     query_text: str,
     k: int = 4,
     text_field: str = Field.CONTENT_KEY.value,
-    must: Optional[list[dict]] = None,
-    must_not: Optional[list[dict]] = None,
-    should: Optional[list[dict]] = None,
+    must: list[dict] | None = None,
+    must_not: list[dict] | None = None,
+    should: list[dict] | None = None,
     minimum_should_match: int = 0,
-    filters: Optional[list[dict]] = None,
-    routing: Optional[str] = None,
-    routing_field: Optional[str] = None,
+    filters: list[dict] | None = None,
+    routing: str | None = None,
+    routing_field: str | None = None,
     **kwargs,
 ):
     query_clause: dict[str, Any] = {}
@@ -505,13 +505,13 @@ def default_vector_search_query(
     query_vector: list[float],
     k: int = 4,
     min_score: str = "0.0",
-    ef_search: Optional[str] = None,  # only for hnsw
-    nprobe: Optional[str] = None,  # "2000"
-    reorder_factor: Optional[str] = None,  # "20"
-    client_refactor: Optional[str] = None,  # "true"
+    ef_search: str | None = None,  # only for hnsw
+    nprobe: str | None = None,  # "2000"
+    reorder_factor: str | None = None,  # "20"
+    client_refactor: str | None = None,  # "true"
     vector_field: str = Field.VECTOR.value,
-    filters: Optional[list[dict]] = None,
-    filter_type: Optional[str] = None,
+    filters: list[dict] | None = None,
+    filter_type: str | None = None,
     **kwargs,
 ):
     if filters is not None:

View File

@@ -3,7 +3,7 @@ import logging
 import uuid
 from collections.abc import Callable
 from functools import wraps
-from typing import Any, Concatenate, Optional, ParamSpec, TypeVar
+from typing import Any, Concatenate, ParamSpec, TypeVar
 from mo_vector.client import MoVectorClient  # type: ignore
 from pydantic import BaseModel, model_validator
@@ -74,7 +74,7 @@ class MatrixoneVector(BaseVector):
         self.client = self._get_client(len(embeddings[0]), True)
         return self.add_texts(texts, embeddings)
-    def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient:
+    def _get_client(self, dimension: int | None = None, create_table: bool = False) -> MoVectorClient:
         """
         Create a new client for the collection.

View File

@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Optional
+from typing import Any
 from packaging import version
 from pydantic import BaseModel, model_validator
@@ -26,13 +26,13 @@ class MilvusConfig(BaseModel):
     """
     uri: str  # Milvus server URI
-    token: Optional[str] = None  # Optional token for authentication
-    user: Optional[str] = None  # Username for authentication
-    password: Optional[str] = None  # Password for authentication
+    token: str | None = None  # Optional token for authentication
+    user: str | None = None  # Username for authentication
+    password: str | None = None  # Password for authentication
     batch_size: int = 100  # Batch size for operations
     database: str = "default"  # Database name
     enable_hybrid_search: bool = False  # Flag to enable hybrid search
-    analyzer_params: Optional[str] = None  # Analyzer params
+    analyzer_params: str | None = None  # Analyzer params
     @model_validator(mode="before")
     @classmethod
@@ -79,7 +79,7 @@ class MilvusVector(BaseVector):
         self._load_collection_fields()
         self._hybrid_search_enabled = self._check_hybrid_search_support()  # Check if hybrid search is supported
-    def _load_collection_fields(self, fields: Optional[list[str]] = None):
+    def _load_collection_fields(self, fields: list[str] | None = None):
         if fields is None:
             # Load collection fields from remote server
             collection_info = self._client.describe_collection(self._collection_name)
@@ -292,7 +292,7 @@
     )
     def create_collection(
-        self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+        self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
     ):
         """
         Create a new collection in Milvus with the specified schema and index parameters.

View File

@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Literal, Optional
+from typing import Any, Literal
 from uuid import uuid4
 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
@@ -26,10 +26,10 @@ class OpenSearchConfig(BaseModel):
     secure: bool = False  # use_ssl
     verify_certs: bool = True
     auth_method: Literal["basic", "aws_managed_iam"] = "basic"
-    user: Optional[str] = None
-    password: Optional[str] = None
-    aws_region: Optional[str] = None
-    aws_service: Optional[str] = None
+    user: str | None = None
+    password: str | None = None
+    aws_region: str | None = None
+    aws_service: str | None = None
     @model_validator(mode="before")
     @classmethod
@@ -236,7 +236,7 @@ class OpenSearchVector(BaseVector):
         return docs
     def create_collection(
-        self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+        self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
     ):
         lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
         with redis_client.lock(lock_name, timeout=20):

View File

@@ -3,7 +3,7 @@ import os
 import uuid
 from collections.abc import Generator, Iterable, Sequence
 from itertools import islice
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union
 import qdrant_client
 from flask import current_app
@@ -46,7 +46,7 @@ class PathQdrantParams(BaseModel):
 class UrlQdrantParams(BaseModel):
     url: str
-    api_key: Optional[str]
+    api_key: str | None
     timeout: float
     verify: bool
     grpc_port: int
@@ -55,9 +55,9 @@ class UrlQdrantParams(BaseModel):
 class QdrantConfig(BaseModel):
     endpoint: str
-    api_key: Optional[str] = None
+    api_key: str | None = None
     timeout: float = 20
-    root_path: Optional[str] = None
+    root_path: str | None = None
     grpc_port: int = 6334
     prefer_grpc: bool = False
     replication_factor: int = 1
@@ -189,10 +189,10 @@ class QdrantVector(BaseVector):
         self,
         texts: Iterable[str],
         embeddings: list[list[float]],
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
+        metadatas: list[dict] | None = None,
+        ids: Sequence[str] | None = None,
         batch_size: int = 64,
-        group_id: Optional[str] = None,
+        group_id: str | None = None,
     ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
         from qdrant_client.http import models as rest
@@ -234,7 +234,7 @@ class QdrantVector(BaseVector):
     def _build_payloads(
         cls,
         texts: Iterable[str],
-        metadatas: Optional[list[dict]],
+        metadatas: list[dict] | None,
         content_payload_key: str,
         metadata_payload_key: str,
         group_id: str,

View File

@@ -1,6 +1,6 @@
 import json
 import uuid
-from typing import Any, Optional
+from typing import Any
 from pydantic import BaseModel, model_validator
 from sqlalchemy import Column, String, Table, create_engine, insert
@@ -160,7 +160,7 @@ class RelytVector(BaseVector):
         else:
             return None
-    def delete_by_uuids(self, ids: Optional[list[str]] = None):
+    def delete_by_uuids(self, ids: list[str] | None = None):
         """Delete by vector IDs.
         Args:
@@ -241,7 +241,7 @@
         self,
         embedding: list[float],
         k: int = 4,
-        filter: Optional[dict] = None,
+        filter: dict | None = None,
     ) -> list[tuple[Document, float]]:
         # Add the filter if provided

View File

@@ -2,7 +2,7 @@ import json
 import logging
 import math
 from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any
 import tablestore  # type: ignore
 from pydantic import BaseModel, model_validator
@@ -22,11 +22,11 @@ logger = logging.getLogger(__name__)
 class TableStoreConfig(BaseModel):
-    access_key_id: Optional[str] = None
-    access_key_secret: Optional[str] = None
-    instance_name: Optional[str] = None
-    endpoint: Optional[str] = None
-    normalize_full_text_bm25_score: Optional[bool] = False
+    access_key_id: str | None = None
+    access_key_secret: str | None = None
+    instance_name: str | None = None
+    endpoint: str | None = None
+    normalize_full_text_bm25_score: bool | None = False
     @model_validator(mode="before")
     @classmethod

View File

@@ -1,7 +1,7 @@
 import json
 import logging
 import math
-from typing import Any, Optional
+from typing import Any
 from pydantic import BaseModel
 from tcvdb_text.encoder import BM25Encoder  # type: ignore
@@ -24,10 +24,10 @@ logger = logging.getLogger(__name__)
 class TencentConfig(BaseModel):
     url: str
-    api_key: Optional[str] = None
+    api_key: str | None = None
     timeout: float = 30
-    username: Optional[str] = None
-    database: Optional[str] = None
+    username: str | None = None
+    database: str | None = None
     index_type: str = "HNSW"
     metric_type: str = "IP"
     shard: int = 1

View File

@@ -3,7 +3,7 @@ import os
 import uuid
 from collections.abc import Generator, Iterable, Sequence
 from itertools import islice
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union
 import qdrant_client
 import requests
@@ -45,9 +45,9 @@ if TYPE_CHECKING:
 class TidbOnQdrantConfig(BaseModel):
     endpoint: str
-    api_key: Optional[str] = None
+    api_key: str | None = None
     timeout: float = 20
-    root_path: Optional[str] = None
+    root_path: str | None = None
     grpc_port: int = 6334
     prefer_grpc: bool = False
     replication_factor: int = 1
@@ -180,10 +180,10 @@ class TidbOnQdrantVector(BaseVector):
         self,
         texts: Iterable[str],
         embeddings: list[list[float]],
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
+        metadatas: list[dict] | None = None,
+        ids: Sequence[str] | None = None,
         batch_size: int = 64,
-        group_id: Optional[str] = None,
+        group_id: str | None = None,
     ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
         from qdrant_client.http import models as rest
@@ -225,7 +225,7 @@ class TidbOnQdrantVector(BaseVector):
     def _build_payloads(
         cls,
         texts: Iterable[str],
-        metadatas: Optional[list[dict]],
+        metadatas: list[dict] | None,
         content_payload_key: str,
         metadata_payload_key: str,
         group_id: str,

View File

@@ -1,7 +1,7 @@
 import logging
 import time
 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any
 from sqlalchemy import select
@@ -32,7 +32,7 @@ class AbstractVectorFactory(ABC):
 class Vector:
-    def __init__(self, dataset: Dataset, attributes: Optional[list] = None):
+    def __init__(self, dataset: Dataset, attributes: list | None = None):
         if attributes is None:
             attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
         self._dataset = dataset
@@ -180,7 +180,7 @@ class Vector:
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
-    def create(self, texts: Optional[list] = None, **kwargs):
+    def create(self, texts: list | None = None, **kwargs):
         if texts:
             start = time.time()
             logger.info("start embedding %s texts %s", len(texts), start)

View File

@@ -1,6 +1,6 @@
 import datetime
 import json
-from typing import Any, Optional
+from typing import Any
 import requests
 import weaviate  # type: ignore
@@ -19,7 +19,7 @@ from models.dataset import Dataset
 class WeaviateConfig(BaseModel):
     endpoint: str
-    api_key: Optional[str] = None
+    api_key: str | None = None
     batch_size: int = 100
     @model_validator(mode="before")
     @classmethod