refactor: Improve model status handling and structured output (#20586)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-06-04 19:56:54 +08:00
committed by GitHub
parent 92614765ff
commit 5ccfb1f4ba
9 changed files with 450 additions and 344 deletions

View File

@@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel):
status: ModelStatus
load_balancing_enabled: bool = False
def raise_for_status(self) -> None:
"""
Check model status and raise ValueError if not active.
:raises ValueError: When model status is not active, with a descriptive message
"""
if self.status == ModelStatus.ACTIVE:
return
error_messages = {
ModelStatus.NO_CONFIGURE: "Model is not configured",
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
ModelStatus.NO_PERMISSION: "No permission to use this model",
ModelStatus.DISABLED: "Model is disabled",
}
if self.status in error_messages:
raise ValueError(error_messages[self.status])
class ModelWithProviderEntity(ProviderModelWithStatusEntity):
"""

View File

@@ -41,45 +41,53 @@ class Extensible:
extensions = []
position_map: dict[str, int] = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
current_dir_path = os.path.dirname(current_path)
# Get the package name from the module path
package_name = ".".join(cls.__module__.split(".")[:-1])
# traverse subdirectories
for subdir_name in os.listdir(current_dir_path):
if subdir_name.startswith("__"):
continue
try:
# Get package directory path
package_spec = importlib.util.find_spec(package_name)
if not package_spec or not package_spec.origin:
raise ImportError(f"Could not find package {package_name}")
subdir_path = os.path.join(current_dir_path, subdir_name)
extension_name = subdir_name
if os.path.isdir(subdir_path):
package_dir = os.path.dirname(package_spec.origin)
# Traverse subdirectories
for subdir_name in os.listdir(package_dir):
if subdir_name.startswith("__"):
continue
subdir_path = os.path.join(package_dir, subdir_name)
if not os.path.isdir(subdir_path):
continue
extension_name = subdir_name
file_names = os.listdir(subdir_path)
# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
# Check for extension module file
if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue
# Check for builtin flag and position
builtin = False
# default position is 0 can not be None for sort_to_dict_by_position_map
position = 0
if "__builtin__" in file_names:
builtin = True
builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path):
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position
if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + ".py")
spec = importlib.util.spec_from_file_location(extension_name, py_path)
# Import the extension module
module_name = f"{package_name}.{extension_name}.{extension_name}"
spec = importlib.util.find_spec(module_name)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {extension_name} from {py_path}")
raise ImportError(f"Failed to load module {module_name}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
# Find extension class
extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
@@ -87,21 +95,21 @@ class Extensible:
break
if not extension_class:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
continue
# Load schema if not builtin
json_data: dict[str, Any] = {}
if not builtin:
if "schema.json" not in file_names:
json_path = os.path.join(subdir_path, "schema.json")
if not os.path.exists(json_path):
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue
json_path = os.path.join(subdir_path, "schema.json")
json_data = {}
if os.path.exists(json_path):
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)
# Create extension
extensions.append(
ModuleExtension(
extension_class=extension_class,
@@ -113,6 +121,11 @@ class Extensible:
)
)
except Exception as e:
logging.exception("Error scanning extensions")
raise
# Sort extensions by position
sorted_extensions = sort_to_dict_by_position_map(
position_map=position_map, data=extensions, name_func=lambda x: x.name
)

View File

@@ -160,6 +160,10 @@ class ProviderModel(BaseModel):
deprecated: bool = False
model_config = ConfigDict(protected_namespaces=())
@property
def support_structure_output(self) -> bool:
return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
class ParameterRule(BaseModel):
"""

View File

@@ -3,7 +3,9 @@ from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
@@ -393,19 +395,13 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
"""
Get all provider records of the workspace.
:param tenant_id: workspace id
:return:
"""
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers:
# TODO: Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
providers = session.scalars(stmt)
for provider in providers:
# Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
return provider_name_to_provider_records_dict
@staticmethod
@@ -416,17 +412,12 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
# Get all provider model records of the workspace
provider_models = (
db.session.query(ProviderModel)
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
.all()
)
provider_name_to_provider_model_records_dict = defaultdict(list)
for provider_model in provider_models:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
return provider_name_to_provider_model_records_dict
@staticmethod
@@ -437,17 +428,14 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
preferred_provider_types = (
db.session.query(TenantPreferredModelProvider)
.filter(TenantPreferredModelProvider.tenant_id == tenant_id)
.all()
)
provider_name_to_preferred_provider_type_records_dict = {
preferred_provider_type.provider_name: preferred_provider_type
for preferred_provider_type in preferred_provider_types
}
provider_name_to_preferred_provider_type_records_dict = {}
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
preferred_provider_types = session.scalars(stmt)
provider_name_to_preferred_provider_type_records_dict = {
preferred_provider_type.provider_name: preferred_provider_type
for preferred_provider_type in preferred_provider_types
}
return provider_name_to_preferred_provider_type_records_dict
@staticmethod
@@ -458,18 +446,14 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
provider_model_settings = (
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
)
provider_name_to_provider_model_settings_dict = defaultdict(list)
for provider_model_setting in provider_model_settings:
(
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
provider_model_setting
)
)
return provider_name_to_provider_model_settings_dict
@staticmethod
@@ -492,15 +476,14 @@ class ProviderManager:
if not model_load_balancing_enabled:
return {}
provider_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
)
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
for provider_load_balancing_config in provider_load_balancing_configs:
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_configs = session.scalars(stmt)
for provider_load_balancing_config in provider_load_balancing_configs:
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)
return provider_name_to_provider_load_balancing_model_configs_dict
@@ -626,10 +609,9 @@ class ProviderManager:
if not cached_provider_credentials:
try:
# fix origin data
if (
custom_provider_record.encrypted_config
and not custom_provider_record.encrypted_config.startswith("{")
):
if custom_provider_record.encrypted_config is None:
raise ValueError("No credentials found")
if not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
@@ -733,7 +715,7 @@ class ProviderManager:
return SystemConfiguration(enabled=False)
# Convert provider_records to dict
quota_type_to_provider_records_dict = {}
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
continue
@@ -758,6 +740,11 @@ class ProviderManager:
else:
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
if provider_record.quota_used is None:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
@@ -791,10 +778,9 @@ class ProviderManager:
cached_provider_credentials = provider_credentials_cache.get()
if not cached_provider_credentials:
try:
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
provider_credentials: dict[str, Any] = {}
if provider_records and provider_records[0].encrypted_config:
provider_credentials = json.loads(provider_records[0].encrypted_config)
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(

View File

@@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None
structured_output_enabled: bool = False
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
@field_validator("prompt_config", mode="before")
@classmethod
@@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
if v is None:
return PromptConfig()
return v
@property
def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None

View File

@@ -12,9 +12,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -74,7 +72,6 @@ from core.workflow.nodes.event import (
from core.workflow.utils.structured_output.entities import (
ResponseFormat,
SpecialModelType,
SupportStructuredOutputStatus,
)
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]):
llm_usage=usage,
)
)
except LLMNodeError as e:
except ValueError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = node_data_model.name
provider_name = node_data_model.provider
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=node_data_model.provider,
model=node_data_model.name,
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
provider_model = model.provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name, model_type=ModelType.LLM
)
if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
# model config
completion_params = node_data_model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = node_data_model.mode
if not model_mode:
raise LLMModeRequiredError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
# Set appropriate response format based on model capabilities
self._set_response_format(completion_params, model_schema.parameter_rules)
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
if self.node_data.structured_output_enabled:
if model_schema.support_structure_output:
node_data_model.completion_params = self._handle_native_json_schema(
node_data_model.completion_params, model_schema.parameter_rules
)
else:
# Set appropriate response format based on model capabilities
self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
return model, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
mode=node_data_model.mode,
provider_model_bundle=model.provider_model_bundle,
credentials=model.credentials,
parameters=node_data_model.completion_params,
stop=stop,
)
@@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]):
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
stop = model_config.stop
return filtered_prompt_messages, stop
model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=self.node_data.model.provider,
model=self.node_data.model.name,
)
model_schema = model.model_type_instance.get_model_schema(
model=self.node_data.model.name,
credentials=model.credentials,
)
if not model_schema:
raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.")
if self.node_data.structured_output_enabled:
if not model_schema.support_structure_output:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
return filtered_prompt_messages, model_config.stop
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
structured_output: dict[str, Any] = {}
@@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
"""
Check if the current model supports structured output.
Returns:
SupportStructuredOutput: The support status of structured output
"""
# Early return if structured output is disabled
if (
not isinstance(self.node_data, LLMNodeData)
or not self.node_data.structured_output_enabled
or not self.node_data.structured_output
):
return SupportStructuredOutputStatus.DISABLED
# Get model schema and check if it exists
model_schema = self._fetch_model_schema(self.node_data.model.provider)
if not model_schema:
return SupportStructuredOutputStatus.DISABLED
# Check if model supports structured output feature
return (
SupportStructuredOutputStatus.SUPPORTED
if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
else SupportStructuredOutputStatus.UNSUPPORTED
)
def _save_multimodal_output_and_convert_result_to_markdown(
self,
contents: str | list[PromptMessageContentUnionTypes] | None,

View File

@@ -14,11 +14,3 @@ class SpecialModelType(StrEnum):
GEMINI = "gemini"
OLLAMA = "ollama"
class SupportStructuredOutputStatus(StrEnum):
"""Constants for structured output support status"""
SUPPORTED = "supported"
UNSUPPORTED = "unsupported"
DISABLED = "disabled"