feat: [backend] vision support (#1510)
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
This commit is contained in:
104
api/core/third_party/langchain/llms/chat_open_ai.py
vendored
104
api/core/third_party/langchain/llms/chat_open_ai.py
vendored
@@ -1,10 +1,13 @@
|
||||
import os
|
||||
|
||||
from typing import Dict, Any, Optional, Union, Tuple
|
||||
from typing import Dict, Any, Optional, Union, Tuple, List, cast
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile
|
||||
|
||||
|
||||
class EnhanceChatOpenAI(ChatOpenAI):
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
|
||||
@@ -48,3 +51,102 @@ class EnhanceChatOpenAI(ChatOpenAI):
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [self._convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
model, encoding = self._get_encoding_model()
|
||||
if model.startswith("gpt-3.5-turbo-0301"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
|
||||
value = text
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
return num_tokens
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, LCHumanMessageWithFiles):
|
||||
content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": message.content
|
||||
}
|
||||
]
|
||||
|
||||
for file in message.files:
|
||||
if file.type == PromptMessageFileType.IMAGE:
|
||||
file = cast(ImagePromptMessageFile, file)
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": file.data,
|
||||
"detail": file.detail.value
|
||||
}
|
||||
})
|
||||
|
||||
message_dict = {"role": "user", "content": content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
Reference in New Issue
Block a user