chore: add ast-grep rule to convert Optional[T] to T | None (#25560)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -13,7 +12,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
tenant_id: str
|
||||
top_k: int = 4
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@@ -37,7 +37,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
dataset_id: str
|
||||
user_id: Optional[str] = None
|
||||
user_id: str | None = None
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
inputs: dict
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -87,9 +87,9 @@ class DatasetRetrieverTool(Tool):
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
@@ -112,9 +112,9 @@ class DatasetRetrieverTool(Tool):
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke dataset retriever tool
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import contextlib
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
@@ -13,7 +13,7 @@ class ProviderConfigCache(Protocol):
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
def get(self) -> dict | None:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from collections.abc import Generator
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
@@ -60,7 +59,7 @@ class ToolFileMessageTransformer:
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Transform tool message and handle file download
|
||||
@@ -165,5 +164,5 @@ class ToolFileMessageTransformer:
|
||||
yield message
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str:
|
||||
return f"/files/tools/{tool_file_id}{extension or '.bin'}"
|
||||
|
||||
@@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
@@ -51,7 +51,7 @@ class ModelInvocationUtils:
|
||||
if not schema:
|
||||
raise InvokeModelError("No model schema found")
|
||||
|
||||
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
max_tokens: int | None = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
if max_tokens is None:
|
||||
return 2048
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import re
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
@@ -198,9 +197,9 @@ class ApiBasedToolSchemaParser:
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
||||
parameter = parameter or {}
|
||||
typ: Optional[str] = None
|
||||
typ: str | None = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import base64
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
@@ -28,7 +28,7 @@ class SystemOAuthEncrypter:
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: Optional[str] = None):
|
||||
def __init__(self, secret_key: str | None = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
|
||||
@@ -130,7 +130,7 @@ class SystemOAuthEncrypter:
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
|
||||
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
|
||||
@@ -144,7 +144,7 @@ def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAu
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
|
||||
_oauth_encrypter: SystemOAuthEncrypter | None = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
|
||||
@@ -2,7 +2,7 @@ import mimetypes
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
@@ -27,7 +27,7 @@ def page_result(text: str, cursor: int, max_length: int) -> str:
|
||||
return text[cursor : cursor + max_length]
|
||||
|
||||
|
||||
def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
||||
def get_url(url: str, user_agent: str | None = None) -> str:
|
||||
"""Fetch URL and return the contents as a string."""
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
||||
|
||||
Reference in New Issue
Block a user