Feat/tool secret parameter (#2760)
This commit is contained in:
@@ -119,7 +119,7 @@ parameters: # Parameter list
|
||||
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
|
||||
- `parameters` Parameter list
|
||||
- `name` Parameter name, unique, no duplication with other parameters
|
||||
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box
|
||||
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type
|
||||
- `required` Required or not
|
||||
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
|
||||
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
|
||||
|
||||
@@ -119,7 +119,7 @@ parameters: # 参数列表
|
||||
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
|
||||
- `parameters` 参数列表
|
||||
- `name` 参数名称,唯一,不允许和其他参数重名
|
||||
- `type` 参数类型,目前支持`string`、`number`、`boolean`、`select` 四种类型,分别对应字符串、数字、布尔值、下拉框
|
||||
- `type` 参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
|
||||
- `required` 是否必填
|
||||
- 在`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数
|
||||
- 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数
|
||||
|
||||
@@ -100,6 +100,7 @@ class ToolParameter(BaseModel):
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
SECRET_INPUT = "secret-input"
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
|
||||
@@ -23,6 +23,8 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
_api_base_url = URL('https://co.aippt.cn/api')
|
||||
_api_token_cache = {}
|
||||
_api_token_cache_lock = Lock()
|
||||
_style_cache = {}
|
||||
_style_cache_lock = Lock()
|
||||
|
||||
_task = {}
|
||||
_task_type_map = {
|
||||
@@ -390,20 +392,31 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
).digest()
|
||||
).decode('utf-8')
|
||||
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
@classmethod
|
||||
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Get styles
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||
"""
|
||||
|
||||
# check cache
|
||||
with cls._style_cache_lock:
|
||||
# clear expired styles
|
||||
now = time()
|
||||
for key in list(cls._style_cache.keys()):
|
||||
if cls._style_cache[key]['expire'] < now:
|
||||
del cls._style_cache[key]
|
||||
|
||||
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
|
||||
if key in cls._style_cache:
|
||||
return cls._style_cache[key]['colors'], cls._style_cache[key]['styles']
|
||||
|
||||
headers = {
|
||||
'x-channel': '',
|
||||
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id)
|
||||
'x-api-key': credentials['aippt_access_key'],
|
||||
'x-token': cls._get_api_token(credentials=credentials, user_id=user_id)
|
||||
}
|
||||
response = get(
|
||||
str(self._api_base_url / 'template_component' / 'suit' / 'select'),
|
||||
str(cls._api_base_url / 'template_component' / 'suit' / 'select'),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
@@ -425,7 +438,26 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
'name': item.get('title'),
|
||||
} for item in response.get('data', {}).get('suit_style') or []]
|
||||
|
||||
with cls._style_cache_lock:
|
||||
cls._style_cache[key] = {
|
||||
'colors': colors,
|
||||
'styles': styles,
|
||||
'expire': now + 60 * 60
|
||||
}
|
||||
|
||||
return colors, styles
|
||||
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Get styles
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||
"""
|
||||
if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'):
|
||||
return [], []
|
||||
|
||||
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
|
||||
|
||||
def _get_suit(self, style_id: int, colour_id: int) -> int:
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,7 @@ description:
|
||||
llm: A tool for sending messages to a chat group on Wecom(企业微信) .
|
||||
parameters:
|
||||
- name: hook_key
|
||||
type: string
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Wecom Group bot webhook key
|
||||
|
||||
@@ -266,6 +266,40 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
return self.parameters
|
||||
|
||||
def get_all_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
get all runtime parameters
|
||||
|
||||
:return: all runtime parameters
|
||||
"""
|
||||
parameters = self.parameters or []
|
||||
parameters = parameters.copy()
|
||||
user_parameters = self.get_runtime_parameters() or []
|
||||
user_parameters = user_parameters.copy()
|
||||
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
found = False
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
found = True
|
||||
break
|
||||
|
||||
if found:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
else:
|
||||
# add new parameter
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def is_tool_available(self) -> bool:
|
||||
"""
|
||||
check if the tool is available
|
||||
|
||||
@@ -6,11 +6,17 @@ from os import listdir, path
|
||||
from typing import Any, Union
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.entities.application_entities import AgentToolEntity
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constant import DEFAULT_PROVIDERS
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
||||
@@ -21,7 +27,12 @@ from core.tools.provider.model_tool_provider import ModelToolProviderController
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.configuration import (
|
||||
ModelToolConfigurationManager,
|
||||
ToolConfigurationManager,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.encoder import serialize_base_model_dict
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||
@@ -172,7 +183,7 @@ class ToolManager:
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
controller = ToolManager.get_builtin_provider(provider_name)
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
@@ -189,7 +200,7 @@ class ToolManager:
|
||||
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
|
||||
|
||||
# decrypt the credentials
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
|
||||
@@ -214,6 +225,71 @@ class ToolManager:
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@staticmethod
|
||||
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = ToolManager.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# get tool parameter from form
|
||||
tool_parameter_config = agent_tool.tool_parameters.get(parameter.name)
|
||||
if not tool_parameter_config:
|
||||
# get default value
|
||||
tool_parameter_config = parameter.default
|
||||
if not tool_parameter_config and parameter.required:
|
||||
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
|
||||
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = list(map(lambda x: x.value, parameter.options))
|
||||
if tool_parameter_config not in options:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
|
||||
|
||||
# convert tool parameter config to correct type
|
||||
try:
|
||||
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
# check if tool parameter is integer
|
||||
if isinstance(tool_parameter_config, int):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, float):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, str):
|
||||
if '.' in tool_parameter_config:
|
||||
tool_parameter_config = float(tool_parameter_config)
|
||||
else:
|
||||
tool_parameter_config = int(tool_parameter_config)
|
||||
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
tool_parameter_config = bool(tool_parameter_config)
|
||||
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
elif parameter.type == ToolParameter.ToolParameterType:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
|
||||
|
||||
# save tool parameter to tool entity memory
|
||||
runtime_parameters[parameter.name] = tool_parameter_config
|
||||
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
tool_runtime=tool_entity,
|
||||
provider_name=agent_tool.provider_id,
|
||||
provider_type=agent_tool.provider_type,
|
||||
)
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
|
||||
"""
|
||||
@@ -396,7 +472,7 @@ class ToolManager:
|
||||
controller = ToolManager.get_builtin_provider(provider_name)
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
@@ -463,7 +539,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
@@ -523,7 +599,7 @@ class ToolManager:
|
||||
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
|
||||
@@ -5,16 +5,19 @@ from pydantic import BaseModel
|
||||
from yaml import FullLoader, load
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ModelToolConfiguration,
|
||||
ModelToolProviderConfiguration,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
class ToolConfiguration(BaseModel):
|
||||
class ToolConfigurationManager(BaseModel):
|
||||
tenant_id: str
|
||||
provider_controller: ToolProviderController
|
||||
|
||||
@@ -101,6 +104,128 @@ class ToolConfiguration(BaseModel):
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
class ToolParameterConfigurationManager(BaseModel):
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
provider_type: str
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
deep copy parameters
|
||||
"""
|
||||
return {key: value for key, value in parameters.items()}
|
||||
|
||||
def _merge_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
merge parameters
|
||||
"""
|
||||
# get tool parameters
|
||||
tool_parameters = self.tool_runtime.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
|
||||
# override parameters
|
||||
current_parameters = tool_parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return current_parameters
|
||||
|
||||
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool parameters
|
||||
|
||||
return a deep copy of parameters with masked values
|
||||
"""
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||
if parameter.name in parameters:
|
||||
if len(parameters[parameter.name]) > 6:
|
||||
parameters[parameter.name] = \
|
||||
parameters[parameter.name][:2] + \
|
||||
'*' * (len(parameters[parameter.name]) - 4) +\
|
||||
parameters[parameter.name][-2:]
|
||||
else:
|
||||
parameters[parameter.name] = '*' * len(parameters[parameter.name])
|
||||
|
||||
return parameters
|
||||
|
||||
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
encrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with encrypted values
|
||||
"""
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||
if parameter.name in parameters:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
parameters[parameter.name] = encrypted
|
||||
|
||||
return parameters
|
||||
|
||||
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER
|
||||
)
|
||||
cached_parameters = cache.get()
|
||||
if cached_parameters:
|
||||
return cached_parameters
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
has_secret_input = False
|
||||
|
||||
for parameter in current_parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||
if parameter.name in parameters:
|
||||
try:
|
||||
has_secret_input = True
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except:
|
||||
pass
|
||||
|
||||
if has_secret_input:
|
||||
cache.set(parameters)
|
||||
|
||||
return parameters
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
class ModelToolConfigurationManager:
|
||||
"""
|
||||
Model as tool configuration
|
||||
|
||||
Reference in New Issue
Block a user