Feat/workflow phase2 (#4687)
This commit is contained in:
@@ -72,8 +72,8 @@ class ToolConfigurationManager(BaseModel):
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
@@ -95,8 +95,8 @@ class ToolConfigurationManager(BaseModel):
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import logging
|
||||
from mimetypes import guess_extension
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@staticmethod
|
||||
def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str) -> list[ToolInvokeMessage]:
|
||||
@@ -62,7 +63,7 @@ class ToolFileMessageTransformer:
|
||||
mimetype=mimetype
|
||||
)
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
|
||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if 'image' in mimetype:
|
||||
@@ -79,7 +80,30 @@ class ToolFileMessageTransformer:
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||
file_var: FileVar = message.meta.get('file_var')
|
||||
if file_var:
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||
if file_var.type == FileType.IMAGE:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(message)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||
179
api/core/tools/utils/model_invocation_utils.py
Normal file
179
api/core/tools/utils/model_invocation_utils.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
|
||||
|
||||
Therefore, a model manager is needed to list/invoke/validate models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
|
||||
class InvokeModelError(Exception):
|
||||
pass
|
||||
|
||||
class ModelInvocationUtils:
|
||||
@staticmethod
|
||||
def get_max_llm_context_tokens(
|
||||
tenant_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
get max llm context tokens of the model
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError('Model not found')
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
if not schema:
|
||||
raise InvokeModelError('No model schema found')
|
||||
|
||||
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
if max_tokens is None:
|
||||
return 2048
|
||||
|
||||
return max_tokens
|
||||
|
||||
@staticmethod
|
||||
def calculate_tokens(
|
||||
tenant_id: str,
|
||||
prompt_messages: list[PromptMessage]
|
||||
) -> int:
|
||||
"""
|
||||
calculate tokens from prompt messages and model parameters
|
||||
"""
|
||||
|
||||
# get model instance
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError('Model not found')
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
|
||||
# get tokens
|
||||
tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
|
||||
|
||||
return tokens
|
||||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str,
|
||||
tool_type: str, tool_name: str,
|
||||
prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
|
||||
:param user_id: user id
|
||||
:param tenant_id: tenant id, the tenant id of the creator of the tool
|
||||
:param tool_provider: tool provider
|
||||
:param tool_id: tool id
|
||||
:param tool_name: tool name
|
||||
:param provider: model provider
|
||||
:param model: model name
|
||||
:param model_parameters: model parameters
|
||||
:param prompt_messages: prompt messages
|
||||
:return: AssistantPromptMessage
|
||||
"""
|
||||
|
||||
# get model manager
|
||||
model_manager = ModelManager()
|
||||
# get model instance
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
|
||||
# get model credentials
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# get prompt tokens
|
||||
prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
|
||||
|
||||
model_parameters = {
|
||||
'temperature': 0.8,
|
||||
'top_p': 0.8,
|
||||
}
|
||||
|
||||
# create tool model invoke
|
||||
tool_model_invoke = ToolModelInvoke(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
tool_type=tool_type,
|
||||
tool_name=tool_name,
|
||||
model_parameters=json.dumps(model_parameters),
|
||||
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
|
||||
model_response='',
|
||||
prompt_tokens=prompt_tokens,
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency='USD',
|
||||
)
|
||||
|
||||
db.session.add(tool_model_invoke)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response: LLMResult = llm_model.invoke(
|
||||
model=model_instance.model,
|
||||
credentials=model_credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[], stop=[], stream=False, user=user_id, callbacks=[]
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
raise InvokeModelError(f'Invoke rate limit error: {e}')
|
||||
except InvokeBadRequestError as e:
|
||||
raise InvokeModelError(f'Invoke bad request error: {e}')
|
||||
except InvokeConnectionError as e:
|
||||
raise InvokeModelError(f'Invoke connection error: {e}')
|
||||
except InvokeAuthorizationError as e:
|
||||
raise InvokeModelError('Invoke authorization error')
|
||||
except InvokeServerUnavailableError as e:
|
||||
raise InvokeModelError(f'Invoke server unavailable error: {e}')
|
||||
except Exception as e:
|
||||
raise InvokeModelError(f'Invoke error: {e}')
|
||||
|
||||
# update tool model invoke
|
||||
tool_model_invoke.model_response = response.message.content
|
||||
if response.usage:
|
||||
tool_model_invoke.answer_tokens = response.usage.completion_tokens
|
||||
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
|
||||
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
|
||||
tool_model_invoke.provider_response_latency = response.usage.latency
|
||||
tool_model_invoke.total_price = response.usage.total_price
|
||||
tool_model_invoke.currency = response.usage.currency
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return response
|
||||
@@ -9,14 +9,14 @@ from requests import get
|
||||
from yaml import YAMLError, safe_load
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
|
||||
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
@@ -145,7 +145,7 @@ class ApiBasedToolSchemaParser:
|
||||
|
||||
interface['operation']['operationId'] = f'{path}_{interface["method"]}'
|
||||
|
||||
bundles.append(ApiBasedToolBundle(
|
||||
bundles.append(ApiToolBundle(
|
||||
server_url=server_url + interface['path'],
|
||||
method=interface['method'],
|
||||
summary=interface['operation']['description'] if 'description' in interface['operation'] else
|
||||
@@ -176,7 +176,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
|
||||
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
@@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser:
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
|
||||
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
@@ -290,7 +290,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiBasedToolBundle], str]:
|
||||
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
|
||||
48
api/core/tools/utils/workflow_configuration_sync.py
Normal file
48
api/core/tools/utils/workflow_configuration_sync.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[dict]):
|
||||
"""
|
||||
check parameter configurations
|
||||
"""
|
||||
for configuration in configurations:
|
||||
if not WorkflowToolParameterConfiguration(**configuration):
|
||||
raise ValueError('invalid parameter configuration')
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get('nodes', [])
|
||||
start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [
|
||||
VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', [])
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(cls,
|
||||
variables: list[VariableEntity],
|
||||
tool_configurations: list[WorkflowToolParameterConfiguration]) -> None:
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user