Feat/workflow phase2 (#4687)
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentials
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType
|
||||
from core.tools.tool.tool import ToolParameter
|
||||
|
||||
|
||||
@@ -14,27 +14,38 @@ class UserTool(BaseModel):
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]]
|
||||
labels: list[str] = None
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal[
|
||||
'builtin', 'api', 'workflow'
|
||||
]]
|
||||
|
||||
class UserToolProvider(BaseModel):
|
||||
class ProviderType(Enum):
|
||||
BUILTIN = "builtin"
|
||||
APP = "app"
|
||||
API = "api"
|
||||
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str
|
||||
label: I18nObject # label
|
||||
type: ProviderType
|
||||
type: ToolProviderType
|
||||
masked_credentials: dict = None
|
||||
original_credentials: dict = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
tools: list[UserTool] = None
|
||||
labels: list[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# -------------
|
||||
# overwrite tool parameter types for temp fix
|
||||
tools = jsonable_encoder(self.tools)
|
||||
for tool in tools:
|
||||
if tool.get('parameters'):
|
||||
for parameter in tool.get('parameters'):
|
||||
if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value:
|
||||
parameter['type'] = 'files'
|
||||
# -------------
|
||||
|
||||
return {
|
||||
'id': self.id,
|
||||
'author': self.author,
|
||||
@@ -46,7 +57,8 @@ class UserToolProvider(BaseModel):
|
||||
'team_credentials': self.masked_credentials,
|
||||
'is_team_authorization': self.is_team_authorization,
|
||||
'allow_delete': self.allow_delete,
|
||||
'tools': self.tools
|
||||
'tools': tools,
|
||||
'labels': self.labels,
|
||||
}
|
||||
|
||||
class UserToolProviderCredentials(BaseModel):
|
||||
@@ -1,3 +0,0 @@
|
||||
class DEFAULT_PROVIDERS:
|
||||
API_BASED = '__api_based'
|
||||
APP_BASED = '__app_based'
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
|
||||
class ApiBasedToolBundle(BaseModel):
|
||||
class ApiToolBundle(BaseModel):
|
||||
"""
|
||||
This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc.
|
||||
"""
|
||||
@@ -25,12 +25,3 @@ class ApiBasedToolBundle(BaseModel):
|
||||
icon: Optional[str] = None
|
||||
# openapi operation
|
||||
openapi: dict
|
||||
|
||||
class AppToolBundle(BaseModel):
|
||||
"""
|
||||
This class is used to store the schema information of an tool for an app.
|
||||
"""
|
||||
type: ToolProviderType
|
||||
credential: Optional[dict[str, Any]] = None
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
@@ -10,10 +10,11 @@ class ToolProviderType(Enum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
BUILT_IN = "built-in"
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
API = "api"
|
||||
APP = "app"
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
APP_BASED = "app-based"
|
||||
API_BASED = "api-based"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'ToolProviderType':
|
||||
@@ -77,6 +78,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
LINK = "link"
|
||||
BLOB = "blob"
|
||||
IMAGE_LINK = "image_link"
|
||||
FILE_VAR = "file_var"
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
@@ -90,6 +92,7 @@ class ToolInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
save_as: str = ''
|
||||
file_var: Optional[dict[str, Any]] = None
|
||||
|
||||
class ToolParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
@@ -102,6 +105,7 @@ class ToolParameter(BaseModel):
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
SECRET_INPUT = "secret-input"
|
||||
FILE = "file"
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
@@ -331,6 +335,15 @@ class ModelToolProviderConfiguration(BaseModel):
|
||||
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
"""
|
||||
Workflow tool configuration
|
||||
"""
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
|
||||
|
||||
class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
Tool invoke meta
|
||||
@@ -358,4 +371,19 @@ class ToolInvokeMeta(BaseModel):
|
||||
'time_cost': self.time_cost,
|
||||
'error': self.error,
|
||||
'tool_config': self.tool_config,
|
||||
}
|
||||
}
|
||||
|
||||
class ToolLabel(BaseModel):
|
||||
"""
|
||||
Tool label
|
||||
"""
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
class ToolInvokeFrom(Enum):
|
||||
"""
|
||||
Enum class for tool invoke
|
||||
"""
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
96
api/core/tools/entities/values.py
Normal file
96
api/core/tools/entities/values.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from enum import Enum
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolLabel
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = 'search'
|
||||
IMAGE = 'image'
|
||||
VIDEOS = 'videos'
|
||||
WEATHER = 'weather'
|
||||
FINANCE = 'finance'
|
||||
DESIGN = 'design'
|
||||
TRAVEL = 'travel'
|
||||
SOCIAL = 'social'
|
||||
NEWS = 'news'
|
||||
MEDICAL = 'medical'
|
||||
PRODUCTIVITY = 'productivity'
|
||||
EDUCATION = 'education'
|
||||
BUSINESS = 'business'
|
||||
ENTERTAINMENT = 'entertainment'
|
||||
UTILITIES = 'utilities'
|
||||
OTHER = 'other'
|
||||
|
||||
ICONS = {
|
||||
ToolLabelEnum.SEARCH: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.IMAGE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.VIDEOS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.WEATHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.FINANCE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.DESIGN: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.TRAVEL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.SOCIAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.NEWS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.MEDICAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.PRODUCTIVITY: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.EDUCATION: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.BUSINESS: '''<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
|
||||
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.ENTERTAINMENT: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.UTILITIES: '''<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
|
||||
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.OTHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
|
||||
</svg>'''
|
||||
}
|
||||
|
||||
default_tool_label_dict = {
|
||||
ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]),
|
||||
ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]),
|
||||
ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]),
|
||||
ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]),
|
||||
ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]),
|
||||
ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]),
|
||||
ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]),
|
||||
ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]),
|
||||
ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]),
|
||||
ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]),
|
||||
ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]),
|
||||
ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]),
|
||||
ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]),
|
||||
ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]),
|
||||
ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]),
|
||||
ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]),
|
||||
}
|
||||
|
||||
default_tool_labels = [v for k, v in default_tool_label_dict.items()]
|
||||
default_tool_label_name_list = [label.name for label in default_tool_labels]
|
||||
@@ -1,2 +0,0 @@
|
||||
class InvokeModelError(Exception):
|
||||
pass
|
||||
@@ -1,104 +0,0 @@
|
||||
ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
|
||||
|
||||
{{instruction}}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $ACTION_INPUT
|
||||
}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
{{historic_messages}}
|
||||
Question: {{query}}
|
||||
{{agent_scratchpad}}
|
||||
Thought:"""
|
||||
|
||||
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
|
||||
Thought:"""
|
||||
|
||||
ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
|
||||
|
||||
{{instruction}}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $ACTION_INPUT
|
||||
}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
"""
|
||||
|
||||
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
|
||||
|
||||
REACT_PROMPT_TEMPLATES = {
|
||||
'english': {
|
||||
'chat': {
|
||||
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
|
||||
},
|
||||
'completion': {
|
||||
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolCredentialsOption,
|
||||
@@ -15,11 +14,11 @@ from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class ApiBasedToolProviderController(ToolProviderController):
|
||||
class ApiToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
|
||||
credentials_schema = {
|
||||
'auth_type': ToolProviderCredentials(
|
||||
name='auth_type',
|
||||
@@ -79,9 +78,11 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
else:
|
||||
raise ValueError(f'invalid auth type {auth_type}')
|
||||
|
||||
return ApiBasedToolProviderController(**{
|
||||
user_name = db_provider.user.name if db_provider.user_id else ''
|
||||
|
||||
return ApiToolProviderController(**{
|
||||
'identity': {
|
||||
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
||||
'author': user_name,
|
||||
'name': db_provider.name,
|
||||
'label': {
|
||||
'en_US': db_provider.name,
|
||||
@@ -98,16 +99,10 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
})
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API_BASED
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
|
||||
def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
|
||||
"""
|
||||
parse tool bundle to tool
|
||||
|
||||
@@ -136,7 +131,7 @@ class ApiBasedToolProviderController(ToolProviderController):
|
||||
'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
|
||||
})
|
||||
|
||||
def load_bundled_tools(self, tools: list[ApiBasedToolBundle]) -> list[ApiTool]:
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
|
||||
@@ -11,10 +11,10 @@ from models.tools import PublishedAppTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AppBasedToolProviderEntity(ToolProviderController):
|
||||
class AppToolProviderEntity(ToolProviderController):
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP_BASED
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os.path
|
||||
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
from core.utils.position_helper import get_position_map, sort_by_position_map
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,3 +10,9 @@ class AIPPTProvider(BuiltinToolProviderController):
|
||||
AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY,
|
||||
ToolLabelEnum.DESIGN,
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.arxiv.tools.arxiv_search import ArxivSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -7,7 +8,7 @@ class ArxivProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
ArxivSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -17,4 +18,9 @@ class ArxivProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH,
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.azuredalle.tools.dalle3 import DallE3Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class AzureDALLEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
DallE3Tool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -22,3 +23,8 @@ class AzureDALLEProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.bing.tools.bing_web_search import BingSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class BingProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
BingSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).validate_credentials(
|
||||
@@ -21,3 +22,8 @@ class BingProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.brave.tools.brave_search import BraveSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class BraveProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
BraveSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -19,4 +20,9 @@ class BraveProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH,
|
||||
]
|
||||
@@ -2,6 +2,7 @@ import matplotlib.pyplot as plt
|
||||
from fontTools.ttLib import TTFont
|
||||
from matplotlib.font_manager import findSystemFonts
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -44,7 +45,7 @@ class ChartProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
LinearChartTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -54,4 +55,9 @@ class ChartProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.DESIGN, ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.UTILITIES
|
||||
]
|
||||
@@ -1,8 +1,14 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.dalle.tools.dalle2 import DallE2Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class DALLEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
DallE2Tool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -21,4 +22,9 @@ class DALLEProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.devdocs.tools.searchDevDocs import SearchDevDocsTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -7,7 +8,7 @@ class DevDocsProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
SearchDevDocsTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -18,4 +19,9 @@ class DevDocsProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.dingtalk.tools.dingtalk_group_bot import DingTalkGroupBotTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -6,3 +7,8 @@ class DingTalkProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
DingTalkGroupBotTool()
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.duckduckgo.tools.duckduckgo_search import DuckDuckGoSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -7,7 +8,7 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
DuckDuckGoSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -17,4 +18,9 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.feishu.tools.feishu_group_bot import FeishuGroupBotTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -6,3 +7,8 @@ class FeishuProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
FeishuGroupBotTool()
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.firecrawl.tools.crawl import CrawlTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -8,7 +9,7 @@ class FirecrawlProvider(BuiltinToolProviderController):
|
||||
try:
|
||||
# Example validation using the Crawl tool
|
||||
CrawlTool().fork_tool_runtime(
|
||||
meta={"credentials": credentials}
|
||||
runtime={"credentials": credentials}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
@@ -20,4 +21,9 @@ class FirecrawlProvider(BuiltinToolProviderController):
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.UTILITIES
|
||||
]
|
||||
@@ -2,6 +2,7 @@ import urllib.parse
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -24,3 +25,9 @@ class GaodeProvider(BuiltinToolProviderController):
|
||||
raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES, ToolLabelEnum.PRODUCTIVITY,
|
||||
ToolLabelEnum.WEATHER, ToolLabelEnum.TRAVEL
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
import requests
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -30,3 +31,8 @@ class GihubProvider(BuiltinToolProviderController):
|
||||
raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
GoogleSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -20,4 +21,9 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -9,4 +10,9 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
try:
|
||||
pass
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.judge0ce.tools.executeCode import ExecuteCodeTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class Judge0CEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
ExecuteCodeTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -20,4 +21,9 @@ class Judge0CEProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.OTHER, ToolLabelEnum.UTILITIES
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.maths.tools.eval_expression import EvaluateExpressionTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -16,3 +17,8 @@ class MathsProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
import requests
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -34,3 +35,8 @@ class OpenweatherProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.WEATHER
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.pubmed.tools.pubmed_search import PubMedSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -7,7 +8,7 @@ class PubMedProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
PubMedSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -17,4 +18,9 @@ class PubMedProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.MEDICAL, ToolLabelEnum.SEARCH
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -14,3 +15,8 @@ class QRCodeProvider(BuiltinToolProviderController):
|
||||
})
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.searxng.tools.searxng_search import SearXNGSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class SearXNGProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
SearXNGSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -22,4 +23,9 @@ class SearXNGProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.slack.tools.slack_webhook import SlackWebhookTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -6,3 +7,8 @@ class SlackProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
SlackWebhookTool()
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -38,3 +39,8 @@ class SparkProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -12,4 +13,9 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz
|
||||
"""
|
||||
This method is responsible for validating the credentials.
|
||||
"""
|
||||
self.sd_validate_credentials(credentials)
|
||||
self.sd_validate_credentials(credentials)
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import StableDiffusionTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,9 +10,14 @@ class StableDiffusionProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
StableDiffusionTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).validate_models()
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.IMAGE
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.stackexchange.tools.searchStackExQuestions import SearchStackExQuestionsTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -7,7 +8,7 @@ class StackExchangeProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
SearchStackExQuestionsTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -22,4 +23,9 @@ class StackExchangeProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH, ToolLabelEnum.UTILITIES
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.tavily.tools.tavily_search import TavilySearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class TavilyProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
TavilySearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -26,4 +27,9 @@ class TavilyProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.time.tools.current_time import CurrentTimeTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -13,4 +14,9 @@ class WikiPediaProvider(BuiltinToolProviderController):
|
||||
tool_parameters={},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.UTILITIES
|
||||
]
|
||||
@@ -2,6 +2,7 @@ from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -31,4 +32,9 @@ class TrelloProvider(BuiltinToolProviderController):
|
||||
raise ToolProviderCredentialValidationError("Error validating Trello credentials")
|
||||
except requests.exceptions.RequestException as e:
|
||||
# Handle other exceptions, such as connection errors
|
||||
raise ToolProviderCredentialValidationError("Error validating Trello credentials")
|
||||
raise ToolProviderCredentialValidationError("Error validating Trello credentials")
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
from twilio.base.exceptions import TwilioRestException
|
||||
from twilio.rest import Client
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -26,4 +27,9 @@ class TwilioProvider(BuiltinToolProviderController):
|
||||
except KeyError as e:
|
||||
raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class VectorizerProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
VectorizerTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -20,4 +21,9 @@ class VectorizerProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.IMAGE
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.webscraper.tools.webscraper import WebscraperTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class WebscraperProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
WebscraperTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -20,4 +21,9 @@ class WebscraperProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.provider.builtin.wecom.tools.wecom_group_bot import WecomGroupBotTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
@@ -6,3 +7,8 @@ class WecomProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
WecomGroupBotTool()
|
||||
pass
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SOCIAL
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.wikipedia.tools.wikipedia_search import WikiPediaSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -7,7 +8,7 @@ class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
WikiPediaSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -17,4 +18,9 @@ class WikiPediaProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.SEARCH
|
||||
]
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.wolframalpha.tools.wolframalpha import WolframAlphaTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -9,7 +10,7 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
WolframAlphaTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -19,4 +20,9 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.PRODUCTIVITY, ToolLabelEnum.UTILITIES
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.yahoo.tools.ticker import YahooFinanceSearchTickerTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -7,7 +8,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
YahooFinanceSearchTickerTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -17,4 +18,9 @@ class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.BUSINESS, ToolLabelEnum.FINANCE
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.tools.entities.values import ToolLabelEnum
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.youtube.tools.videos import YoutubeVideosAnalyticsTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
@@ -7,7 +8,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
YoutubeVideosAnalyticsTool().fork_tool_runtime(
|
||||
meta={
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
@@ -19,4 +20,9 @@ class YahooFinanceProvider(BuiltinToolProviderController):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
return [
|
||||
ToolLabelEnum.VIDEOS
|
||||
]
|
||||
@@ -2,8 +2,9 @@ from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.api_entities import UserToolProviderCredentials
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProviderCredentials
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolNotFoundError,
|
||||
ToolParameterValidationError,
|
||||
@@ -19,7 +20,7 @@ from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED:
|
||||
if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
|
||||
super().__init__(**data)
|
||||
return
|
||||
|
||||
@@ -129,13 +130,29 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
len(self.credentials_schema) != 0
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
@property
|
||||
def tool_labels(self) -> list[str]:
|
||||
"""
|
||||
returns the labels of the provider
|
||||
|
||||
:return: labels of the provider
|
||||
"""
|
||||
label_enums = self._get_tool_labels()
|
||||
return [default_tool_label_dict[label].name for label in label_enums]
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
"""
|
||||
returns the labels of the provider
|
||||
"""
|
||||
return []
|
||||
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
|
||||
@@ -3,13 +3,13 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.api_entities import UserToolProviderCredentials
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.entities.user_entities import UserToolProviderCredentials
|
||||
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
@@ -67,7 +67,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||
return tool.parameters
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@@ -197,26 +197,4 @@ class ToolProviderController(BaseModel, ABC):
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
# validate credentials format
|
||||
self.validate_credentials_format(credentials)
|
||||
|
||||
# validate credentials
|
||||
self._validate_credentials(credentials)
|
||||
|
||||
@abstractmethod
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
pass
|
||||
|
||||
230
api/core/tools/provider/workflow_tool_provider.py
Normal file
230
api/core/tools/provider/workflow_tool_provider.py
Normal file
@@ -0,0 +1,230 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolParameter,
|
||||
ToolParameterOption,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
|
||||
app = db_provider.app
|
||||
|
||||
if not app:
|
||||
raise ValueError('app not found')
|
||||
|
||||
controller = WorkflowToolProviderController(**{
|
||||
'identity': {
|
||||
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
||||
'name': db_provider.label,
|
||||
'label': {
|
||||
'en_US': db_provider.label,
|
||||
'zh_Hans': db_provider.label
|
||||
},
|
||||
'description': {
|
||||
'en_US': db_provider.description,
|
||||
'zh_Hans': db_provider.description
|
||||
},
|
||||
'icon': db_provider.icon,
|
||||
},
|
||||
'credentials_schema': {},
|
||||
'provider_id': db_provider.id or '',
|
||||
})
|
||||
|
||||
# init tools
|
||||
|
||||
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
||||
|
||||
return controller
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
||||
"""
|
||||
get db provider tool
|
||||
:param db_provider: the db provider
|
||||
:param app: the app
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == db_provider.app_id,
|
||||
Workflow.version == db_provider.version
|
||||
).first()
|
||||
if not workflow:
|
||||
raise ValueError('workflow not found')
|
||||
|
||||
# fetch start node
|
||||
graph: dict = workflow.graph_dict
|
||||
features_dict: dict = workflow.features_dict
|
||||
features = WorkflowAppConfigManager.convert_features(
|
||||
config_dict=features_dict,
|
||||
app_mode=AppMode.WORKFLOW
|
||||
)
|
||||
|
||||
parameters = db_provider.parameter_configurations
|
||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||
|
||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||
|
||||
user = db_provider.user
|
||||
|
||||
workflow_tool_parameters = []
|
||||
for parameter in parameters:
|
||||
variable = fetch_workflow_variable(parameter.name)
|
||||
if variable:
|
||||
parameter_type = None
|
||||
options = None
|
||||
if variable.type in [
|
||||
VariableEntity.Type.TEXT_INPUT,
|
||||
VariableEntity.Type.PARAGRAPH,
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.STRING
|
||||
elif variable.type in [
|
||||
VariableEntity.Type.SELECT
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.SELECT
|
||||
elif variable.type in [
|
||||
VariableEntity.Type.NUMBER
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.NUMBER
|
||||
else:
|
||||
raise ValueError(f'unsupported variable type {variable.type}')
|
||||
|
||||
if variable.type == VariableEntity.Type.SELECT and variable.options:
|
||||
options = [
|
||||
ToolParameterOption(
|
||||
value=option,
|
||||
label=I18nObject(
|
||||
en_US=option,
|
||||
zh_Hans=option
|
||||
)
|
||||
) for option in variable.options
|
||||
]
|
||||
|
||||
workflow_tool_parameters.append(
|
||||
ToolParameter(
|
||||
name=parameter.name,
|
||||
label=I18nObject(
|
||||
en_US=variable.label,
|
||||
zh_Hans=variable.label
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.description,
|
||||
zh_Hans=parameter.description
|
||||
),
|
||||
type=parameter_type,
|
||||
form=parameter.form,
|
||||
llm_description=parameter.description,
|
||||
required=variable.required,
|
||||
options=options,
|
||||
default=variable.default
|
||||
)
|
||||
)
|
||||
elif features.file_upload:
|
||||
workflow_tool_parameters.append(
|
||||
ToolParameter(
|
||||
name=parameter.name,
|
||||
label=I18nObject(
|
||||
en_US=parameter.name,
|
||||
zh_Hans=parameter.name
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.description,
|
||||
zh_Hans=parameter.description
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.FILE,
|
||||
llm_description=parameter.description,
|
||||
required=False,
|
||||
form=parameter.form,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError('variable not found')
|
||||
|
||||
return WorkflowTool(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else '',
|
||||
name=db_provider.name,
|
||||
label=I18nObject(
|
||||
en_US=db_provider.label,
|
||||
zh_Hans=db_provider.label
|
||||
),
|
||||
provider=self.provider_id,
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(
|
||||
en_US=db_provider.description,
|
||||
zh_Hans=db_provider.description
|
||||
),
|
||||
llm=db_provider.description,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
is_team_authorization=True,
|
||||
workflow_app_id=app.id,
|
||||
workflow_entities={
|
||||
'app': app,
|
||||
'workflow': workflow,
|
||||
},
|
||||
version=db_provider.version,
|
||||
workflow_call_depth=0,
|
||||
label=db_provider.label
|
||||
)
|
||||
|
||||
def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:return: the tools
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
).first()
|
||||
|
||||
if not db_providers:
|
||||
return []
|
||||
|
||||
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
return None
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
return None
|
||||
@@ -8,9 +8,8 @@ import httpx
|
||||
import requests
|
||||
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
@@ -20,12 +19,12 @@ API_TOOL_DEFAULT_TIMEOUT = (
|
||||
)
|
||||
|
||||
class ApiTool(Tool):
|
||||
api_bundle: ApiBasedToolBundle
|
||||
api_bundle: ApiToolBundle
|
||||
|
||||
"""
|
||||
Api tool
|
||||
"""
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@@ -37,7 +36,7 @@ class ApiTool(Tool):
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
|
||||
runtime=Tool.Runtime(**meta)
|
||||
runtime=Tool.Runtime(**runtime)
|
||||
)
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
|
||||
@@ -55,7 +54,7 @@ class ApiTool(Tool):
|
||||
return self.validate_and_parse_response(response)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.API
|
||||
return ToolProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.model.tool_model_manager import ToolModelManager
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
|
||||
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
|
||||
@@ -34,7 +33,7 @@ class BuiltinTool(Tool):
|
||||
:return: the model result
|
||||
"""
|
||||
# invoke model
|
||||
return ToolModelManager.invoke(
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tool_type='builtin',
|
||||
@@ -43,7 +42,7 @@ class BuiltinTool(Tool):
|
||||
)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.BUILTIN
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def get_max_tokens(self) -> int:
|
||||
"""
|
||||
@@ -52,7 +51,7 @@ class BuiltinTool(Tool):
|
||||
:param model_config: the model config
|
||||
:return: the max tokens
|
||||
"""
|
||||
return ToolModelManager.get_max_llm_context_tokens(
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
)
|
||||
|
||||
@@ -63,7 +62,7 @@ class BuiltinTool(Tool):
|
||||
:param prompt_messages: the prompt messages
|
||||
:return: the tokens
|
||||
"""
|
||||
return ToolModelManager.calculate_tokens(
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
@@ -4,9 +4,12 @@ from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileVar
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolInvokeFrom,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
@@ -25,10 +28,7 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
@validator('parameters', pre=True, always=True)
|
||||
def set_parameters(cls, v, values):
|
||||
if not v:
|
||||
return []
|
||||
|
||||
return v
|
||||
return v or []
|
||||
|
||||
class Runtime(BaseModel):
|
||||
"""
|
||||
@@ -41,6 +41,8 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
tenant_id: str = None
|
||||
tool_id: str = None
|
||||
invoke_from: InvokeFrom = None
|
||||
tool_invoke_from: ToolInvokeFrom = None
|
||||
credentials: dict[str, Any] = None
|
||||
runtime_parameters: dict[str, Any] = None
|
||||
|
||||
@@ -53,7 +55,7 @@ class Tool(BaseModel, ABC):
|
||||
class VARIABLE_KEY(Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@@ -64,7 +66,7 @@ class Tool(BaseModel, ABC):
|
||||
identity=self.identity.copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**meta),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -208,17 +210,17 @@ class Tool(BaseModel, ABC):
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += response.message
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
result += f"result link: {response.message}. please tell user to check it. \n"
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n"
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
if len(response.message) > 114:
|
||||
result += str(response.message[:114]) + '...'
|
||||
else:
|
||||
result += str(response.message)
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
result += f"tool response: {response.message}. \n"
|
||||
|
||||
return result
|
||||
|
||||
@@ -343,6 +345,14 @@ class Tool(BaseModel, ABC):
|
||||
message=image,
|
||||
save_as=save_as)
|
||||
|
||||
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
|
||||
message='',
|
||||
meta={
|
||||
'file_var': file_var
|
||||
},
|
||||
save_as='')
|
||||
|
||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
||||
200
api/core/tools/tool/workflow_tool.py
Normal file
200
api/core/tools/tool/workflow_tool.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
workflow_app_id: str
|
||||
version: str
|
||||
workflow_entities: dict[str, Any]
|
||||
workflow_call_depth: int
|
||||
|
||||
label: str
|
||||
|
||||
"""
|
||||
Workflow tool.
|
||||
"""
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke the tool
|
||||
"""
|
||||
app = self._get_app(app_id=self.workflow_app_id)
|
||||
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
||||
|
||||
# transform the tool parameters
|
||||
tool_parameters, files = self._transform_args(tool_parameters)
|
||||
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
generator = WorkflowAppGenerator()
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=self._get_user(user_id),
|
||||
args={
|
||||
'inputs': tool_parameters,
|
||||
'files': files
|
||||
},
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
stream=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
)
|
||||
|
||||
data = result.get('data', {})
|
||||
|
||||
if data.get('error'):
|
||||
raise Exception(data.get('error'))
|
||||
|
||||
result = []
|
||||
|
||||
outputs = data.get('outputs', {})
|
||||
outputs, files = self._extract_files(outputs)
|
||||
for file in files:
|
||||
result.append(self.create_file_var_message(file))
|
||||
|
||||
result.append(self.create_text_message(json.dumps(outputs)))
|
||||
|
||||
return result
|
||||
|
||||
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
|
||||
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
if not user:
|
||||
user = db.session.query(Account).filter(Account.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
raise ValueError('user not found')
|
||||
|
||||
return user
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
:param meta: the meta data of a tool call processing, tenant_id is required
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=deepcopy(self.identity),
|
||||
parameters=deepcopy(self.parameters),
|
||||
description=deepcopy(self.description),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
workflow_app_id=self.workflow_app_id,
|
||||
workflow_entities=self.workflow_entities,
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
version=self.version,
|
||||
label=self.label
|
||||
)
|
||||
|
||||
def _get_workflow(self, app_id: str, version: str) -> Workflow:
|
||||
"""
|
||||
get the workflow by app id and version
|
||||
"""
|
||||
if not version:
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == app_id,
|
||||
Workflow.version != 'draft'
|
||||
).order_by(Workflow.created_at.desc()).first()
|
||||
else:
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == app_id,
|
||||
Workflow.version == version
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('workflow not found or not published')
|
||||
|
||||
return workflow
|
||||
|
||||
def _get_app(self, app_id: str) -> App:
|
||||
"""
|
||||
get the app by app id
|
||||
"""
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError('app not found')
|
||||
|
||||
return app
|
||||
|
||||
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
|
||||
"""
|
||||
transform the tool parameters
|
||||
|
||||
:param tool_parameters: the tool parameters
|
||||
:return: tool_parameters, files
|
||||
"""
|
||||
parameter_rules = self.get_all_runtime_parameters()
|
||||
parameters_result = {}
|
||||
files = []
|
||||
for parameter in parameter_rules:
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
file = tool_parameters.get(parameter.name)
|
||||
if file:
|
||||
try:
|
||||
file_var_list = [FileVar(**f) for f in file]
|
||||
for file_var in file_var_list:
|
||||
file_dict = {
|
||||
'transfer_method': file_var.transfer_method.value,
|
||||
'type': file_var.type.value,
|
||||
}
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict['tool_file_id'] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict['upload_file_id'] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict['url'] = file_var.preview_url
|
||||
|
||||
files.append(file_dict)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
else:
|
||||
parameters_result[parameter.name] = tool_parameters.get(parameter.name)
|
||||
|
||||
return parameters_result, files
|
||||
|
||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
|
||||
"""
|
||||
extract files from the result
|
||||
|
||||
:param result: the result
|
||||
:return: the result, files
|
||||
"""
|
||||
files = []
|
||||
result = {}
|
||||
for key, value in outputs.items():
|
||||
if isinstance(value, list):
|
||||
has_file = False
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get('__variant') == 'FileVar':
|
||||
try:
|
||||
files.append(FileVar(**item))
|
||||
has_file = True
|
||||
except Exception as e:
|
||||
pass
|
||||
if has_file:
|
||||
continue
|
||||
|
||||
result[key] = value
|
||||
|
||||
return result, files
|
||||
@@ -1,7 +1,10 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from mimetypes import guess_type
|
||||
from typing import Union
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
@@ -17,6 +20,7 @@ from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageFile
|
||||
@@ -115,7 +119,8 @@ class ToolEngine:
|
||||
@staticmethod
|
||||
def workflow_invoke(tool: Tool, tool_parameters: dict,
|
||||
user_id: str, workflow_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler) \
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int) \
|
||||
-> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
@@ -127,6 +132,9 @@ class ToolEngine:
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
|
||||
# hit the callback handler
|
||||
@@ -195,8 +203,24 @@ class ToolEngine:
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
mimetype = None
|
||||
if response.meta.get('mime_type'):
|
||||
mimetype = response.meta.get('mime_type')
|
||||
else:
|
||||
try:
|
||||
url = URL(response.message)
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f'a{extension}')
|
||||
if guess_type_result:
|
||||
mimetype = guess_type_result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not mimetype:
|
||||
mimetype = 'image/jpeg'
|
||||
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
mimetype=response.meta.get('mime_type', 'image/jpeg'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
|
||||
96
api/core/tools/tool_label_manager.py
Normal file
96
api/core/tools/tool_label_manager.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from core.tools.entities.values import default_tool_label_name_list
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolLabelBinding
|
||||
|
||||
|
||||
class ToolLabelManager:
|
||||
@classmethod
|
||||
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
|
||||
"""
|
||||
Filter tool labels
|
||||
"""
|
||||
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
|
||||
return list(set(tool_labels))
|
||||
|
||||
@classmethod
|
||||
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
|
||||
"""
|
||||
Update tool labels
|
||||
"""
|
||||
labels = cls.filter_tool_labels(labels)
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
else:
|
||||
raise ValueError('Unsupported tool type')
|
||||
|
||||
# delete old labels
|
||||
db.session.query(ToolLabelBinding).filter(
|
||||
ToolLabelBinding.tool_id == provider_id
|
||||
).delete()
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
db.session.add(ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type.value,
|
||||
label_name=label,
|
||||
))
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
|
||||
"""
|
||||
Get tool labels
|
||||
"""
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
elif isinstance(controller, BuiltinToolProviderController):
|
||||
return controller.tool_labels
|
||||
else:
|
||||
raise ValueError('Unsupported tool type')
|
||||
|
||||
labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
).all()
|
||||
|
||||
return [label.label_name for label in labels]
|
||||
|
||||
@classmethod
|
||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get tools labels
|
||||
|
||||
:param tool_providers: list of tool providers
|
||||
|
||||
:return: dict of tool labels
|
||||
:key: tool id
|
||||
:value: list of tool labels
|
||||
"""
|
||||
if not tool_providers:
|
||||
return {}
|
||||
|
||||
for controller in tool_providers:
|
||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
raise ValueError('Unsupported tool type')
|
||||
|
||||
provider_ids = [controller.provider_id for controller in tool_providers]
|
||||
|
||||
labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter(
|
||||
ToolLabelBinding.tool_id.in_(provider_ids)
|
||||
).all()
|
||||
|
||||
tool_labels = {
|
||||
label.tool_id: [] for label in labels
|
||||
}
|
||||
|
||||
for label in labels:
|
||||
tool_labels[label.tool_id].append(label.label_name)
|
||||
|
||||
return tool_labels
|
||||
@@ -9,21 +9,24 @@ from typing import Any, Union
|
||||
from flask import current_app
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools import *
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ToolConfigurationManager,
|
||||
ToolParameterConfigurationManager,
|
||||
@@ -31,8 +34,8 @@ from core.tools.utils.configuration import (
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||
from services.tools_transform_service import ToolTransformService
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -99,7 +102,12 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
|
||||
def get_tool_runtime(cls, provider_type: str,
|
||||
provider_id: str,
|
||||
tool_name: str,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
@@ -111,51 +119,76 @@ class ToolManager:
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == 'builtin':
|
||||
builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
|
||||
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
||||
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_name)
|
||||
provider_controller = cls.get_builtin_provider(provider_id)
|
||||
if not provider_controller.need_credentials:
|
||||
return builtin_tool.fork_tool_runtime(meta={
|
||||
return builtin_tool.fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': {},
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
BuiltinToolProvider.provider == provider_id,
|
||||
).first()
|
||||
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
|
||||
raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found')
|
||||
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
controller = cls.get_builtin_provider(provider_name)
|
||||
controller = cls.get_builtin_provider(provider_id)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
return builtin_tool.fork_tool_runtime(meta={
|
||||
return builtin_tool.fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': decrypted_credentials,
|
||||
'runtime_parameters': {}
|
||||
'runtime_parameters': {},
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
|
||||
elif provider_type == 'api':
|
||||
if tenant_id is None:
|
||||
raise ValueError('tenant id is required for api provider')
|
||||
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
|
||||
# decrypt the credentials
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': decrypted_credentials,
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
elif provider_type == 'workflow':
|
||||
workflow_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == provider_id
|
||||
).first()
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(
|
||||
db_provider=workflow_provider
|
||||
)
|
||||
|
||||
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': {},
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
elif provider_type == 'app':
|
||||
raise NotImplementedError('app provider not implemented')
|
||||
@@ -207,18 +240,25 @@ class ToolManager:
|
||||
return parameter_value
|
||||
|
||||
@classmethod
|
||||
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
|
||||
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
|
||||
provider_type=agent_tool.provider_type,
|
||||
provider_id=agent_tool.provider_id,
|
||||
tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
# check file types
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
|
||||
@@ -238,15 +278,17 @@ class ToolManager:
|
||||
return tool_entity
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
|
||||
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=workflow_tool.provider_type,
|
||||
provider_name=workflow_tool.provider_id,
|
||||
provider_id=workflow_tool.provider_id,
|
||||
tool_name=workflow_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
@@ -371,51 +413,91 @@ class ToolManager:
|
||||
return cls._builtin_tools_labels[tool_name]
|
||||
|
||||
@classmethod
|
||||
def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]:
|
||||
result_providers: dict[str, UserToolProvider] = {}
|
||||
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers()
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
|
||||
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
filters = []
|
||||
if not typ:
|
||||
filters.extend(['builtin', 'api', 'workflow'])
|
||||
else:
|
||||
filters.append(typ)
|
||||
|
||||
find_db_builtin_provider = lambda provider: next(
|
||||
(x for x in db_builtin_providers if x.provider == provider),
|
||||
None
|
||||
)
|
||||
if 'builtin' in filters:
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||
decrypt_credentials=False
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers()
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
|
||||
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
find_db_builtin_provider = lambda provider: next(
|
||||
(x for x in db_builtin_providers if x.provider == provider),
|
||||
None
|
||||
)
|
||||
|
||||
result_providers[provider.identity.name] = user_provider
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||
decrypt_credentials=False
|
||||
)
|
||||
|
||||
result_providers[provider.identity.name] = user_provider
|
||||
|
||||
# get db api providers
|
||||
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
|
||||
filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
for db_api_provider in db_api_providers:
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(
|
||||
db_provider=db_api_provider,
|
||||
)
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=db_api_provider,
|
||||
decrypt_credentials=False
|
||||
)
|
||||
result_providers[db_api_provider.name] = user_provider
|
||||
if 'api' in filters:
|
||||
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
|
||||
filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
api_provider_controllers = [{
|
||||
'provider': provider,
|
||||
'controller': ToolTransformService.api_provider_to_controller(provider)
|
||||
} for provider in db_api_providers]
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller['controller'],
|
||||
db_provider=api_provider_controller['provider'],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller['controller'].provider_id, [])
|
||||
)
|
||||
result_providers[f'api_provider.{user_provider.name}'] = user_provider
|
||||
|
||||
if 'workflow' in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \
|
||||
filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
workflow_provider_controllers = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
||||
)
|
||||
except Exception as e:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f'workflow_provider.{user_provider.name}'] = user_provider
|
||||
|
||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||
|
||||
@classmethod
|
||||
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
|
||||
ApiBasedToolProviderController, dict[str, Any]]:
|
||||
ApiToolProviderController, dict[str, Any]]:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
@@ -431,7 +513,7 @@ class ToolManager:
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
|
||||
|
||||
controller = ApiBasedToolProviderController.from_db(
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider,
|
||||
ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
|
||||
ApiProviderAuthType.NONE
|
||||
@@ -462,7 +544,7 @@ class ToolManager:
|
||||
credentials = {}
|
||||
|
||||
# package tool provider controller
|
||||
controller = ApiBasedToolProviderController.from_db(
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
)
|
||||
# init tool configuration
|
||||
@@ -479,6 +561,9 @@ class ToolManager:
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
|
||||
# add tool labels
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return jsonable_encoder({
|
||||
'schema_type': provider.schema_type,
|
||||
'schema': provider.schema,
|
||||
@@ -487,7 +572,8 @@ class ToolManager:
|
||||
'description': provider.description,
|
||||
'credentials': masked_credentials,
|
||||
'privacy_policy': provider.privacy_policy,
|
||||
'custom_disclaimer': provider.custom_disclaimer
|
||||
'custom_disclaimer': provider.custom_disclaimer,
|
||||
'labels': labels,
|
||||
})
|
||||
|
||||
@classmethod
|
||||
@@ -519,6 +605,15 @@ class ToolManager:
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
elif provider_type == 'workflow':
|
||||
provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == provider_id
|
||||
).first()
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
|
||||
|
||||
return json.loads(provider.icon)
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
||||
@@ -72,8 +72,8 @@ class ToolConfigurationManager(BaseModel):
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
@@ -95,8 +95,8 @@ class ToolConfigurationManager(BaseModel):
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import logging
|
||||
from mimetypes import guess_extension
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@staticmethod
|
||||
def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str) -> list[ToolInvokeMessage]:
|
||||
@@ -62,7 +63,7 @@ class ToolFileMessageTransformer:
|
||||
mimetype=mimetype
|
||||
)
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
|
||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if 'image' in mimetype:
|
||||
@@ -79,7 +80,30 @@ class ToolFileMessageTransformer:
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||
file_var: FileVar = message.meta.get('file_var')
|
||||
if file_var:
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||
if file_var.type == FileType.IMAGE:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(message)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||
@@ -20,12 +20,14 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.model.errors import InvokeModelError
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
|
||||
class ToolModelManager:
|
||||
class InvokeModelError(Exception):
|
||||
pass
|
||||
|
||||
class ModelInvocationUtils:
|
||||
@staticmethod
|
||||
def get_max_llm_context_tokens(
|
||||
tenant_id: str,
|
||||
@@ -9,14 +9,14 @@ from requests import get
|
||||
from yaml import YAMLError, safe_load
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
|
||||
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
@@ -145,7 +145,7 @@ class ApiBasedToolSchemaParser:
|
||||
|
||||
interface['operation']['operationId'] = f'{path}_{interface["method"]}'
|
||||
|
||||
bundles.append(ApiBasedToolBundle(
|
||||
bundles.append(ApiToolBundle(
|
||||
server_url=server_url + interface['path'],
|
||||
method=interface['method'],
|
||||
summary=interface['operation']['description'] if 'description' in interface['operation'] else
|
||||
@@ -176,7 +176,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
|
||||
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
@@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser:
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
|
||||
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
@@ -290,7 +290,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiBasedToolBundle], str]:
|
||||
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
|
||||
48
api/core/tools/utils/workflow_configuration_sync.py
Normal file
48
api/core/tools/utils/workflow_configuration_sync.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[dict]):
|
||||
"""
|
||||
check parameter configurations
|
||||
"""
|
||||
for configuration in configurations:
|
||||
if not WorkflowToolParameterConfiguration(**configuration):
|
||||
raise ValueError('invalid parameter configuration')
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get('nodes', [])
|
||||
start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [
|
||||
VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', [])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(cls,
|
||||
variables: list[VariableEntity],
|
||||
tool_configurations: list[WorkflowToolParameterConfiguration]) -> None:
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user