refactor advanced prompt core. (#1350)

This commit is contained in:
Garfield Dai
2023-10-18 20:02:52 +08:00
committed by GitHub
parent 52ebffa857
commit fe14130b3c
9 changed files with 405 additions and 272 deletions

View File

@@ -1,6 +1,8 @@
import copy
from core.model_providers.models.entity.model_params import ModelMode
from core.prompt.prompt_transform import AppMode
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
@@ -13,7 +15,7 @@ class AdvancedPromptTemplateService:
model_name = args['model_name']
has_context = args['has_context']
if 'baichuan' in model_name:
if 'baichuan' in model_name.lower():
return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
else:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@@ -22,15 +24,15 @@ class AdvancedPromptTemplateService:
def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == 'chat':
if model_mode == 'completion':
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == 'completion':
if model_mode == 'completion':
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
@classmethod
@@ -51,13 +53,13 @@ class AdvancedPromptTemplateService:
def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == 'chat':
if model_mode == 'completion':
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif app_mode == 'completion':
if model_mode == 'completion':
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)

View File

@@ -1,6 +1,7 @@
import re
import uuid
from core.prompt.prompt_transform import AppMode
from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType, ModelMode
@@ -418,7 +419,7 @@ class AppModelConfigService:
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
@@ -427,3 +428,10 @@ class AppModelConfigService:
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']
if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")