Feat/workflow phase2 (#4687)

This commit is contained in:
Yeuoly
2024-05-27 22:01:11 +08:00
committed by GitHub
parent 45deaee762
commit e852a21634
139 changed files with 5997 additions and 779 deletions

View File

@@ -2,8 +2,9 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
@@ -37,6 +38,9 @@ class BaseNode(ABC):
workflow_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
node_id: str
node_data: BaseNodeData
@@ -49,13 +53,17 @@ class BaseNode(ABC):
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: list[BaseWorkflowCallback] = None,
workflow_call_depth: int = 0) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.user_id = user_id
self.user_from = user_from
self.invoke_from = invoke_from
self.workflow_call_depth = workflow_call_depth
self.node_id = config.get("id")
if not self.node_id:
@@ -140,3 +148,38 @@ class BaseNode(ABC):
:return:
"""
return self._node_type
class BaseIterationNode(BaseNode):
    """
    Base class for nodes that drive an iteration/loop sub-graph.

    Subclasses implement `_run` (build the initial iteration state) and
    `_get_next_iteration` (advance the state and choose the next start node).
    """

    @abstractmethod
    def _run(self, variable_pool: VariablePool) -> BaseIterationState:
        """
        Run node
        :param variable_pool: variable pool
        :return: initial iteration state
        """
        raise NotImplementedError

    def run(self, variable_pool: VariablePool) -> BaseIterationState:
        """
        Run node entry
        :param variable_pool: variable pool
        :return: initial iteration state
        """
        return self._run(variable_pool=variable_pool)

    def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
        """
        Get next iteration start node id based on the graph.
        :param variable_pool: variable pool
        :param state: current iteration state
        :return: next node id, or a NodeRunResult when the iteration is finished
        """
        return self._get_next_iteration(variable_pool, state)

    @abstractmethod
    def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
        """
        Get next iteration start node id based on the graph.
        :param variable_pool: variable pool
        :param state: current iteration state
        :return: next node id, or a NodeRunResult when the iteration is finished
        """
        raise NotImplementedError

View File

@@ -0,0 +1,39 @@
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
class IterationNodeData(BaseIterationNodeData):
    """
    Iteration Node Data.
    """
    parent_loop_id: Optional[str]  # redundant field, not used currently
    iterator_selector: list[str]  # variable selector pointing at the list to iterate over
    output_selector: list[str]  # selector of the per-iteration output variable
class IterationState(BaseIterationState):
    """
    Iteration State.

    Tracks the outputs collected so far and the output of the pass in flight.
    """
    outputs: list[Any] = None
    current_output: Optional[Any] = None

    class MetaData(BaseIterationState.MetaData):
        """
        Data.
        """
        iterator_length: int

    def get_last_output(self) -> Optional[Any]:
        """
        Return the most recently collected output, or None when nothing
        has been collected yet.
        """
        return self.outputs[-1] if self.outputs else None

    def get_current_output(self) -> Optional[Any]:
        """
        Return the output of the iteration currently in flight (may be None).
        """
        return self.current_output

View File

@@ -0,0 +1,119 @@
from typing import cast
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
from models.workflow import WorkflowNodeExecutionStatus
class IterationNode(BaseIterationNode):
    """
    Iteration Node.

    Drives a sub-graph once per element of a list variable, collecting the
    sub-graph's output of each pass into the iteration state.
    """
    _node_data_cls = IterationNodeData
    _node_type = NodeType.ITERATION

    def _run(self, variable_pool: VariablePool) -> BaseIterationState:
        """
        Run the node.
        :param variable_pool: variable pool holding the iterator list
        :return: fresh IterationState positioned before the first element
        """
        iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector)

        if not isinstance(iterator, list):
            raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")

        # index starts at -1; _get_next_iteration advances it to 0 before the first pass
        state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
            'iterator_selector': iterator
        }, outputs=[], metadata=IterationState.MetaData(
            iterator_length=len(iterator) if iterator is not None else 0
        ))

        self._set_current_iteration_variable(variable_pool, state)
        return state

    def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
        """
        Get next iteration start node id based on the graph.
        :param variable_pool: variable pool
        :param state: current iteration state
        :return: start node id of the next pass, or a SUCCEEDED NodeRunResult when exhausted
        """
        # resolve current output
        self._resolve_current_output(variable_pool, state)
        # move to next iteration
        self._next_iteration(variable_pool, state)

        node_data = cast(IterationNodeData, self.node_data)
        if self._reached_iteration_limit(variable_pool, state):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={
                    'output': jsonable_encoder(state.outputs)
                }
            )

        return node_data.start_node_id

    def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
        """
        Publish the current iteration's `index` and `item` into the variable pool.
        :param variable_pool: variable pool
        :param state: current iteration state
        """
        node_data = cast(IterationNodeData, self.node_data)

        variable_pool.append_variable(self.node_id, ['index'], state.index)
        # get the iterator value
        iterator = variable_pool.get_variable_value(node_data.iterator_selector)

        if iterator is None or not isinstance(iterator, list):
            return

        # NOTE(review): when state.index is -1 (before the first pass) this
        # publishes iterator[-1] (the last element) as 'item' — confirm intended.
        if state.index < len(iterator):
            variable_pool.append_variable(self.node_id, ['item'], iterator[state.index])

    def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
        """
        Move to next iteration.
        :param variable_pool: variable pool
        :param state: current iteration state (index is advanced in place)
        """
        state.index += 1
        self._set_current_iteration_variable(variable_pool, state)

    def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
        """
        Check if iteration limit is reached.
        :return: True if iteration limit is reached, False otherwise
        """
        node_data = cast(IterationNodeData, self.node_data)
        iterator = variable_pool.get_variable_value(node_data.iterator_selector)

        if iterator is None or not isinstance(iterator, list):
            # no usable iterator -> treat as exhausted
            return True

        return state.index >= len(iterator)

    def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
        """
        Record the sub-graph's output for the pass that just finished and
        clear it from the pool for the next pass.
        :param variable_pool: variable pool
        :param state: current iteration state
        """
        output_selector = cast(IterationNodeData, self.node_data).output_selector
        output = variable_pool.get_variable_value(output_selector)

        # clear the output for this iteration
        variable_pool.append_variable(self.node_id, output_selector[1:], None)

        state.current_output = output
        if output is not None:
            state.outputs.append(output)

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping of variable name to its selector
        """
        return {
            'input_selector': node_data.iterator_selector,
        }

View File

View File

@@ -0,0 +1,13 @@
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
class LoopNodeData(BaseIterationNodeData):
    """
    Loop Node Data.

    No loop-specific fields yet; everything is inherited from BaseIterationNodeData.
    """


class LoopState(BaseIterationState):
    """
    Loop State.

    No loop-specific state yet; everything is inherited from BaseIterationState.
    """
View File

@@ -0,0 +1,20 @@
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
class LoopNode(BaseIterationNode):
    """
    Loop Node.

    Placeholder: loop semantics are not implemented yet.
    """
    _node_data_cls = LoopNodeData
    _node_type = NodeType.LOOP

    def _run(self, variable_pool: VariablePool) -> LoopState:
        # Not implemented — defers to the abstract base, which raises NotImplementedError.
        return super()._run(variable_pool)

    def _get_next_iteration(self, variable_pool: VariablePool, state: LoopState) -> NodeRunResult | str:
        """
        Get next iteration start node id based on the graph.

        :param variable_pool: variable pool
        :param state: current loop state
        :return: next node id, or a NodeRunResult when the loop is finished

        Fix: signature now matches BaseIterationNode._get_next_iteration, which
        is invoked as `self._get_next_iteration(variable_pool, state)`; the
        previous one-parameter signature (misnamed `variable_loop`) raised
        TypeError on every call. The body previously fell through and returned
        None implicitly; raise explicitly while unimplemented.
        """
        raise NotImplementedError

View File

@@ -0,0 +1,85 @@
from typing import Any, Literal, Optional
from pydantic import BaseModel, validator
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
    """
    Model Config.
    """
    provider: str  # model provider identifier
    name: str  # model name
    mode: str  # presumably 'chat' or 'completion' (parsed via ModelMode.value_of) — confirm
    completion_params: dict[str, Any] = {}  # raw inference parameters forwarded to the model
class ParameterConfig(BaseModel):
    """
    Parameter Config.

    Describes one extraction target: its name, declared type, and (for
    `select`) the allowed options.
    """
    name: str
    type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]']
    options: Optional[list[str]]
    description: str
    required: bool

    @validator('name', pre=True, always=True)
    def validate_name(cls, value):
        """Reject empty names and the reserved output keys."""
        if not value:
            raise ValueError('Parameter name is required')
        if value in ('__reason', '__is_success'):
            raise ValueError('Invalid parameter name, __reason and __is_success are reserved')
        return value
class ParameterExtractorNodeData(BaseNodeData):
    """
    Parameter Extractor Node Data.
    """
    model: ModelConfig
    query: list[str]
    parameters: list[ParameterConfig]
    instruction: Optional[str]
    memory: Optional[MemoryConfig]
    reasoning_mode: Literal['function_call', 'prompt']

    @validator('reasoning_mode', pre=True, always=True)
    def set_reasoning_mode(cls, v):
        """Default to 'function_call' when unset."""
        return v or 'function_call'

    def get_parameter_json_schema(self) -> dict:
        """
        Build a JSON schema describing the configured parameters.
        :return: JSON schema dict with `properties` and `required` entries
        """
        properties = {}
        required = []

        for param in self.parameters:
            entry: dict = {'description': param.description}

            if param.type in ('string', 'select'):
                entry['type'] = 'string'
            elif param.type.startswith('array'):
                entry['type'] = 'array'
                # 'array[xxx]' -> nested item type 'xxx'
                entry['items'] = {'type': param.type[6:-1]}
            else:
                entry['type'] = param.type

            if param.type == 'select':
                entry['enum'] = param.options

            properties[param.name] = entry
            if param.required:
                required.append(param.name)

        return {
            'type': 'object',
            'properties': properties,
            'required': required,
        }

View File

@@ -0,0 +1,711 @@
import json
import uuid
from typing import Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.parameter_extractor.prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT,
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
FUNCTION_CALLING_EXTRACTOR_NAME,
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
class ParameterExtractorNode(LLMNode):
    """
    Parameter Extractor Node.

    Extracts structured parameters from a natural-language query with an LLM,
    using native function calling when the model supports it, otherwise
    prompt engineering.
    """
    _node_data_cls = ParameterExtractorNodeData
    _node_type = NodeType.PARAMETER_EXTRACTOR
    # Lazily populated caches, filled by _fetch_model_config on first use.
    _model_instance: Optional[ModelInstance] = None
    _model_config: Optional[ModelConfigWithCredentialsEntity] = None

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Return the default node configuration.

        :param filters: unused here; kept for interface parity with other nodes
        :return: default completion-model prompt settings
        """
        return {
            "model": {
                "prompt_templates": {
                    "completion_model": {
                        "conversation_histories_role": {
                            "user_prefix": "Human",
                            "assistant_prefix": "Assistant"
                        },
                        "stop": ["Human:"]
                    }
                }
            }
        }
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
    """
    Run the node.

    Extracts structured parameters from the query via function calling
    (when supported) or prompt engineering, then validates and normalizes
    the result. LLM failures are reported as a FAILED NodeRunResult with
    the reserved __is_success/__reason outputs rather than raised.

    :param variable_pool: variable pool holding the query variable
    :return: node run result with extracted parameters in `outputs`
    """
    node_data = cast(ParameterExtractorNodeData, self.node_data)
    query = variable_pool.get_variable_value(node_data.query)
    if not query:
        raise ValueError("Query not found")

    inputs = {
        'query': query,
        'parameters': jsonable_encoder(node_data.parameters),
        'instruction': jsonable_encoder(node_data.instruction),
    }

    model_instance, model_config = self._fetch_model_config(node_data.model)
    if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
        raise ValueError("Model is not a Large Language Model")

    llm_model = model_instance.model_type_instance
    model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
    if not model_schema:
        raise ValueError("Model schema not found")

    # fetch memory
    memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

    # Fix: the feature set previously listed MULTI_TOOL_CALL twice, so models
    # advertising only TOOL_CALL wrongly fell back to prompt engineering.
    if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
            and node_data.reasoning_mode == 'function_call':
        # use function call
        prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
            node_data, query, variable_pool, model_config, memory
        )
    else:
        # use prompt engineering
        prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool,
                                                                   model_config, memory)
        prompt_message_tools = []

    process_data = {
        'model_mode': model_config.mode,
        'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
            model_mode=model_config.mode,
            prompt_messages=prompt_messages
        ),
        'usage': None,
        'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
        'tool_call': None,
    }

    try:
        text, usage, tool_call = self._invoke_llm(
            node_data_model=node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            tools=prompt_message_tools,
            stop=model_config.stop,
        )
        process_data['usage'] = jsonable_encoder(usage)
        process_data['tool_call'] = jsonable_encoder(tool_call)
        process_data['llm_text'] = text
    except Exception as e:
        # Report invocation failures through the reserved outputs instead of raising.
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=inputs,
            process_data={},
            outputs={
                '__is_success': 0,
                '__reason': str(e)
            },
            error=str(e),
            metadata={}
        )

    error = None

    if tool_call:
        result = self._extract_json_from_tool_call(tool_call)
    else:
        result = self._extract_complete_json_response(text)

    if not result:
        result = self._generate_default_result(node_data)
        error = "Failed to extract result from function call or text response, using empty result."

    try:
        result = self._validate_result(node_data, result)
    except Exception as e:
        error = str(e)

    # transform result into standard format
    result = self._transform_result(node_data, result)

    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        inputs=inputs,
        process_data=process_data,
        outputs={
            '__is_success': 1 if not error else 0,
            '__reason': error,
            **result
        },
        metadata={
            NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
            NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
            NodeRunMetadataKey.CURRENCY: usage.currency
        }
    )
def _invoke_llm(self, node_data_model: ModelConfig,
                model_instance: ModelInstance,
                prompt_messages: list[PromptMessage],
                tools: list[PromptMessageTool],
                stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
    """
    Invoke large language model
    :param node_data_model: node data model
    :param model_instance: model instance
    :param prompt_messages: prompt messages
    :param tools: tool definitions exposed to the model (may be empty)
    :param stop: stop
    :return: (response text, token usage, first tool call if any)
    """
    # Release the DB connection before the potentially long LLM round trip.
    db.session.close()

    invoke_result = model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters=node_data_model.completion_params,
        tools=tools,
        stop=stop,
        stream=False,
        user=self.user_id,
    )

    # handle invoke result
    if not isinstance(invoke_result, LLMResult):
        raise ValueError(f"Invalid invoke result: {invoke_result}")

    text = invoke_result.message.content
    usage = invoke_result.usage
    # Only the first tool call matters: extraction expects a single function call.
    tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None

    # deduct quota
    self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

    return text, usage, tool_call
def _generate_function_call_prompt(self,
                                   node_data: ParameterExtractorNodeData,
                                   query: str,
                                   variable_pool: VariablePool,
                                   model_config: ModelConfigWithCredentialsEntity,
                                   memory: Optional[TokenBufferMemory],
                                   ) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
    """
    Generate function call prompt.

    Builds the chat messages — inserting few-shot tool-call examples just
    before the last user message — plus the single extraction tool.

    :return: (prompt messages, [extraction tool])
    """
    query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(node_data.get_parameter_json_schema()))

    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
    prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, rest_token)
    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={},
        query='',
        files=[],
        context='',
        memory_config=node_data.memory,
        memory=None,
        model_config=model_config
    )

    # find last user message
    last_user_message_idx = -1
    for i, prompt_message in enumerate(prompt_messages):
        if prompt_message.role == PromptMessageRole.USER:
            last_user_message_idx = i

    # add function call messages before last user message
    example_messages = []
    for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE:
        # Fix: renamed local `id` -> `call_id` to stop shadowing the builtin.
        call_id = uuid.uuid4().hex
        example_messages.extend([
            UserPromptMessage(content=example['user']['query']),
            AssistantPromptMessage(
                content=example['assistant']['text'],
                tool_calls=[
                    AssistantPromptMessage.ToolCall(
                        id=call_id,
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=example['assistant']['function_call']['name'],
                            arguments=json.dumps(example['assistant']['function_call']['parameters'])
                        )
                    )
                ]
            ),
            ToolPromptMessage(
                content='Great! You have called the function with the correct parameters.',
                tool_call_id=call_id
            ),
            AssistantPromptMessage(
                content='I have extracted the parameters, let\'s move on.',
            )
        ])

    prompt_messages = prompt_messages[:last_user_message_idx] + \
        example_messages + prompt_messages[last_user_message_idx:]

    # generate tool
    tool = PromptMessageTool(
        name=FUNCTION_CALLING_EXTRACTOR_NAME,
        description='Extract parameters from the natural language text',
        parameters=node_data.get_parameter_json_schema(),
    )

    return prompt_messages, [tool]
def _generate_prompt_engineering_prompt(self,
                                        data: ParameterExtractorNodeData,
                                        query: str,
                                        variable_pool: VariablePool,
                                        model_config: ModelConfigWithCredentialsEntity,
                                        memory: Optional[TokenBufferMemory],
                                        ) -> list[PromptMessage]:
    """
    Generate prompt engineering prompt.

    Dispatches to the completion- or chat-specific builder based on the
    configured model mode.
    """
    model_mode = ModelMode.value_of(data.model.mode)

    if model_mode == ModelMode.COMPLETION:
        builder = self._generate_prompt_engineering_completion_prompt
    elif model_mode == ModelMode.CHAT:
        builder = self._generate_prompt_engineering_chat_prompt
    else:
        raise ValueError(f"Invalid model mode: {model_mode}")

    return builder(data, query, variable_pool, model_config, memory)
def _generate_prompt_engineering_completion_prompt(self,
                                                   node_data: ParameterExtractorNodeData,
                                                   query: str,
                                                   variable_pool: VariablePool,
                                                   model_config: ModelConfigWithCredentialsEntity,
                                                   memory: Optional[TokenBufferMemory],
                                                   ) -> list[PromptMessage]:
    """
    Generate completion prompt.

    The parameter JSON schema is supplied through the `structure` prompt
    input so the completion template can interpolate it.
    """
    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
    prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, rest_token)
    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={
            'structure': json.dumps(node_data.get_parameter_json_schema())
        },
        query='',
        files=[],
        context='',
        memory_config=node_data.memory,
        memory=memory,
        model_config=model_config
    )

    return prompt_messages
def _generate_prompt_engineering_chat_prompt(self,
                                             node_data: ParameterExtractorNodeData,
                                             query: str,
                                             variable_pool: VariablePool,
                                             model_config: ModelConfigWithCredentialsEntity,
                                             memory: Optional[TokenBufferMemory],
                                             ) -> list[PromptMessage]:
    """
    Generate chat prompt.

    Few-shot JSON-generation examples are inserted just before the last
    user message so the model sees them as recent context.
    """
    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
    prompt_template = self._get_prompt_engineering_prompt_template(
        node_data,
        CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
            structure=json.dumps(node_data.get_parameter_json_schema()),
            text=query
        ),
        variable_pool, memory, rest_token
    )

    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={},
        query='',
        files=[],
        context='',
        memory_config=node_data.memory,
        memory=memory,
        model_config=model_config
    )

    # find last user message
    last_user_message_idx = -1
    for i, prompt_message in enumerate(prompt_messages):
        if prompt_message.role == PromptMessageRole.USER:
            last_user_message_idx = i

    # add example messages before last user message
    example_messages = []
    for example in CHAT_EXAMPLE:
        example_messages.extend([
            UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
                structure=json.dumps(example['user']['json']),
                text=example['user']['query'],
            )),
            AssistantPromptMessage(
                content=json.dumps(example['assistant']['json']),
            )
        ])

    prompt_messages = prompt_messages[:last_user_message_idx] + \
        example_messages + prompt_messages[last_user_message_idx:]

    return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
    """
    Validate result.

    Checks the parameter count, required presence, and per-type value
    shapes; raises ValueError on the first violation.

    :param data: node data with the configured parameters
    :param result: extracted values keyed by parameter name
    :return: the validated result, unchanged
    :raises ValueError: on any mismatch
    """
    if len(data.parameters) != len(result):
        raise ValueError("Invalid number of parameters")

    for parameter in data.parameters:
        value = result.get(parameter.name)

        if parameter.required and parameter.name not in result:
            raise ValueError(f"Parameter {parameter.name} is required")

        if parameter.type == 'select' and parameter.options and value not in parameter.options:
            raise ValueError(f"Invalid `select` value for parameter {parameter.name}")

        if parameter.type == 'number' and not isinstance(value, int | float):
            raise ValueError(f"Invalid `number` value for parameter {parameter.name}")

        if parameter.type == 'bool' and not isinstance(value, bool):
            raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")

        if parameter.type == 'string' and not isinstance(value, str):
            raise ValueError(f"Invalid `string` value for parameter {parameter.name}")

        if parameter.type.startswith('array'):
            if not isinstance(value, list):
                raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
            # 'array[xxx]' -> nested item type 'xxx'
            nested_type = parameter.type[6:-1]
            for item in value:
                if nested_type == 'number' and not isinstance(item, int | float):
                    raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
                if nested_type == 'string' and not isinstance(item, str):
                    raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
                if nested_type == 'object' and not isinstance(item, dict):
                    raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")

    return result
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
    """
    Transform result into standard format.

    Coerces each configured parameter to its declared type where possible
    and fills a per-type default for anything missing or uncoercible.

    :param data: node data with the configured parameters
    :param result: raw extracted values keyed by parameter name
    :return: normalized result containing every configured parameter
    """
    transformed_result = {}
    for parameter in data.parameters:
        if parameter.name in result:
            # transform value
            if parameter.type == 'number':
                if isinstance(result[parameter.name], int | float):
                    transformed_result[parameter.name] = result[parameter.name]
                elif isinstance(result[parameter.name], str):
                    try:
                        # Fix: store the parsed value in transformed_result;
                        # previously it was written back into `result` only,
                        # so numeric strings fell through to the default 0 below.
                        if '.' in result[parameter.name]:
                            transformed_result[parameter.name] = float(result[parameter.name])
                        else:
                            transformed_result[parameter.name] = int(result[parameter.name])
                    except ValueError:
                        pass
            # TODO: bool is not supported in the current version (falls through
            # to the False default below).
            elif parameter.type in ['string', 'select']:
                if isinstance(result[parameter.name], str):
                    transformed_result[parameter.name] = result[parameter.name]
            elif parameter.type.startswith('array'):
                if isinstance(result[parameter.name], list):
                    # 'array[xxx]' -> nested item type 'xxx'
                    nested_type = parameter.type[6:-1]
                    transformed_result[parameter.name] = []
                    for item in result[parameter.name]:
                        if nested_type == 'number':
                            if isinstance(item, int | float):
                                transformed_result[parameter.name].append(item)
                            elif isinstance(item, str):
                                try:
                                    if '.' in item:
                                        transformed_result[parameter.name].append(float(item))
                                    else:
                                        transformed_result[parameter.name].append(int(item))
                                except ValueError:
                                    pass
                        elif nested_type == 'string':
                            if isinstance(item, str):
                                transformed_result[parameter.name].append(item)
                        elif nested_type == 'object':
                            if isinstance(item, dict):
                                transformed_result[parameter.name].append(item)

        if parameter.name not in transformed_result:
            # per-type default for missing or uncoercible values
            if parameter.type == 'number':
                transformed_result[parameter.name] = 0
            elif parameter.type == 'bool':
                transformed_result[parameter.name] = False
            elif parameter.type in ['string', 'select']:
                transformed_result[parameter.name] = ''
            elif parameter.type.startswith('array'):
                transformed_result[parameter.name] = []

    return transformed_result
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
"""
Extract complete json response.
"""
def extract_json(text):
"""
From a given JSON started from '{' or '[' extract the complete JSON object.
"""
stack = []
for i, c in enumerate(text):
if c == '{' or c == '[':
stack.append(c)
elif c == '}' or c == ']':
# check if stack is empty
if not stack:
return text[:i]
# check if the last element in stack is matching
if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['):
stack.pop()
if not stack:
return text[:i+1]
else:
return text[:i]
return None
# extract json from the text
for idx in range(len(result)):
if result[idx] == '{' or result[idx] == '[':
json_str = extract_json(result[idx:])
if json_str:
try:
return json.loads(json_str)
except Exception:
pass
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
    """
    Extract json from tool call.

    :param tool_call: assistant tool call carrying JSON-encoded arguments
    :return: decoded arguments, or None when absent or malformed
    """
    if not tool_call or not tool_call.function.arguments:
        return None

    try:
        return json.loads(tool_call.function.arguments)
    except json.JSONDecodeError:
        # Fix: malformed LLM output previously raised out of _run; returning
        # None lets the caller fall back to the generated default result.
        return None
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
    """
    Generate default result.

    Produces a per-type zero value for each configured parameter. Array
    types receive no entry here — matching the original behavior; the
    defaults for arrays are filled later by _transform_result.
    """
    defaults = {'number': 0, 'bool': False, 'string': '', 'select': ''}
    return {
        parameter.name: defaults[parameter.type]
        for parameter in data.parameters
        if parameter.type in defaults
    }
def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
    """
    Render instruction.

    Substitutes every variable template found in *instruction* with its
    current value from the variable pool.
    """
    parser = VariableTemplateParser(instruction)
    values = {
        selector.variable: variable_pool.get_variable_value(selector.value_selector)
        for selector in parser.extract_variable_selectors()
    }
    return parser.format(values)
def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
                                          variable_pool: VariablePool,
                                          memory: Optional[TokenBufferMemory],
                                          max_token_limit: int = 2000) \
        -> list[ChatModelMessage]:
    """
    Build the chat prompt (system + user message) for function-calling extraction.

    :raises ValueError: for any non-chat model mode
    """
    model_mode = ModelMode.value_of(node_data.model.mode)
    input_text = query
    memory_str = ''
    instruction = self._render_instruction(node_data.instruction or '', variable_pool)

    if memory:
        # NOTE(review): assumes node_data.memory is set whenever a memory
        # instance is passed — otherwise `.window.size` raises. TODO confirm.
        memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
                                                    message_limit=node_data.memory.window.size)
    if model_mode == ModelMode.CHAT:
        system_prompt_messages = ChatModelMessage(
            role=PromptMessageRole.SYSTEM,
            text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction)
        )
        user_prompt_message = ChatModelMessage(
            role=PromptMessageRole.USER,
            text=input_text
        )
        return [system_prompt_messages, user_prompt_message]
    else:
        raise ValueError(f"Model mode {model_mode} not support.")
def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
                                            variable_pool: VariablePool,
                                            memory: Optional[TokenBufferMemory],
                                            max_token_limit: int = 2000) \
        -> list[ChatModelMessage]:
    """
    Build the prompt template for prompt-engineering extraction.

    NOTE(review): in COMPLETION mode this returns a single
    CompletionModelPromptTemplate, not a list as annotated — confirm the
    caller (AdvancedPromptTransform.get_prompt) accepts both shapes.
    """
    model_mode = ModelMode.value_of(node_data.model.mode)
    input_text = query
    memory_str = ''
    instruction = self._render_instruction(node_data.instruction or '', variable_pool)

    if memory:
        # NOTE(review): assumes node_data.memory is set whenever a memory
        # instance is passed — otherwise `.window.size` raises. TODO confirm.
        memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
                                                    message_limit=node_data.memory.window.size)
    if model_mode == ModelMode.CHAT:
        # NOTE(review): reuses the function-calling system prompt here —
        # presumably intentional since both share the histories/instruction
        # placeholders; verify against the prompt constants.
        system_prompt_messages = ChatModelMessage(
            role=PromptMessageRole.SYSTEM,
            text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction)
        )
        user_prompt_message = ChatModelMessage(
            role=PromptMessageRole.USER,
            text=input_text
        )
        return [system_prompt_messages, user_prompt_message]
    elif model_mode == ModelMode.COMPLETION:
        # The '{γγγ' / '}γγγ' sequences appear to be placeholder markers from
        # the prompt constants — stripped before use; verify in prompts.py.
        return CompletionModelPromptTemplate(
            text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str,
                                                        text=input_text,
                                                        instruction=instruction)
            .replace('{γγγ', '')
            .replace('}γγγ', '')
        )
    else:
        raise ValueError(f"Model mode {model_mode} not support.")
def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str,
                          variable_pool: VariablePool,
                          model_config: ModelConfigWithCredentialsEntity,
                          context: Optional[str]) -> int:
    """
    Calculate the token budget left for history/context after the prompt.

    :return: remaining token count (>= 0); defaults to 2000 when the model
             does not declare a context size
    """
    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)

    model_instance, model_config = self._fetch_model_config(node_data.model)
    if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
        raise ValueError("Model is not a Large Language Model")

    llm_model = model_instance.model_type_instance
    model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
    if not model_schema:
        raise ValueError("Model schema not found")

    # Fix: the feature set previously listed MULTI_TOOL_CALL twice instead of
    # TOOL_CALL + MULTI_TOOL_CALL, so TOOL_CALL-only models were measured with
    # the wrong prompt template.
    if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
        prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
    else:
        prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)

    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={},
        query='',
        files=[],
        context=context,
        memory_config=node_data.memory,
        memory=None,
        model_config=model_config
    )
    rest_tokens = 2000

    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens:
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        curr_message_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        ) + 1000  # add 1000 to ensure tool call messages

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
        rest_tokens = max(rest_tokens, 0)

    return rest_tokens
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    """
    Fetch the model instance and its credentialed config, memoized per node.

    The parent implementation is consulted only on the first call; later
    calls return the values cached on the instance.
    """
    cache_ready = self._model_instance and self._model_config
    if not cache_ready:
        self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)

    return self._model_instance, self._model_config
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[str, list[str]]:
    """
    Extract variable selector to variable mapping.

    :param node_data: node data
    :return: mapping of variable name to its value selector; always contains
        'query', plus one entry per variable referenced by the instruction
        template (if any).
    """
    # Removed the no-op `node_data = node_data` self-assignment.
    variable_mapping = {
        'query': node_data.query
    }

    if node_data.instruction:
        variable_template_parser = VariableTemplateParser(template=node_data.instruction)
        for selector in variable_template_parser.extract_variable_selectors():
            variable_mapping[selector.variable] = selector.value_selector

    return variable_mapping

View File

@@ -0,0 +1,206 @@
# Name of the function the LLM is instructed to call to emit extracted parameters.
FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters'

# System prompt for function-calling based extraction.
# NOTE: \x7b / \x7d are hex escapes for literal '{' / '}', so {histories} and
# {instruction} survive this f-string and remain placeholders for a later
# formatting pass, while {FUNCTION_CALLING_EXTRACTOR_NAME} is interpolated now.
# Fix: step 4 previously referenced a non-existent `extract_parameter`
# function; it now interpolates the actual function name.
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
### Task
Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria.
### Memory
Here is the chat history between the human and assistant, provided within <histories> tags:
<histories>
\x7bhistories\x7d
</histories>
### Instructions:
Some additional information is provided below. Always adhere to these instructions as closely as possible:
<instruction>
\x7binstruction\x7d
</instruction>
Steps:
1. Review the chat history provided within the <histories> tags.
2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
3. Generate a well-formatted output using the defined functions and arguments.
4. Use the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function to create structured outputs with appropriate parameters.
5. Do not include any XML tags in your output.
### Example
To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples.
### Final Output
Produce well-formatted function calls in json without XML tags, as shown in the example.
"""
# User message template for function-calling extraction.
# NOTE: \x7b / \x7d are hex escapes for literal '{' / '}', so {content} and
# {structure} remain placeholders for a later formatting pass, while
# {FUNCTION_CALLING_EXTRACTOR_NAME} is interpolated by this f-string now.
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside <context></context> XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside <structure></structure> XML tags.
<context>
\x7bcontent\x7d
</context>
<structure>
\x7bstructure\x7d
</structure>
"""
# Few-shot examples for function-calling extraction: each entry pairs a user
# query (with the function parameter schema to satisfy) and the expected
# assistant reply (reasoning text plus the function call with its arguments).
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{
    'user': {
        'query': 'What is the weather today in SF?',
        'function': {
            'name': FUNCTION_CALLING_EXTRACTOR_NAME,
            'parameters': {
                'type': 'object',
                'properties': {
                    'location': {
                        'type': 'string',
                        'description': 'The location to get the weather information',
                        'required': True
                    },
                },
                'required': ['location']
            }
        }
    },
    'assistant': {
        'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.',
        'function_call' : {
            'name': FUNCTION_CALLING_EXTRACTOR_NAME,
            'parameters': {
                'location': 'San Francisco'
            }
        }
    }
}, {
    'user': {
        'query': 'I want to eat some apple pie.',
        'function': {
            'name': FUNCTION_CALLING_EXTRACTOR_NAME,
            'parameters': {
                'type': 'object',
                'properties': {
                    'food': {
                        'type': 'string',
                        'description': 'The food to eat',
                        'required': True
                    }
                },
                'required': ['food']
            }
        }
    },
    'assistant': {
        'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.',
        'function_call' : {
            'name': FUNCTION_CALLING_EXTRACTOR_NAME,
            'parameters': {
                'food': 'apple pie'
            }
        }
    }
}]
# Prompt for completion-model JSON extraction; filled via str.format-style
# substitution of {instruction}, {histories} and {text}.
# NOTE(review): '{{ structure }}' is brace-escaped, so it is NOT substituted
# and appears literally as '{ structure }' — confirm against the consumer.
# NOTE(review): the 'γγγ' runs around the example object look like
# mis-encoded fence/delimiter characters — verify the intended characters.
COMPLETION_GENERATE_JSON_PROMPT = """### Instructions:
Some extra information are provided below, I should always follow the instructions as possible as I can.
<instructions>
{instruction}
</instructions>
### Extract parameter Workflow
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
<information to be extracted>
{{ structure }}
</information to be extracted>
Step 1: Carefully read the input and understand the structure of the expected output.
Step 2: Extract relevant parameters from the provided text based on the name and description of object.
Step 3: Structure the extracted parameters to JSON object as specified in <structure>.
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### Structure
Here is the structure of the expected output, I should always follow the output structure.
{{γγγ
'properties1': 'relevant text extracted from input',
'properties2': 'relevant text extracted from input',
}}γγγ
### Input Text
Inside <text></text> XML tags, there is a text that I should extract parameters and convert to a JSON object.
<text>
{text}
</text>
### Answer
I should always output a valid JSON object. Output nothing other than the JSON object.
```JSON
"""
# Chat-model system prompt for JSON extraction; filled via str.format-style
# substitution of {histories}.
# NOTE(review): '{{instructions}}' is brace-escaped, so the first formatting
# pass leaves the literal text '{instructions}' in place — presumably filled
# by a second pass; verify against the caller.
CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object.
The structure of the JSON object you can found in the instructions.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### Instructions:
Some extra information are provided below, you should always follow the instructions as possible as you can.
<instructions>
{{instructions}}
</instructions>
"""
# Chat-model user message template for JSON extraction; filled via
# str.format-style substitution of {structure} and {text}.
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure
Here is the structure of the JSON object, you should always follow the structure.
<structure>
{structure}
</structure>
### Text to be converted to JSON
Inside <text></text> XML tags, there is a text that you should convert to a JSON object.
<text>
{text}
</text>
"""
# Few-shot examples for chat-model JSON extraction: each entry pairs a user
# query (with the JSON schema to satisfy) and the expected assistant answer.
CHAT_EXAMPLE = [{
    'user': {
        'query': 'What is the weather today in SF?',
        'json': {
            'type': 'object',
            'properties': {
                'location': {
                    'type': 'string',
                    'description': 'The location to get the weather information',
                    'required': True
                }
            },
            'required': ['location']
        }
    },
    'assistant': {
        'text': 'I need to output a valid JSON object.',
        'json': {
            'location': 'San Francisco'
        }
    }
}, {
    'user': {
        'query': 'I want to eat some apple pie.',
        'json': {
            'type': 'object',
            'properties': {
                'food': {
                    'type': 'string',
                    'description': 'The food to eat',
                    'required': True
                }
            },
            'required': ['food']
        }
    },
    'assistant': {
        'text': 'I need to output a valid JSON object.',
        'json': {
            # Fix: the schema above requires the key 'food'; the example
            # previously answered with an unrelated 'result' key, teaching
            # the model to emit the wrong property name.
            'food': 'apple pie'
        }
    }
}]

View File

@@ -7,7 +7,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class ToolEntity(BaseModel):
    """
    Tool binding carried by a tool node's configuration.
    """
    provider_id: str
    # Fix: removed the stale duplicate declaration Literal['builtin', 'api']
    # (a leftover diff pre-image line) that was dead code shadowed by this one.
    # 'workflow' providers are accepted alongside 'builtin' and 'api'.
    provider_type: Literal['builtin', 'api', 'workflow']
    provider_name: str  # redundancy
    tool_name: str
    tool_label: str  # redundancy

View File

@@ -1,13 +1,14 @@
from os import path
from typing import cast
from typing import Optional, cast
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
@@ -35,20 +36,23 @@ class ToolNode(BaseNode):
'provider_id': node_data.provider_id
}
# get parameters
parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters,
inputs={},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to get tool runtime: {str(e)}'
)
# get parameters
parameters = self._generate_parameters(variable_pool, node_data, tool_runtime)
try:
messages = ToolEngine.workflow_invoke(
@@ -56,7 +60,8 @@ class ToolNode(BaseNode):
tool_parameters=parameters,
user_id=self.user_id,
workflow_id=self.workflow_id,
workflow_tool_callback=DifyWorkflowCallbackHandler()
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth + 1
)
except Exception as e:
return NodeRunResult(
@@ -83,19 +88,32 @@ class ToolNode(BaseNode):
inputs=parameters
)
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict:
"""
Generate parameters
"""
tool_parameters = tool_runtime.get_all_runtime_parameters()
def fetch_parameter(name: str) -> Optional[ToolParameter]:
return next((parameter for parameter in tool_parameters if parameter.name == name), None)
result = {}
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
elif input.type == 'variable':
result[parameter_name] = variable_pool.get_variable_value(input.value)
elif input.type == 'constant':
result[parameter_name] = input.value
parameter = fetch_parameter(parameter_name)
if not parameter:
continue
if parameter.type == ToolParameter.ToolParameterType.FILE:
result[parameter_name] = [
v.to_dict() for v in self._fetch_files(variable_pool)
]
else:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
elif input.type == 'variable':
result[parameter_name] = variable_pool.get_variable_value(input.value)
elif input.type == 'constant':
result[parameter_name] = input.value
return result
@@ -109,6 +127,13 @@ class ToolNode(BaseNode):
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
return template_parser.format(inputs)
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
if not files:
return []
return files
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
"""

View File

@@ -0,0 +1,33 @@
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class AdvancedSetting(BaseModel):
    """
    Advanced (grouped) aggregation setting.
    """
    # Whether grouped aggregation mode is turned on.
    group_enabled: bool

    class Group(BaseModel):
        """
        A named group of variable selectors sharing one declared output type.
        """
        output_type: Literal['string', 'number', 'array', 'object']
        # Variable selectors belonging to this group.
        variables: list[list[str]]
        group_name: str

    # All configured groups.
    groups: list[Group]
class VariableAssignerNodeData(BaseNodeData):
    """
    Variable Aggregator node data.

    NOTE(review): keeps the legacy 'variable-assigner' class/type name while
    living in the variable_aggregator package — confirm this is intentional
    for backward compatibility.
    """
    type: str = 'variable-assigner'
    # Declared type of the aggregated output value.
    output_type: str
    # Variable selectors to probe.
    variables: list[list[str]]
    # Present only when grouped aggregation is configured.
    advanced_setting: Optional[AdvancedSetting]

View File

@@ -0,0 +1,52 @@
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class VariableAggregatorNode(BaseNode):
    """
    Exposes the first non-None value among the configured variable selectors:
    as a single 'output' in basic mode, or as one '<group_name>_output' per
    group when grouped aggregation is enabled.
    """
    _node_data_cls = VariableAssignerNodeData
    _node_type = NodeType.VARIABLE_AGGREGATOR

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        :param variable_pool: pool the variable selectors are resolved against
        :return: SUCCEEDED result with the selected value(s) in outputs and
            the winning selector path(s) in inputs
        """
        node_data = cast(VariableAssignerNodeData, self.node_data)
        # Get variables
        outputs = {}
        inputs = {}

        # Fix: condition was inverted — group_enabled=True previously took the
        # single-output branch, so the configured groups were never consulted.
        if not node_data.advanced_setting or not node_data.advanced_setting.group_enabled:
            for variable in node_data.variables:
                value = variable_pool.get_variable_value(variable)

                if value is not None:
                    outputs = {
                        "output": value
                    }

                    inputs = {
                        '.'.join(variable[1:]): value
                    }
                    break
        else:
            for group in node_data.advanced_setting.groups:
                for variable in group.variables:
                    value = variable_pool.get_variable_value(variable)

                    if value is not None:
                        # First non-None value wins within each group.
                        outputs[f'{group.group_name}_output'] = value
                        inputs['.'.join(variable[1:])] = value
                        break

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs=outputs,
            inputs=inputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        # Selectors are resolved dynamically at run time; no static mapping.
        return {}

View File

@@ -1,12 +0,0 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
class VariableAssignerNodeData(BaseNodeData):
    """
    Variable Assigner node data.
    """
    type: str = 'variable-assigner'
    # Declared type of the selected output value.
    output_type: str
    # Variable selectors to probe.
    variables: list[list[str]]

View File

@@ -1,41 +0,0 @@
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_assigner.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class VariableAssignerNode(BaseNode):
    """
    Exposes the first non-None value among the configured variable selectors
    as the node's single 'output'.
    """
    _node_data_cls = VariableAssignerNodeData
    _node_type = NodeType.VARIABLE_ASSIGNER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        :param variable_pool: pool the variable selectors are resolved against
        :return: SUCCEEDED result with the selected value in outputs and the
            winning selector path in inputs
        """
        # Fix: cast() should receive a concrete type expression —
        # cast(self._node_data_cls, ...) is opaque to static type checkers.
        node_data = cast(VariableAssignerNodeData, self.node_data)
        # Get variables
        outputs = {}
        inputs = {}

        for variable in node_data.variables:
            value = variable_pool.get_variable_value(variable)

            if value is not None:
                outputs = {
                    "output": value
                }

                inputs = {
                    '.'.join(variable[1:]): value
                }
                break

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs=outputs,
            inputs=inputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        # Selectors are resolved dynamically at run time; no static mapping.
        return {}