Feat/workflow phase2 (#4687)

This commit is contained in:
Yeuoly
2024-05-27 22:01:11 +08:00
committed by GitHub
parent 45deaee762
commit e852a21634
139 changed files with 5997 additions and 779 deletions

View File

@@ -8,9 +8,8 @@ import httpx
import requests
import core.helper.ssrf_proxy as ssrf_proxy
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool
@@ -20,12 +19,12 @@ API_TOOL_DEFAULT_TIMEOUT = (
)
class ApiTool(Tool):
api_bundle: ApiBasedToolBundle
api_bundle: ApiToolBundle
"""
Api tool
"""
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
@@ -37,7 +36,7 @@ class ApiTool(Tool):
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
runtime=Tool.Runtime(**meta)
runtime=Tool.Runtime(**runtime)
)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
@@ -55,7 +54,7 @@ class ApiTool(Tool):
return self.validate_and_parse_response(response)
def tool_provider_type(self) -> ToolProviderType:
return UserToolProvider.ProviderType.API
return ToolProviderType.API
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers = {}

View File

@@ -2,9 +2,8 @@
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.user_entities import UserToolProvider
from core.tools.model.tool_model_manager import ToolModelManager
from core.tools.tool.tool import Tool
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.tools.utils.web_reader_tool import get_url
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
@@ -34,7 +33,7 @@ class BuiltinTool(Tool):
:return: the model result
"""
# invoke model
return ToolModelManager.invoke(
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id,
tool_type='builtin',
@@ -43,7 +42,7 @@ class BuiltinTool(Tool):
)
def tool_provider_type(self) -> ToolProviderType:
return UserToolProvider.ProviderType.BUILTIN
return ToolProviderType.BUILT_IN
def get_max_tokens(self) -> int:
"""
@@ -52,7 +51,7 @@ class BuiltinTool(Tool):
:param model_config: the model config
:return: the max tokens
"""
return ToolModelManager.get_max_llm_context_tokens(
return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id,
)
@@ -63,7 +62,7 @@ class BuiltinTool(Tool):
:param prompt_messages: the prompt messages
:return: the tokens
"""
return ToolModelManager.calculate_tokens(
return ModelInvocationUtils.calculate_tokens(
tenant_id=self.runtime.tenant_id,
prompt_messages=prompt_messages
)

View File

@@ -4,9 +4,12 @@ from typing import Any, Optional, Union
from pydantic import BaseModel, validator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileVar
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
ToolInvokeFrom,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
@@ -25,10 +28,7 @@ class Tool(BaseModel, ABC):
@validator('parameters', pre=True, always=True)
def set_parameters(cls, v, values):
if not v:
return []
return v
return v or []
class Runtime(BaseModel):
"""
@@ -41,6 +41,8 @@ class Tool(BaseModel, ABC):
tenant_id: str = None
tool_id: str = None
invoke_from: InvokeFrom = None
tool_invoke_from: ToolInvokeFrom = None
credentials: dict[str, Any] = None
runtime_parameters: dict[str, Any] = None
@@ -53,7 +55,7 @@ class Tool(BaseModel, ABC):
class VARIABLE_KEY(Enum):
IMAGE = 'image'
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
@@ -64,7 +66,7 @@ class Tool(BaseModel, ABC):
identity=self.identity.copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
runtime=Tool.Runtime(**meta),
runtime=Tool.Runtime(**runtime),
)
@abstractmethod
@@ -208,17 +210,17 @@ class Tool(BaseModel, ABC):
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please tell user to check it."
result += f"result link: {response.message}. please tell user to check it. \n"
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n"
elif response.type == ToolInvokeMessage.MessageType.BLOB:
if len(response.message) > 114:
result += str(response.message[:114]) + '...'
else:
result += str(response.message)
else:
result += f"tool response: {response.message}."
result += f"tool response: {response.message}. \n"
return result
@@ -343,6 +345,14 @@ class Tool(BaseModel, ABC):
message=image,
save_as=save_as)
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
message='',
meta={
'file_var': file_var
},
save_as='')
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message

View File

@@ -0,0 +1,200 @@
import json
import logging
from copy import deepcopy
from typing import Any, Union
from core.file.file_obj import FileTransferMethod, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class WorkflowTool(Tool):
workflow_app_id: str
version: str
workflow_entities: dict[str, Any]
workflow_call_depth: int
label: str
"""
Workflow tool.
"""
def tool_provider_type(self) -> ToolProviderType:
"""
get the tool provider type
:return: the tool provider type
"""
return ToolProviderType.WORKFLOW
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke the tool
"""
app = self._get_app(app_id=self.workflow_app_id)
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
# transform the tool parameters
tool_parameters, files = self._transform_args(tool_parameters)
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
generator = WorkflowAppGenerator()
result = generator.generate(
app_model=app,
workflow=workflow,
user=self._get_user(user_id),
args={
'inputs': tool_parameters,
'files': files
},
invoke_from=self.runtime.invoke_from,
stream=False,
call_depth=self.workflow_call_depth + 1,
)
data = result.get('data', {})
if data.get('error'):
raise Exception(data.get('error'))
result = []
outputs = data.get('outputs', {})
outputs, files = self._extract_files(outputs)
for file in files:
result.append(self.create_file_var_message(file))
result.append(self.create_text_message(json.dumps(outputs)))
return result
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
"""
get the user by user id
"""
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
if not user:
user = db.session.query(Account).filter(Account.id == user_id).first()
if not user:
raise ValueError('user not found')
return user
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool':
"""
fork a new tool with meta data
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
return self.__class__(
identity=deepcopy(self.identity),
parameters=deepcopy(self.parameters),
description=deepcopy(self.description),
runtime=Tool.Runtime(**runtime),
workflow_app_id=self.workflow_app_id,
workflow_entities=self.workflow_entities,
workflow_call_depth=self.workflow_call_depth,
version=self.version,
label=self.label
)
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
if not version:
workflow = db.session.query(Workflow).filter(
Workflow.app_id == app_id,
Workflow.version != 'draft'
).order_by(Workflow.created_at.desc()).first()
else:
workflow = db.session.query(Workflow).filter(
Workflow.app_id == app_id,
Workflow.version == version
).first()
if not workflow:
raise ValueError('workflow not found or not published')
return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
app = db.session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError('app not found')
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""
transform the tool parameters
:param tool_parameters: the tool parameters
:return: tool_parameters, files
"""
parameter_rules = self.get_all_runtime_parameters()
parameters_result = {}
files = []
for parameter in parameter_rules:
if parameter.type == ToolParameter.ToolParameterType.FILE:
file = tool_parameters.get(parameter.name)
if file:
try:
file_var_list = [FileVar(**f) for f in file]
for file_var in file_var_list:
file_dict = {
'transfer_method': file_var.transfer_method.value,
'type': file_var.type.value,
}
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
file_dict['tool_file_id'] = file_var.related_id
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict['upload_file_id'] = file_var.related_id
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
file_dict['url'] = file_var.preview_url
files.append(file_dict)
except Exception as e:
logger.exception(e)
else:
parameters_result[parameter.name] = tool_parameters.get(parameter.name)
return parameters_result, files
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
"""
extract files from the result
:param result: the result
:return: the result, files
"""
files = []
result = {}
for key, value in outputs.items():
if isinstance(value, list):
has_file = False
for item in value:
if isinstance(item, dict) and item.get('__variant') == 'FileVar':
try:
files.append(FileVar(**item))
has_file = True
except Exception as e:
pass
if has_file:
continue
result[key] = value
return result, files