refactor: Improve model status handling and structured output (#20586)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-06-04 19:56:54 +08:00
committed by GitHub
parent 92614765ff
commit 5ccfb1f4ba
9 changed files with 450 additions and 344 deletions

View File

@@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None
structured_output_enabled: bool = False
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
@field_validator("prompt_config", mode="before")
@classmethod
@@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
if v is None:
return PromptConfig()
return v
@property
def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None

View File

@@ -12,9 +12,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -74,7 +72,6 @@ from core.workflow.nodes.event import (
from core.workflow.utils.structured_output.entities import (
ResponseFormat,
SpecialModelType,
SupportStructuredOutputStatus,
)
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]):
llm_usage=usage,
)
)
except LLMNodeError as e:
except ValueError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = node_data_model.name
provider_name = node_data_model.provider
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=node_data_model.provider,
model=node_data_model.name,
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
provider_model = model.provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name, model_type=ModelType.LLM
)
if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
# model config
completion_params = node_data_model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = node_data_model.mode
if not model_mode:
raise LLMModeRequiredError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
# Set appropriate response format based on model capabilities
self._set_response_format(completion_params, model_schema.parameter_rules)
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
if self.node_data.structured_output_enabled:
if model_schema.support_structure_output:
node_data_model.completion_params = self._handle_native_json_schema(
node_data_model.completion_params, model_schema.parameter_rules
)
else:
# Set appropriate response format based on model capabilities
self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
return model, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
mode=node_data_model.mode,
provider_model_bundle=model.provider_model_bundle,
credentials=model.credentials,
parameters=node_data_model.completion_params,
stop=stop,
)
@@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]):
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
stop = model_config.stop
return filtered_prompt_messages, stop
model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=self.node_data.model.provider,
model=self.node_data.model.name,
)
model_schema = model.model_type_instance.get_model_schema(
model=self.node_data.model.name,
credentials=model.credentials,
)
if not model_schema:
raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.")
if self.node_data.structured_output_enabled:
if not model_schema.support_structure_output:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
return filtered_prompt_messages, model_config.stop
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
structured_output: dict[str, Any] = {}
@@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
"""
Check if the current model supports structured output.
Returns:
SupportStructuredOutput: The support status of structured output
"""
# Early return if structured output is disabled
if (
not isinstance(self.node_data, LLMNodeData)
or not self.node_data.structured_output_enabled
or not self.node_data.structured_output
):
return SupportStructuredOutputStatus.DISABLED
# Get model schema and check if it exists
model_schema = self._fetch_model_schema(self.node_data.model.provider)
if not model_schema:
return SupportStructuredOutputStatus.DISABLED
# Check if model supports structured output feature
return (
SupportStructuredOutputStatus.SUPPORTED
if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
else SupportStructuredOutputStatus.UNSUPPORTED
)
def _save_multimodal_output_and_convert_result_to_markdown(
self,
contents: str | list[PromptMessageContentUnionTypes] | None,

View File

@@ -14,11 +14,3 @@ class SpecialModelType(StrEnum):
GEMINI = "gemini"
OLLAMA = "ollama"
class SupportStructuredOutputStatus(StrEnum):
"""Constants for structured output support status"""
SUPPORTED = "supported"
UNSUPPORTED = "unsupported"
DISABLED = "disabled"