import re
from typing import Any, List, Optional, Sequence, Tuple, Union, cast

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, OutputParserException,
                              get_buffer_string)
from langchain.tools import BaseTool

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""


class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
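    """Structured chat agent that keeps its prompt within the model's context window.

    When the assembled prompt would exceed the remaining token budget, older
    intermediate steps are folded into a rolling summary (``moving_summary_buffer``)
    by the configured summary model, and only the most recent step is kept verbatim.
    """
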
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: Optional[ModelConfigEntity] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used for this query.

        Using ReACT to first decide whether an agent is needed is itself a costly LLM
        call, so it is cheaper to skip that check and always reason with the agent.

        :param query: user query
        :return: True, always
        """
        return True

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
        messages = []
        if prompts:
            messages = prompts[0].to_messages()

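        # Check how much of the context window is left for this prompt; if it is
        # already over budget, summarize older intermediate steps before calling the model.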
        prompt_messages = lc_messages_to_prompt_messages(messages)
        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            raise e

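        # Parse the model output into an agent decision. If the chosen tool is the
        # dataset tool, pass the user's original query through instead of the
        # model-generated one.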
        try:
            agent_decision = self.output_parser.parse(full_output)
            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            return agent_decision
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                          "I don't know how to respond to that."}, "")

    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
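        """Fold older intermediate steps into the rolling summary.

        Observations from steps between ``moving_summary_index`` and the latest step
        are summarized with the summary model; only the most recent step is kept
        verbatim, and the updated summary replaces the previous summary message in
        ``chat_history``. Raises ExceededLLMTokensLimitError if there are not enough
        steps to summarize or no summary model is configured.
        """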
        if len(intermediate_steps) >= 2 and self.summary_model_config:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    def predict_new_summary(
        self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

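        # Use LangChain's progressive SUMMARY_PROMPT to merge the new lines into the
        # existing summary.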
        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
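        """Create the chat-style prompt for chat models.

        The system message combines the prefix, the rendered tool list, the format
        instructions and the suffix; memory prompts and the human message template
        follow as separate messages.
        """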
        tool_strings = []
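        # Escape braces in each tool's args schema so they are not interpreted as
        # template variables when the prompt is formatted.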
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def create_completion_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero-shot agent, for completion models.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
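        """Concatenate prior actions and observations into the agent scratchpad.

        For chat models the scratchpad is wrapped in a reminder that the user only
        sees the final answer; completion models receive it unchanged.
        """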
        agent_scratchpad = ""
        for action, observation in intermediate_steps:
            agent_scratchpad += action.log
            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"

        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if agent_scratchpad:
            llm_chain = cast(LLMChain, self.llm_chain)
            if llm_chain.model_config.mode == "chat":
                return (
                    f"This was your previous work "
                    f"(but I haven't seen any of it! I only see what "
                    f"you return as final answer):\n{agent_scratchpad}"
                )
            else:
                return agent_scratchpad
        else:
            return agent_scratchpad

    @classmethod
    def from_llm_and_tools(
        cls,
        model_config: ModelConfigEntity,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
        agent_llm_callback: Optional[AgentLLMCallback] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        if model_config.mode == "chat":
            prompt = cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
                input_variables=input_variables,
                memory_prompts=memory_prompts,
            )
        else:
            prompt = cls.create_completion_prompt(
                tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=input_variables,
            )

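        # Build the reasoning chain with fixed generation parameters
        # (temperature 0.2, top_p 0.3, max_tokens 1500).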
        llm_chain = LLMChain(
            model_config=model_config,
            prompt=prompt,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )
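
# Illustrative usage sketch (not part of this module): the surrounding application is
# expected to supply `model_config`, `summary_model_config`, and `tools`; the names
# below are placeholders, not guaranteed call sites.
#
#   agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
#       model_config=model_config,
#       tools=tools,
#       summary_model_config=summary_model_config,
#   )
#   decision = agent.plan(intermediate_steps=[], input="user question here")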