Feat/workflow phase2 (#4687)

This commit is contained in:
Yeuoly
2024-05-27 22:01:11 +08:00
committed by GitHub
parent 45deaee762
commit e852a21634
139 changed files with 5997 additions and 779 deletions

View File

@@ -72,8 +72,8 @@ class ToolConfigurationManager(BaseModel):
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cached_credentials = cache.get()
@@ -95,8 +95,8 @@ class ToolConfigurationManager(BaseModel):
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cache.delete()

View File

@@ -1,14 +1,15 @@
import logging
from mimetypes import guess_extension
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@staticmethod
def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
@classmethod
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
user_id: str,
tenant_id: str,
conversation_id: str) -> list[ToolInvokeMessage]:
@@ -62,7 +63,7 @@ class ToolFileMessageTransformer:
mimetype=mimetype
)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
# check if file is image
if 'image' in mimetype:
@@ -79,7 +80,30 @@ class ToolFileMessageTransformer:
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var: FileVar = message.meta.get('file_var')
if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
if file_var.type == FileType.IMAGE:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(message)
return result
return result
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
return f'/files/tools/{tool_file_id}{extension or ".bin"}'

View File

@@ -0,0 +1,179 @@
"""
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
from typing import cast
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.tools import ToolModelInvoke
class InvokeModelError(Exception):
pass
class ModelInvocationUtils:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,
) -> int:
"""
get max llm context tokens of the model
"""
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM,
)
if not model_instance:
raise InvokeModelError('Model not found')
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not schema:
raise InvokeModelError('No model schema found')
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
return max_tokens
@staticmethod
def calculate_tokens(
tenant_id: str,
prompt_messages: list[PromptMessage]
) -> int:
"""
calculate tokens from prompt messages and model parameters
"""
# get model instance
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM
)
if not model_instance:
raise InvokeModelError('Model not found')
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
# get tokens
tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
return tokens
@staticmethod
def invoke(
user_id: str, tenant_id: str,
tool_type: str, tool_name: str,
prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context
:param user_id: user id
:param tenant_id: tenant id, the tenant id of the creator of the tool
:param tool_provider: tool provider
:param tool_id: tool id
:param tool_name: tool name
:param provider: model provider
:param model: model name
:param model_parameters: model parameters
:param prompt_messages: prompt messages
:return: AssistantPromptMessage
"""
# get model manager
model_manager = ModelManager()
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM,
)
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
# get model credentials
model_credentials = model_instance.credentials
# get prompt tokens
prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
model_parameters = {
'temperature': 0.8,
'top_p': 0.8,
}
# create tool model invoke
tool_model_invoke = ToolModelInvoke(
user_id=user_id,
tenant_id=tenant_id,
provider=model_instance.provider,
tool_type=tool_type,
tool_name=tool_name,
model_parameters=json.dumps(model_parameters),
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
model_response='',
prompt_tokens=prompt_tokens,
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency='USD',
)
db.session.add(tool_model_invoke)
db.session.commit()
try:
response: LLMResult = llm_model.invoke(
model=model_instance.model,
credentials=model_credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[], stop=[], stream=False, user=user_id, callbacks=[]
)
except InvokeRateLimitError as e:
raise InvokeModelError(f'Invoke rate limit error: {e}')
except InvokeBadRequestError as e:
raise InvokeModelError(f'Invoke bad request error: {e}')
except InvokeConnectionError as e:
raise InvokeModelError(f'Invoke connection error: {e}')
except InvokeAuthorizationError as e:
raise InvokeModelError('Invoke authorization error')
except InvokeServerUnavailableError as e:
raise InvokeModelError(f'Invoke server unavailable error: {e}')
except Exception as e:
raise InvokeModelError(f'Invoke error: {e}')
# update tool model invoke
tool_model_invoke.model_response = response.message.content
if response.usage:
tool_model_invoke.answer_tokens = response.usage.completion_tokens
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
tool_model_invoke.provider_response_latency = response.usage.latency
tool_model_invoke.total_price = response.usage.total_price
tool_model_invoke.currency = response.usage.currency
db.session.commit()
return response

View File

@@ -9,14 +9,14 @@ from requests import get
from yaml import YAMLError, safe_load
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@@ -145,7 +145,7 @@ class ApiBasedToolSchemaParser:
interface['operation']['operationId'] = f'{path}_{interface["method"]}'
bundles.append(ApiBasedToolBundle(
bundles.append(ApiToolBundle(
server_url=server_url + interface['path'],
method=interface['method'],
summary=interface['operation']['description'] if 'description' in interface['operation'] else
@@ -176,7 +176,7 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.STRING
@staticmethod
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
@@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser:
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
@@ -290,7 +290,7 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
@staticmethod
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiBasedToolBundle], str]:
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle

View File

@@ -0,0 +1,48 @@
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[dict]):
"""
check parameter configurations
"""
for configuration in configurations:
if not WorkflowToolParameterConfiguration(**configuration):
raise ValueError('invalid parameter configuration')
@classmethod
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get('nodes', [])
start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None)
if not start_node:
return []
return [
VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', [])
]
@classmethod
def check_is_synced(cls,
variables: list[VariableEntity],
tool_configurations: list[WorkflowToolParameterConfiguration]) -> None:
"""
check is synced
raise ValueError if not synced
"""
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError('parameter configuration mismatch, please republish the tool to update')
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError('parameter configuration mismatch, please republish the tool to update')
return True