chore: adopt StrEnum and auto() for some string-typed enums (#25129)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@@ -1,20 +1,20 @@
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
|
||||
class PromptMessageRole(Enum):
|
||||
class PromptMessageRole(StrEnum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
SYSTEM = auto()
|
||||
USER = auto()
|
||||
ASSISTANT = auto()
|
||||
TOOL = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "PromptMessageRole":
|
||||
@@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum):
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
AUDIO = "audio"
|
||||
VIDEO = "video"
|
||||
DOCUMENT = "document"
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
AUDIO = auto()
|
||||
VIDEO = auto()
|
||||
DOCUMENT = auto()
|
||||
|
||||
|
||||
class PromptMessageContent(ABC, BaseModel):
|
||||
@@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
"""
|
||||
|
||||
class DETAIL(StrEnum):
|
||||
LOW = "low"
|
||||
HIGH = "high"
|
||||
LOW = auto()
|
||||
HIGH = auto()
|
||||
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from decimal import Decimal
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
@@ -7,17 +7,17 @@ from pydantic import BaseModel, ConfigDict, model_validator
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
class ModelType(StrEnum):
|
||||
"""
|
||||
Enum class for model type.
|
||||
"""
|
||||
|
||||
LLM = "llm"
|
||||
LLM = auto()
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
TTS = "tts"
|
||||
RERANK = auto()
|
||||
SPEECH2TEXT = auto()
|
||||
MODERATION = auto()
|
||||
TTS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, origin_model_type: str) -> "ModelType":
|
||||
@@ -26,17 +26,17 @@ class ModelType(Enum):
|
||||
|
||||
:return: model type
|
||||
"""
|
||||
if origin_model_type in {"text-generation", cls.LLM.value}:
|
||||
if origin_model_type in {"text-generation", cls.LLM}:
|
||||
return cls.LLM
|
||||
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
|
||||
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
|
||||
return cls.TEXT_EMBEDDING
|
||||
elif origin_model_type in {"reranking", cls.RERANK.value}:
|
||||
elif origin_model_type in {"reranking", cls.RERANK}:
|
||||
return cls.RERANK
|
||||
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
|
||||
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type in {"tts", cls.TTS.value}:
|
||||
elif origin_model_type in {"tts", cls.TTS}:
|
||||
return cls.TTS
|
||||
elif origin_model_type == cls.MODERATION.value:
|
||||
elif origin_model_type == cls.MODERATION:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
raise ValueError(f"invalid origin model type {origin_model_type}")
|
||||
@@ -63,7 +63,7 @@ class ModelType(Enum):
|
||||
raise ValueError(f"invalid model type {self}")
|
||||
|
||||
|
||||
class FetchFrom(Enum):
|
||||
class FetchFrom(StrEnum):
|
||||
"""
|
||||
Enum class for fetch from.
|
||||
"""
|
||||
@@ -72,7 +72,7 @@ class FetchFrom(Enum):
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class ModelFeature(Enum):
|
||||
class ModelFeature(StrEnum):
|
||||
"""
|
||||
Enum class for llm feature.
|
||||
"""
|
||||
@@ -80,11 +80,11 @@ class ModelFeature(Enum):
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = "vision"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
STRUCTURED_OUTPUT = "structured-output"
|
||||
|
||||
|
||||
@@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum):
|
||||
Enum class for parameter template variable.
|
||||
"""
|
||||
|
||||
TEMPERATURE = "temperature"
|
||||
TOP_P = "top_p"
|
||||
TOP_K = "top_k"
|
||||
PRESENCE_PENALTY = "presence_penalty"
|
||||
FREQUENCY_PENALTY = "frequency_penalty"
|
||||
MAX_TOKENS = "max_tokens"
|
||||
RESPONSE_FORMAT = "response_format"
|
||||
JSON_SCHEMA = "json_schema"
|
||||
TEMPERATURE = auto()
|
||||
TOP_P = auto()
|
||||
TOP_K = auto()
|
||||
PRESENCE_PENALTY = auto()
|
||||
FREQUENCY_PENALTY = auto()
|
||||
MAX_TOKENS = auto()
|
||||
RESPONSE_FORMAT = auto()
|
||||
JSON_SCHEMA = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: Any) -> "DefaultParameterName":
|
||||
@@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum):
|
||||
raise ValueError(f"invalid parameter name {value}")
|
||||
|
||||
|
||||
class ParameterType(Enum):
|
||||
class ParameterType(StrEnum):
|
||||
"""
|
||||
Enum class for parameter type.
|
||||
"""
|
||||
|
||||
FLOAT = "float"
|
||||
INT = "int"
|
||||
STRING = "string"
|
||||
BOOLEAN = "boolean"
|
||||
TEXT = "text"
|
||||
FLOAT = auto()
|
||||
INT = auto()
|
||||
STRING = auto()
|
||||
BOOLEAN = auto()
|
||||
TEXT = auto()
|
||||
|
||||
|
||||
class ModelPropertyKey(Enum):
|
||||
class ModelPropertyKey(StrEnum):
|
||||
"""
|
||||
Enum class for model property key.
|
||||
"""
|
||||
|
||||
MODE = "mode"
|
||||
CONTEXT_SIZE = "context_size"
|
||||
MAX_CHUNKS = "max_chunks"
|
||||
FILE_UPLOAD_LIMIT = "file_upload_limit"
|
||||
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
|
||||
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
|
||||
DEFAULT_VOICE = "default_voice"
|
||||
VOICES = "voices"
|
||||
WORD_LIMIT = "word_limit"
|
||||
AUDIO_TYPE = "audio_type"
|
||||
MAX_WORKERS = "max_workers"
|
||||
MODE = auto()
|
||||
CONTEXT_SIZE = auto()
|
||||
MAX_CHUNKS = auto()
|
||||
FILE_UPLOAD_LIMIT = auto()
|
||||
SUPPORTED_FILE_EXTENSIONS = auto()
|
||||
MAX_CHARACTERS_PER_CHUNK = auto()
|
||||
DEFAULT_VOICE = auto()
|
||||
VOICES = auto()
|
||||
WORD_LIMIT = auto()
|
||||
AUDIO_TYPE = auto()
|
||||
MAX_WORKERS = auto()
|
||||
|
||||
|
||||
class ProviderModel(BaseModel):
|
||||
@@ -220,13 +220,13 @@ class ModelUsage(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class PriceType(Enum):
|
||||
class PriceType(StrEnum):
|
||||
"""
|
||||
Enum class for price type.
|
||||
"""
|
||||
|
||||
INPUT = "input"
|
||||
OUTPUT = "output"
|
||||
INPUT = auto()
|
||||
OUTPUT = auto()
|
||||
|
||||
|
||||
class PriceInfo(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
@@ -17,16 +17,16 @@ class ConfigurateMethod(Enum):
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class FormType(Enum):
|
||||
class FormType(StrEnum):
|
||||
"""
|
||||
Enum class for form type.
|
||||
"""
|
||||
|
||||
TEXT_INPUT = "text-input"
|
||||
SECRET_INPUT = "secret-input"
|
||||
SELECT = "select"
|
||||
RADIO = "radio"
|
||||
SWITCH = "switch"
|
||||
SELECT = auto()
|
||||
RADIO = auto()
|
||||
SWITCH = auto()
|
||||
|
||||
|
||||
class FormShowOnObject(BaseModel):
|
||||
|
||||
@@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel):
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=input_type.value,
|
||||
input_type=input_type,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@@ -18,7 +18,7 @@ from pydantic_core import Url
|
||||
from pydantic_extra_types.color import Color
|
||||
|
||||
|
||||
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any):
|
||||
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
|
||||
return model.model_dump(mode=mode, **kwargs)
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ def jsonable_encoder(
|
||||
exclude_none: bool = False,
|
||||
custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
|
||||
sqlalchemy_safe: bool = True,
|
||||
):
|
||||
) -> Any:
|
||||
custom_encoder = custom_encoder or {}
|
||||
if custom_encoder:
|
||||
if type(obj) in custom_encoder:
|
||||
|
||||
Reference in New Issue
Block a user