Feat/workflow phase2 (#4687)
This commit is contained in:
@@ -8,9 +8,8 @@ import httpx
|
||||
import requests
|
||||
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
@@ -20,12 +19,12 @@ API_TOOL_DEFAULT_TIMEOUT = (
|
||||
)
|
||||
|
||||
class ApiTool(Tool):
|
||||
api_bundle: ApiBasedToolBundle
|
||||
api_bundle: ApiToolBundle
|
||||
|
||||
"""
|
||||
Api tool
|
||||
"""
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@@ -37,7 +36,7 @@ class ApiTool(Tool):
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
|
||||
runtime=Tool.Runtime(**meta)
|
||||
runtime=Tool.Runtime(**runtime)
|
||||
)
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
|
||||
@@ -55,7 +54,7 @@ class ApiTool(Tool):
|
||||
return self.validate_and_parse_response(response)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.API
|
||||
return ToolProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.model.tool_model_manager import ToolModelManager
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
|
||||
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
|
||||
@@ -34,7 +33,7 @@ class BuiltinTool(Tool):
|
||||
:return: the model result
|
||||
"""
|
||||
# invoke model
|
||||
return ToolModelManager.invoke(
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tool_type='builtin',
|
||||
@@ -43,7 +42,7 @@ class BuiltinTool(Tool):
|
||||
)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.BUILTIN
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def get_max_tokens(self) -> int:
|
||||
"""
|
||||
@@ -52,7 +51,7 @@ class BuiltinTool(Tool):
|
||||
:param model_config: the model config
|
||||
:return: the max tokens
|
||||
"""
|
||||
return ToolModelManager.get_max_llm_context_tokens(
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
)
|
||||
|
||||
@@ -63,7 +62,7 @@ class BuiltinTool(Tool):
|
||||
:param prompt_messages: the prompt messages
|
||||
:return: the tokens
|
||||
"""
|
||||
return ToolModelManager.calculate_tokens(
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
@@ -4,9 +4,12 @@ from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileVar
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolInvokeFrom,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
@@ -25,10 +28,7 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
@validator('parameters', pre=True, always=True)
|
||||
def set_parameters(cls, v, values):
|
||||
if not v:
|
||||
return []
|
||||
|
||||
return v
|
||||
return v or []
|
||||
|
||||
class Runtime(BaseModel):
|
||||
"""
|
||||
@@ -41,6 +41,8 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
tenant_id: str = None
|
||||
tool_id: str = None
|
||||
invoke_from: InvokeFrom = None
|
||||
tool_invoke_from: ToolInvokeFrom = None
|
||||
credentials: dict[str, Any] = None
|
||||
runtime_parameters: dict[str, Any] = None
|
||||
|
||||
@@ -53,7 +55,7 @@ class Tool(BaseModel, ABC):
|
||||
class VARIABLE_KEY(Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@@ -64,7 +66,7 @@ class Tool(BaseModel, ABC):
|
||||
identity=self.identity.copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**meta),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -208,17 +210,17 @@ class Tool(BaseModel, ABC):
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += response.message
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
result += f"result link: {response.message}. please tell user to check it. \n"
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n"
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
if len(response.message) > 114:
|
||||
result += str(response.message[:114]) + '...'
|
||||
else:
|
||||
result += str(response.message)
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
result += f"tool response: {response.message}. \n"
|
||||
|
||||
return result
|
||||
|
||||
@@ -343,6 +345,14 @@ class Tool(BaseModel, ABC):
|
||||
message=image,
|
||||
save_as=save_as)
|
||||
|
||||
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
|
||||
message='',
|
||||
meta={
|
||||
'file_var': file_var
|
||||
},
|
||||
save_as='')
|
||||
|
||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
||||
200
api/core/tools/tool/workflow_tool.py
Normal file
200
api/core/tools/tool/workflow_tool.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
workflow_app_id: str
|
||||
version: str
|
||||
workflow_entities: dict[str, Any]
|
||||
workflow_call_depth: int
|
||||
|
||||
label: str
|
||||
|
||||
"""
|
||||
Workflow tool.
|
||||
"""
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke the tool
|
||||
"""
|
||||
app = self._get_app(app_id=self.workflow_app_id)
|
||||
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
||||
|
||||
# transform the tool parameters
|
||||
tool_parameters, files = self._transform_args(tool_parameters)
|
||||
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
generator = WorkflowAppGenerator()
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=self._get_user(user_id),
|
||||
args={
|
||||
'inputs': tool_parameters,
|
||||
'files': files
|
||||
},
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
stream=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
)
|
||||
|
||||
data = result.get('data', {})
|
||||
|
||||
if data.get('error'):
|
||||
raise Exception(data.get('error'))
|
||||
|
||||
result = []
|
||||
|
||||
outputs = data.get('outputs', {})
|
||||
outputs, files = self._extract_files(outputs)
|
||||
for file in files:
|
||||
result.append(self.create_file_var_message(file))
|
||||
|
||||
result.append(self.create_text_message(json.dumps(outputs)))
|
||||
|
||||
return result
|
||||
|
||||
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
|
||||
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
if not user:
|
||||
user = db.session.query(Account).filter(Account.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
raise ValueError('user not found')
|
||||
|
||||
return user
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
:param meta: the meta data of a tool call processing, tenant_id is required
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=deepcopy(self.identity),
|
||||
parameters=deepcopy(self.parameters),
|
||||
description=deepcopy(self.description),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
workflow_app_id=self.workflow_app_id,
|
||||
workflow_entities=self.workflow_entities,
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
version=self.version,
|
||||
label=self.label
|
||||
)
|
||||
|
||||
def _get_workflow(self, app_id: str, version: str) -> Workflow:
|
||||
"""
|
||||
get the workflow by app id and version
|
||||
"""
|
||||
if not version:
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == app_id,
|
||||
Workflow.version != 'draft'
|
||||
).order_by(Workflow.created_at.desc()).first()
|
||||
else:
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.app_id == app_id,
|
||||
Workflow.version == version
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('workflow not found or not published')
|
||||
|
||||
return workflow
|
||||
|
||||
def _get_app(self, app_id: str) -> App:
|
||||
"""
|
||||
get the app by app id
|
||||
"""
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError('app not found')
|
||||
|
||||
return app
|
||||
|
||||
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
|
||||
"""
|
||||
transform the tool parameters
|
||||
|
||||
:param tool_parameters: the tool parameters
|
||||
:return: tool_parameters, files
|
||||
"""
|
||||
parameter_rules = self.get_all_runtime_parameters()
|
||||
parameters_result = {}
|
||||
files = []
|
||||
for parameter in parameter_rules:
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
file = tool_parameters.get(parameter.name)
|
||||
if file:
|
||||
try:
|
||||
file_var_list = [FileVar(**f) for f in file]
|
||||
for file_var in file_var_list:
|
||||
file_dict = {
|
||||
'transfer_method': file_var.transfer_method.value,
|
||||
'type': file_var.type.value,
|
||||
}
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict['tool_file_id'] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
file_dict['upload_file_id'] = file_var.related_id
|
||||
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
file_dict['url'] = file_var.preview_url
|
||||
|
||||
files.append(file_dict)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
else:
|
||||
parameters_result[parameter.name] = tool_parameters.get(parameter.name)
|
||||
|
||||
return parameters_result, files
|
||||
|
||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
|
||||
"""
|
||||
extract files from the result
|
||||
|
||||
:param result: the result
|
||||
:return: the result, files
|
||||
"""
|
||||
files = []
|
||||
result = {}
|
||||
for key, value in outputs.items():
|
||||
if isinstance(value, list):
|
||||
has_file = False
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get('__variant') == 'FileVar':
|
||||
try:
|
||||
files.append(FileVar(**item))
|
||||
has_file = True
|
||||
except Exception as e:
|
||||
pass
|
||||
if has_file:
|
||||
continue
|
||||
|
||||
result[key] = value
|
||||
|
||||
return result, files
|
||||
Reference in New Issue
Block a user