Feat/assistant app (#2086)

Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: Pascal M <11357019+perzeuss@users.noreply.github.com>
This commit is contained in:
Yeuoly
2024-01-23 19:58:23 +08:00
committed by GitHub
parent 7bbe12b2bd
commit 86286e1ac8
175 changed files with 11619 additions and 1235 deletions

View File

@@ -1,30 +1,27 @@
import logging
from typing import List, Optional, cast
from typing import cast, Optional, List
from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.tools import BaseTool, WikipediaQueryRun, Tool
from pydantic import BaseModel, Field
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity, InvokeFrom,
ModelConfigEntity)
from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \
AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIInput, OptimizedSerpAPIWrapper
from core.tool.web_reader_tool import WebReaderTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from models.dataset import Dataset
from models.model import Message
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
@@ -132,55 +129,6 @@ class AgentRunnerFeature:
logger.exception("agent_executor run failed")
return None
def to_tools(self, tool_configs: list[AgentToolEntity],
invoke_from: InvokeFrom,
callbacks: list[BaseCallbackHandler]) \
-> Optional[List[BaseTool]]:
"""
Convert tool configs to tools
:param tool_configs: tool configs
:param invoke_from: invoke from
:param callbacks: callbacks
"""
tools = []
for tool_config in tool_configs:
tool = None
if tool_config.tool_id == "dataset":
tool = self.to_dataset_retriever_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
elif tool_config.tool_id == "web_reader":
tool = self.to_web_reader_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
elif tool_config.tool_id == "google_search":
tool = self.to_google_search_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
elif tool_config.tool_id == "wikipedia":
tool = self.to_wikipedia_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
elif tool_config.tool_id == "current_datetime":
tool = self.to_current_datetime_tool(
tool_config=tool_config.config,
invoke_from=invoke_from
)
if tool:
if tool.callbacks is not None:
tool.callbacks.extend(callbacks)
else:
tool.callbacks = callbacks
tools.append(tool)
return tools
def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
@@ -247,78 +195,4 @@ class AgentRunnerFeature:
retriever_from=invoke_from.to_source()
)
return tool
def to_web_reader_tool(self, tool_config: dict,
invoke_from: InvokeFrom) -> Optional[BaseTool]:
"""
A tool for reading web pages
:param tool_config: tool config
:param invoke_from: invoke from
:return:
"""
model_parameters = {
"temperature": 0,
"max_tokens": 500
}
tool = WebReaderTool(
model_config=self.model_config,
model_parameters=model_parameters,
max_chunk_length=4000,
continue_reading=True
)
return tool
def to_google_search_tool(self, tool_config: dict,
invoke_from: InvokeFrom) -> Optional[BaseTool]:
"""
A tool for performing a Google search and extracting snippets and webpages
:param tool_config: tool config
:param invoke_from: invoke from
:return:
"""
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
func_kwargs = tool_provider.credentials_to_func_kwargs()
if not func_kwargs:
return None
tool = Tool(
name="google_search",
description="A tool for performing a Google search and extracting snippets and webpages "
"when you need to search for something you don't know or when your information "
"is not up to date. "
"Input should be a search query.",
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
args_schema=OptimizedSerpAPIInput
)
return tool
def to_current_datetime_tool(self, tool_config: dict,
invoke_from: InvokeFrom) -> Optional[BaseTool]:
"""
A tool for getting the current date and time
:param tool_config: tool config
:param invoke_from: invoke from
:return:
"""
return DatetimeTool()
def to_wikipedia_tool(self, tool_config: dict,
invoke_from: InvokeFrom) -> Optional[BaseTool]:
"""
A tool for searching Wikipedia
:param tool_config: tool config
:param invoke_from: invoke from
:return:
"""
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
return WikipediaQueryRun(
name="wikipedia",
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
args_schema=WikipediaInput
)
return tool

View File

@@ -0,0 +1,558 @@
import logging
import json
from typing import Optional, List, Tuple, Union
from datetime import datetime
from mimetypes import guess_extension
from core.app_runner.app_runner import AppRunner
from extensions.ext_database import db
from models.model import MessageAgentThought, Message, MessageFile
from models.tools import ToolConversationVariables
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, \
ToolRuntimeVariablePool, ToolParamter
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.app_runner.app_runner import AppRunner
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ModelConfigEntity, AgentEntity, AgentToolEntity
from core.application_queue_manager import ApplicationQueueManager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.entities.application_entities import ModelConfigEntity, \
AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.file.message_file_parser import FileTransferMethod
logger = logging.getLogger(__name__)
class BaseAssistantApplicationRunner(AppRunner):
def __init__(self, tenant_id: str,
application_generate_entity: ApplicationGenerateEntity,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[List[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = prompt_messages
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
# init callback
self.agent_callback = DifyAgentCallbackHandler()
# init dataset tools
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=queue_manager,
app_id=self.application_generate_entity.app_id,
message_id=message.id,
user_id=user_id,
invoke_from=self.application_generate_entity.invoke_from,
)
self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
tenant_id=tenant_id,
dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [],
retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None,
return_resource=app_orchestration_config.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback
)
# get how many agent thoughts have been created
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
"""
Repacket app orchestration config
"""
if app_orchestration_config.prompt_template.simple_prompt_template is None:
app_orchestration_config.prompt_template.simple_prompt_template = ''
return app_orchestration_config
def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
result = ''
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please dirct user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += f"image has been created and sent to user already, you should tell user to check it now."
else:
result += f"tool response: {response.message}."
return result
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
"""
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tanent_id=self.application_generate_entity.tenant_id,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
message_tool = PromptMessageTool(
name=tool.tool_name,
description=tool_entity.description.llm,
parameters={
"type": "object",
"properties": {},
"required": [],
}
)
runtime_parameters = {}
parameters = tool_entity.parameters or []
user_parameters = tool_entity.get_runtime_parameters() or []
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
for parameter in parameters:
parameter_type = 'string'
enum = []
if parameter.type == ToolParamter.ToolParameterType.STRING:
parameter_type = 'string'
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
parameter_type = 'boolean'
elif parameter.type == ToolParamter.ToolParameterType.NUMBER:
parameter_type = 'number'
elif parameter.type == ToolParamter.ToolParameterType.SELECT:
for option in parameter.options:
enum.append(option.value)
parameter_type = 'string'
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParamter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParamter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter.type == ToolParamter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParamter.ToolParameterType.SELECT, ToolParamter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParamter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
elif parameter.form == ToolParamter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
"""
convert dataset retriever tool to prompt message tool
"""
prompt_tool = PromptMessageTool(
name=tool.identity.name,
description=tool.description.llm,
parameters={
"type": "object",
"properties": {},
"required": [],
}
)
for parameter in tool.get_runtime_parameters():
parameter_type = 'string'
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
"""
update prompt message tool
"""
# try to get tool runtime parameters
tool_runtime_parameters = tool.get_runtime_parameters() or []
for parameter in tool_runtime_parameters:
parameter_type = 'string'
enum = []
if parameter.type == ToolParamter.ToolParameterType.STRING:
parameter_type = 'string'
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
parameter_type = 'boolean'
elif parameter.type == ToolParamter.ToolParameterType.NUMBER:
parameter_type = 'number'
elif parameter.type == ToolParamter.ToolParameterType.SELECT:
for option in parameter.options:
enum.append(option.value)
parameter_type = 'string'
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParamter.ToolParameterForm.LLM:
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message,
save_as=response.save_as,
))
elif response.type == ToolInvokeMessage.MessageType.BLOB:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message,
save_as=response.save_as,
))
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and 'mime_type' in response.meta:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
url=response.message,
save_as=response.save_as,
))
return result
def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
"""
Create message file
:param messages: messages
:return: message files, should save as variable
"""
result = []
for message in messages:
file_type = 'bin'
if 'image' in message.mimetype:
file_type = 'image'
elif 'video' in message.mimetype:
file_type = 'video'
elif 'audio' in message.mimetype:
file_type = 'audio'
elif 'text' in message.mimetype:
file_type = 'text'
elif 'pdf' in message.mimetype:
file_type = 'pdf'
elif 'zip' in message.mimetype:
file_type = 'archive'
# ...
invoke_from = self.application_generate_entity.invoke_from
message_file = MessageFile(
message_id=self.message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE.value,
belongs_to='assistant',
url=message.url,
upload_file_id=None,
created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
created_by=self.user_id,
)
db.session.add(message_file)
result.append((
message_file,
message.save_as
))
db.session.commit()
return result
def create_agent_thought(self, message_id: str, message: str,
tool_name: str, tool_input: str, messages_ids: List[str]
) -> MessageAgentThought:
"""
Create agent thought
"""
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
thought='',
tool=tool_name,
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_files=json.dumps(messages_ids) if messages_ids else '',
answer='',
observation='',
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
tokens=0,
total_price=0,
position=self.agent_thought_count + 1,
currency='USD',
latency=0,
created_by_role='account',
created_by=self.user_id,
)
db.session.add(thought)
db.session.commit()
self.agent_thought_count += 1
return thought
def save_agent_thought(self,
agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: str,
answer: str,
messages_ids: List[str],
llm_usage: LLMUsage = None) -> MessageAgentThought:
"""
Save agent thought
"""
if thought is not None:
agent_thought.thought = thought
if tool_name is not None:
agent_thought.tool = tool_name
if tool_input is not None:
if isinstance(tool_input, dict):
try:
tool_input = json.dumps(tool_input, ensure_ascii=False)
except Exception as e:
tool_input = json.dumps(tool_input)
agent_thought.tool_input = tool_input
if observation is not None:
agent_thought.observation = observation
if answer is not None:
agent_thought.answer = answer
if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids)
if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit
agent_thought.message_unit_price = llm_usage.prompt_unit_price
agent_thought.answer_token = llm_usage.completion_tokens
agent_thought.answer_price_unit = llm_usage.completion_price_unit
agent_thought.answer_unit_price = llm_usage.completion_unit_price
agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price
db.session.commit()
def get_history_prompt_messages(self) -> List[PromptMessage]:
"""
Get history prompt messages
"""
if self.history_prompt_messages is None:
self.history_prompt_messages = db.session.query(PromptMessage).filter(
PromptMessage.message_id == self.message.id,
).order_by(PromptMessage.position.asc()).all()
return self.history_prompt_messages
def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
"""
Transform tool message into agent thought
"""
result = []
for message in messages:
if message.type == ToolInvokeMessage.MessageType.TEXT:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.LINK:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
# try to download image
try:
file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id,
conversation_id=self.message.conversation_id,
file_url=message.message)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
except Exception as e:
logger.exception(e)
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
meta=message.meta.copy() if message.meta is not None else {},
save_as=message.save_as,
))
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
mimetype = message.meta.get('mime_type', 'octet/stream')
# if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode('utf-8')
file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id,
conversation_id=self.message.conversation_id,
file_binary=message.message,
mimetype=mimetype)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
# check if file is image
if 'image' in mimetype:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(message)
return result
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()

View File

@@ -0,0 +1,578 @@
import json
import logging
import re
from typing import Literal, Union, Generator, Dict, List
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
from core.application_queue_manager import PublishFrom
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, \
UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
from core.model_manager import ModelInstance
from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \
ToolProviderCredentialValidationError
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from models.model import Conversation, Message
logger = logging.getLogger(__name__)
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
message: Message,
query: str,
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
"""
app_orchestration_config = self.app_orchestration_config
self._repacket_app_orchestration_config(app_orchestration_config)
agent_scratchpad: List[AgentScratchpadUnit] = []
# check model mode
if self.app_orchestration_config.model_config.mode == "completion":
# TODO: stop words
if 'Observation' not in app_orchestration_config.model_config.stop:
app_orchestration_config.model_config.stop.append('Observation')
iteration_step = 1
max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1
prompt_messages = self.history_prompt_messages
# convert tools into ModelRuntime Tool format
prompt_messages_tools: List[PromptMessageTool] = []
tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
# api tool may be deleted
continue
# save tool entity
tool_instances[tool.tool_name] = tool_entity
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# convert dataset tools into ModelRuntime Tool format
for dataset_tool in self.dataset_tools:
prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
tool_instances[dataset_tool.identity.name] = dataset_tool
function_call_state = True
llm_usage = {
'usage': None
}
final_answer = ''
def increse_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
llm_usage = final_llm_usage_dict['usage']
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there is not any tool call
function_call_state = False
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
prompt_messages_tools = []
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# update prompt messages
prompt_messages = self._originze_cot_prompt_messages(
mode=app_orchestration_config.model_config.mode,
prompt_messages=prompt_messages,
tools=prompt_messages_tools,
agent_scratchpad=agent_scratchpad,
agent_prompt_message=app_orchestration_config.agent.prompt,
instruction=app_orchestration_config.prompt_template.simple_prompt_template,
input=query
)
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
llm_result: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=[],
stop=app_orchestration_config.model_config.stop,
stream=False,
user=self.user_id,
callbacks=[],
)
# check llm result
if not llm_result:
raise ValueError("failed to invoke llm")
# get scratchpad
scratchpad = self._extract_response_scratchpad(llm_result.message.content)
agent_scratchpad.append(scratchpad)
# get llm usage
if llm_result.usage:
increse_usage(llm_usage, llm_result.usage)
self.save_agent_thought(agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input=scratchpad.action.action_input if scratchpad.action else '',
thought=scratchpad.thought,
observation='',
answer=llm_result.message.content,
messages_ids=[],
llm_usage=llm_result.usage)
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# publish agent thought if it's not empty and there is a action
if scratchpad.thought and scratchpad.action:
# check if final answer
if not scratchpad.action.action_name.lower() == "final answer":
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=scratchpad.thought
),
usage=llm_result.usage,
),
system_fingerprint=''
)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = scratchpad.agent_response or ''
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
final_answer = scratchpad.action.action_input if \
isinstance(scratchpad.action.action_input, str) else \
json.dumps(scratchpad.action.action_input)
except json.JSONDecodeError:
final_answer = f'{scratchpad.action.action_input}'
else:
function_call_state = True
# action is tool call, invoke tool
tool_call_name = scratchpad.action.action_name
tool_call_args = scratchpad.action.action_input
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
logger.error(f"failed to find tool instance: {tool_call_name}")
answer = f"there is not a tool named {tool_call_name}"
self.save_agent_thought(agent_thought=agent_thought,
tool_name='',
tool_input='',
thought=None,
observation=answer,
answer=answer,
messages_ids=[])
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
else:
# invoke tool
error_response = None
try:
tool_response = tool_instance.invoke(
user_id=self.user_id,
tool_paramters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
)
# transform tool response to llm friendly response
tool_response = self.transform_tool_invoke_messages(tool_response)
# extract binary data from tool invoke message
binary_files = self.extract_tool_response_binary(tool_response)
# create message file
message_files = self.create_message_files(binary_files)
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name,
value=message_file.id,
name=save_as)
self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
message_file_ids = [message_file.id for message_file, _ in message_files]
except ToolProviderCredentialValidationError as e:
error_response = f"Plese check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
error_response = f"there is not a tool named {tool_call_name}"
except (
ToolParamterValidationError
) as e:
error_response = f"tool paramters validation error: {e}, please check your tool paramters"
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
except Exception as e:
error_response = f"unknown error: {e}"
if error_response:
observation = error_response
logger.error(error_response)
else:
observation = self._convert_tool_response_to_str(tool_response)
# save scratchpad
scratchpad.observation = observation
scratchpad.agent_response = llm_result.message.content
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=tool_call_name,
tool_input=tool_call_args,
thought=None,
observation=observation,
answer=llm_result.message.content,
messages_ids=message_file_ids,
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# update prompt tool message
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=final_answer
),
usage=llm_usage['usage']
),
system_fingerprint=''
)
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name='',
tool_input='',
thought=final_answer,
observation='',
answer=final_answer,
messages_ids=[]
)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish_message_end(LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
),
usage=llm_usage['usage'],
system_fingerprint=''
), PublishFrom.APPLICATION_MANAGER)
def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
"""
extract response from llm response
"""
def extra_quotes() -> AgentScratchpadUnit:
agent_response = content
# try to extract all quotes
pattern = re.compile(r'```(.*?)```', re.DOTALL)
quotes = pattern.findall(content)
# try to extract action from end to start
for i in range(len(quotes) - 1, 0, -1):
"""
1. use json load to parse action
2. use plain text `Action: xxx` to parse action
"""
try:
action = json.loads(quotes[i].replace('```', ''))
action_name = action.get("action")
action_input = action.get("action_input")
agent_thought = agent_response.replace(quotes[i], '')
if action_name and action_input:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
action_str=quotes[i],
action=AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
)
except:
# try to parse action from plain text
action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
# delete action from agent response
agent_thought = agent_response.replace(quotes[i], '')
# remove extra quotes
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
action_str=quotes[i],
action=AgentScratchpadUnit.Action(
action_name=action_name[0],
action_input=action_input[0],
)
)
def extra_json():
agent_response = content
# try to extract all json
structures, pair_match_stack = [], []
started_at, end_at = 0, 0
for i in range(len(content)):
if content[i] == '{':
pair_match_stack.append(i)
if len(pair_match_stack) == 1:
started_at = i
elif content[i] == '}':
begin = pair_match_stack.pop()
if not pair_match_stack:
end_at = i + 1
structures.append((content[begin:i+1], (started_at, end_at)))
# handle the last character
if pair_match_stack:
end_at = len(content)
structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
for i in range(len(structures), 0, -1):
try:
json_content, (started_at, end_at) = structures[i - 1]
action = json.loads(json_content)
action_name = action.get("action")
action_input = action.get("action_input")
# delete json content from agent response
agent_thought = agent_response[:started_at] + agent_response[end_at:]
# remove extra quotes like ```(json)*\n\n```
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
action_str=json_content,
action=AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
)
except:
pass
agent_scratchpad = extra_quotes()
if agent_scratchpad:
return agent_scratchpad
agent_scratchpad = extra_json()
if agent_scratchpad:
return agent_scratchpad
return AgentScratchpadUnit(
agent_response=content,
thought=content,
action_str='',
action=None
)
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
agent_prompt_message: AgentPromptEntity,
):
"""
check chain of thought prompt messages, a standard prompt message is like:
Respond to the human as helpfully and accurately as possible.
{{instruction}}
You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid action values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
```
{
"action": $TOOL_NAME,
"action_input": $ACTION_INPUT
}
```
"""
# parse agent prompt message
first_prompt = agent_prompt_message.first_prompt
next_iteration = agent_prompt_message.next_iteration
if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode")
# check instruction, tools, and tool_names slots
if not first_prompt.find("{{instruction}}") >= 0:
raise ValueError("{{instruction}} is required in first_prompt")
if not first_prompt.find("{{tools}}") >= 0:
raise ValueError("{{tools}} is required in first_prompt")
if not first_prompt.find("{{tool_names}}") >= 0:
raise ValueError("{{tool_names}} is required in first_prompt")
if mode == "completion":
if not first_prompt.find("{{query}}") >= 0:
raise ValueError("{{query}} is required in first_prompt")
if not first_prompt.find("{{agent_scratchpad}}") >= 0:
raise ValueError("{{agent_scratchpad}} is required in first_prompt")
if mode == "completion":
if not next_iteration.find("{{observation}}") >= 0:
raise ValueError("{{observation}} is required in next_iteration")
def _convert_strachpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
"""
convert agent scratchpad list to str
"""
next_iteration = self.app_orchestration_config.agent.prompt.next_iteration
result = ''
for scratchpad in agent_scratchpad:
result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation) + "\n"
return result
def _originze_cot_prompt_messages(self, mode: Literal["completion", "chat"],
prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool],
agent_scratchpad: List[AgentScratchpadUnit],
agent_prompt_message: AgentPromptEntity,
instruction: str,
input: str,
) -> List[PromptMessage]:
"""
originze chain of thought prompt messages, a standard prompt message is like:
Respond to the human as helpfully and accurately as possible.
{{instruction}}
You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid action values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $ACTION_INPUT
}}}}
```
"""
self._check_cot_prompt_messages(mode, agent_prompt_message)
# parse agent prompt message
first_prompt = agent_prompt_message.first_prompt
# parse tools
tools_str = self._jsonify_tool_prompt_messages(tools)
# parse tools name
tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'
# get system message
system_message = first_prompt.replace("{{instruction}}", instruction) \
.replace("{{tools}}", tools_str) \
.replace("{{tool_names}}", tool_names)
# originze prompt messages
if mode == "chat":
# override system message
overrided = False
prompt_messages = prompt_messages.copy()
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_message.content = system_message
overrided = True
break
if not overrided:
prompt_messages.insert(0, SystemPromptMessage(
content=system_message,
))
# add assistant message
if len(agent_scratchpad) > 0:
prompt_messages.append(AssistantPromptMessage(
content=agent_scratchpad[-1].thought + "\n" + agent_scratchpad[-1].observation
))
# add user message
if len(agent_scratchpad) > 0:
prompt_messages.append(UserPromptMessage(
content=input,
))
return prompt_messages
elif mode == "completion":
# parse agent scratchpad
agent_scratchpad_str = self._convert_strachpad_list_to_str(agent_scratchpad)
# parse prompt messages
return [UserPromptMessage(
content=first_prompt.replace("{{instruction}}", instruction)
.replace("{{tools}}", tools_str)
.replace("{{tool_names}}", tool_names)
.replace("{{query}}", input)
.replace("{{agent_scratchpad}}", agent_scratchpad_str),
)]
else:
raise ValueError(f"mode {mode} is not supported")
def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
"""
jsonify tool prompt messages
"""
tools = jsonable_encoder(tools)
try:
return json.dumps(tools, ensure_ascii=False)
except json.JSONDecodeError:
return json.dumps(tools)

View File

@@ -0,0 +1,335 @@
import json
import logging
from typing import Union, Generator, Dict, Any, Tuple, List
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
from core.model_manager import ModelInstance
from core.application_queue_manager import PublishFrom
from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \
ToolProviderCredentialValidationError
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from models.model import Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
message: Message,
query: str,
) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
app_orchestration_config = self.app_orchestration_config
prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or ''
prompt_messages = self.history_prompt_messages
prompt_messages = self.organize_prompt_messages(
prompt_template=prompt_template,
query=query,
prompt_messages=prompt_messages
)
# convert tools into ModelRuntime Tool format
prompt_messages_tools: List[PromptMessageTool] = []
tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
# api tool may be deleted
continue
# save tool entity
tool_instances[tool.tool_name] = tool_entity
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# convert dataset tools into ModelRuntime Tool format
for dataset_tool in self.dataset_tools:
prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
tool_instances[dataset_tool.identity.name] = dataset_tool
iteration_step = 1
max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call
function_call_state = True
agent_thoughts: List[MessageAgentThought] = []
llm_usage = {
'usage': None
}
final_answer = ''
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
llm_usage = final_llm_usage_dict['usage']
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
prompt_messages_tools = []
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=prompt_messages_tools,
stop=app_orchestration_config.model_config.stop,
stream=True,
user=self.user_id,
callbacks=[],
)
tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
# save full response
response = ''
# save tool call names and inputs
tool_call_names = ''
tool_call_inputs = ''
current_llm_usage = None
for chunk in chunks:
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += chunk.delta.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
# save thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
observation=None,
answer=response,
messages_ids=[],
llm_usage=current_llm_usage
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
final_answer += response + '\n'
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
logger.error(f"failed to find tool instance: {tool_call_name}")
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}"
}
tool_responses.append(tool_response)
else:
# invoke tool
error_response = None
try:
tool_invoke_message = tool_instance.invoke(
user_id=self.user_id,
tool_paramters=tool_call_args,
)
# transform tool invoke message to get LLM friendly message
tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
# extract binary data from tool invoke message
binary_files = self.extract_tool_response_binary(tool_invoke_message)
# create message file
message_files = self.create_message_files(binary_files)
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
# publish message file
self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file.id)
except ToolProviderCredentialValidationError as e:
error_response = f"Plese check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
error_response = f"there is not a tool named {tool_call_name}"
except (
ToolParamterValidationError
) as e:
error_response = f"tool paramters validation error: {e}, please check your tool paramters"
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
except Exception as e:
error_response = f"unknown error: {e}"
if error_response:
observation = error_response
logger.error(error_response)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": error_response
}
tool_responses.append(tool_response)
else:
observation = self._convert_tool_response_to_str(tool_invoke_message)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": observation
}
tool_responses.append(tool_response)
prompt_messages = self.organize_prompt_messages(
prompt_template=prompt_template,
query=None,
tool_call_id=tool_call_id,
tool_call_name=tool_call_name,
tool_response=tool_response['tool_response'],
prompt_messages=prompt_messages,
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=None,
tool_input=None,
thought=None,
observation=tool_response['tool_response'],
answer=None,
messages_ids=message_file_ids
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish_message_end(LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer,
),
usage=llm_usage['usage'],
system_fingerprint=''
), PublishFrom.APPLICATION_MANAGER)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
Check if there is any tool call in llm result chunk
"""
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls:
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
json.loads(prompt_message.function.arguments),
))
return tool_calls
def organize_prompt_messages(self, prompt_template: str,
query: str = None,
tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
"""
Organize prompt messages
"""
if not prompt_messages:
prompt_messages = [
SystemPromptMessage(content=prompt_template),
UserPromptMessage(content=query),
]
else:
if tool_response:
prompt_messages = prompt_messages.copy()
prompt_messages.append(
ToolPromptMessage(
content=tool_response,
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
return prompt_messages

View File

@@ -6,8 +6,8 @@ from core.entities.application_entities import DatasetEntity, DatasetRetrieveCon
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from langchain.tools import BaseTool
from models.dataset import Dataset