feat: advanced prompt backend (#1301)
Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
56
api/services/advanced_prompt_template_service.py
Normal file
56
api/services/advanced_prompt_template_service.py
Normal file
@@ -0,0 +1,56 @@
|
||||
|
||||
import copy
|
||||
|
||||
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
|
||||
|
||||
class AdvancedPromptTemplateService:
|
||||
|
||||
def get_prompt(self, args: dict) -> dict:
|
||||
app_mode = args['app_mode']
|
||||
model_mode = args['model_mode']
|
||||
model_name = args['model_name']
|
||||
has_context = args['has_context']
|
||||
|
||||
if 'baichuan' in model_name:
|
||||
return self.get_baichuan_prompt(app_mode, model_mode, has_context)
|
||||
else:
|
||||
return self.get_common_prompt(app_mode, model_mode, has_context)
|
||||
|
||||
def get_common_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict:
|
||||
if app_mode == 'chat':
|
||||
if model_mode == 'completion':
|
||||
return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
|
||||
elif model_mode == 'chat':
|
||||
return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
|
||||
elif app_mode == 'completion':
|
||||
if model_mode == 'completion':
|
||||
return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
|
||||
elif model_mode == 'chat':
|
||||
return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
|
||||
|
||||
def get_completion_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict:
|
||||
if has_context == 'true':
|
||||
prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
|
||||
|
||||
return prompt_template
|
||||
|
||||
|
||||
def get_chat_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict:
|
||||
if has_context == 'true':
|
||||
prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
|
||||
|
||||
return prompt_template
|
||||
|
||||
|
||||
def get_baichuan_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict:
|
||||
if app_mode == 'chat':
|
||||
if model_mode == 'completion':
|
||||
return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
|
||||
elif model_mode == 'chat':
|
||||
return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
|
||||
elif app_mode == 'completion':
|
||||
if model_mode == 'completion':
|
||||
return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
|
||||
elif model_mode == 'chat':
|
||||
return self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
|
||||
from core.agent.agent_executor import PlanningStrategy
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelMode
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
@@ -34,40 +34,28 @@ class AppModelConfigService:
|
||||
# max_tokens
|
||||
if 'max_tokens' not in cp:
|
||||
cp["max_tokens"] = 512
|
||||
#
|
||||
# if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
|
||||
# llm_constant.max_context_token_length[model_name]:
|
||||
# raise ValueError(
|
||||
# "max_tokens must be an integer greater than 0 "
|
||||
# "and not exceeding the maximum value of the corresponding model")
|
||||
#
|
||||
|
||||
# temperature
|
||||
if 'temperature' not in cp:
|
||||
cp["temperature"] = 1
|
||||
#
|
||||
# if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
|
||||
# raise ValueError("temperature must be a float between 0 and 2")
|
||||
#
|
||||
|
||||
# top_p
|
||||
if 'top_p' not in cp:
|
||||
cp["top_p"] = 1
|
||||
|
||||
# if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
|
||||
# raise ValueError("top_p must be a float between 0 and 2")
|
||||
#
|
||||
# presence_penalty
|
||||
if 'presence_penalty' not in cp:
|
||||
cp["presence_penalty"] = 0
|
||||
|
||||
# if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
|
||||
# raise ValueError("presence_penalty must be a float between -2 and 2")
|
||||
#
|
||||
# presence_penalty
|
||||
if 'frequency_penalty' not in cp:
|
||||
cp["frequency_penalty"] = 0
|
||||
|
||||
# if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
|
||||
# raise ValueError("frequency_penalty must be a float between -2 and 2")
|
||||
# stop
|
||||
if 'stop' not in cp:
|
||||
cp["stop"] = []
|
||||
elif not isinstance(cp["stop"], list):
|
||||
raise ValueError("stop in model.completion_params must be of list type")
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_cp = {
|
||||
@@ -75,7 +63,8 @@ class AppModelConfigService:
|
||||
"temperature": cp["temperature"],
|
||||
"top_p": cp["top_p"],
|
||||
"presence_penalty": cp["presence_penalty"],
|
||||
"frequency_penalty": cp["frequency_penalty"]
|
||||
"frequency_penalty": cp["frequency_penalty"],
|
||||
"stop": cp["stop"]
|
||||
}
|
||||
|
||||
return filtered_cp
|
||||
@@ -211,6 +200,10 @@ class AppModelConfigService:
|
||||
model_ids = [m['id'] for m in model_list]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
# model.mode
|
||||
if 'mode' not in config['model'] or not config['model']["mode"]:
|
||||
config['model']["mode"] = ""
|
||||
|
||||
# model.completion_params
|
||||
if 'completion_params' not in config["model"]:
|
||||
@@ -339,6 +332,9 @@ class AppModelConfigService:
|
||||
# dataset_query_variable
|
||||
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
|
||||
|
||||
# advanced prompt validation
|
||||
AppModelConfigService.is_advanced_prompt_valid(config, mode)
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_config = {
|
||||
"opening_statement": config["opening_statement"],
|
||||
@@ -351,12 +347,17 @@ class AppModelConfigService:
|
||||
"model": {
|
||||
"provider": config["model"]["provider"],
|
||||
"name": config["model"]["name"],
|
||||
"mode": config['model']["mode"],
|
||||
"completion_params": config["model"]["completion_params"]
|
||||
},
|
||||
"user_input_form": config["user_input_form"],
|
||||
"dataset_query_variable": config.get('dataset_query_variable'),
|
||||
"pre_prompt": config["pre_prompt"],
|
||||
"agent_mode": config["agent_mode"]
|
||||
"agent_mode": config["agent_mode"],
|
||||
"prompt_type": config["prompt_type"],
|
||||
"chat_prompt_config": config["chat_prompt_config"],
|
||||
"completion_prompt_config": config["completion_prompt_config"],
|
||||
"dataset_configs": config["dataset_configs"]
|
||||
}
|
||||
|
||||
return filtered_config
|
||||
@@ -375,4 +376,51 @@ class AppModelConfigService:
|
||||
|
||||
if dataset_exists and not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
|
||||
# prompt_type
|
||||
if 'prompt_type' not in config or not config["prompt_type"]:
|
||||
config["prompt_type"] = "simple"
|
||||
|
||||
if config['prompt_type'] not in ['simple', 'advanced']:
|
||||
raise ValueError("prompt_type must be in ['simple', 'advanced']")
|
||||
|
||||
# chat_prompt_config
|
||||
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
|
||||
config["chat_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["chat_prompt_config"], dict):
|
||||
raise ValueError("chat_prompt_config must be of object type")
|
||||
|
||||
# completion_prompt_config
|
||||
if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
|
||||
config["completion_prompt_config"] = {}
|
||||
|
||||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
# dataset_configs
|
||||
if 'dataset_configs' not in config or not config["dataset_configs"]:
|
||||
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if config['prompt_type'] == 'advanced':
|
||||
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
|
||||
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
|
||||
|
||||
if config['model']["mode"] not in ['chat', 'completion']:
|
||||
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
|
||||
|
||||
if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
|
||||
if not user_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
|
||||
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
|
||||
@@ -244,7 +244,8 @@ class CompletionService:
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
|
||||
message_id: str, streaming: bool = True) -> Union[dict | Generator]:
|
||||
message_id: str, streaming: bool = True,
|
||||
retriever_from: str = 'dev') -> Union[dict | Generator]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
|
||||
@@ -266,14 +267,11 @@ class CompletionService:
|
||||
raise MoreLikeThisDisabledError()
|
||||
|
||||
app_model_config = message.app_model_config
|
||||
|
||||
if message.override_model_configs:
|
||||
override_model_configs = json.loads(message.override_model_configs)
|
||||
pre_prompt = override_model_configs.get("pre_prompt", '')
|
||||
elif app_model_config:
|
||||
pre_prompt = app_model_config.pre_prompt
|
||||
else:
|
||||
raise AppModelConfigBrokenError()
|
||||
model_dict = app_model_config.model_dict
|
||||
completion_params = model_dict.get('completion_params')
|
||||
completion_params['temperature'] = 0.9
|
||||
model_dict['completion_params'] = completion_params
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
generate_task_id = str(uuid.uuid4())
|
||||
|
||||
@@ -282,58 +280,28 @@ class CompletionService:
|
||||
|
||||
user = cls.get_real_user_instead_of_proxy_obj(user)
|
||||
|
||||
generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
|
||||
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'generate_task_id': generate_task_id,
|
||||
'detached_app_model': app_model,
|
||||
'app_model_config': app_model_config,
|
||||
'detached_message': message,
|
||||
'pre_prompt': pre_prompt,
|
||||
'query': message.query,
|
||||
'inputs': message.inputs,
|
||||
'detached_user': user,
|
||||
'streaming': streaming
|
||||
'detached_conversation': None,
|
||||
'streaming': streaming,
|
||||
'is_model_config_override': True,
|
||||
'retriever_from': retriever_from
|
||||
})
|
||||
|
||||
generate_worker_thread.start()
|
||||
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
|
||||
# wait for 10 minutes to close the thread
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
|
||||
generate_task_id)
|
||||
|
||||
return cls.compact_response(pubsub, streaming)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
|
||||
app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
|
||||
detached_user: Union[Account, EndUser], streaming: bool):
|
||||
with flask_app.app_context():
|
||||
# fixed the state of the model object when it detached from the original session
|
||||
user = db.session.merge(detached_user)
|
||||
app_model = db.session.merge(detached_app_model)
|
||||
message = db.session.merge(detached_message)
|
||||
|
||||
try:
|
||||
# run
|
||||
Completion.generate_more_like_this(
|
||||
task_id=generate_task_id,
|
||||
app=app_model,
|
||||
user=user,
|
||||
message=message,
|
||||
pre_prompt=pre_prompt,
|
||||
app_model_config=app_model_config,
|
||||
streaming=streaming
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
pass
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
|
||||
ModelCurrentlyNotSupportError) as e:
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
except LLMAuthorizationError:
|
||||
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
|
||||
except Exception as e:
|
||||
logging.exception("Unknown Error in completion")
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
finally:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
||||
if user_inputs is None:
|
||||
|
||||
@@ -482,6 +482,9 @@ class ProviderService:
|
||||
'features': []
|
||||
}
|
||||
|
||||
if 'mode' in model:
|
||||
valid_model_dict['model_mode'] = model['mode']
|
||||
|
||||
if 'features' in model:
|
||||
valid_model_dict['features'] = model['features']
|
||||
|
||||
|
||||
Reference in New Issue
Block a user