FEAT: NEW WORKFLOW ENGINE (#3160)
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: jyong <718720800@qq.com>
api/core/agent/__init__.py (new file, 0 lines)
api/core/agent/base_agent_runner.py (new file, 473 lines)
@@ -0,0 +1,473 @@
import json
import logging
import uuid
from datetime import datetime
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import (
    ToolInvokeMessage,
    ToolParameter,
    ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.model import Message, MessageAgentThought
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)


class BaseAgentRunner(AppRunner):
    def __init__(self, tenant_id: str,
                 application_generate_entity: AgentChatAppGenerateEntity,
                 app_config: AgentChatAppConfig,
                 model_config: ModelConfigWithCredentialsEntity,
                 config: AgentEntity,
                 queue_manager: AppQueueManager,
                 message: Message,
                 user_id: str,
                 memory: Optional[TokenBufferMemory] = None,
                 prompt_messages: Optional[list[PromptMessage]] = None,
                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
                 db_variables: Optional[ToolConversationVariables] = None,
                 model_instance: Optional[ModelInstance] = None
                 ) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param application_generate_entity: app generate entity
        :param app_config: app config
        :param model_config: model config
        :param config: agent config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param memory: memory
        :param prompt_messages: prompt messages
        :param variables_pool: tool runtime variables pool
        :param db_variables: db variables
        :param model_instance: model instance
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(
            prompt_messages=prompt_messages or []
        )
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.message_id == self.message.id,
        ).count()
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
            self.stream_tool_call = True
        else:
            self.stream_tool_call = False

    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
            -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ''

        return app_generate_entity

    def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
        """
        Handle tool response
        """
        result = ''
        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.TEXT:
                result += response.message
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                result += f"result link: {response.message}. please tell user to check it."
            elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                    response.type == ToolInvokeMessage.MessageType.IMAGE:
                result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
            else:
                result += f"tool response: {response.message}."

        return result

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            agent_tool=tool,
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.STRING:
                parameter_type = 'string'
            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                parameter_type = 'boolean'
            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
                parameter_type = 'number'
            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
                for option in parameter.options:
                    enum.append(option.value)
                parameter_type = 'string'
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            message_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if len(enum) > 0:
                message_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                message_tool.parameters['required'].append(parameter.name)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = 'string'

            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.STRING:
                parameter_type = 'string'
            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                parameter_type = 'boolean'
            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
                parameter_type = 'number'
            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
                for option in parameter.options:
                    enum.append(option.value)
                parameter_type = 'string'
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if len(enum) > 0:
                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def create_agent_thought(self, message_id: str, message: str,
                             tool_name: str, tool_input: str, messages_ids: list[str]
                             ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought='',
            tool=tool_name,
            tool_labels_str='{}',
            tool_meta_str='{}',
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else '',
            answer='',
            observation='',
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency='USD',
            latency=0,
            created_by_role='account',
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(self,
                           agent_thought: MessageAgentThought,
                           tool_name: str,
                           tool_input: Union[str, dict],
                           thought: str,
                           observation: Union[str, dict],
                           tool_invoke_meta: Union[str, dict],
                           answer: str,
                           messages_ids: list[str],
                           llm_usage: Optional[LLMUsage] = None) -> None:
        """
        Save agent thought
        """
        agent_thought = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.id == agent_thought.id
        ).first()

        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation is not None:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            agent_thought.observation = observation

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # ensure tool labels are populated for every tool recorded on this thought
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(';') if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {'en_US': tool, 'zh_Hans': tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        convert tool variables to db variables
        """
        db_variables = db.session.query(ToolConversationVariables).filter(
            ToolConversationVariables.conversation_id == self.message.conversation_id,
        ).first()

        db_variables.updated_at = datetime.utcnow()
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result = []
        # check if there is a system message in the beginning of the conversation
        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
            result.append(prompt_messages[0])

        messages: list[Message] = db.session.query(Message).filter(
            Message.conversation_id == self.message.conversation_id,
        ).order_by(Message.created_at.asc()).all()

        for message in messages:
            result.append(UserPromptMessage(content=message.query))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(';')
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            tool_responses = {tool: agent_thought.observation for tool in tools}

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(AssistantPromptMessage.ToolCall(
                                id=tool_call_id,
                                type='function',
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=tool,
                                    arguments=json.dumps(tool_inputs.get(tool, {})),
                                )
                            ))
                            tool_call_response.append(ToolPromptMessage(
                                content=tool_responses.get(tool, agent_thought.observation),
                                name=tool,
                                tool_call_id=tool_call_id,
                            ))

                        result.extend([
                            AssistantPromptMessage(
                                content=agent_thought.thought,
                                tool_calls=tool_calls,
                            ),
                            *tool_call_response
                        ])
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result
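For orientation, `_convert_tool_to_prompt_message_tool` and `update_prompt_message_tool` above flatten a tool's LLM-facing runtime parameters into an OpenAI-style JSON Schema function definition. A minimal sketch of the resulting shape, using a hypothetical weather tool (the tool name, description, and fields are illustrative, not from this commit):

# Illustrative only: the dict produced for a hypothetical tool with one
# required string parameter and one SELECT parameter. SELECT options become
# an "enum" on a string-typed property, exactly as in the loops above.
weather_forecast_tool = {
    "name": "weather_forecast",
    "description": "Query the weather forecast for a city.",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "City to query."},
            "unit": {
                "type": "string",
                "description": "Temperature unit.",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["city"],
    },
}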
api/core/agent/cot_agent_runner.py (new file, 690 lines)
@@ -0,0 +1,690 @@
import json
import re
from collections.abc import Generator
from typing import Literal, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message


class CotAgentRunner(BaseAgentRunner):
    _is_first_iteration = True
    _ignore_observation_providers = ['wenxin']

    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            inputs: dict[str, str],
            ) -> Union[Generator, LLMResult]:
        """
        Run Cot agent application
        """
        app_generate_entity = self.application_generate_entity
        self._repack_app_generate_entity(app_generate_entity)

        agent_scratchpad: list[AgentScratchpadUnit] = []
        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)

        # check model mode
        if 'Observation' not in app_generate_entity.model_config.stop:
            if app_generate_entity.model_config.provider not in self._ignore_observation_providers:
                app_generate_entity.model_config.stop.append('Observation')

        app_config = self.app_config

        # override inputs
        inputs = inputs or {}
        instruction = app_config.prompt_template.simple_prompt_template
        instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        prompt_messages = self.history_prompt_messages

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: list[PromptMessageTool] = []
        tool_instances = {}
        for tool in app_config.agent.tools if app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        function_call_state = True
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []

            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            # update prompt messages
            prompt_messages = self._organize_cot_prompt_messages(
                mode=app_generate_entity.model_config.mode,
                prompt_messages=prompt_messages,
                tools=prompt_messages_tools,
                agent_scratchpad=agent_scratchpad,
                agent_prompt_message=app_config.agent.prompt,
                instruction=instruction,
                input=query
            )

            # recalc llm max tokens
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_config.parameters,
                tools=[],
                stop=app_generate_entity.model_config.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            # check llm result
            if not chunks:
                raise ValueError("failed to invoke llm")

            usage_dict = {}
            react_chunks = self._handle_stream_react(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response='',
                thought='',
                action_str='',
                observation='',
                action=None,
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            for chunk in react_chunks:
                if isinstance(chunk, dict):
                    scratchpad.agent_response += json.dumps(chunk)
                    try:
                        if scratchpad.action:
                            # an action was already captured for this round; fall
                            # through to the except branch and treat this blob as thought
                            raise Exception("")
                        scratchpad.action_str = json.dumps(chunk)
                        scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=chunk['action'],
                            action_input=chunk['action_input']
                        )
                    except Exception:
                        scratchpad.thought += json.dumps(chunk)
                        yield LLMResultChunk(
                            model=self.model_config.model,
                            prompt_messages=prompt_messages,
                            system_fingerprint='',
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(
                                    content=json.dumps(chunk, ensure_ascii=False)  # if ensure_ascii=True, the text in webui may be garbled
                                ),
                                usage=None
                            )
                        )
                else:
                    scratchpad.agent_response += chunk
                    scratchpad.thought += chunk
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint='',
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(
                                content=chunk
                            ),
                            usage=None
                        )
                    )

            scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
            agent_scratchpad.append(scratchpad)

            # get llm usage
            if 'usage' in usage_dict:
                increase_usage(llm_usage, usage_dict['usage'])
            else:
                usage_dict['usage'] = LLMUsage.empty_usage()

            self.save_agent_thought(agent_thought=agent_thought,
                                    tool_name=scratchpad.action.action_name if scratchpad.action else '',
                                    tool_input={
                                        scratchpad.action.action_name: scratchpad.action.action_input
                                    } if scratchpad.action else '',
                                    tool_invoke_meta={},
                                    thought=scratchpad.thought,
                                    observation='',
                                    answer=scratchpad.agent_response,
                                    messages_ids=[],
                                    llm_usage=usage_dict['usage'])

            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = scratchpad.agent_response or ''
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        final_answer = scratchpad.action.action_input if \
                            isinstance(scratchpad.action.action_input, str) else \
                            json.dumps(scratchpad.action.action_input)
                    except json.JSONDecodeError:
                        final_answer = f'{scratchpad.action.action_input}'
                else:
                    function_call_state = True

                    # action is tool call, invoke tool
                    tool_call_name = scratchpad.action.action_name
                    tool_call_args = scratchpad.action.action_input
                    tool_instance = tool_instances.get(tool_call_name)
                    if not tool_instance:
                        answer = f"there is not a tool named {tool_call_name}"
                        self.save_agent_thought(
                            agent_thought=agent_thought,
                            tool_name='',
                            tool_input='',
                            tool_invoke_meta=ToolInvokeMeta.error_instance(
                                f"there is not a tool named {tool_call_name}"
                            ).to_dict(),
                            thought=None,
                            observation={
                                tool_call_name: answer
                            },
                            answer=answer,
                            messages_ids=[]
                        )
                        self.queue_manager.publish(QueueAgentThoughtEvent(
                            agent_thought_id=agent_thought.id
                        ), PublishFrom.APPLICATION_MANAGER)
                    else:
                        if isinstance(tool_call_args, str):
                            try:
                                tool_call_args = json.loads(tool_call_args)
                            except json.JSONDecodeError:
                                pass

                        # invoke tool
                        tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                            tool=tool_instance,
                            tool_parameters=tool_call_args,
                            user_id=self.user_id,
                            tenant_id=self.tenant_id,
                            message=self.message,
                            invoke_from=self.application_generate_entity.invoke_from,
                            agent_tool_callback=self.agent_callback
                        )
                        # publish files
                        for message_file, save_as in message_files:
                            if save_as:
                                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

                            # publish message file
                            self.queue_manager.publish(QueueMessageFileEvent(
                                message_file_id=message_file.id
                            ), PublishFrom.APPLICATION_MANAGER)
                            # add message file ids
                            message_file_ids.append(message_file.id)

                        # publish files
                        for message_file, save_as in message_files:
                            if save_as:
                                self.variables_pool.set_file(tool_name=tool_call_name,
                                                             value=message_file.id,
                                                             name=save_as)
                            self.queue_manager.publish(QueueMessageFileEvent(
                                message_file_id=message_file.id
                            ), PublishFrom.APPLICATION_MANAGER)

                        message_file_ids = [message_file.id for message_file, _ in message_files]

                        observation = tool_invoke_response

                        # save scratchpad
                        scratchpad.observation = observation

                        # save agent thought
                        self.save_agent_thought(
                            agent_thought=agent_thought,
                            tool_name=tool_call_name,
                            tool_input={
                                tool_call_name: tool_call_args
                            },
                            tool_invoke_meta={
                                tool_call_name: tool_invoke_meta.to_dict()
                            },
                            thought=None,
                            observation={
                                tool_call_name: observation
                            },
                            answer=scratchpad.agent_response,
                            messages_ids=message_file_ids,
                        )
                        self.queue_manager.publish(QueueAgentThoughtEvent(
                            agent_thought_id=agent_thought.id
                        ), PublishFrom.APPLICATION_MANAGER)

            # update prompt tool message
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(
                    content=final_answer
                ),
                usage=llm_usage['usage']
            ),
            system_fingerprint=''
        )

        # save agent thought
        self.save_agent_thought(
            agent_thought=agent_thought,
            tool_name='',
            tool_input={},
            tool_invoke_meta={},
            thought=final_answer,
            observation={},
            answer=final_answer,
            messages_ids=[]
        )

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        )), PublishFrom.APPLICATION_MANAGER)

    def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
            -> Generator[Union[str, dict], None, None]:
        def parse_json(json_str):
            try:
                return json.loads(json_str.strip())
            except Exception:
                return json_str

        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
            if not code_blocks:
                return
            for block in code_blocks:
                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
                yield parse_json(json_text)

        # scan the streamed text character by character, yielding parsed dicts for
        # fenced ``` blocks and bare {...} JSON objects, and plain text otherwise
        code_block_cache = ''
        code_block_delimiter_count = 0
        in_code_block = False
        json_cache = ''
        json_quote_count = 0
        in_json = False
        got_json = False

        for response in llm_response:
            response = response.delta.message.content
            if not isinstance(response, str):
                continue

            # stream
            index = 0
            while index < len(response):
                steps = 1
                delta = response[index:index + steps]
                if delta == '`':
                    code_block_cache += delta
                    code_block_delimiter_count += 1
                else:
                    if not in_code_block:
                        if code_block_delimiter_count > 0:
                            yield code_block_cache
                            code_block_cache = ''
                    else:
                        code_block_cache += delta
                    code_block_delimiter_count = 0

                if code_block_delimiter_count == 3:
                    if in_code_block:
                        yield from extra_json_from_code_block(code_block_cache)
                        code_block_cache = ''

                    in_code_block = not in_code_block
                    code_block_delimiter_count = 0

                if not in_code_block:
                    # handle single json
                    if delta == '{':
                        json_quote_count += 1
                        in_json = True
                        json_cache += delta
                    elif delta == '}':
                        json_cache += delta
                        if json_quote_count > 0:
                            json_quote_count -= 1
                            if json_quote_count == 0:
                                in_json = False
                                got_json = True
                                index += steps
                                continue
                    else:
                        if in_json:
                            json_cache += delta

                    if got_json:
                        got_json = False
                        yield parse_json(json_cache)
                        json_cache = ''
                        json_quote_count = 0
                        in_json = False

                if not in_code_block and not in_json:
                    yield delta.replace('`', '')

                index += steps

        if code_block_cache:
            yield code_block_cache

        if json_cache:
            yield parse_json(json_cache)

    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
        """
        fill in inputs from external data tools
        """
        for key, value in inputs.items():
            try:
                instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
            except Exception:
                continue

        return instruction

    def _init_agent_scratchpad(self,
                               agent_scratchpad: list[AgentScratchpadUnit],
                               messages: list[PromptMessage]
                               ) -> list[AgentScratchpadUnit]:
        """
        init agent scratchpad
        """
        current_scratchpad: Optional[AgentScratchpadUnit] = None
        for message in messages:
            if isinstance(message, AssistantPromptMessage):
                current_scratchpad = AgentScratchpadUnit(
                    agent_response=message.content,
                    thought=message.content or 'I am thinking about how to help you',
                    action_str='',
                    action=None,
                    observation=None,
                )
                if message.tool_calls:
                    try:
                        current_scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=message.tool_calls[0].function.name,
                            action_input=json.loads(message.tool_calls[0].function.arguments)
                        )
                    except Exception:
                        pass

                agent_scratchpad.append(current_scratchpad)
            elif isinstance(message, ToolPromptMessage):
                if current_scratchpad:
                    current_scratchpad.observation = message.content

        return agent_scratchpad

    def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                   agent_prompt_message: AgentPromptEntity,
                                   ):
        """
        check chain of thought prompt messages, a standard prompt message is like:
            Respond to the human as helpfully and accurately as possible.

            {{instruction}}

            You have access to the following tools:

            {{tools}}

            Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
            Valid action values: "Final Answer" or {{tool_names}}

            Provide only ONE action per $JSON_BLOB, as shown:

            ```
            {
                "action": $TOOL_NAME,
                "action_input": $ACTION_INPUT
            }
            ```
        """

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt
        next_iteration = agent_prompt_message.next_iteration

        if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
            raise ValueError("first_prompt or next_iteration is required in CoT agent mode")

        # check instruction, tools, and tool_names slots
        if not first_prompt.find("{{instruction}}") >= 0:
            raise ValueError("{{instruction}} is required in first_prompt")
        if not first_prompt.find("{{tools}}") >= 0:
            raise ValueError("{{tools}} is required in first_prompt")
        if not first_prompt.find("{{tool_names}}") >= 0:
            raise ValueError("{{tool_names}} is required in first_prompt")

        if mode == "completion":
            if not first_prompt.find("{{query}}") >= 0:
                raise ValueError("{{query}} is required in first_prompt")
            if not first_prompt.find("{{agent_scratchpad}}") >= 0:
                raise ValueError("{{agent_scratchpad}} is required in first_prompt")

        if mode == "completion":
            if not next_iteration.find("{{observation}}") >= 0:
                raise ValueError("{{observation}} is required in next_iteration")

    def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
        convert agent scratchpad list to str
        """
        next_iteration = self.app_config.agent.prompt.next_iteration

        result = ''
        for scratchpad in agent_scratchpad:
            result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
                next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')

        return result

    def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                      prompt_messages: list[PromptMessage],
                                      tools: list[PromptMessageTool],
                                      agent_scratchpad: list[AgentScratchpadUnit],
                                      agent_prompt_message: AgentPromptEntity,
                                      instruction: str,
                                      input: str,
                                      ) -> list[PromptMessage]:
        """
        organize chain of thought prompt messages, a standard prompt message is like:
            Respond to the human as helpfully and accurately as possible.

            {{instruction}}

            You have access to the following tools:

            {{tools}}

            Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
            Valid action values: "Final Answer" or {{tool_names}}

            Provide only ONE action per $JSON_BLOB, as shown:

            ```
            {{{{
                "action": $TOOL_NAME,
                "action_input": $ACTION_INPUT
            }}}}
            ```
        """

        self._check_cot_prompt_messages(mode, agent_prompt_message)

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt

        # parse tools
        tools_str = self._jsonify_tool_prompt_messages(tools)

        # parse tools name
        tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'

        # get system message
        system_message = first_prompt.replace("{{instruction}}", instruction) \
            .replace("{{tools}}", tools_str) \
            .replace("{{tool_names}}", tool_names)

        # organize prompt messages
        if mode == "chat":
            # override system message
            overridden = False
            prompt_messages = prompt_messages.copy()
            for prompt_message in prompt_messages:
                if isinstance(prompt_message, SystemPromptMessage):
                    prompt_message.content = system_message
                    overridden = True
                    break

            # convert tool prompt messages to user prompt messages
            for idx, prompt_message in enumerate(prompt_messages):
                if isinstance(prompt_message, ToolPromptMessage):
                    prompt_messages[idx] = UserPromptMessage(
                        content=prompt_message.content
                    )

            if not overridden:
                prompt_messages.insert(0, SystemPromptMessage(
                    content=system_message,
                ))

            # add assistant message
            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                prompt_messages.append(AssistantPromptMessage(
                    content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
                ))

            # add user message
            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                prompt_messages.append(UserPromptMessage(
                    content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
                ))

            self._is_first_iteration = False

            return prompt_messages
        elif mode == "completion":
            # parse agent scratchpad
            agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
            self._is_first_iteration = False
            # parse prompt messages
            return [UserPromptMessage(
                content=first_prompt.replace("{{instruction}}", instruction)
                .replace("{{tools}}", tools_str)
                .replace("{{tool_names}}", tool_names)
                .replace("{{query}}", input)
                .replace("{{agent_scratchpad}}", agent_scratchpad_str),
            )]
        else:
            raise ValueError(f"mode {mode} is not supported")

    def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
        """
        jsonify tool prompt messages
        """
        tools = jsonable_encoder(tools)
        try:
            return json.dumps(tools, ensure_ascii=False)
        except json.JSONDecodeError:
            return json.dumps(tools)
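As a rough illustration of the template substitution performed in `_organize_cot_prompt_messages`, a minimal standalone sketch follows. The `first_prompt` text, tool description, and instruction are invented stand-ins for the app's configured ReAct prompt, not values from this commit:

# Illustrative sketch: how the {{instruction}}, {{tools}}, and {{tool_names}}
# slots in a configured first_prompt are filled to produce the system message.
first_prompt = (
    "Respond to the human as helpfully and accurately as possible.\n\n"
    "{{instruction}}\n\n"
    "You have access to the following tools:\n\n{{tools}}\n\n"
    'Valid action values: "Final Answer" or {{tool_names}}'
)
tools_str = '[{"name": "weather_forecast", "description": "Query the weather forecast."}]'
tool_names = '"weather_forecast"'

system_message = (first_prompt
                  .replace("{{instruction}}", "Answer weather questions politely.")
                  .replace("{{tools}}", tools_str)
                  .replace("{{tool_names}}", tool_names))
print(system_message)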
api/core/agent/entities.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from enum import Enum
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel


class AgentToolEntity(BaseModel):
    """
    Agent Tool Entity.
    """
    provider_type: Literal["builtin", "api"]
    provider_id: str
    tool_name: str
    tool_parameters: dict[str, Any] = {}


class AgentPromptEntity(BaseModel):
    """
    Agent Prompt Entity.
    """
    first_prompt: str
    next_iteration: str


class AgentScratchpadUnit(BaseModel):
    """
    Agent Scratchpad Unit.
    """

    class Action(BaseModel):
        """
        Action Entity.
        """
        action_name: str
        action_input: Union[dict, str]

    agent_response: Optional[str] = None
    thought: Optional[str] = None
    action_str: Optional[str] = None
    observation: Optional[str] = None
    action: Optional[Action] = None


class AgentEntity(BaseModel):
    """
    Agent Entity.
    """

    class Strategy(Enum):
        """
        Agent Strategy.
        """
        CHAIN_OF_THOUGHT = 'chain-of-thought'
        FUNCTION_CALLING = 'function-calling'

    provider: str
    model: str
    strategy: Strategy
    prompt: Optional[AgentPromptEntity] = None
    tools: Optional[list[AgentToolEntity]] = None
    max_iteration: int = 5
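A minimal construction sketch for these models (field values are invented, and it assumes the Dify `api` package is on the import path):

from core.agent.entities import AgentEntity, AgentToolEntity

# hypothetical provider/model/tool names, chosen only to show the shapes
agent = AgentEntity(
    provider='openai',
    model='gpt-4',
    strategy=AgentEntity.Strategy.FUNCTION_CALLING,
    tools=[AgentToolEntity(
        provider_type='builtin',
        provider_id='weather',
        tool_name='weather_forecast',
    )],
    max_iteration=3,
)
assert agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING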
api/core/agent/fc_agent_runner.py (new file, 414 lines)
@@ -0,0 +1,414 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message, MessageAgentThought

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config

        prompt_template = app_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self.organize_prompt_messages(
            prompt_template=prompt_template,
            query=query,
            prompt_messages=prompt_messages
        )

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: list[PromptMessageTool] = []
        tool_instances = {}
        for tool in app_config.agent.tools if app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there is not any tool call
        function_call_state = True
        agent_thoughts: list[MessageAgentThought] = []
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            # recalc llm max tokens
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_config.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_config.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None

            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(QueueAgentThoughtEvent(
                            agent_thought_id=agent_thought.id
                        ), PublishFrom.APPLICATION_MANAGER)
                        is_first_chunk = False
                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, ensure_ascii=False)
                        except json.JSONDecodeError:
                            # ensure ascii to avoid encoding error
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            })

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except json.JSONDecodeError:
                        # ensure ascii to avoid encoding error
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        })

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ''

                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    )
                )

            if tool_calls:
                prompt_messages.append(AssistantPromptMessage(
                    content='',
                    name='',
                    tool_calls=[AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1],
                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
                    ) for tool_call in tool_calls]
                ))

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )
            self.queue_manager.publish(QueueAgentThoughtEvent(
                agent_thought_id=agent_thought.id
            ), PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'

            # update prompt messages
            if response.strip():
                prompt_messages.append(AssistantPromptMessage(
                    content=response,
                ))

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is not a tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                    )
                    # publish files
                    for message_file, save_as in message_files:
                        if save_as:
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

                        # publish message file
                        self.queue_manager.publish(QueueMessageFileEvent(
                            message_file_id=message_file.id
                        ), PublishFrom.APPLICATION_MANAGER)
                        # add message file ids
                        message_file_ids.append(message_file.id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict()
                    }

                tool_responses.append(tool_response)
                prompt_messages = self.organize_prompt_messages(
                    prompt_template=prompt_template,
                    query=None,
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response['tool_call_name']: tool_response['meta']
                        for tool_response in tool_responses
                    },
                    observation={
                        tool_response['tool_call_name']: tool_response['tool_response']
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        )), PublishFrom.APPLICATION_MANAGER)

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls

    def organize_prompt_messages(self, prompt_template: str,
                                 query: Optional[str] = None,
                                 tool_call_id: Optional[str] = None, tool_call_name: Optional[str] = None,
                                 tool_response: Optional[str] = None,
                                 prompt_messages: Optional[list[PromptMessage]] = None
                                 ) -> list[PromptMessage]:
        """
        Organize prompt messages
        """

        if not prompt_messages:
            prompt_messages = [
                SystemPromptMessage(content=prompt_template),
                UserPromptMessage(content=query),
            ]
        else:
            if tool_response:
                prompt_messages = prompt_messages.copy()
                prompt_messages.append(
                    ToolPromptMessage(
                        content=tool_response,
                        tool_call_id=tool_call_id,
                        name=tool_call_name,
                    )
                )

        return prompt_messages
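For reference, `extract_tool_calls` yields `(tool_call_id, tool_call_name, tool_call_args)` tuples, and `run` serializes them back into an assistant message. A standalone sketch of that round trip using plain dicts (the ID, tool name, and arguments are invented):

import json

# (tool_call_id, tool_call_name, tool_call_args), as returned by extract_tool_calls
tool_calls = [("call_0", "weather_forecast", {"city": "Berlin"})]

# the OpenAI-style tool_calls payload that run() rebuilds for the assistant message
assistant_tool_calls = [{
    "id": call_id,
    "type": "function",
    "function": {"name": name, "arguments": json.dumps(args, ensure_ascii=False)},
} for call_id, name, args in tool_calls]

print(json.dumps(assistant_tool_calls, indent=2))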