feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -32,8 +32,8 @@ class UserToolProvider(BaseModel):
|
||||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
tools: list[UserTool] = None
|
||||
labels: list[str] = None
|
||||
tools: list[UserTool] | None = None
|
||||
labels: list[str] | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# -------------
|
||||
@@ -42,7 +42,7 @@ class UserToolProvider(BaseModel):
|
||||
for tool in tools:
|
||||
if tool.get("parameters"):
|
||||
for parameter in tool.get("parameters"):
|
||||
if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value:
|
||||
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
|
||||
@@ -104,14 +104,15 @@ class ToolInvokeMessage(BaseModel):
|
||||
BLOB = "blob"
|
||||
JSON = "json"
|
||||
IMAGE_LINK = "image_link"
|
||||
FILE_VAR = "file_var"
|
||||
FILE = "file"
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
plain text, image url or link url
|
||||
"""
|
||||
message: str | bytes | dict | None = None
|
||||
meta: dict[str, Any] | None = None
|
||||
# TODO: Use a BaseModel for meta
|
||||
meta: dict[str, Any] = Field(default_factory=dict)
|
||||
save_as: str = ""
|
||||
|
||||
|
||||
@@ -143,6 +144,67 @@ class ToolParameter(BaseModel):
|
||||
SELECT = "select"
|
||||
SECRET_INPUT = "secret-input"
|
||||
FILE = "file"
|
||||
FILES = "files"
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = "systme-files"
|
||||
|
||||
def as_normal_type(self):
|
||||
if self in {
|
||||
ToolParameter.ToolParameterType.SECRET_INPUT,
|
||||
ToolParameter.ToolParameterType.SELECT,
|
||||
}:
|
||||
return "string"
|
||||
return self.value
|
||||
|
||||
def cast_value(self, value: Any, /):
|
||||
try:
|
||||
match self:
|
||||
case (
|
||||
ToolParameter.ToolParameterType.STRING
|
||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
| ToolParameter.ToolParameterType.SELECT
|
||||
):
|
||||
if value is None:
|
||||
return ""
|
||||
else:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case ToolParameter.ToolParameterType.NUMBER:
|
||||
if isinstance(value, int | float):
|
||||
return value
|
||||
elif isinstance(value, str) and value:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
case (
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||
| ToolParameter.ToolParameterType.FILE
|
||||
| ToolParameter.ToolParameterType.FILES
|
||||
):
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
|
||||
@@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
|
||||
for image in response.data:
|
||||
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
||||
blob_message = self.create_blob_message(
|
||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
|
||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE
|
||||
)
|
||||
result.append(blob_message)
|
||||
return result
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Any
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
from core.file.file_obj import FileTransferMethod
|
||||
from core.file.models import FileTransferMethod
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
<svg width="100" height="100" viewBox="0 0 100 100" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="100" height="100" rx="20" fill="#4A90E2" />
|
||||
<path
|
||||
d="M50 25C40.6 25 33 32.6 33 42V58C33 67.4 40.6 75 50 75C59.4 75 67 67.4 67 58V42C67 32.6 59.4 25 50 25ZM61 58C61 64.1 56.1 69 50 69C43.9 69 39 64.1 39 58V42C39 35.9 43.9 31 50 31C56.1 31 61 35.9 61 42V58Z"
|
||||
fill="white" />
|
||||
<path d="M50 37C47.2 37 45 39.2 45 42V58C45 60.8 47.2 63 50 63C52.8 63 55 60.8 55 58V42C55 39.2 52.8 37 50 37Z"
|
||||
fill="white" />
|
||||
<path
|
||||
d="M73 49H69V58C69 68.5 60.5 77 50 77C39.5 77 31 68.5 31 58V49H27V58C27 70.7 37.3 81 50 81C62.7 81 73 70.7 73 58V49Z"
|
||||
fill="white" />
|
||||
<path d="M50 85C51.1 85 52 84.1 52 83V81H48V83C48 84.1 48.9 85 50 85Z" fill="white" />
|
||||
<path
|
||||
d="M35 45C36.1046 45 37 44.1046 37 43C37 41.8954 36.1046 41 35 41C33.8954 41 33 41.8954 33 43C33 44.1046 33.8954 45 35 45Z"
|
||||
fill="white" />
|
||||
<path
|
||||
d="M35 55C36.1046 55 37 54.1046 37 53C37 51.8954 36.1046 51 35 51C33.8954 51 33 51.8954 33 53C33 54.1046 33.8954 55 35 55Z"
|
||||
fill="white" />
|
||||
<path
|
||||
d="M65 45C66.1046 45 67 44.1046 67 43C67 41.8954 66.1046 41 65 41C63.8954 41 63 41.8954 63 43C63 44.1046 63.8954 45 65 45Z"
|
||||
fill="white" />
|
||||
<path
|
||||
d="M65 55C66.1046 55 67 54.1046 67 53C67 51.8954 66.1046 51 65 51C63.8954 51 63 51.8954 63 53C63 54.1046 63.8954 55 65 55Z"
|
||||
fill="white" />
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
@@ -0,0 +1,33 @@
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class PodcastGeneratorProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
tts_service = credentials.get("tts_service")
|
||||
api_key = credentials.get("api_key")
|
||||
|
||||
if not tts_service:
|
||||
raise ToolProviderCredentialValidationError("TTS service is not specified")
|
||||
|
||||
if not api_key:
|
||||
raise ToolProviderCredentialValidationError("API key is missing")
|
||||
|
||||
if tts_service == "openai":
|
||||
self._validate_openai_credentials(api_key)
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")
|
||||
|
||||
def _validate_openai_credentials(self, api_key: str) -> None:
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
try:
|
||||
# We're using a simple API call to validate the credentials
|
||||
client.models.list()
|
||||
except openai.AuthenticationError:
|
||||
raise ToolProviderCredentialValidationError("Invalid OpenAI API key")
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(f"Error validating OpenAI API key: {str(e)}")
|
||||
@@ -0,0 +1,34 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: podcast_generator
|
||||
label:
|
||||
en_US: Podcast Generator
|
||||
zh_Hans: 播客生成器
|
||||
description:
|
||||
en_US: Generate podcast audio using Text-to-Speech services
|
||||
zh_Hans: 使用文字转语音服务生成播客音频
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
||||
tts_service:
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: TTS Service
|
||||
zh_Hans: TTS 服务
|
||||
placeholder:
|
||||
en_US: Select a TTS service
|
||||
zh_Hans: 选择一个 TTS 服务
|
||||
options:
|
||||
- label:
|
||||
en_US: OpenAI TTS
|
||||
zh_Hans: OpenAI TTS
|
||||
value: openai
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API Key
|
||||
zh_Hans: API 密钥
|
||||
placeholder:
|
||||
en_US: Enter your TTS service API key
|
||||
zh_Hans: 输入您的 TTS 服务 API 密钥
|
||||
@@ -0,0 +1,100 @@
|
||||
import concurrent.futures
|
||||
import io
|
||||
import random
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import openai
|
||||
from pydub import AudioSegment
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class PodcastAudioGeneratorTool(BuiltinTool):
|
||||
@staticmethod
|
||||
def _generate_silence(duration: float):
|
||||
# Generate silent WAV data using pydub
|
||||
silence = AudioSegment.silent(duration=int(duration * 1000)) # pydub uses milliseconds
|
||||
return silence
|
||||
|
||||
@staticmethod
|
||||
def _generate_audio_segment(
|
||||
client: openai.OpenAI,
|
||||
line: str,
|
||||
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"],
|
||||
index: int,
|
||||
) -> tuple[int, Union[AudioSegment, str], Optional[AudioSegment]]:
|
||||
try:
|
||||
response = client.audio.speech.create(model="tts-1", voice=voice, input=line.strip(), response_format="wav")
|
||||
audio = AudioSegment.from_wav(io.BytesIO(response.content))
|
||||
silence_duration = random.uniform(0.1, 1.5)
|
||||
silence = PodcastAudioGeneratorTool._generate_silence(silence_duration)
|
||||
return index, audio, silence
|
||||
except Exception as e:
|
||||
return index, f"Error generating audio: {str(e)}", None
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
# Extract parameters
|
||||
script = tool_parameters.get("script", "")
|
||||
host1_voice = tool_parameters.get("host1_voice")
|
||||
host2_voice = tool_parameters.get("host2_voice")
|
||||
|
||||
# Split the script into lines
|
||||
script_lines = [line for line in script.split("\n") if line.strip()]
|
||||
|
||||
# Ensure voices are provided
|
||||
if not host1_voice or not host2_voice:
|
||||
raise ToolParameterValidationError("Host voices are required")
|
||||
|
||||
# Get OpenAI API key from credentials
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
|
||||
api_key = self.runtime.credentials.get("api_key")
|
||||
if not api_key:
|
||||
raise ToolProviderCredentialValidationError("OpenAI API key is missing")
|
||||
|
||||
# Initialize OpenAI client
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
# Create a thread pool
|
||||
max_workers = 5
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = []
|
||||
for i, line in enumerate(script_lines):
|
||||
voice = host1_voice if i % 2 == 0 else host2_voice
|
||||
future = executor.submit(self._generate_audio_segment, client, line, voice, i)
|
||||
futures.append(future)
|
||||
|
||||
# Collect results
|
||||
audio_segments: list[Any] = [None] * len(script_lines)
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
index, audio, silence = future.result()
|
||||
if isinstance(audio, str): # Error occurred
|
||||
return self.create_text_message(audio)
|
||||
audio_segments[index] = (audio, silence)
|
||||
|
||||
# Combine audio segments in the correct order
|
||||
combined_audio = AudioSegment.empty()
|
||||
for i, (audio, silence) in enumerate(audio_segments):
|
||||
if audio:
|
||||
combined_audio += audio
|
||||
if i < len(audio_segments) - 1 and silence:
|
||||
combined_audio += silence
|
||||
|
||||
# Export the combined audio to a WAV file in memory
|
||||
buffer = io.BytesIO()
|
||||
combined_audio.export(buffer, format="wav")
|
||||
wav_bytes = buffer.getvalue()
|
||||
|
||||
# Create a blob message with the combined audio
|
||||
return [
|
||||
self.create_text_message("Audio generated successfully"),
|
||||
self.create_blob_message(
|
||||
blob=wav_bytes,
|
||||
meta={"mime_type": "audio/x-wav"},
|
||||
save_as=self.VariableKey.AUDIO,
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,95 @@
|
||||
identity:
|
||||
name: podcast_audio_generator
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Podcast Audio Generator
|
||||
zh_Hans: 播客音频生成器
|
||||
description:
|
||||
human:
|
||||
en_US: Generate a podcast audio file from a script with two alternating voices using OpenAI's TTS service.
|
||||
zh_Hans: 使用 OpenAI 的 TTS 服务,从包含两个交替声音的脚本生成播客音频文件。
|
||||
llm: This tool converts a prepared podcast script into an audio file using OpenAI's Text-to-Speech service, with two specified voices for alternating hosts.
|
||||
parameters:
|
||||
- name: script
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Podcast Script
|
||||
zh_Hans: 播客脚本
|
||||
human_description:
|
||||
en_US: A string containing alternating lines for two hosts, separated by newline characters.
|
||||
zh_Hans: 包含两位主持人交替台词的字符串,每行用换行符分隔。
|
||||
llm_description: A string representing the script, with alternating lines for two hosts separated by newline characters.
|
||||
form: llm
|
||||
- name: host1_voice
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Host 1 Voice
|
||||
zh_Hans: 主持人1 音色
|
||||
human_description:
|
||||
en_US: The voice for the first host.
|
||||
zh_Hans: 第一位主持人的音色。
|
||||
llm_description: The voice identifier for the first host's voice.
|
||||
options:
|
||||
- label:
|
||||
en_US: Alloy
|
||||
zh_Hans: Alloy
|
||||
value: alloy
|
||||
- label:
|
||||
en_US: Echo
|
||||
zh_Hans: Echo
|
||||
value: echo
|
||||
- label:
|
||||
en_US: Fable
|
||||
zh_Hans: Fable
|
||||
value: fable
|
||||
- label:
|
||||
en_US: Onyx
|
||||
zh_Hans: Onyx
|
||||
value: onyx
|
||||
- label:
|
||||
en_US: Nova
|
||||
zh_Hans: Nova
|
||||
value: nova
|
||||
- label:
|
||||
en_US: Shimmer
|
||||
zh_Hans: Shimmer
|
||||
value: shimmer
|
||||
form: form
|
||||
- name: host2_voice
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Host 2 Voice
|
||||
zh_Hans: 主持人2 音色
|
||||
human_description:
|
||||
en_US: The voice for the second host.
|
||||
zh_Hans: 第二位主持人的音色。
|
||||
llm_description: The voice identifier for the second host's voice.
|
||||
options:
|
||||
- label:
|
||||
en_US: Alloy
|
||||
zh_Hans: Alloy
|
||||
value: alloy
|
||||
- label:
|
||||
en_US: Echo
|
||||
zh_Hans: Echo
|
||||
value: echo
|
||||
- label:
|
||||
en_US: Fable
|
||||
zh_Hans: Fable
|
||||
value: fable
|
||||
- label:
|
||||
en_US: Onyx
|
||||
zh_Hans: Onyx
|
||||
value: onyx
|
||||
- label:
|
||||
en_US: Nova
|
||||
zh_Hans: Nova
|
||||
value: nova
|
||||
- label:
|
||||
en_US: Shimmer
|
||||
zh_Hans: Shimmer
|
||||
value: shimmer
|
||||
form: form
|
||||
@@ -13,7 +13,6 @@ from core.tools.errors import (
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
@@ -208,9 +207,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
default_value = ToolParameterConverter.cast_parameter_by_type(
|
||||
parameter_schema.default, parameter_schema.type
|
||||
)
|
||||
default_value = parameter_schema.type.cast_value(parameter_schema.default)
|
||||
tool_parameters[parameter] = default_value
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
|
||||
@@ -11,7 +11,6 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
|
||||
class ToolProviderController(BaseModel, ABC):
|
||||
@@ -127,9 +126,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||
|
||||
# the parameter is not set currently, set the default value if needed
|
||||
if parameter_schema.default is not None:
|
||||
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(
|
||||
parameter_schema.default, parameter_schema.type
|
||||
)
|
||||
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
@@ -23,6 +23,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
||||
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
||||
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
||||
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
|
||||
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
|
||||
}
|
||||
|
||||
|
||||
@@ -36,8 +38,8 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
controller = WorkflowToolProviderController(
|
||||
**{
|
||||
controller = WorkflowToolProviderController.model_validate(
|
||||
{
|
||||
"identity": {
|
||||
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
|
||||
"name": db_provider.label,
|
||||
@@ -67,7 +69,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow = (
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.first()
|
||||
@@ -76,14 +78,14 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
raise ValueError("workflow not found")
|
||||
|
||||
# fetch start node
|
||||
graph: dict = workflow.graph_dict
|
||||
features_dict: dict = workflow.features_dict
|
||||
graph = workflow.graph_dict
|
||||
features_dict = workflow.features_dict
|
||||
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
|
||||
|
||||
parameters = db_provider.parameter_configurations
|
||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||
|
||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
|
||||
def fetch_workflow_variable(variable_name: str):
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||
|
||||
user = db_provider.user
|
||||
@@ -114,7 +116,6 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm_description=parameter.description,
|
||||
required=variable.required,
|
||||
options=options,
|
||||
default=variable.default,
|
||||
)
|
||||
)
|
||||
elif features.file_upload:
|
||||
@@ -123,7 +124,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
name=parameter.name,
|
||||
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
|
||||
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
||||
type=ToolParameter.ToolParameterType.FILE,
|
||||
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
llm_description=parameter.description,
|
||||
required=False,
|
||||
form=parameter.form,
|
||||
|
||||
@@ -20,10 +20,9 @@ from core.tools.entities.tool_entities import (
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
class Tool(BaseModel, ABC):
|
||||
@@ -63,8 +62,12 @@ class Tool(BaseModel, ABC):
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
class VariableKey(Enum):
|
||||
class VariableKey(str, Enum):
|
||||
IMAGE = "image"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
CUSTOM = "custom"
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
|
||||
"""
|
||||
@@ -221,9 +224,7 @@ class Tool(BaseModel, ABC):
|
||||
result = deepcopy(tool_parameters)
|
||||
for parameter in self.parameters or []:
|
||||
if parameter.name in tool_parameters:
|
||||
result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
|
||||
tool_parameters[parameter.name], parameter.type
|
||||
)
|
||||
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||
|
||||
return result
|
||||
|
||||
@@ -295,10 +296,8 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
|
||||
|
||||
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as=""
|
||||
)
|
||||
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="")
|
||||
|
||||
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileVar
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
@@ -45,11 +45,13 @@ class WorkflowTool(Tool):
|
||||
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
||||
|
||||
# transform the tool parameters
|
||||
tool_parameters, files = self._transform_args(tool_parameters)
|
||||
tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
|
||||
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
|
||||
generator = WorkflowAppGenerator()
|
||||
assert self.runtime is not None
|
||||
assert self.runtime.invoke_from is not None
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
@@ -74,7 +76,7 @@ class WorkflowTool(Tool):
|
||||
else:
|
||||
outputs, files = self._extract_files(outputs)
|
||||
for file in files:
|
||||
result.append(self.create_file_var_message(file))
|
||||
result.append(self.create_file_message(file))
|
||||
|
||||
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
||||
result.append(self.create_json_message(outputs))
|
||||
@@ -154,22 +156,22 @@ class WorkflowTool(Tool):
|
||||
parameters_result = {}
|
||||
files = []
|
||||
for parameter in parameter_rules:
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
|
||||
file = tool_parameters.get(parameter.name)
|
||||
if file:
|
||||
try:
|
||||
file_var_list = [FileVar(**f) for f in file]
|
||||
for file_var in file_var_list:
|
||||
file_dict = {
|
||||
"transfer_method": file_var.transfer_method.value,
|
||||
"type": file_var.type.value,
|
||||
file_var_list = [File.model_validate(f) for f in file]
|
||||
for file in file_var_list:
|
||||
file_dict: dict[str, str | None] = {
|
||||
"transfer_method": file.transfer_method.value,
|
||||
"type": file.type.value,
|
||||
}
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict["url"] = file_var.preview_url
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file.related_id
|
||||
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict["upload_file_id"] = file.related_id
|
||||
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict["url"] = file.generate_url()
|
||||
|
||||
files.append(file_dict)
|
||||
except Exception as e:
|
||||
@@ -179,7 +181,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
return parameters_result, files
|
||||
|
||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
|
||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
|
||||
"""
|
||||
extract files from the result
|
||||
|
||||
@@ -190,17 +192,13 @@ class WorkflowTool(Tool):
|
||||
result = {}
|
||||
for key, value in outputs.items():
|
||||
if isinstance(value, list):
|
||||
has_file = False
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get("__variant") == "FileVar":
|
||||
try:
|
||||
files.append(FileVar(**item))
|
||||
has_file = True
|
||||
except Exception as e:
|
||||
pass
|
||||
if has_file:
|
||||
continue
|
||||
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
file = File.model_validate(item)
|
||||
files.append(file)
|
||||
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
file = File.model_validate(value)
|
||||
files.append(file)
|
||||
|
||||
result[key] = value
|
||||
|
||||
return result, files
|
||||
|
||||
@@ -10,7 +10,8 @@ from yarl import URL
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod
|
||||
from core.file import FileType
|
||||
from core.file.models import FileTransferMethod
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
|
||||
from core.tools.errors import (
|
||||
@@ -26,6 +27,7 @@ from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
|
||||
@@ -128,6 +130,7 @@ class ToolEngine:
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
assert tool.identity is not None
|
||||
workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
@@ -258,7 +261,10 @@ class ToolEngine:
|
||||
|
||||
@staticmethod
|
||||
def _create_message_files(
|
||||
tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str
|
||||
tool_messages: list[ToolInvokeMessageBinary],
|
||||
agent_message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
user_id: str,
|
||||
) -> list[tuple[Any, str]]:
|
||||
"""
|
||||
Create message file
|
||||
@@ -269,29 +275,31 @@ class ToolEngine:
|
||||
result = []
|
||||
|
||||
for message in tool_messages:
|
||||
file_type = "bin"
|
||||
if "image" in message.mimetype:
|
||||
file_type = "image"
|
||||
file_type = FileType.IMAGE
|
||||
elif "video" in message.mimetype:
|
||||
file_type = "video"
|
||||
file_type = FileType.VIDEO
|
||||
elif "audio" in message.mimetype:
|
||||
file_type = "audio"
|
||||
elif "text" in message.mimetype:
|
||||
file_type = "text"
|
||||
elif "pdf" in message.mimetype:
|
||||
file_type = "pdf"
|
||||
elif "zip" in message.mimetype:
|
||||
file_type = "archive"
|
||||
# ...
|
||||
file_type = FileType.AUDIO
|
||||
elif "text" in message.mimetype or "pdf" in message.mimetype:
|
||||
file_type = FileType.DOCUMENT
|
||||
else:
|
||||
file_type = FileType.CUSTOM
|
||||
|
||||
# extract tool file id from url
|
||||
tool_file_id = message.url.split("/")[-1].split(".")[0]
|
||||
message_file = MessageFile(
|
||||
message_id=agent_message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE.value,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
url=message.url,
|
||||
upload_file_id=None,
|
||||
created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
|
||||
upload_file_id=tool_file_id,
|
||||
created_by_role=(
|
||||
CreatedByRole.ACCOUNT
|
||||
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else CreatedByRole.END_USER
|
||||
),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
@@ -57,22 +56,32 @@ class ToolFileManager:
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(
|
||||
user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str],
|
||||
file_binary: bytes,
|
||||
mimetype: str,
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"tools/{tenant_id}/{unique_name}{extension}"
|
||||
storage.save(filename, file_binary)
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
name=filename,
|
||||
size=len(file_binary),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
@@ -80,29 +89,34 @@ class ToolFileManager:
|
||||
def create_file_by_url(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str,
|
||||
conversation_id: str | None,
|
||||
file_url: str,
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
# try to download image
|
||||
response = get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
try:
|
||||
response = get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download file from {file_url}: {e}")
|
||||
raise
|
||||
|
||||
mimetype = guess_type(file_url)[0] or "octet/stream"
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"tools/{tenant_id}/{unique_name}{extension}"
|
||||
storage.save(filename, blob)
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filename,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
original_url=file_url,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
@@ -110,18 +124,6 @@ class ToolFileManager:
|
||||
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_key(
|
||||
user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype
|
||||
)
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
@@ -131,7 +133,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile = (
|
||||
tool_file = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == id,
|
||||
@@ -155,7 +157,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile = (
|
||||
message_file = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
@@ -166,13 +168,16 @@ class ToolFileManager:
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
if message_file.url is not None:
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
else:
|
||||
tool_file_id = None
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile = (
|
||||
tool_file = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
@@ -188,7 +193,7 @@ class ToolFileManager:
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]:
|
||||
def get_file_generator_by_tool_file_id(tool_file_id: str):
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@@ -196,7 +201,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile = (
|
||||
tool_file = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
@@ -205,11 +210,11 @@ class ToolFileManager:
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
return None, None
|
||||
|
||||
generator = storage.load_stream(tool_file.file_key)
|
||||
stream = storage.load_stream(tool_file.file_key)
|
||||
|
||||
return generator, tool_file.mimetype
|
||||
return stream, tool_file
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
|
||||
@@ -24,7 +24,6 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
@@ -203,7 +202,7 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
|
||||
|
||||
@classmethod
|
||||
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
|
||||
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
|
||||
"""
|
||||
init runtime parameter
|
||||
"""
|
||||
@@ -222,7 +221,7 @@ class ToolManager:
|
||||
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
|
||||
)
|
||||
|
||||
return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
|
||||
return parameter_rule.type.cast_value(parameter_value)
|
||||
|
||||
@classmethod
|
||||
def get_agent_tool_runtime(
|
||||
@@ -243,7 +242,11 @@ class ToolManager:
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
# check file types
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
if parameter.type in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}:
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import logging
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileType
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
@@ -11,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
class ToolFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(
|
||||
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str
|
||||
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None
|
||||
) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Transform tool message and handle file download
|
||||
@@ -21,7 +22,7 @@ class ToolFileMessageTransformer:
|
||||
for message in messages:
|
||||
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
||||
result.append(message)
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str):
|
||||
# try to download image
|
||||
try:
|
||||
file = ToolFileManager.create_file_by_url(
|
||||
@@ -50,11 +51,14 @@ class ToolFileMessageTransformer:
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
assert message.meta is not None
|
||||
mimetype = message.meta.get("mime_type", "octet/stream")
|
||||
# if message is str, encode it to bytes
|
||||
if isinstance(message.message, str):
|
||||
message.message = message.message.encode("utf-8")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message, bytes)
|
||||
file = ToolFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -63,7 +67,7 @@ class ToolFileMessageTransformer:
|
||||
mimetype=mimetype,
|
||||
)
|
||||
|
||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
||||
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
@@ -84,12 +88,14 @@ class ToolFileMessageTransformer:
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||
file_var = message.meta.get("file_var")
|
||||
if file_var:
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||
if file_var.type == FileType.IMAGE:
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
file = message.meta.get("file")
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
|
||||
if file.type == FileType.IMAGE:
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
@@ -107,11 +113,13 @@ class ToolFileMessageTransformer:
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(message)
|
||||
else:
|
||||
result.append(message)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
|
||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
|
||||
class ToolParameterConverter:
|
||||
@staticmethod
|
||||
def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
|
||||
match parameter_type:
|
||||
case (
|
||||
ToolParameter.ToolParameterType.STRING
|
||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
| ToolParameter.ToolParameterType.SELECT
|
||||
):
|
||||
return "string"
|
||||
|
||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
||||
return "boolean"
|
||||
|
||||
case ToolParameter.ToolParameterType.NUMBER:
|
||||
return "number"
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Unsupported parameter type {parameter_type}")
|
||||
|
||||
@staticmethod
|
||||
def cast_parameter_by_type(value: Any, parameter_type: str) -> Any:
|
||||
# convert tool parameter config to correct type
|
||||
try:
|
||||
match parameter_type:
|
||||
case (
|
||||
ToolParameter.ToolParameterType.STRING
|
||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
| ToolParameter.ToolParameterType.SELECT
|
||||
):
|
||||
if value is None:
|
||||
return ""
|
||||
else:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case ToolParameter.ToolParameterType.NUMBER:
|
||||
if isinstance(value, int) | isinstance(value, float):
|
||||
return value
|
||||
elif isinstance(value, str) and value != "":
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
case ToolParameter.ToolParameterType.FILE:
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")
|
||||
@@ -1,19 +1,18 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[dict]):
|
||||
"""
|
||||
check parameter configurations
|
||||
"""
|
||||
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
|
||||
for configuration in configurations:
|
||||
if not WorkflowToolParameterConfiguration(**configuration):
|
||||
raise ValueError("invalid parameter configuration")
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
@@ -17,15 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
try:
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
except Exception as e:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
|
||||
except Exception as e:
|
||||
if not file_path or not Path(file_path).exists():
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise e
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
|
||||
Reference in New Issue
Block a user