chore: adopt StrEnum and auto() for some string-typed enums (#25129)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Krito.
2025-09-12 21:14:26 +08:00
committed by GitHub
parent 635e7d3e70
commit a13d7987e0
68 changed files with 558 additions and 559 deletions

View File

@@ -1,20 +1,20 @@
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, Optional, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
class PromptMessageRole(Enum):
class PromptMessageRole(StrEnum):
"""
Enum class for prompt message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
@classmethod
def value_of(cls, value: str) -> "PromptMessageRole":
@@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum):
Enum class for prompt message content type.
"""
TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
DOCUMENT = "document"
TEXT = auto()
IMAGE = auto()
AUDIO = auto()
VIDEO = auto()
DOCUMENT = auto()
class PromptMessageContent(ABC, BaseModel):
@@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
class DETAIL(StrEnum):
LOW = "low"
HIGH = "high"
LOW = auto()
HIGH = auto()
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW

View File

@@ -1,5 +1,5 @@
from decimal import Decimal
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, model_validator
@@ -7,17 +7,17 @@ from pydantic import BaseModel, ConfigDict, model_validator
from core.model_runtime.entities.common_entities import I18nObject
class ModelType(Enum):
class ModelType(StrEnum):
"""
Enum class for model type.
"""
LLM = "llm"
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
TTS = "tts"
RERANK = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
TTS = auto()
@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
@@ -26,17 +26,17 @@ class ModelType(Enum):
:return: model type
"""
if origin_model_type in {"text-generation", cls.LLM.value}:
if origin_model_type in {"text-generation", cls.LLM}:
return cls.LLM
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
return cls.TEXT_EMBEDDING
elif origin_model_type in {"reranking", cls.RERANK.value}:
elif origin_model_type in {"reranking", cls.RERANK}:
return cls.RERANK
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
return cls.SPEECH2TEXT
elif origin_model_type in {"tts", cls.TTS.value}:
elif origin_model_type in {"tts", cls.TTS}:
return cls.TTS
elif origin_model_type == cls.MODERATION.value:
elif origin_model_type == cls.MODERATION:
return cls.MODERATION
else:
raise ValueError(f"invalid origin model type {origin_model_type}")
@@ -63,7 +63,7 @@ class ModelType(Enum):
raise ValueError(f"invalid model type {self}")
class FetchFrom(Enum):
class FetchFrom(StrEnum):
"""
Enum class for fetch from.
"""
@@ -72,7 +72,7 @@ class FetchFrom(Enum):
CUSTOMIZABLE_MODEL = "customizable-model"
class ModelFeature(Enum):
class ModelFeature(StrEnum):
"""
Enum class for llm feature.
"""
@@ -80,11 +80,11 @@ class ModelFeature(Enum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()
STRUCTURED_OUTPUT = "structured-output"
@@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum):
Enum class for parameter template variable.
"""
TEMPERATURE = "temperature"
TOP_P = "top_p"
TOP_K = "top_k"
PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
JSON_SCHEMA = "json_schema"
TEMPERATURE = auto()
TOP_P = auto()
TOP_K = auto()
PRESENCE_PENALTY = auto()
FREQUENCY_PENALTY = auto()
MAX_TOKENS = auto()
RESPONSE_FORMAT = auto()
JSON_SCHEMA = auto()
@classmethod
def value_of(cls, value: Any) -> "DefaultParameterName":
@@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum):
raise ValueError(f"invalid parameter name {value}")
class ParameterType(Enum):
class ParameterType(StrEnum):
"""
Enum class for parameter type.
"""
FLOAT = "float"
INT = "int"
STRING = "string"
BOOLEAN = "boolean"
TEXT = "text"
FLOAT = auto()
INT = auto()
STRING = auto()
BOOLEAN = auto()
TEXT = auto()
class ModelPropertyKey(Enum):
class ModelPropertyKey(StrEnum):
"""
Enum class for model property key.
"""
MODE = "mode"
CONTEXT_SIZE = "context_size"
MAX_CHUNKS = "max_chunks"
FILE_UPLOAD_LIMIT = "file_upload_limit"
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
DEFAULT_VOICE = "default_voice"
VOICES = "voices"
WORD_LIMIT = "word_limit"
AUDIO_TYPE = "audio_type"
MAX_WORKERS = "max_workers"
MODE = auto()
CONTEXT_SIZE = auto()
MAX_CHUNKS = auto()
FILE_UPLOAD_LIMIT = auto()
SUPPORTED_FILE_EXTENSIONS = auto()
MAX_CHARACTERS_PER_CHUNK = auto()
DEFAULT_VOICE = auto()
VOICES = auto()
WORD_LIMIT = auto()
AUDIO_TYPE = auto()
MAX_WORKERS = auto()
class ProviderModel(BaseModel):
@@ -220,13 +220,13 @@ class ModelUsage(BaseModel):
pass
class PriceType(Enum):
class PriceType(StrEnum):
"""
Enum class for price type.
"""
INPUT = "input"
OUTPUT = "output"
INPUT = auto()
OUTPUT = auto()
class PriceInfo(BaseModel):

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import Enum, StrEnum, auto
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -17,16 +17,16 @@ class ConfigurateMethod(Enum):
CUSTOMIZABLE_MODEL = "customizable-model"
class FormType(Enum):
class FormType(StrEnum):
"""
Enum class for form type.
"""
TEXT_INPUT = "text-input"
SECRET_INPUT = "secret-input"
SELECT = "select"
RADIO = "radio"
SWITCH = "switch"
SELECT = auto()
RADIO = auto()
SWITCH = auto()
class FormShowOnObject(BaseModel):

View File

@@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel):
model=model,
credentials=credentials,
texts=texts,
input_type=input_type.value,
input_type=input_type,
)
except Exception as e:
raise self._transform_invoke_error(e)

View File

@@ -18,7 +18,7 @@ from pydantic_core import Url
from pydantic_extra_types.color import Color
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any):
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
return model.model_dump(mode=mode, **kwargs)
@@ -100,7 +100,7 @@ def jsonable_encoder(
exclude_none: bool = False,
custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
):
) -> Any:
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder: