chore: adopt StrEnum and auto() for some string-typed enums (#25129)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Krito.
2025-09-12 21:14:26 +08:00
committed by GitHub
parent 635e7d3e70
commit a13d7987e0
68 changed files with 558 additions and 559 deletions

View File

@@ -1,4 +1,4 @@
import enum
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
@@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity):
class AgentStrategyParameter(PluginParameter):
class AgentStrategyParameterType(enum.StrEnum):
class AgentStrategyParameterType(StrEnum):
"""
Keep all the types from PluginParameterType
"""
STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
ANY = CommonParameterType.ANY.value
STRING = CommonParameterType.STRING
NUMBER = CommonParameterType.NUMBER
BOOLEAN = CommonParameterType.BOOLEAN
SELECT = CommonParameterType.SELECT
SECRET_INPUT = CommonParameterType.SECRET_INPUT
FILE = CommonParameterType.FILE
FILES = CommonParameterType.FILES
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
ANY = CommonParameterType.ANY
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)
@@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity):
pass
class AgentFeature(enum.StrEnum):
class AgentFeature(StrEnum):
"""
Agent Feature, used to describe the features of the agent strategy.
"""

View File

@@ -70,7 +70,7 @@ class PromptTemplateConfigManager:
:param config: app model config args
"""
if not config.get("prompt_type"):
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config["prompt_type"] not in prompt_type_vals:
@@ -90,7 +90,7 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
@@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel):
Prompt Template Entity.
"""
class PromptType(Enum):
class PromptType(StrEnum):
"""
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = "simple"
ADVANCED = "advanced"
SIMPLE = auto()
ADVANCED = auto()
@classmethod
def value_of(cls, value: str):
@@ -195,14 +195,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
class RetrieveStrategy(StrEnum):
"""
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = "single"
MULTIPLE = "multiple"
SINGLE = auto()
MULTIPLE = auto()
@classmethod
def value_of(cls, value: str):
@@ -293,12 +293,12 @@ class AppConfig(BaseModel):
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
class EasyUIBasedAppModelConfigFrom(Enum):
class EasyUIBasedAppModelConfigFrom(StrEnum):
"""
App Model Config From.
"""
ARGS = "args"
ARGS = auto()
APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"

View File

@@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Optional
from pydantic import BaseModel
@@ -626,15 +626,15 @@ class QueueStopEvent(AppQueueEvent):
QueueStopEvent entity
"""
class StopBy(Enum):
class StopBy(StrEnum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
INPUT_MODERATION = "input-moderation"
USER_MANUAL = auto()
ANNOTATION_REPLY = auto()
OUTPUT_MODERATION = auto()
INPUT_MODERATION = auto()
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy

View File

@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from enum import StrEnum, auto
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field
@@ -50,37 +50,37 @@ class WorkflowTaskState(TaskState):
answer: str = ""
class StreamEvent(Enum):
class StreamEvent(StrEnum):
"""
Stream event
"""
PING = "ping"
ERROR = "error"
MESSAGE = "message"
MESSAGE_END = "message_end"
TTS_MESSAGE = "tts_message"
TTS_MESSAGE_END = "tts_message_end"
MESSAGE_FILE = "message_file"
MESSAGE_REPLACE = "message_replace"
AGENT_THOUGHT = "agent_thought"
AGENT_MESSAGE = "agent_message"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
LOOP_STARTED = "loop_started"
LOOP_NEXT = "loop_next"
LOOP_COMPLETED = "loop_completed"
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
PING = auto()
ERROR = auto()
MESSAGE = auto()
MESSAGE_END = auto()
TTS_MESSAGE = auto()
TTS_MESSAGE_END = auto()
MESSAGE_FILE = auto()
MESSAGE_REPLACE = auto()
AGENT_THOUGHT = auto()
AGENT_MESSAGE = auto()
WORKFLOW_STARTED = auto()
WORKFLOW_FINISHED = auto()
NODE_STARTED = auto()
NODE_FINISHED = auto()
NODE_RETRY = auto()
PARALLEL_BRANCH_STARTED = auto()
PARALLEL_BRANCH_FINISHED = auto()
ITERATION_STARTED = auto()
ITERATION_NEXT = auto()
ITERATION_COMPLETED = auto()
LOOP_STARTED = auto()
LOOP_NEXT = auto()
LOOP_COMPLETED = auto()
TEXT_CHUNK = auto()
TEXT_REPLACE = auto()
AGENT_LOG = auto()
class StreamResponse(BaseModel):

View File

@@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(

View File

@@ -92,7 +92,7 @@ class MessageCycleManager:
if not conversation:
return
if conversation.mode != AppMode.COMPLETION.value:
if conversation.mode != AppMode.COMPLETION:
app_model = conversation.app
if not app_model:
return

View File

@@ -1,8 +1,8 @@
from enum import Enum
from enum import StrEnum, auto
class PlanningStrategy(Enum):
ROUTER = "router"
REACT_ROUTER = "react_router"
REACT = "react"
FUNCTION_CALL = "function_call"
class PlanningStrategy(StrEnum):
ROUTER = auto()
REACT_ROUTER = auto()
REACT = auto()
FUNCTION_CALL = auto()

View File

@@ -1,10 +1,10 @@
from enum import Enum
from enum import StrEnum, auto
class EmbeddingInputType(Enum):
class EmbeddingInputType(StrEnum):
"""
Enum for embedding input type.
"""
DOCUMENT = "document"
QUERY = "query"
DOCUMENT = auto()
QUERY = auto()

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum, auto
from typing import Optional
from pydantic import BaseModel, ConfigDict
@@ -9,16 +9,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ProviderEntity
class ModelStatus(Enum):
class ModelStatus(StrEnum):
"""
Enum class for model status.
"""
ACTIVE = "active"
ACTIVE = auto()
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
DISABLED = "disabled"
DISABLED = auto()
CREDENTIAL_REMOVED = "credential-removed"

View File

@@ -1,20 +1,20 @@
from enum import StrEnum
from enum import StrEnum, auto
class CommonParameterType(StrEnum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
FILES = "files"
SELECT = auto()
STRING = auto()
NUMBER = auto()
FILE = auto()
FILES = auto()
SYSTEM_FILES = "system-files"
BOOLEAN = "boolean"
BOOLEAN = auto()
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
ANY = "any"
ANY = auto()
# Dynamic select parameter
# Once you are not sure about the available options until authorization is done
@@ -23,29 +23,29 @@ class CommonParameterType(StrEnum):
# TOOL_SELECTOR = "tool-selector"
# MCP object and array type parameters
ARRAY = "array"
OBJECT = "object"
ARRAY = auto()
OBJECT = auto()
class AppSelectorScope(StrEnum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"
ALL = auto()
CHAT = auto()
WORKFLOW = auto()
COMPLETION = auto()
class ModelSelectorScope(StrEnum):
LLM = "llm"
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"
RERANK = auto()
TTS = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
VISION = auto()
class ToolSelectorScope(StrEnum):
ALL = "all"
CUSTOM = "custom"
BUILTIN = "builtin"
WORKFLOW = "workflow"
ALL = auto()
CUSTOM = auto()
BUILTIN = auto()
WORKFLOW = auto()

View File

@@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum, auto
from typing import Optional, Union
from pydantic import BaseModel, ConfigDict, Field
@@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
class ProviderQuotaType(Enum):
PAID = "paid"
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""
FREE = "free"
FREE = auto()
"""third-party free quota"""
TRIAL = "trial"
TRIAL = auto()
"""hosted trial quota"""
@staticmethod
@@ -31,20 +31,20 @@ class ProviderQuotaType(Enum):
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum):
TIMES = "times"
TOKENS = "tokens"
CREDITS = "credits"
class QuotaUnit(StrEnum):
TIMES = auto()
TOKENS = auto()
CREDITS = auto()
class SystemConfigurationStatus(Enum):
class SystemConfigurationStatus(StrEnum):
"""
Enum class for system configuration status.
"""
ACTIVE = "active"
ACTIVE = auto()
QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
UNSUPPORTED = auto()
class RestrictModel(BaseModel):
@@ -168,14 +168,14 @@ class BasicProviderConfig(BaseModel):
Base model class for common provider settings like credentials
"""
class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
class Type(StrEnum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT
TEXT_INPUT = CommonParameterType.TEXT_INPUT
SELECT = CommonParameterType.SELECT
BOOLEAN = CommonParameterType.BOOLEAN
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":

View File

@@ -1,8 +1,8 @@
import enum
import importlib.util
import json
import logging
import os
from enum import StrEnum, auto
from pathlib import Path
from typing import Any, Optional
@@ -13,9 +13,9 @@ from core.helper.position_helper import sort_to_dict_by_position_map
logger = logging.getLogger(__name__)
class ExtensionModule(enum.Enum):
MODERATION = "moderation"
EXTERNAL_DATA_TOOL = "external_data_tool"
class ExtensionModule(StrEnum):
MODERATION = auto()
EXTERNAL_DATA_TOOL = auto()
class ModuleExtension(BaseModel):

View File

@@ -1,12 +1,12 @@
import json
from enum import Enum
from enum import StrEnum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum):
class ProviderCredentialsCacheType(StrEnum):
PROVIDER = "provider"
MODEL = "provider_model"
LOAD_BALANCING_MODEL = "load_balancing_provider_model"
@@ -14,7 +14,7 @@ class ProviderCredentialsCacheType(Enum):
class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""

View File

@@ -1,12 +1,12 @@
import json
from enum import Enum
from enum import StrEnum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolParameterCacheType(Enum):
class ToolParameterCacheType(StrEnum):
PARAMETER = "tool_parameter"
@@ -15,7 +15,7 @@ class ToolParameterCache:
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
):
self.cache_key = (
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
f":identity_id:{identity_id}"
)

View File

@@ -142,7 +142,7 @@ def handle_call_tool(
end_user,
args,
InvokeFrom.SERVICE_API,
streaming=app.mode == AppMode.AGENT_CHAT.value,
streaming=app.mode == AppMode.AGENT_CHAT,
)
answer = extract_answer_from_response(app, response)
@@ -157,7 +157,7 @@ def build_parameter_schema(
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}:
return {
"type": "object",
"properties": parameters,
@@ -175,9 +175,9 @@ def build_parameter_schema(
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW.value:
if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION.value:
elif app.mode == AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
@@ -218,13 +218,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
AppMode.ADVANCED_CHAT,
AppMode.COMPLETION,
AppMode.CHAT,
AppMode.AGENT_CHAT,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW.value:
elif app.mode == AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))

View File

@@ -1,20 +1,20 @@
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, Optional, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
class PromptMessageRole(Enum):
class PromptMessageRole(StrEnum):
"""
Enum class for prompt message.
"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
@classmethod
def value_of(cls, value: str) -> "PromptMessageRole":
@@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum):
Enum class for prompt message content type.
"""
TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
DOCUMENT = "document"
TEXT = auto()
IMAGE = auto()
AUDIO = auto()
VIDEO = auto()
DOCUMENT = auto()
class PromptMessageContent(ABC, BaseModel):
@@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
class DETAIL(StrEnum):
LOW = "low"
HIGH = "high"
LOW = auto()
HIGH = auto()
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW

View File

@@ -1,5 +1,5 @@
from decimal import Decimal
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, model_validator
@@ -7,17 +7,17 @@ from pydantic import BaseModel, ConfigDict, model_validator
from core.model_runtime.entities.common_entities import I18nObject
class ModelType(Enum):
class ModelType(StrEnum):
"""
Enum class for model type.
"""
LLM = "llm"
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
TTS = "tts"
RERANK = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
TTS = auto()
@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
@@ -26,17 +26,17 @@ class ModelType(Enum):
:return: model type
"""
if origin_model_type in {"text-generation", cls.LLM.value}:
if origin_model_type in {"text-generation", cls.LLM}:
return cls.LLM
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
return cls.TEXT_EMBEDDING
elif origin_model_type in {"reranking", cls.RERANK.value}:
elif origin_model_type in {"reranking", cls.RERANK}:
return cls.RERANK
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
return cls.SPEECH2TEXT
elif origin_model_type in {"tts", cls.TTS.value}:
elif origin_model_type in {"tts", cls.TTS}:
return cls.TTS
elif origin_model_type == cls.MODERATION.value:
elif origin_model_type == cls.MODERATION:
return cls.MODERATION
else:
raise ValueError(f"invalid origin model type {origin_model_type}")
@@ -63,7 +63,7 @@ class ModelType(Enum):
raise ValueError(f"invalid model type {self}")
class FetchFrom(Enum):
class FetchFrom(StrEnum):
"""
Enum class for fetch from.
"""
@@ -72,7 +72,7 @@ class FetchFrom(Enum):
CUSTOMIZABLE_MODEL = "customizable-model"
class ModelFeature(Enum):
class ModelFeature(StrEnum):
"""
Enum class for llm feature.
"""
@@ -80,11 +80,11 @@ class ModelFeature(Enum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()
STRUCTURED_OUTPUT = "structured-output"
@@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum):
Enum class for parameter template variable.
"""
TEMPERATURE = "temperature"
TOP_P = "top_p"
TOP_K = "top_k"
PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
JSON_SCHEMA = "json_schema"
TEMPERATURE = auto()
TOP_P = auto()
TOP_K = auto()
PRESENCE_PENALTY = auto()
FREQUENCY_PENALTY = auto()
MAX_TOKENS = auto()
RESPONSE_FORMAT = auto()
JSON_SCHEMA = auto()
@classmethod
def value_of(cls, value: Any) -> "DefaultParameterName":
@@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum):
raise ValueError(f"invalid parameter name {value}")
class ParameterType(Enum):
class ParameterType(StrEnum):
"""
Enum class for parameter type.
"""
FLOAT = "float"
INT = "int"
STRING = "string"
BOOLEAN = "boolean"
TEXT = "text"
FLOAT = auto()
INT = auto()
STRING = auto()
BOOLEAN = auto()
TEXT = auto()
class ModelPropertyKey(Enum):
class ModelPropertyKey(StrEnum):
"""
Enum class for model property key.
"""
MODE = "mode"
CONTEXT_SIZE = "context_size"
MAX_CHUNKS = "max_chunks"
FILE_UPLOAD_LIMIT = "file_upload_limit"
SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
DEFAULT_VOICE = "default_voice"
VOICES = "voices"
WORD_LIMIT = "word_limit"
AUDIO_TYPE = "audio_type"
MAX_WORKERS = "max_workers"
MODE = auto()
CONTEXT_SIZE = auto()
MAX_CHUNKS = auto()
FILE_UPLOAD_LIMIT = auto()
SUPPORTED_FILE_EXTENSIONS = auto()
MAX_CHARACTERS_PER_CHUNK = auto()
DEFAULT_VOICE = auto()
VOICES = auto()
WORD_LIMIT = auto()
AUDIO_TYPE = auto()
MAX_WORKERS = auto()
class ProviderModel(BaseModel):
@@ -220,13 +220,13 @@ class ModelUsage(BaseModel):
pass
class PriceType(Enum):
class PriceType(StrEnum):
"""
Enum class for price type.
"""
INPUT = "input"
OUTPUT = "output"
INPUT = auto()
OUTPUT = auto()
class PriceInfo(BaseModel):

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import Enum, StrEnum, auto
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -17,16 +17,16 @@ class ConfigurateMethod(Enum):
CUSTOMIZABLE_MODEL = "customizable-model"
class FormType(Enum):
class FormType(StrEnum):
"""
Enum class for form type.
"""
TEXT_INPUT = "text-input"
SECRET_INPUT = "secret-input"
SELECT = "select"
RADIO = "radio"
SWITCH = "switch"
SELECT = auto()
RADIO = auto()
SWITCH = auto()
class FormShowOnObject(BaseModel):

View File

@@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel):
model=model,
credentials=credentials,
texts=texts,
input_type=input_type.value,
input_type=input_type,
)
except Exception as e:
raise self._transform_invoke_error(e)

View File

@@ -18,7 +18,7 @@ from pydantic_core import Url
from pydantic_extra_types.color import Color
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any):
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
return model.model_dump(mode=mode, **kwargs)
@@ -100,7 +100,7 @@ def jsonable_encoder(
exclude_none: bool = False,
custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
):
) -> Any:
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from enum import Enum
from enum import StrEnum, auto
from typing import Optional
from pydantic import BaseModel, Field
@@ -7,9 +7,9 @@ from pydantic import BaseModel, Field
from core.extension.extensible import Extensible, ExtensionModule
class ModerationAction(Enum):
DIRECT_OUTPUT = "direct_output"
OVERRIDDEN = "overridden"
class ModerationAction(StrEnum):
DIRECT_OUTPUT = auto()
OVERRIDDEN = auto()
class ModerationInputsResult(BaseModel):

View File

@@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
@@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
class GenAISpanKind(Enum):
class GenAISpanKind(StrEnum):
CHAIN = "CHAIN"
RETRIEVER = "RETRIEVER"
RERANKER = "RERANKER"

View File

@@ -27,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
app = cls._get_app(app_id, tenant_id)
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app.workflow
if workflow is None:
raise ValueError("unexpected app type")
@@ -70,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
conversation_id = conversation_id or ""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}:
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
if not query:
raise ValueError("missing query")
@@ -96,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke chat app
"""
if app.mode == AppMode.ADVANCED_CHAT.value:
if app.mode == AppMode.ADVANCED_CHAT:
workflow = app.workflow
if not workflow:
raise ValueError("unexpected app type")
@@ -114,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
elif app.mode == AppMode.AGENT_CHAT.value:
elif app.mode == AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
@@ -127,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
elif app.mode == AppMode.CHAT.value:
elif app.mode == AppMode.CHAT:
return ChatAppGenerator().generate(
app_model=app,
user=user,

View File

@@ -1,5 +1,5 @@
import enum
import json
from enum import StrEnum, auto
from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator
@@ -25,44 +25,44 @@ class PluginParameterOption(BaseModel):
return value
class PluginParameterType(enum.StrEnum):
class PluginParameterType(StrEnum):
"""
all available parameter types
"""
STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
ANY = CommonParameterType.ANY.value
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
STRING = CommonParameterType.STRING
NUMBER = CommonParameterType.NUMBER
BOOLEAN = CommonParameterType.BOOLEAN
SELECT = CommonParameterType.SELECT
SECRET_INPUT = CommonParameterType.SECRET_INPUT
FILE = CommonParameterType.FILE
FILES = CommonParameterType.FILES
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
ANY = CommonParameterType.ANY
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
# MCP object and array type parameters
ARRAY = CommonParameterType.ARRAY.value
OBJECT = CommonParameterType.OBJECT.value
ARRAY = CommonParameterType.ARRAY
OBJECT = CommonParameterType.OBJECT
class MCPServerParameterType(enum.StrEnum):
class MCPServerParameterType(StrEnum):
"""
MCP server got complex parameter types
"""
ARRAY = "array"
OBJECT = "object"
ARRAY = auto()
OBJECT = auto()
class PluginParameterAutoGenerate(BaseModel):
class Type(enum.StrEnum):
PROMPT_INSTRUCTION = "prompt_instruction"
class Type(StrEnum):
PROMPT_INSTRUCTION = auto()
type: Type
@@ -93,7 +93,7 @@ class PluginParameter(BaseModel):
return v
def as_normal_type(typ: enum.StrEnum):
def as_normal_type(typ: StrEnum):
if typ.value in {
PluginParameterType.SECRET_INPUT,
PluginParameterType.SELECT,
@@ -102,7 +102,7 @@ def as_normal_type(typ: enum.StrEnum):
return typ.value
def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
def cast_parameter_value(typ: StrEnum, value: Any, /):
try:
match typ.value:
case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT:
@@ -190,7 +190,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any):
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
"""
init frontend parameter by rule
"""

View File

@@ -1,7 +1,7 @@
import datetime
import enum
import re
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Any, Optional
from packaging.version import InvalidVersion, Version
@@ -16,11 +16,11 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntity
class PluginInstallationSource(enum.StrEnum):
Github = "github"
Marketplace = "marketplace"
Package = "package"
Remote = "remote"
class PluginInstallationSource(StrEnum):
Github = auto()
Marketplace = auto()
Package = auto()
Remote = auto()
class PluginResourceRequirements(BaseModel):
@@ -58,10 +58,10 @@ class PluginResourceRequirements(BaseModel):
permission: Optional[Permission] = Field(default=None)
class PluginCategory(enum.StrEnum):
Tool = "tool"
Model = "model"
Extension = "extension"
class PluginCategory(StrEnum):
Tool = auto()
Model = auto()
Extension = auto()
AgentStrategy = "agent-strategy"
@@ -206,10 +206,10 @@ class ToolProviderID(GenericProviderID):
class PluginDependency(BaseModel):
class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value
Marketplace = PluginInstallationSource.Marketplace.value
Package = PluginInstallationSource.Package.value
class Type(StrEnum):
Github = PluginInstallationSource.Github
Marketplace = PluginInstallationSource.Marketplace
Package = PluginInstallationSource.Package
class Github(BaseModel):
repo: str

View File

@@ -1,7 +1,7 @@
import enum
import json
import os
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.app_config.entities import PromptTemplateEntity
@@ -25,9 +25,9 @@ if TYPE_CHECKING:
from core.file.models import File
class ModelMode(enum.StrEnum):
COMPLETION = "completion"
CHAT = "chat"
class ModelMode(StrEnum):
COMPLETION = auto()
CHAT = auto()
prompt_file_contents: dict[str, Any] = {}

View File

@@ -1,13 +1,13 @@
from enum import Enum
from enum import StrEnum, auto
class Field(Enum):
class Field(StrEnum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
VECTOR = auto()
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
SPARSE_VECTOR = auto()
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

View File

@@ -1,7 +1,7 @@
import json
import logging
import uuid
from enum import Enum
from enum import StrEnum
from typing import Any
from clickhouse_connect import get_client
@@ -27,7 +27,7 @@ class MyScaleConfig(BaseModel):
fts_params: str
class SortOrder(Enum):
class SortOrder(StrEnum):
ASC = "ASC"
DESC = "DESC"

View File

@@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum
class DatasourceType(Enum):
class DatasourceType(StrEnum):
FILE = "upload_file"
NOTION = "notion_import"
WEBSITE = "website_crawl"

View File

@@ -1,15 +1,15 @@
from enum import Enum, StrEnum
from enum import StrEnum, auto
class BuiltInField(StrEnum):
document_name = "document_name"
uploader = "uploader"
upload_date = "upload_date"
last_update_date = "last_update_date"
source = "source"
document_name = auto()
uploader = auto()
upload_date = auto()
last_update_date = auto()
source = auto()
class MetadataDataSource(Enum):
class MetadataDataSource(StrEnum):
upload_file = "file_upload"
website_crawl = "website"
notion_import = "notion"

View File

@@ -1,8 +1,7 @@
import base64
import contextlib
import enum
from collections.abc import Mapping
from enum import Enum
from enum import StrEnum, auto
from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
@@ -22,37 +21,37 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
class ToolLabelEnum(Enum):
SEARCH = "search"
IMAGE = "image"
VIDEOS = "videos"
WEATHER = "weather"
FINANCE = "finance"
DESIGN = "design"
TRAVEL = "travel"
SOCIAL = "social"
NEWS = "news"
MEDICAL = "medical"
PRODUCTIVITY = "productivity"
EDUCATION = "education"
BUSINESS = "business"
ENTERTAINMENT = "entertainment"
UTILITIES = "utilities"
OTHER = "other"
class ToolLabelEnum(StrEnum):
SEARCH = auto()
IMAGE = auto()
VIDEOS = auto()
WEATHER = auto()
FINANCE = auto()
DESIGN = auto()
TRAVEL = auto()
SOCIAL = auto()
NEWS = auto()
MEDICAL = auto()
PRODUCTIVITY = auto()
EDUCATION = auto()
BUSINESS = auto()
ENTERTAINMENT = auto()
UTILITIES = auto()
OTHER = auto()
class ToolProviderType(enum.StrEnum):
class ToolProviderType(StrEnum):
"""
Enum class for tool provider
"""
PLUGIN = "plugin"
PLUGIN = auto()
BUILT_IN = "builtin"
WORKFLOW = "workflow"
API = "api"
APP = "app"
WORKFLOW = auto()
API = auto()
APP = auto()
DATASET_RETRIEVAL = "dataset-retrieval"
MCP = "mcp"
MCP = auto()
@classmethod
def value_of(cls, value: str) -> "ToolProviderType":
@@ -68,15 +67,15 @@ class ToolProviderType(enum.StrEnum):
raise ValueError(f"invalid mode value {value}")
class ApiProviderSchemaType(Enum):
class ApiProviderSchemaType(StrEnum):
"""
Enum class for api provider schema type.
"""
OPENAPI = "openapi"
SWAGGER = "swagger"
OPENAI_PLUGIN = "openai_plugin"
OPENAI_ACTIONS = "openai_actions"
OPENAPI = auto()
SWAGGER = auto()
OPENAI_PLUGIN = auto()
OPENAI_ACTIONS = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderSchemaType":
@@ -92,14 +91,14 @@ class ApiProviderSchemaType(Enum):
raise ValueError(f"invalid mode value {value}")
class ApiProviderAuthType(Enum):
class ApiProviderAuthType(StrEnum):
"""
Enum class for api provider auth type.
"""
NONE = "none"
API_KEY_HEADER = "api_key_header"
API_KEY_QUERY = "api_key_query"
NONE = auto()
API_KEY_HEADER = auto()
API_KEY_QUERY = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
@@ -176,10 +175,10 @@ class ToolInvokeMessage(BaseModel):
return value
class LogMessage(BaseModel):
class LogStatus(Enum):
START = "start"
ERROR = "error"
SUCCESS = "success"
class LogStatus(StrEnum):
START = auto()
ERROR = auto()
SUCCESS = auto()
id: str
label: str = Field(..., description="The label of the log")
@@ -193,19 +192,19 @@ class ToolInvokeMessage(BaseModel):
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
LINK = "link"
BLOB = "blob"
JSON = "json"
IMAGE_LINK = "image_link"
BINARY_LINK = "binary_link"
VARIABLE = "variable"
FILE = "file"
LOG = "log"
BLOB_CHUNK = "blob_chunk"
RETRIEVER_RESOURCES = "retriever_resources"
class MessageType(StrEnum):
TEXT = auto()
IMAGE = auto()
LINK = auto()
BLOB = auto()
JSON = auto()
IMAGE_LINK = auto()
BINARY_LINK = auto()
VARIABLE = auto()
FILE = auto()
LOG = auto()
BLOB_CHUNK = auto()
RETRIEVER_RESOURCES = auto()
type: MessageType = MessageType.TEXT
"""
@@ -250,29 +249,29 @@ class ToolParameter(PluginParameter):
Overrides type
"""
class ToolParameterType(enum.StrEnum):
class ToolParameterType(StrEnum):
"""
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
ANY = PluginParameterType.ANY.value
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
APP_SELECTOR = PluginParameterType.APP_SELECTOR
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
ANY = PluginParameterType.ANY
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT
# MCP object and array type parameters
ARRAY = MCPServerParameterType.ARRAY.value
OBJECT = MCPServerParameterType.OBJECT.value
ARRAY = MCPServerParameterType.ARRAY
OBJECT = MCPServerParameterType.OBJECT
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)
@@ -280,10 +279,10 @@ class ToolParameter(PluginParameter):
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
FORM = "form" # should be set before invoking tool
LLM = "llm" # will be set by LLM
class ToolParameterForm(StrEnum):
SCHEMA = auto() # should be set while adding tool
FORM = auto() # should be set before invoking tool
LLM = auto() # will be set by LLM
type: ToolParameterType = Field(..., description="The type of the parameter")
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
@@ -446,14 +445,14 @@ class ToolLabel(BaseModel):
icon: str = Field(..., description="The icon of the tool")
class ToolInvokeFrom(Enum):
class ToolInvokeFrom(StrEnum):
"""
Enum class for tool invoke
"""
WORKFLOW = "workflow"
AGENT = "agent"
PLUGIN = "plugin"
WORKFLOW = auto()
AGENT = auto()
PLUGIN = auto()
class ToolSelector(BaseModel):
@@ -478,9 +477,9 @@ class ToolSelector(BaseModel):
return self.model_dump()
class CredentialType(enum.StrEnum):
class CredentialType(StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
OAUTH2 = auto()
def get_name(self):
if self == CredentialType.API_KEY:

View File

@@ -1,6 +1,6 @@
import uuid
from datetime import datetime
from enum import Enum
from enum import StrEnum, auto
from typing import Optional
from pydantic import BaseModel, Field
@@ -11,12 +11,12 @@ from libs.datetime_utils import naive_utc_now
class RouteNodeState(BaseModel):
class Status(Enum):
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
EXCEPTION = "exception"
class Status(StrEnum):
RUNNING = auto()
SUCCESS = auto()
FAILED = auto()
PAUSED = auto()
EXCEPTION = auto()
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""

View File

@@ -1,4 +1,4 @@
from enum import Enum, StrEnum
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel
@@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(Enum):
CLOSE = 0
OPEN = 1
class ParamsAutoGenerated(IntEnum):
CLOSE = auto()
OPEN = auto()
class AgentOldVersionModelFeatures(StrEnum):
@@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum, auto
from pydantic import BaseModel, Field
@@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
Generate Route Chunk.
"""
class ChunkType(Enum):
VAR = "var"
TEXT = "text"
class ChunkType(StrEnum):
VAR = auto()
TEXT = auto()
type: ChunkType = Field(..., description="generate route chunk type")

View File

@@ -259,7 +259,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
@@ -291,7 +291,7 @@ class KnowledgeRetrievalNode(BaseNode):
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":