feat/enhance the multi-modal support (#8818)

This commit is contained in:
-LAN-
2024-10-21 10:43:49 +08:00
committed by GitHub
parent 7a1d6fe509
commit e61752bd3a
267 changed files with 6263 additions and 3523 deletions

View File

@@ -32,8 +32,8 @@ class UserToolProvider(BaseModel):
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
tools: list[UserTool] = None
labels: list[str] = None
tools: list[UserTool] | None = None
labels: list[str] | None = None
def to_dict(self) -> dict:
# -------------
@@ -42,7 +42,7 @@ class UserToolProvider(BaseModel):
for tool in tools:
if tool.get("parameters"):
for parameter in tool.get("parameters"):
if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value:
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------

View File

@@ -104,14 +104,15 @@ class ToolInvokeMessage(BaseModel):
BLOB = "blob"
JSON = "json"
IMAGE_LINK = "image_link"
FILE_VAR = "file_var"
FILE = "file"
type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: str | bytes | dict | None = None
meta: dict[str, Any] | None = None
# TODO: Use a BaseModel for meta
meta: dict[str, Any] = Field(default_factory=dict)
save_as: str = ""
@@ -143,6 +144,67 @@ class ToolParameter(BaseModel):
SELECT = "select"
SECRET_INPUT = "secret-input"
FILE = "file"
FILES = "files"
# deprecated, should not use.
SYSTEM_FILES = "systme-files"
def as_normal_type(self):
if self in {
ToolParameter.ToolParameterType.SECRET_INPUT,
ToolParameter.ToolParameterType.SELECT,
}:
return "string"
return self.value
def cast_value(self, value: Any, /):
try:
match self:
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
if value is None:
return ""
else:
return value if isinstance(value, str) else str(value)
case ToolParameter.ToolParameterType.BOOLEAN:
if value is None:
return False
elif isinstance(value, str):
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case "true" | "yes" | "y" | "1":
return True
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
else:
return value if isinstance(value, bool) else bool(value)
case ToolParameter.ToolParameterType.NUMBER:
if isinstance(value, int | float):
return value
elif isinstance(value, str) and value:
if "." in value:
return float(value)
else:
return int(value)
case (
ToolParameter.ToolParameterType.SYSTEM_FILES
| ToolParameter.ToolParameterType.FILE
| ToolParameter.ToolParameterType.FILES
):
return value
case _:
return str(value)
except Exception:
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool

View File

@@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
for image in response.data:
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE
)
result.append(blob_message)
return result

View File

@@ -2,7 +2,7 @@ from typing import Any
from duckduckgo_search import DDGS
from core.file.file_obj import FileTransferMethod
from core.file.models import FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

View File

@@ -0,0 +1,24 @@
<svg width="100" height="100" viewBox="0 0 100 100" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="100" height="100" rx="20" fill="#4A90E2" />
<path
d="M50 25C40.6 25 33 32.6 33 42V58C33 67.4 40.6 75 50 75C59.4 75 67 67.4 67 58V42C67 32.6 59.4 25 50 25ZM61 58C61 64.1 56.1 69 50 69C43.9 69 39 64.1 39 58V42C39 35.9 43.9 31 50 31C56.1 31 61 35.9 61 42V58Z"
fill="white" />
<path d="M50 37C47.2 37 45 39.2 45 42V58C45 60.8 47.2 63 50 63C52.8 63 55 60.8 55 58V42C55 39.2 52.8 37 50 37Z"
fill="white" />
<path
d="M73 49H69V58C69 68.5 60.5 77 50 77C39.5 77 31 68.5 31 58V49H27V58C27 70.7 37.3 81 50 81C62.7 81 73 70.7 73 58V49Z"
fill="white" />
<path d="M50 85C51.1 85 52 84.1 52 83V81H48V83C48 84.1 48.9 85 50 85Z" fill="white" />
<path
d="M35 45C36.1046 45 37 44.1046 37 43C37 41.8954 36.1046 41 35 41C33.8954 41 33 41.8954 33 43C33 44.1046 33.8954 45 35 45Z"
fill="white" />
<path
d="M35 55C36.1046 55 37 54.1046 37 53C37 51.8954 36.1046 51 35 51C33.8954 51 33 51.8954 33 53C33 54.1046 33.8954 55 35 55Z"
fill="white" />
<path
d="M65 45C66.1046 45 67 44.1046 67 43C67 41.8954 66.1046 41 65 41C63.8954 41 63 41.8954 63 43C63 44.1046 63.8954 45 65 45Z"
fill="white" />
<path
d="M65 55C66.1046 55 67 54.1046 67 53C67 51.8954 66.1046 51 65 51C63.8954 51 63 51.8954 63 53C63 54.1046 63.8954 55 65 55Z"
fill="white" />
</svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -0,0 +1,33 @@
from typing import Any
import openai
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class PodcastGeneratorProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
tts_service = credentials.get("tts_service")
api_key = credentials.get("api_key")
if not tts_service:
raise ToolProviderCredentialValidationError("TTS service is not specified")
if not api_key:
raise ToolProviderCredentialValidationError("API key is missing")
if tts_service == "openai":
self._validate_openai_credentials(api_key)
else:
raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}")
def _validate_openai_credentials(self, api_key: str) -> None:
client = openai.OpenAI(api_key=api_key)
try:
# We're using a simple API call to validate the credentials
client.models.list()
except openai.AuthenticationError:
raise ToolProviderCredentialValidationError("Invalid OpenAI API key")
except Exception as e:
raise ToolProviderCredentialValidationError(f"Error validating OpenAI API key: {str(e)}")

View File

@@ -0,0 +1,34 @@
identity:
author: Dify
name: podcast_generator
label:
en_US: Podcast Generator
zh_Hans: 播客生成器
description:
en_US: Generate podcast audio using Text-to-Speech services
zh_Hans: 使用文字转语音服务生成播客音频
icon: icon.svg
credentials_for_provider:
tts_service:
type: select
required: true
label:
en_US: TTS Service
zh_Hans: TTS 服务
placeholder:
en_US: Select a TTS service
zh_Hans: 选择一个 TTS 服务
options:
- label:
en_US: OpenAI TTS
zh_Hans: OpenAI TTS
value: openai
api_key:
type: secret-input
required: true
label:
en_US: API Key
zh_Hans: API 密钥
placeholder:
en_US: Enter your TTS service API key
zh_Hans: 输入您的 TTS 服务 API 密钥

View File

@@ -0,0 +1,100 @@
import concurrent.futures
import io
import random
from typing import Any, Literal, Optional, Union
import openai
from pydub import AudioSegment
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool
class PodcastAudioGeneratorTool(BuiltinTool):
@staticmethod
def _generate_silence(duration: float):
# Generate silent WAV data using pydub
silence = AudioSegment.silent(duration=int(duration * 1000)) # pydub uses milliseconds
return silence
@staticmethod
def _generate_audio_segment(
client: openai.OpenAI,
line: str,
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"],
index: int,
) -> tuple[int, Union[AudioSegment, str], Optional[AudioSegment]]:
try:
response = client.audio.speech.create(model="tts-1", voice=voice, input=line.strip(), response_format="wav")
audio = AudioSegment.from_wav(io.BytesIO(response.content))
silence_duration = random.uniform(0.1, 1.5)
silence = PodcastAudioGeneratorTool._generate_silence(silence_duration)
return index, audio, silence
except Exception as e:
return index, f"Error generating audio: {str(e)}", None
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
# Extract parameters
script = tool_parameters.get("script", "")
host1_voice = tool_parameters.get("host1_voice")
host2_voice = tool_parameters.get("host2_voice")
# Split the script into lines
script_lines = [line for line in script.split("\n") if line.strip()]
# Ensure voices are provided
if not host1_voice or not host2_voice:
raise ToolParameterValidationError("Host voices are required")
# Get OpenAI API key from credentials
if not self.runtime or not self.runtime.credentials:
raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
api_key = self.runtime.credentials.get("api_key")
if not api_key:
raise ToolProviderCredentialValidationError("OpenAI API key is missing")
# Initialize OpenAI client
client = openai.OpenAI(api_key=api_key)
# Create a thread pool
max_workers = 5
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for i, line in enumerate(script_lines):
voice = host1_voice if i % 2 == 0 else host2_voice
future = executor.submit(self._generate_audio_segment, client, line, voice, i)
futures.append(future)
# Collect results
audio_segments: list[Any] = [None] * len(script_lines)
for future in concurrent.futures.as_completed(futures):
index, audio, silence = future.result()
if isinstance(audio, str): # Error occurred
return self.create_text_message(audio)
audio_segments[index] = (audio, silence)
# Combine audio segments in the correct order
combined_audio = AudioSegment.empty()
for i, (audio, silence) in enumerate(audio_segments):
if audio:
combined_audio += audio
if i < len(audio_segments) - 1 and silence:
combined_audio += silence
# Export the combined audio to a WAV file in memory
buffer = io.BytesIO()
combined_audio.export(buffer, format="wav")
wav_bytes = buffer.getvalue()
# Create a blob message with the combined audio
return [
self.create_text_message("Audio generated successfully"),
self.create_blob_message(
blob=wav_bytes,
meta={"mime_type": "audio/x-wav"},
save_as=self.VariableKey.AUDIO,
),
]

View File

@@ -0,0 +1,95 @@
identity:
name: podcast_audio_generator
author: Dify
label:
en_US: Podcast Audio Generator
zh_Hans: 播客音频生成器
description:
human:
en_US: Generate a podcast audio file from a script with two alternating voices using OpenAI's TTS service.
zh_Hans: 使用 OpenAI 的 TTS 服务,从包含两个交替声音的脚本生成播客音频文件。
llm: This tool converts a prepared podcast script into an audio file using OpenAI's Text-to-Speech service, with two specified voices for alternating hosts.
parameters:
- name: script
type: string
required: true
label:
en_US: Podcast Script
zh_Hans: 播客脚本
human_description:
en_US: A string containing alternating lines for two hosts, separated by newline characters.
zh_Hans: 包含两位主持人交替台词的字符串,每行用换行符分隔。
llm_description: A string representing the script, with alternating lines for two hosts separated by newline characters.
form: llm
- name: host1_voice
type: select
required: true
label:
en_US: Host 1 Voice
zh_Hans: 主持人1 音色
human_description:
en_US: The voice for the first host.
zh_Hans: 第一位主持人的音色。
llm_description: The voice identifier for the first host's voice.
options:
- label:
en_US: Alloy
zh_Hans: Alloy
value: alloy
- label:
en_US: Echo
zh_Hans: Echo
value: echo
- label:
en_US: Fable
zh_Hans: Fable
value: fable
- label:
en_US: Onyx
zh_Hans: Onyx
value: onyx
- label:
en_US: Nova
zh_Hans: Nova
value: nova
- label:
en_US: Shimmer
zh_Hans: Shimmer
value: shimmer
form: form
- name: host2_voice
type: select
required: true
label:
en_US: Host 2 Voice
zh_Hans: 主持人2 音色
human_description:
en_US: The voice for the second host.
zh_Hans: 第二位主持人的音色。
llm_description: The voice identifier for the second host's voice.
options:
- label:
en_US: Alloy
zh_Hans: Alloy
value: alloy
- label:
en_US: Echo
zh_Hans: Echo
value: echo
- label:
en_US: Fable
zh_Hans: Fable
value: fable
- label:
en_US: Onyx
zh_Hans: Onyx
value: onyx
- label:
en_US: Nova
zh_Hans: Nova
value: nova
- label:
en_US: Shimmer
zh_Hans: Shimmer
value: shimmer
form: form

View File

@@ -13,7 +13,6 @@ from core.tools.errors import (
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.tools.utils.yaml_utils import load_yaml_file
@@ -208,9 +207,7 @@ class BuiltinToolProviderController(ToolProviderController):
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = ToolParameterConverter.cast_parameter_by_type(
parameter_schema.default, parameter_schema.type
)
default_value = parameter_schema.type.cast_value(parameter_schema.default)
tool_parameters[parameter] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None:

View File

@@ -11,7 +11,6 @@ from core.tools.entities.tool_entities import (
)
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
class ToolProviderController(BaseModel, ABC):
@@ -127,9 +126,7 @@ class ToolProviderController(BaseModel, ABC):
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(
parameter_schema.default, parameter_schema.type
)
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""

View File

@@ -1,6 +1,6 @@
from typing import Optional
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.app_config.entities import VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
@@ -23,6 +23,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
}
@@ -36,8 +38,8 @@ class WorkflowToolProviderController(ToolProviderController):
if not app:
raise ValueError("app not found")
controller = WorkflowToolProviderController(
**{
controller = WorkflowToolProviderController.model_validate(
{
"identity": {
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
"name": db_provider.label,
@@ -67,7 +69,7 @@ class WorkflowToolProviderController(ToolProviderController):
:param app: the app
:return: the tool
"""
workflow: Workflow = (
workflow = (
db.session.query(Workflow)
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
@@ -76,14 +78,14 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("workflow not found")
# fetch start node
graph: dict = workflow.graph_dict
features_dict: dict = workflow.features_dict
graph = workflow.graph_dict
features_dict = workflow.features_dict
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
parameters = db_provider.parameter_configurations
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
def fetch_workflow_variable(variable_name: str):
return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user
@@ -114,7 +116,6 @@ class WorkflowToolProviderController(ToolProviderController):
llm_description=parameter.description,
required=variable.required,
options=options,
default=variable.default,
)
)
elif features.file_upload:
@@ -123,7 +124,7 @@ class WorkflowToolProviderController(ToolProviderController):
name=parameter.name,
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
type=ToolParameter.ToolParameterType.FILE,
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
llm_description=parameter.description,
required=False,
form=parameter.form,

View File

@@ -20,10 +20,9 @@ from core.tools.entities.tool_entities import (
ToolRuntimeVariablePool,
)
from core.tools.tool_file_manager import ToolFileManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
if TYPE_CHECKING:
from core.file.file_obj import FileVar
from core.file.models import File
class Tool(BaseModel, ABC):
@@ -63,8 +62,12 @@ class Tool(BaseModel, ABC):
def __init__(self, **data: Any):
super().__init__(**data)
class VariableKey(Enum):
class VariableKey(str, Enum):
IMAGE = "image"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
CUSTOM = "custom"
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
"""
@@ -221,9 +224,7 @@ class Tool(BaseModel, ABC):
result = deepcopy(tool_parameters)
for parameter in self.parameters or []:
if parameter.name in tool_parameters:
result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
tool_parameters[parameter.name], parameter.type
)
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
return result
@@ -295,10 +296,8 @@ class Tool(BaseModel, ABC):
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as=""
)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="")
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
"""

View File

@@ -3,7 +3,7 @@ import logging
from copy import deepcopy
from typing import Any, Optional, Union
from core.file.file_obj import FileTransferMethod, FileVar
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.tool.tool import Tool
from extensions.ext_database import db
@@ -45,11 +45,13 @@ class WorkflowTool(Tool):
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
# transform the tool parameters
tool_parameters, files = self._transform_args(tool_parameters)
tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
generator = WorkflowAppGenerator()
assert self.runtime is not None
assert self.runtime.invoke_from is not None
result = generator.generate(
app_model=app,
workflow=workflow,
@@ -74,7 +76,7 @@ class WorkflowTool(Tool):
else:
outputs, files = self._extract_files(outputs)
for file in files:
result.append(self.create_file_var_message(file))
result.append(self.create_file_message(file))
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
result.append(self.create_json_message(outputs))
@@ -154,22 +156,22 @@ class WorkflowTool(Tool):
parameters_result = {}
files = []
for parameter in parameter_rules:
if parameter.type == ToolParameter.ToolParameterType.FILE:
if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
file = tool_parameters.get(parameter.name)
if file:
try:
file_var_list = [FileVar(**f) for f in file]
for file_var in file_var_list:
file_dict = {
"transfer_method": file_var.transfer_method.value,
"type": file_var.type.value,
file_var_list = [File.model_validate(f) for f in file]
for file in file_var_list:
file_dict: dict[str, str | None] = {
"transfer_method": file.transfer_method.value,
"type": file.type.value,
}
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_var.related_id
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file_var.related_id
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
file_dict["url"] = file_var.preview_url
if file.transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file.related_id
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file.related_id
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
file_dict["url"] = file.generate_url()
files.append(file_dict)
except Exception as e:
@@ -179,7 +181,7 @@ class WorkflowTool(Tool):
return parameters_result, files
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
"""
extract files from the result
@@ -190,17 +192,13 @@ class WorkflowTool(Tool):
result = {}
for key, value in outputs.items():
if isinstance(value, list):
has_file = False
for item in value:
if isinstance(item, dict) and item.get("__variant") == "FileVar":
try:
files.append(FileVar(**item))
has_file = True
except Exception as e:
pass
if has_file:
continue
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
file = File.model_validate(item)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
file = File.model_validate(value)
files.append(file)
result[key] = value
return result, files

View File

@@ -10,7 +10,8 @@ from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod
from core.file import FileType
from core.file.models import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
from core.tools.errors import (
@@ -26,6 +27,7 @@ from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from extensions.ext_database import db
from models.enums import CreatedByRole
from models.model import Message, MessageFile
@@ -128,6 +130,7 @@ class ToolEngine:
"""
try:
# hit the callback handler
assert tool.identity is not None
workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
if isinstance(tool, WorkflowTool):
@@ -258,7 +261,10 @@ class ToolEngine:
@staticmethod
def _create_message_files(
tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str
tool_messages: list[ToolInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str,
) -> list[tuple[Any, str]]:
"""
Create message file
@@ -269,29 +275,31 @@ class ToolEngine:
result = []
for message in tool_messages:
file_type = "bin"
if "image" in message.mimetype:
file_type = "image"
file_type = FileType.IMAGE
elif "video" in message.mimetype:
file_type = "video"
file_type = FileType.VIDEO
elif "audio" in message.mimetype:
file_type = "audio"
elif "text" in message.mimetype:
file_type = "text"
elif "pdf" in message.mimetype:
file_type = "pdf"
elif "zip" in message.mimetype:
file_type = "archive"
# ...
file_type = FileType.AUDIO
elif "text" in message.mimetype or "pdf" in message.mimetype:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
# extract tool file id from url
tool_file_id = message.url.split("/")[-1].split(".")[0]
message_file = MessageFile(
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE.value,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
url=message.url,
upload_file_id=None,
created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
upload_file_id=tool_file_id,
created_by_role=(
CreatedByRole.ACCOUNT
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER
),
created_by=user_id,
)

View File

@@ -4,7 +4,6 @@ import hmac
import logging
import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
@@ -57,22 +56,32 @@ class ToolFileManager:
@staticmethod
def create_file_by_raw(
user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
file_binary: bytes,
mimetype: str,
) -> ToolFile:
"""
create file
"""
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, file_binary)
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, file_binary)
tool_file = ToolFile(
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
name=filename,
size=len(file_binary),
)
db.session.add(tool_file)
db.session.commit()
db.session.refresh(tool_file)
return tool_file
@@ -80,29 +89,34 @@ class ToolFileManager:
def create_file_by_url(
user_id: str,
tenant_id: str,
conversation_id: str,
conversation_id: str | None,
file_url: str,
) -> ToolFile:
"""
create file
"""
# try to download image
response = get(file_url)
response.raise_for_status()
blob = response.content
try:
response = get(file_url)
response.raise_for_status()
blob = response.content
except Exception as e:
logger.error(f"Failed to download file from {file_url}: {e}")
raise
mimetype = guess_type(file_url)[0] or "octet/stream"
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, blob)
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filename,
file_key=filepath,
mimetype=mimetype,
original_url=file_url,
name=filename,
size=len(blob),
)
db.session.add(tool_file)
@@ -110,18 +124,6 @@ class ToolFileManager:
return tool_file
@staticmethod
def create_file_by_key(
user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str
) -> ToolFile:
"""
create file
"""
tool_file = ToolFile(
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype
)
return tool_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
@@ -131,7 +133,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
tool_file: ToolFile = (
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == id,
@@ -155,7 +157,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
message_file: MessageFile = (
message_file = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
@@ -166,13 +168,16 @@ class ToolFileManager:
# Check if message_file is not None
if message_file is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
tool_file_id = None
tool_file: ToolFile = (
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
@@ -188,7 +193,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]:
def get_file_generator_by_tool_file_id(tool_file_id: str):
"""
get file binary
@@ -196,7 +201,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
tool_file: ToolFile = (
tool_file = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
@@ -205,11 +210,11 @@ class ToolFileManager:
)
if not tool_file:
return None
return None, None
generator = storage.load_stream(tool_file.file_key)
stream = storage.load_stream(tool_file.file_key)
return generator, tool_file.mimetype
return stream, tool_file
# init tool_file_parser

View File

@@ -24,7 +24,6 @@ from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
@@ -203,7 +202,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
@classmethod
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
"""
init runtime parameter
"""
@@ -222,7 +221,7 @@ class ToolManager:
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
)
return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
return parameter_rule.type.cast_value(parameter_value)
@classmethod
def get_agent_tool_runtime(
@@ -243,7 +242,11 @@ class ToolManager:
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
# check file types
if parameter.type == ToolParameter.ToolParameterType.FILE:
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
if parameter.form == ToolParameter.ToolParameterForm.FORM:

View File

@@ -1,7 +1,8 @@
import logging
from mimetypes import guess_extension
from typing import Optional
from core.file.file_obj import FileTransferMethod, FileType
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
@@ -11,7 +12,7 @@ logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None
) -> list[ToolInvokeMessage]:
"""
Transform tool message and handle file download
@@ -21,7 +22,7 @@ class ToolFileMessageTransformer:
for message in messages:
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str):
# try to download image
try:
file = ToolFileManager.create_file_by_url(
@@ -50,11 +51,14 @@ class ToolFileMessageTransformer:
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
assert message.meta is not None
mimetype = message.meta.get("mime_type", "octet/stream")
# if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode("utf-8")
# FIXME: should do a type check here.
assert isinstance(message.message, bytes)
file = ToolFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
@@ -63,7 +67,7 @@ class ToolFileMessageTransformer:
mimetype=mimetype,
)
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
# check if file is image
if "image" in mimetype:
@@ -84,12 +88,14 @@ class ToolFileMessageTransformer:
meta=message.meta.copy() if message.meta is not None else {},
)
)
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var = message.meta.get("file_var")
if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
if file_var.type == FileType.IMAGE:
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
file = message.meta.get("file")
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
@@ -107,11 +113,13 @@ class ToolFileMessageTransformer:
meta=message.meta.copy() if message.meta is not None else {},
)
)
else:
result.append(message)
else:
result.append(message)
return result
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
return f'/files/tools/{tool_file_id}{extension or ".bin"}'

View File

@@ -1,71 +0,0 @@
from typing import Any
from core.tools.entities.tool_entities import ToolParameter
class ToolParameterConverter:
@staticmethod
def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
match parameter_type:
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
return "string"
case ToolParameter.ToolParameterType.BOOLEAN:
return "boolean"
case ToolParameter.ToolParameterType.NUMBER:
return "number"
case _:
raise ValueError(f"Unsupported parameter type {parameter_type}")
@staticmethod
def cast_parameter_by_type(value: Any, parameter_type: str) -> Any:
# convert tool parameter config to correct type
try:
match parameter_type:
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
if value is None:
return ""
else:
return value if isinstance(value, str) else str(value)
case ToolParameter.ToolParameterType.BOOLEAN:
if value is None:
return False
elif isinstance(value, str):
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case "true" | "yes" | "y" | "1":
return True
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
else:
return value if isinstance(value, bool) else bool(value)
case ToolParameter.ToolParameterType.NUMBER:
if isinstance(value, int) | isinstance(value, float):
return value
elif isinstance(value, str) and value != "":
if "." in value:
return float(value)
else:
return int(value)
case ToolParameter.ToolParameterType.FILE:
return value
case _:
return str(value)
except Exception:
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")

View File

@@ -1,19 +1,18 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[dict]):
"""
check parameter configurations
"""
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
for configuration in configurations:
if not WorkflowToolParameterConfiguration(**configuration):
raise ValueError("invalid parameter configuration")
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""

View File

@@ -1,4 +1,5 @@
import logging
from pathlib import Path
from typing import Any
import yaml
@@ -17,15 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
try:
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
except Exception as e:
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise e
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e