feat: advanced prompt backend (#1301)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Garfield Dai
2023-10-12 23:13:10 +08:00
committed by GitHub
parent 2d1cb076c6
commit 42a5b3ec17
61 changed files with 767 additions and 581 deletions

View File

@@ -0,0 +1,56 @@
import copy
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
class AdvancedPromptTemplateService:
def get_prompt(self, args: dict) -> dict:
app_mode = args['app_mode']
model_mode = args['model_mode']
model_name = args['model_name']
has_context = args['has_context']
if 'baichuan' in model_name:
return self.get_baichuan_prompt(app_mode, model_mode, has_context)
else:
return self.get_common_prompt(app_mode, model_mode, has_context)
def get_common_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict:
if app_mode == 'chat':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
elif app_mode == 'completion':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
def get_completion_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict:
if has_context == 'true':
prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
return prompt_template
def get_chat_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict:
if has_context == 'true':
prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
return prompt_template
def get_baichuan_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict:
if app_mode == 'chat':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
elif app_mode == 'completion':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)

View File

@@ -3,7 +3,7 @@ import uuid
from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.models.entity.model_params import ModelType, ModelMode
from models.account import Account
from services.dataset_service import DatasetService
@@ -34,40 +34,28 @@ class AppModelConfigService:
# max_tokens
if 'max_tokens' not in cp:
cp["max_tokens"] = 512
#
# if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
# llm_constant.max_context_token_length[model_name]:
# raise ValueError(
# "max_tokens must be an integer greater than 0 "
# "and not exceeding the maximum value of the corresponding model")
#
# temperature
if 'temperature' not in cp:
cp["temperature"] = 1
#
# if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
# raise ValueError("temperature must be a float between 0 and 2")
#
# top_p
if 'top_p' not in cp:
cp["top_p"] = 1
# if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
# raise ValueError("top_p must be a float between 0 and 2")
#
# presence_penalty
if 'presence_penalty' not in cp:
cp["presence_penalty"] = 0
# if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
# raise ValueError("presence_penalty must be a float between -2 and 2")
#
# presence_penalty
if 'frequency_penalty' not in cp:
cp["frequency_penalty"] = 0
# if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
# raise ValueError("frequency_penalty must be a float between -2 and 2")
# stop
if 'stop' not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")
# Filter out extra parameters
filtered_cp = {
@@ -75,7 +63,8 @@ class AppModelConfigService:
"temperature": cp["temperature"],
"top_p": cp["top_p"],
"presence_penalty": cp["presence_penalty"],
"frequency_penalty": cp["frequency_penalty"]
"frequency_penalty": cp["frequency_penalty"],
"stop": cp["stop"]
}
return filtered_cp
@@ -211,6 +200,10 @@ class AppModelConfigService:
model_ids = [m['id'] for m in model_list]
if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list")
# model.mode
if 'mode' not in config['model'] or not config['model']["mode"]:
config['model']["mode"] = ""
# model.completion_params
if 'completion_params' not in config["model"]:
@@ -339,6 +332,9 @@ class AppModelConfigService:
# dataset_query_variable
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
# advanced prompt validation
AppModelConfigService.is_advanced_prompt_valid(config, mode)
# Filter out extra parameters
filtered_config = {
"opening_statement": config["opening_statement"],
@@ -351,12 +347,17 @@ class AppModelConfigService:
"model": {
"provider": config["model"]["provider"],
"name": config["model"]["name"],
"mode": config['model']["mode"],
"completion_params": config["model"]["completion_params"]
},
"user_input_form": config["user_input_form"],
"dataset_query_variable": config.get('dataset_query_variable'),
"pre_prompt": config["pre_prompt"],
"agent_mode": config["agent_mode"]
"agent_mode": config["agent_mode"],
"prompt_type": config["prompt_type"],
"chat_prompt_config": config["chat_prompt_config"],
"completion_prompt_config": config["completion_prompt_config"],
"dataset_configs": config["dataset_configs"]
}
return filtered_config
@@ -375,4 +376,51 @@ class AppModelConfigService:
if dataset_exists and not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist")
@staticmethod
def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
# prompt_type
if 'prompt_type' not in config or not config["prompt_type"]:
config["prompt_type"] = "simple"
if config['prompt_type'] not in ['simple', 'advanced']:
raise ValueError("prompt_type must be in ['simple', 'advanced']")
# chat_prompt_config
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
config["chat_prompt_config"] = {}
if not isinstance(config["chat_prompt_config"], dict):
raise ValueError("chat_prompt_config must be of object type")
# completion_prompt_config
if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
config["completion_prompt_config"] = {}
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
# dataset_configs
if 'dataset_configs' not in config or not config["dataset_configs"]:
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if config['prompt_type'] == 'advanced':
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
if not user_prefix:
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'

View File

@@ -244,7 +244,8 @@ class CompletionService:
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
message_id: str, streaming: bool = True) -> Union[dict | Generator]:
message_id: str, streaming: bool = True,
retriever_from: str = 'dev') -> Union[dict | Generator]:
if not user:
raise ValueError('user cannot be None')
@@ -266,14 +267,11 @@ class CompletionService:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
if message.override_model_configs:
override_model_configs = json.loads(message.override_model_configs)
pre_prompt = override_model_configs.get("pre_prompt", '')
elif app_model_config:
pre_prompt = app_model_config.pre_prompt
else:
raise AppModelConfigBrokenError()
model_dict = app_model_config.model_dict
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
app_model_config.model = json.dumps(model_dict)
generate_task_id = str(uuid.uuid4())
@@ -282,58 +280,28 @@ class CompletionService:
user = cls.get_real_user_instead_of_proxy_obj(user)
generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id,
'detached_app_model': app_model,
'app_model_config': app_model_config,
'detached_message': message,
'pre_prompt': pre_prompt,
'query': message.query,
'inputs': message.inputs,
'detached_user': user,
'streaming': streaming
'detached_conversation': None,
'streaming': streaming,
'is_model_config_override': True,
'retriever_from': retriever_from
})
generate_worker_thread.start()
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
# wait for 10 minutes to close the thread
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
generate_task_id)
return cls.compact_response(pubsub, streaming)
@classmethod
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
detached_user: Union[Account, EndUser], streaming: bool):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
app_model = db.session.merge(detached_app_model)
message = db.session.merge(detached_message)
try:
# run
Completion.generate_more_like_this(
task_id=generate_task_id,
app=app_model,
user=user,
message=message,
pre_prompt=pre_prompt,
app_model_config=app_model_config,
streaming=streaming
)
except ConversationTaskStoppedException:
pass
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
finally:
db.session.commit()
@classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
if user_inputs is None:

View File

@@ -482,6 +482,9 @@ class ProviderService:
'features': []
}
if 'mode' in model:
valid_model_dict['model_mode'] = model['mode']
if 'features' in model:
valid_model_dict['features'] = model['features']