feat: remove llm client use (#1316)

This commit is contained in:
takatost
2023-10-12 03:02:53 +08:00
committed by GitHub
parent c007dbdc13
commit cbf095465c
14 changed files with 434 additions and 353 deletions

View File

@@ -4,16 +4,17 @@ from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
get_buffer_string
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.model_providers.models.llm.base import BaseLLM
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -52,8 +53,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
summary_model_instance: BaseLLM = None
class Config:
"""Configuration for this pydantic object."""
@@ -95,14 +95,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
try:
@@ -118,7 +118,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_llm:
if len(intermediate_steps) >= 2 and self.summary_model_instance:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
@@ -130,11 +130,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
summary_handler = SummarizerMixin(llm=self.summary_llm)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = summary_handler.predict_new_summary(
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
@@ -144,6 +143,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
def create_prompt(
cls,
@@ -176,7 +187,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
@@ -188,16 +199,27 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
model_instance=model_instance,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)