chore: add ast-grep rule to convert Optional[T] to T | None (#25560)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Authored by -LAN- on 2025-09-15 13:06:33 +08:00, committed by GitHub
parent 2e44ebe98d
commit bab4975809
394 changed files with 2555 additions and 2792 deletions
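The rewrite applied below converts Optional[T] annotations to the PEP 604 union form T | None. The ast-grep rule itself is not visible in this diff excerpt, so the following is only a minimal sketch of what such a rule typically looks like; the rule id and the exact pattern and constraints used in the repository are assumptions.

    # Hypothetical sketch only; the actual rule shipped with this commit may differ.
    # An ast-grep rule is a YAML file: rule.pattern matches source code using
    # metavariables such as $T, and fix is the rewrite applied to each match.
    id: optional-to-union-none
    language: python
    rule:
      pattern: Optional[$T]
    fix: $T | None

Run repo-wide with the ast-grep CLI (with the autofix bot credited above committing the result), a single pattern like this accounts for the 394-file footprint. Note that T | None in runtime positions requires Python 3.10+, while annotation-only uses can also rely on from __future__ import annotations, as one of the touched modules already does.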

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
-from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
@@ -31,10 +30,10 @@ class Callback(ABC):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
+user: str | None = None,
):
"""
Before invoke callback
@@ -60,10 +59,10 @@ class Callback(ABC):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
+user: str | None = None,
):
"""
On new chunk callback
@@ -90,10 +89,10 @@ class Callback(ABC):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
+user: str | None = None,
):
"""
After invoke callback
@@ -120,10 +119,10 @@ class Callback(ABC):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
+user: str | None = None,
):
"""
Invoke error callback
@@ -141,7 +140,7 @@ class Callback(ABC):
"""
raise NotImplementedError()
-def print_text(self, text: str, color: Optional[str] = None, end: str = ""):
+def print_text(self, text: str, color: str | None = None, end: str = ""):
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(text_to_print, end=end)

View File

@@ -2,7 +2,7 @@ import json
import logging
import sys
from collections.abc import Sequence
-from typing import Optional, cast
+from typing import cast
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -20,10 +20,10 @@ class LoggingCallback(Callback):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
+user: str | None = None,
):
"""
Before invoke callback
@@ -76,10 +76,10 @@ class LoggingCallback(Callback):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
+user: str | None = None,
):
"""
On new chunk callback
@@ -106,10 +106,10 @@ class LoggingCallback(Callback):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
+user: str | None = None,
):
"""
After invoke callback
@@ -147,10 +147,10 @@ class LoggingCallback(Callback):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
+user: str | None = None,
):
"""
Invoke error callback

View File

@@ -1,5 +1,3 @@
-from typing import Optional
from pydantic import BaseModel
@@ -8,7 +6,7 @@ class I18nObject(BaseModel):
Model class for i18n object.
"""
-zh_Hans: Optional[str] = None
+zh_Hans: str | None = None
en_US: str
def __init__(self, **data):

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from decimal import Decimal
from enum import StrEnum
-from typing import Any, Optional, TypedDict, Union
+from typing import Any, TypedDict, Union
from pydantic import BaseModel, Field
@@ -150,13 +150,13 @@ class LLMResult(BaseModel):
Model class for llm result.
"""
-id: Optional[str] = None
+id: str | None = None
model: str
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
message: AssistantPromptMessage
usage: LLMUsage
-system_fingerprint: Optional[str] = None
-reasoning_content: Optional[str] = None
+system_fingerprint: str | None = None
+reasoning_content: str | None = None
class LLMStructuredOutput(BaseModel):
@@ -164,7 +164,7 @@ class LLMStructuredOutput(BaseModel):
Model class for llm structured output.
"""
-structured_output: Optional[Mapping[str, Any]] = None
+structured_output: Mapping[str, Any] | None = None
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
@@ -180,8 +180,8 @@ class LLMResultChunkDelta(BaseModel):
index: int
message: AssistantPromptMessage
-usage: Optional[LLMUsage] = None
-finish_reason: Optional[str] = None
+usage: LLMUsage | None = None
+finish_reason: str | None = None
class LLMResultChunk(BaseModel):
@@ -191,7 +191,7 @@ class LLMResultChunk(BaseModel):
model: str
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
-system_fingerprint: Optional[str] = None
+system_fingerprint: str | None = None
delta: LLMResultChunkDelta

View File

@@ -1,7 +1,7 @@
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
-from typing import Annotated, Any, Literal, Optional, Union
+from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
@@ -146,8 +146,8 @@ class PromptMessage(ABC, BaseModel):
"""
role: PromptMessageRole
-content: Optional[str | list[PromptMessageContentUnionTypes]] = None
-name: Optional[str] = None
+content: str | list[PromptMessageContentUnionTypes] | None = None
+name: str | None = None
def is_empty(self) -> bool:
"""
@@ -193,8 +193,8 @@ class PromptMessage(ABC, BaseModel):
@field_serializer("content")
def serialize_content(
-self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
-) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]:
+self, content: Union[str, Sequence[PromptMessageContent]] | None
+) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None:
if content is None or isinstance(content, str):
return content
if isinstance(content, list):

View File

@@ -1,6 +1,6 @@
from decimal import Decimal
from enum import StrEnum, auto
-from typing import Any, Optional
+from typing import Any
from pydantic import BaseModel, ConfigDict, model_validator
@@ -154,7 +154,7 @@ class ProviderModel(BaseModel):
model: str
label: I18nObject
model_type: ModelType
-features: Optional[list[ModelFeature]] = None
+features: list[ModelFeature] | None = None
fetch_from: FetchFrom
model_properties: dict[ModelPropertyKey, Any]
deprecated: bool = False
@@ -171,15 +171,15 @@ class ParameterRule(BaseModel):
"""
name: str
-use_template: Optional[str] = None
+use_template: str | None = None
label: I18nObject
type: ParameterType
-help: Optional[I18nObject] = None
+help: I18nObject | None = None
required: bool = False
-default: Optional[Any] = None
-min: Optional[float] = None
-max: Optional[float] = None
-precision: Optional[int] = None
+default: Any | None = None
+min: float | None = None
+max: float | None = None
+precision: int | None = None
options: list[str] = []
@@ -189,7 +189,7 @@ class PriceConfig(BaseModel):
"""
input: Decimal
-output: Optional[Decimal] = None
+output: Decimal | None = None
unit: Decimal
currency: str
@@ -200,7 +200,7 @@ class AIModelEntity(ProviderModel):
"""
parameter_rules: list[ParameterRule] = []
-pricing: Optional[PriceConfig] = None
+pricing: PriceConfig | None = None
@model_validator(mode="after")
def validate_model(self):

View File

@@ -1,6 +1,5 @@
from collections.abc import Sequence
from enum import Enum, StrEnum, auto
-from typing import Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -62,9 +61,9 @@ class CredentialFormSchema(BaseModel):
label: I18nObject
type: FormType
required: bool = True
-default: Optional[str] = None
-options: Optional[list[FormOption]] = None
-placeholder: Optional[I18nObject] = None
+default: str | None = None
+options: list[FormOption] | None = None
+placeholder: I18nObject | None = None
max_length: int = 0
show_on: list[FormShowOnObject] = []
@@ -79,7 +78,7 @@ class ProviderCredentialSchema(BaseModel):
class FieldModelSchema(BaseModel):
label: I18nObject
-placeholder: Optional[I18nObject] = None
+placeholder: I18nObject | None = None
class ModelCredentialSchema(BaseModel):
@@ -98,8 +97,8 @@ class SimpleProviderEntity(BaseModel):
provider: str
label: I18nObject
-icon_small: Optional[I18nObject] = None
-icon_large: Optional[I18nObject] = None
+icon_small: I18nObject | None = None
+icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
@@ -120,24 +119,24 @@ class ProviderEntity(BaseModel):
provider: str
label: I18nObject
-description: Optional[I18nObject] = None
-icon_small: Optional[I18nObject] = None
-icon_large: Optional[I18nObject] = None
-icon_small_dark: Optional[I18nObject] = None
-icon_large_dark: Optional[I18nObject] = None
-background: Optional[str] = None
-help: Optional[ProviderHelpEntity] = None
+description: I18nObject | None = None
+icon_small: I18nObject | None = None
+icon_large: I18nObject | None = None
+icon_small_dark: I18nObject | None = None
+icon_large_dark: I18nObject | None = None
+background: str | None = None
+help: ProviderHelpEntity | None = None
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
models: list[AIModelEntity] = Field(default_factory=list)
-provider_credential_schema: Optional[ProviderCredentialSchema] = None
-model_credential_schema: Optional[ModelCredentialSchema] = None
+provider_credential_schema: ProviderCredentialSchema | None = None
+model_credential_schema: ModelCredentialSchema | None = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
# position from plugin _position.yaml
-position: Optional[dict[str, list[str]]] = {}
+position: dict[str, list[str]] | None = {}
@field_validator("models", mode="before")
@classmethod

View File

@@ -1,12 +1,9 @@
-from typing import Optional
class InvokeError(ValueError):
"""Base class for all LLM exceptions."""
-description: Optional[str] = None
+description: str | None = None
-def __init__(self, description: Optional[str] = None):
+def __init__(self, description: str | None = None):
self.description = description
def __str__(self):

View File

@@ -1,7 +1,6 @@
import decimal
import hashlib
from threading import Lock
-from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
@@ -99,7 +98,7 @@ class AIModel(BaseModel):
model_schema = self.get_model_schema(model, credentials)
# get price info from predefined model schema
-price_config: Optional[PriceConfig] = None
+price_config: PriceConfig | None = None
if model_schema and model_schema.pricing:
price_config = model_schema.pricing
@@ -132,7 +131,7 @@ class AIModel(BaseModel):
currency=price_config.currency,
)
-def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
+def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None:
"""
Get model schema by model name and credentials
@@ -171,7 +170,7 @@ class AIModel(BaseModel):
return schema
-def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema from credentials
@@ -229,7 +228,7 @@ class AIModel(BaseModel):
return schema
-def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema

View File

@@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Sequence
-from typing import Optional, Union
+from typing import Union
from pydantic import ConfigDict
@@ -94,12 +94,12 @@ class LargeLanguageModel(AIModel):
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
-model_parameters: Optional[dict] = None,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[list[str]] = None,
+model_parameters: dict | None = None,
+tools: list[PromptMessageTool] | None = None,
+stop: list[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
-callbacks: Optional[list[Callback]] = None,
+user: str | None = None,
+callbacks: list[Callback] | None = None,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
"""
Invoke large language model
@@ -243,11 +243,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
-callbacks: Optional[list[Callback]] = None,
+user: str | None = None,
+callbacks: list[Callback] | None = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Invoke result generator
@@ -328,7 +328,7 @@ class LargeLanguageModel(AIModel):
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
-tools: Optional[list[PromptMessageTool]] = None,
+tools: list[PromptMessageTool] | None = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -403,11 +403,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
-callbacks: Optional[list[Callback]] = None,
+user: str | None = None,
+callbacks: list[Callback] | None = None,
):
"""
Trigger before invoke callbacks
@@ -451,11 +451,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
-callbacks: Optional[list[Callback]] = None,
+user: str | None = None,
+callbacks: list[Callback] | None = None,
):
"""
Trigger new chunk callbacks
@@ -498,11 +498,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
-callbacks: Optional[list[Callback]] = None,
+user: str | None = None,
+callbacks: list[Callback] | None = None,
):
"""
Trigger after invoke callbacks
@@ -548,11 +548,11 @@ class LargeLanguageModel(AIModel):
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
-tools: Optional[list[PromptMessageTool]] = None,
-stop: Optional[Sequence[str]] = None,
+tools: list[PromptMessageTool] | None = None,
+stop: Sequence[str] | None = None,
stream: bool = True,
-user: Optional[str] = None,
-callbacks: Optional[list[Callback]] = None,
+user: str | None = None,
+callbacks: list[Callback] | None = None,
):
"""
Trigger invoke error callbacks

View File

@@ -1,5 +1,4 @@
import time
-from typing import Optional
from pydantic import ConfigDict
@@ -18,7 +17,7 @@ class ModerationModel(AIModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
-def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
+def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
"""
Invoke moderation model

View File

@@ -1,5 +1,3 @@
-from typing import Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -19,9 +17,9 @@ class RerankModel(AIModel):
credentials: dict,
query: str,
docs: list[str],
-score_threshold: Optional[float] = None,
-top_n: Optional[int] = None,
-user: Optional[str] = None,
+score_threshold: float | None = None,
+top_n: int | None = None,
+user: str | None = None,
) -> RerankResult:
"""
Invoke rerank model

View File

@@ -1,4 +1,4 @@
-from typing import IO, Optional
+from typing import IO
from pydantic import ConfigDict
@@ -17,7 +17,7 @@ class Speech2TextModel(AIModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
-def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
+def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
"""
Invoke speech to text model

View File

@@ -1,5 +1,3 @@
-from typing import Optional
from pydantic import ConfigDict
from core.entities.embedding_type import EmbeddingInputType
@@ -24,7 +22,7 @@ class TextEmbeddingModel(AIModel):
model: str,
credentials: dict,
texts: list[str],
-user: Optional[str] = None,
+user: str | None = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""

View File

@@ -1,10 +1,10 @@
import logging
from threading import Lock
-from typing import Any, Optional
+from typing import Any
logger = logging.getLogger(__name__)
-_tokenizer: Optional[Any] = None
+_tokenizer: Any | None = None
_lock = Lock()

View File

@@ -1,6 +1,5 @@
import logging
from collections.abc import Iterable
-from typing import Optional
from pydantic import ConfigDict
@@ -28,7 +27,7 @@ class TTSModel(AIModel):
credentials: dict,
content_text: str,
voice: str,
-user: Optional[str] = None,
+user: str | None = None,
) -> Iterable[bytes]:
"""
Invoke large language model
@@ -56,7 +55,7 @@ class TTSModel(AIModel):
except Exception as e:
raise self._transform_invoke_error(e)
-def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None):
+def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None):
"""
Retrieves the list of voices supported by a given text-to-speech (TTS) model.

View File

@@ -2,7 +2,6 @@ import hashlib
import logging
from collections.abc import Sequence
from threading import Lock
-from typing import Optional
import contexts
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@@ -206,9 +205,9 @@ class ModelProviderFactory:
def get_models(
self,
*,
-provider: Optional[str] = None,
-model_type: Optional[ModelType] = None,
-provider_configs: Optional[list[ProviderConfig]] = None,
+provider: str | None = None,
+model_type: ModelType | None = None,
+provider_configs: list[ProviderConfig] | None = None,
) -> list[SimpleProviderEntity]:
"""
Get all models for given model type

View File

@@ -8,7 +8,7 @@ from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6
from pathlib import Path, PurePath
from re import Pattern
from types import GeneratorType
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Union
from uuid import UUID
from pydantic import BaseModel
@@ -98,7 +98,7 @@ def jsonable_encoder(
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
-custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
+custom_encoder: dict[Any, Callable[[Any], Any]] | None = None,
sqlalchemy_safe: bool = True,
) -> Any:
custom_encoder = custom_encoder or {}