Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -14,39 +14,24 @@ from models.model import Conversation
|
||||
|
||||
def test__get_completion_model_prompt_messages():
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-3.5-turbo-instruct'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = "gpt-3.5-turbo-instruct"
|
||||
|
||||
prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
|
||||
prompt_template_config = CompletionModelPromptTemplate(
|
||||
text=prompt_template
|
||||
)
|
||||
prompt_template_config = CompletionModelPromptTemplate(text=prompt_template)
|
||||
|
||||
memory_config = MemoryConfig(
|
||||
role_prefix=MemoryConfig.RolePrefix(
|
||||
user="Human",
|
||||
assistant="Assistant"
|
||||
),
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
|
||||
window=MemoryConfig.WindowConfig(enabled=False),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
inputs = {"name": "John"}
|
||||
files = []
|
||||
context = "I am superman."
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=Conversation(),
|
||||
model_instance=model_config_mock
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
|
||||
|
||||
history_prompt_messages = [
|
||||
UserPromptMessage(content="Hi"),
|
||||
AssistantPromptMessage(content="Hello")
|
||||
]
|
||||
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
|
||||
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
@@ -59,16 +44,22 @@ def test__get_completion_model_prompt_messages():
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 1
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({
|
||||
"#context#": context,
|
||||
"#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: "
|
||||
f"{prompt.content}" for prompt in history_prompt_messages]),
|
||||
**inputs,
|
||||
})
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format(
|
||||
{
|
||||
"#context#": context,
|
||||
"#histories#": "\n".join(
|
||||
[
|
||||
f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " f"{prompt.content}"
|
||||
for prompt in history_prompt_messages
|
||||
]
|
||||
),
|
||||
**inputs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
@@ -77,15 +68,9 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
files = []
|
||||
query = "Hi2."
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=Conversation(),
|
||||
model_instance=model_config_mock
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
|
||||
|
||||
history_prompt_messages = [
|
||||
UserPromptMessage(content="Hi1."),
|
||||
AssistantPromptMessage(content="Hello1!")
|
||||
]
|
||||
history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")]
|
||||
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
@@ -98,14 +83,14 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 6
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
|
||||
{**inputs, "#context#": context}
|
||||
)
|
||||
assert prompt_messages[5].content == query
|
||||
|
||||
|
||||
@@ -124,14 +109,14 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 3
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
|
||||
{**inputs, "#context#": context}
|
||||
)
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
|
||||
@@ -148,7 +133,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
image_config={
|
||||
"detail": "high",
|
||||
}
|
||||
)
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -162,14 +147,14 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 4
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
|
||||
{**inputs, "#context#": context}
|
||||
)
|
||||
assert isinstance(prompt_messages[3].content, list)
|
||||
assert len(prompt_messages[3].content) == 2
|
||||
assert prompt_messages[3].content[1].data == files[0].url
|
||||
@@ -178,33 +163,20 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
@pytest.fixture
|
||||
def get_chat_model_args():
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-4'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = "gpt-4"
|
||||
|
||||
memory_config = MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
)
|
||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||
|
||||
prompt_messages = [
|
||||
ChatModelMessage(
|
||||
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM
|
||||
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM
|
||||
),
|
||||
ChatModelMessage(
|
||||
text="Hi.",
|
||||
role=PromptMessageRole.USER
|
||||
),
|
||||
ChatModelMessage(
|
||||
text="Hello!",
|
||||
role=PromptMessageRole.ASSISTANT
|
||||
)
|
||||
ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
|
||||
ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
inputs = {"name": "John"}
|
||||
|
||||
context = "I am superman."
|
||||
|
||||
|
||||
@@ -18,27 +18,28 @@ from models.model import Conversation
|
||||
|
||||
def test_get_prompt():
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content='System Template'),
|
||||
UserPromptMessage(content='User Query'),
|
||||
SystemPromptMessage(content="System Template"),
|
||||
UserPromptMessage(content="User Query"),
|
||||
]
|
||||
history_messages = [
|
||||
SystemPromptMessage(content='System Prompt 1'),
|
||||
UserPromptMessage(content='User Prompt 1'),
|
||||
AssistantPromptMessage(content='Assistant Thought 1'),
|
||||
ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
|
||||
ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
|
||||
SystemPromptMessage(content='System Prompt 2'),
|
||||
UserPromptMessage(content='User Prompt 2'),
|
||||
AssistantPromptMessage(content='Assistant Thought 2'),
|
||||
ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
|
||||
ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
|
||||
UserPromptMessage(content='User Prompt 3'),
|
||||
AssistantPromptMessage(content='Assistant Thought 3'),
|
||||
SystemPromptMessage(content="System Prompt 1"),
|
||||
UserPromptMessage(content="User Prompt 1"),
|
||||
AssistantPromptMessage(content="Assistant Thought 1"),
|
||||
ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"),
|
||||
ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"),
|
||||
SystemPromptMessage(content="System Prompt 2"),
|
||||
UserPromptMessage(content="User Prompt 2"),
|
||||
AssistantPromptMessage(content="Assistant Thought 2"),
|
||||
ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"),
|
||||
ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"),
|
||||
UserPromptMessage(content="User Prompt 3"),
|
||||
AssistantPromptMessage(content="Assistant Thought 3"),
|
||||
]
|
||||
|
||||
# use message number instead of token for testing
|
||||
def side_effect_get_num_tokens(*args):
|
||||
return len(args[2])
|
||||
|
||||
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
|
||||
large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
|
||||
|
||||
@@ -46,20 +47,17 @@ def test_get_prompt():
|
||||
provider_model_bundle_mock.model_type_instance = large_language_model_mock
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_mock.model = 'openai'
|
||||
model_config_mock.model = "openai"
|
||||
model_config_mock.credentials = {}
|
||||
model_config_mock.provider_model_bundle = provider_model_bundle_mock
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=Conversation(),
|
||||
model_instance=model_config_mock
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
|
||||
|
||||
transform = AgentHistoryPromptTransform(
|
||||
model_config=model_config_mock,
|
||||
prompt_messages=prompt_messages,
|
||||
history_messages=history_messages,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
max_token_limit = 5
|
||||
|
||||
@@ -12,19 +12,15 @@ from core.prompt.prompt_transform import PromptTransform
|
||||
def test__calculate_rest_token():
|
||||
model_schema_mock = MagicMock(spec=AIModelEntity)
|
||||
parameter_rule_mock = MagicMock(spec=ParameterRule)
|
||||
parameter_rule_mock.name = 'max_tokens'
|
||||
model_schema_mock.parameter_rules = [
|
||||
parameter_rule_mock
|
||||
]
|
||||
model_schema_mock.model_properties = {
|
||||
ModelPropertyKey.CONTEXT_SIZE: 62
|
||||
}
|
||||
parameter_rule_mock.name = "max_tokens"
|
||||
model_schema_mock.parameter_rules = [parameter_rule_mock]
|
||||
model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
|
||||
|
||||
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
|
||||
large_language_model_mock.get_num_tokens.return_value = 6
|
||||
|
||||
provider_mock = MagicMock(spec=ProviderEntity)
|
||||
provider_mock.provider = 'openai'
|
||||
provider_mock.provider = "openai"
|
||||
|
||||
provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
|
||||
provider_configuration_mock.provider = provider_mock
|
||||
@@ -35,11 +31,9 @@ def test__calculate_rest_token():
|
||||
provider_model_bundle_mock.configuration = provider_configuration_mock
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.model = 'gpt-4'
|
||||
model_config_mock.model = "gpt-4"
|
||||
model_config_mock.credentials = {}
|
||||
model_config_mock.parameters = {
|
||||
'max_tokens': 50
|
||||
}
|
||||
model_config_mock.parameters = {"max_tokens": 50}
|
||||
model_config_mock.model_schema = model_schema_mock
|
||||
model_config_mock.provider_model_bundle = provider_model_bundle_mock
|
||||
|
||||
@@ -49,8 +43,10 @@ def test__calculate_rest_token():
|
||||
rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
|
||||
|
||||
# Validate based on the mock configuration and expected logic
|
||||
expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
|
||||
- model_config_mock.parameters['max_tokens']
|
||||
- large_language_model_mock.get_num_tokens.return_value)
|
||||
expected_rest_tokens = (
|
||||
model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
|
||||
- model_config_mock.parameters["max_tokens"]
|
||||
- large_language_model_mock.get_num_tokens.return_value
|
||||
)
|
||||
assert rest_tokens == expected_rest_tokens
|
||||
assert rest_tokens == 6
|
||||
|
||||
@@ -19,12 +19,15 @@ def test_get_common_chat_app_prompt_template_with_pcqm():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=True,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ pre_prompt + '\n'
|
||||
+ prompt_rules['histories_prompt']
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"]
|
||||
+ pre_prompt
|
||||
+ "\n"
|
||||
+ prompt_rules["histories_prompt"]
|
||||
+ prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
|
||||
|
||||
|
||||
def test_get_baichuan_chat_app_prompt_template_with_pcqm():
|
||||
@@ -39,12 +42,15 @@ def test_get_baichuan_chat_app_prompt_template_with_pcqm():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=True,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ pre_prompt + '\n'
|
||||
+ prompt_rules['histories_prompt']
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"]
|
||||
+ pre_prompt
|
||||
+ "\n"
|
||||
+ prompt_rules["histories_prompt"]
|
||||
+ prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
|
||||
|
||||
|
||||
def test_get_common_completion_app_prompt_template_with_pcq():
|
||||
@@ -59,11 +65,11 @@ def test_get_common_completion_app_prompt_template_with_pcq():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ pre_prompt + '\n'
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
|
||||
|
||||
|
||||
def test_get_baichuan_completion_app_prompt_template_with_pcq():
|
||||
@@ -78,12 +84,12 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
print(prompt_template['prompt_template'].template)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ pre_prompt + '\n'
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
|
||||
print(prompt_template["prompt_template"].template)
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_q():
|
||||
@@ -98,9 +104,9 @@ def test_get_common_chat_app_prompt_template_with_q():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == prompt_rules['query_prompt']
|
||||
assert prompt_template['special_variable_keys'] == ['#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"]
|
||||
assert prompt_template["special_variable_keys"] == ["#query#"]
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_cq():
|
||||
@@ -115,10 +121,11 @@ def test_get_common_chat_app_prompt_template_with_cq():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"] + prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_p():
|
||||
@@ -133,30 +140,25 @@ def test_get_common_chat_app_prompt_template_with_p():
|
||||
query_in_prompt=False,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
assert prompt_template['prompt_template'].template == pre_prompt + '\n'
|
||||
assert prompt_template['custom_variable_keys'] == ['name']
|
||||
assert prompt_template['special_variable_keys'] == []
|
||||
assert prompt_template["prompt_template"].template == pre_prompt + "\n"
|
||||
assert prompt_template["custom_variable_keys"] == ["name"]
|
||||
assert prompt_template["special_variable_keys"] == []
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages():
|
||||
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-4'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = "gpt-4"
|
||||
|
||||
memory_mock = MagicMock(spec=TokenBufferMemory)
|
||||
history_prompt_messages = [
|
||||
UserPromptMessage(content="Hi"),
|
||||
AssistantPromptMessage(content="Hello")
|
||||
]
|
||||
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
|
||||
memory_mock.get_history_prompt_messages.return_value = history_prompt_messages
|
||||
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
|
||||
pre_prompt = "You are a helpful assistant {{name}}."
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
inputs = {"name": "John"}
|
||||
context = "yes or no."
|
||||
query = "How are you?"
|
||||
prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
|
||||
@@ -167,7 +169,7 @@ def test__get_chat_model_prompt_messages():
|
||||
files=[],
|
||||
context=context,
|
||||
memory=memory_mock,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
prompt_template = prompt_transform.get_prompt_template(
|
||||
@@ -180,8 +182,8 @@ def test__get_chat_model_prompt_messages():
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
|
||||
full_inputs = {**inputs, '#context#': context}
|
||||
real_system_prompt = prompt_template['prompt_template'].format(full_inputs)
|
||||
full_inputs = {**inputs, "#context#": context}
|
||||
real_system_prompt = prompt_template["prompt_template"].format(full_inputs)
|
||||
|
||||
assert len(prompt_messages) == 4
|
||||
assert prompt_messages[0].content == real_system_prompt
|
||||
@@ -192,26 +194,18 @@ def test__get_chat_model_prompt_messages():
|
||||
|
||||
def test__get_completion_model_prompt_messages():
|
||||
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-3.5-turbo-instruct'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = "gpt-3.5-turbo-instruct"
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=Conversation(),
|
||||
model_instance=model_config_mock
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
|
||||
|
||||
history_prompt_messages = [
|
||||
UserPromptMessage(content="Hi"),
|
||||
AssistantPromptMessage(content="Hello")
|
||||
]
|
||||
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
|
||||
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
|
||||
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
pre_prompt = "You are a helpful assistant {{name}}."
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
inputs = {"name": "John"}
|
||||
context = "yes or no."
|
||||
query = "How are you?"
|
||||
prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
|
||||
@@ -222,7 +216,7 @@ def test__get_completion_model_prompt_messages():
|
||||
files=[],
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
prompt_template = prompt_transform.get_prompt_template(
|
||||
@@ -235,14 +229,19 @@ def test__get_completion_model_prompt_messages():
|
||||
with_memory_prompt=True,
|
||||
)
|
||||
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
|
||||
max_token_limit=2000,
|
||||
human_prefix=prompt_rules.get("human_prefix", "Human"),
|
||||
ai_prefix=prompt_rules.get("assistant_prefix", "Assistant")
|
||||
)}
|
||||
real_prompt = prompt_template['prompt_template'].format(full_inputs)
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
full_inputs = {
|
||||
**inputs,
|
||||
"#context#": context,
|
||||
"#query#": query,
|
||||
"#histories#": memory.get_history_prompt_text(
|
||||
max_token_limit=2000,
|
||||
human_prefix=prompt_rules.get("human_prefix", "Human"),
|
||||
ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
|
||||
),
|
||||
}
|
||||
real_prompt = prompt_template["prompt_template"].format(full_inputs)
|
||||
|
||||
assert len(prompt_messages) == 1
|
||||
assert stops == prompt_rules.get('stops')
|
||||
assert stops == prompt_rules.get("stops")
|
||||
assert prompt_messages[0].content == real_prompt
|
||||
|
||||
Reference in New Issue
Block a user