feat/enhance the multi-modal support (#8818)

This commit is contained in:
-LAN-
2024-10-21 10:43:49 +08:00
committed by GitHub
parent 7a1d6fe509
commit e61752bd3a
267 changed files with 6263 additions and 3523 deletions

View File

@@ -1,7 +1,8 @@
import logging
from mimetypes import guess_extension
from typing import Optional
from core.file.file_obj import FileTransferMethod, FileType
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
@@ -11,7 +12,7 @@ logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None
) -> list[ToolInvokeMessage]:
"""
Transform tool message and handle file download
@@ -21,7 +22,7 @@ class ToolFileMessageTransformer:
for message in messages:
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str):
# try to download image
try:
file = ToolFileManager.create_file_by_url(
@@ -50,11 +51,14 @@ class ToolFileMessageTransformer:
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
assert message.meta is not None
mimetype = message.meta.get("mime_type", "octet/stream")
# if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode("utf-8")
# FIXME: should do a type check here.
assert isinstance(message.message, bytes)
file = ToolFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
@@ -63,7 +67,7 @@ class ToolFileMessageTransformer:
mimetype=mimetype,
)
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
# check if file is image
if "image" in mimetype:
@@ -84,12 +88,14 @@ class ToolFileMessageTransformer:
meta=message.meta.copy() if message.meta is not None else {},
)
)
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var = message.meta.get("file_var")
if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
if file_var.type == FileType.IMAGE:
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
file = message.meta.get("file")
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
@@ -107,11 +113,13 @@ class ToolFileMessageTransformer:
meta=message.meta.copy() if message.meta is not None else {},
)
)
else:
result.append(message)
else:
result.append(message)
return result
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
return f'/files/tools/{tool_file_id}{extension or ".bin"}'

View File

@@ -1,71 +0,0 @@
from typing import Any
from core.tools.entities.tool_entities import ToolParameter
class ToolParameterConverter:
@staticmethod
def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
match parameter_type:
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
return "string"
case ToolParameter.ToolParameterType.BOOLEAN:
return "boolean"
case ToolParameter.ToolParameterType.NUMBER:
return "number"
case _:
raise ValueError(f"Unsupported parameter type {parameter_type}")
@staticmethod
def cast_parameter_by_type(value: Any, parameter_type: str) -> Any:
# convert tool parameter config to correct type
try:
match parameter_type:
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
if value is None:
return ""
else:
return value if isinstance(value, str) else str(value)
case ToolParameter.ToolParameterType.BOOLEAN:
if value is None:
return False
elif isinstance(value, str):
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case "true" | "yes" | "y" | "1":
return True
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
else:
return value if isinstance(value, bool) else bool(value)
case ToolParameter.ToolParameterType.NUMBER:
if isinstance(value, int) | isinstance(value, float):
return value
elif isinstance(value, str) and value != "":
if "." in value:
return float(value)
else:
return int(value)
case ToolParameter.ToolParameterType.FILE:
return value
case _:
return str(value)
except Exception:
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")

View File

@@ -1,19 +1,18 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[dict]):
"""
check parameter configurations
"""
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
for configuration in configurations:
if not WorkflowToolParameterConfiguration(**configuration):
raise ValueError("invalid parameter configuration")
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""

View File

@@ -1,4 +1,5 @@
import logging
from pathlib import Path
from typing import Any
import yaml
@@ -17,15 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
try:
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
except Exception as e:
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise e
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e