feat: support LLM jinja2 template prompt (#3968)
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
17
api/core/helper/code_executor/jinja2_formatter.py
Normal file
17
api/core/helper/code_executor/jinja2_formatter.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
|
||||
|
||||
class Jinja2Formatter:
|
||||
@classmethod
|
||||
def format(cls, template: str, inputs: str) -> str:
|
||||
"""
|
||||
Format template
|
||||
:param template: template
|
||||
:param inputs: inputs
|
||||
:return:
|
||||
"""
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language='jinja2', code=template, inputs=inputs
|
||||
)
|
||||
|
||||
return result['result']
|
||||
@@ -2,6 +2,7 @@ from typing import Optional, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file.file_obj import FileVar
|
||||
from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
@@ -80,29 +81,35 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
|
||||
prompt_messages = []
|
||||
|
||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
|
||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
|
||||
if memory and memory_config:
|
||||
role_prefix = memory_config.role_prefix
|
||||
prompt_inputs = self._set_histories_variable(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
raw_prompt=raw_prompt,
|
||||
role_prefix=role_prefix,
|
||||
prompt_template=prompt_template,
|
||||
prompt_inputs=prompt_inputs,
|
||||
model_config=model_config
|
||||
if memory and memory_config:
|
||||
role_prefix = memory_config.role_prefix
|
||||
prompt_inputs = self._set_histories_variable(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
raw_prompt=raw_prompt,
|
||||
role_prefix=role_prefix,
|
||||
prompt_template=prompt_template,
|
||||
prompt_inputs=prompt_inputs,
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
if query:
|
||||
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
else:
|
||||
prompt = raw_prompt
|
||||
prompt_inputs = inputs
|
||||
|
||||
if query:
|
||||
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
||||
|
||||
if files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
||||
@@ -135,14 +142,22 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
for prompt_item in raw_prompt_list:
|
||||
raw_prompt = prompt_item.text
|
||||
|
||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
|
||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
elif prompt_item.edition_type == 'jinja2':
|
||||
prompt = raw_prompt
|
||||
prompt_inputs = inputs
|
||||
|
||||
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
||||
else:
|
||||
raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')
|
||||
|
||||
if prompt_item.role == PromptMessageRole.USER:
|
||||
prompt_messages.append(UserPromptMessage(content=prompt))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel):
|
||||
"""
|
||||
text: str
|
||||
role: PromptMessageRole
|
||||
edition_type: Optional[Literal['basic', 'jinja2']]
|
||||
|
||||
|
||||
class CompletionModelPromptTemplate(BaseModel):
|
||||
@@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
|
||||
Completion Model Prompt Template.
|
||||
"""
|
||||
text: str
|
||||
edition_type: Optional[Literal['basic', 'jinja2']]
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
@@ -37,13 +38,31 @@ class VisionConfig(BaseModel):
|
||||
enabled: bool
|
||||
configs: Optional[Configs] = None
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""
|
||||
Prompt Config.
|
||||
"""
|
||||
jinja2_variables: Optional[list[VariableSelector]] = None
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
"""
|
||||
LLM Node Chat Model Message.
|
||||
"""
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
"""
|
||||
LLM Node Chat Model Prompt Template.
|
||||
"""
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
"""
|
||||
LLM Node Data.
|
||||
"""
|
||||
model: ModelConfig
|
||||
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
|
||||
prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
|
||||
prompt_config: Optional[PromptConfig] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
@@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
@@ -39,16 +45,24 @@ class LLMNode(BaseNode):
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
node_data = cast(LLMNodeData, deepcopy(self.node_data))
|
||||
|
||||
node_inputs = None
|
||||
process_data = None
|
||||
|
||||
try:
|
||||
# init messages template
|
||||
node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
|
||||
|
||||
# fetch variables and fetch values from variable pool
|
||||
inputs = self._fetch_inputs(node_data, variable_pool)
|
||||
|
||||
# fetch jinja2 inputs
|
||||
jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
|
||||
|
||||
# merge inputs
|
||||
inputs.update(jinja_inputs)
|
||||
|
||||
node_inputs = {}
|
||||
|
||||
# fetch files
|
||||
@@ -183,6 +197,86 @@ class LLMNode(BaseNode):
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
return full_text, usage
|
||||
|
||||
def _transform_chat_messages(self,
|
||||
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
"""
|
||||
Transform chat messages
|
||||
|
||||
:param messages: chat messages
|
||||
:return:
|
||||
"""
|
||||
|
||||
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
|
||||
if messages.edition_type == 'jinja2':
|
||||
messages.text = messages.jinja2_text
|
||||
|
||||
return messages
|
||||
|
||||
for message in messages:
|
||||
if message.edition_type == 'jinja2':
|
||||
message.text = message.jinja2_text
|
||||
|
||||
return messages
|
||||
|
||||
def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Fetch jinja inputs
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
variables = {}
|
||||
|
||||
if not node_data.prompt_config:
|
||||
return variables
|
||||
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable = variable_selector.variable
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
|
||||
def parse_dict(d: dict) -> str:
|
||||
"""
|
||||
Parse dict into string
|
||||
"""
|
||||
# check if it's a context structure
|
||||
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
|
||||
return d['content']
|
||||
|
||||
# else, parse the dict
|
||||
try:
|
||||
return json.dumps(d, ensure_ascii=False)
|
||||
except Exception:
|
||||
return str(d)
|
||||
|
||||
if isinstance(value, str):
|
||||
value = value
|
||||
elif isinstance(value, list):
|
||||
result = ''
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
result += parse_dict(item)
|
||||
elif isinstance(item, str):
|
||||
result += item
|
||||
elif isinstance(item, int | float):
|
||||
result += str(item)
|
||||
else:
|
||||
result += str(item)
|
||||
result += '\n'
|
||||
value = result.strip()
|
||||
elif isinstance(value, dict):
|
||||
value = parse_dict(value)
|
||||
elif isinstance(value, int | float):
|
||||
value = str(value)
|
||||
else:
|
||||
value = str(value)
|
||||
|
||||
variables[variable] = value
|
||||
|
||||
return variables
|
||||
|
||||
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
@@ -531,25 +625,25 @@ class LLMNode(BaseNode):
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = node_data
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
|
||||
prompt_template = node_data.prompt_template
|
||||
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
if prompt.edition_type != 'jinja2':
|
||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
else:
|
||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
if prompt_template.edition_type != 'jinja2':
|
||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
@@ -571,6 +665,22 @@ class LLMNode(BaseNode):
|
||||
if node_data.memory:
|
||||
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
|
||||
|
||||
if node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type == 'jinja2':
|
||||
enable_jinja = True
|
||||
break
|
||||
else:
|
||||
if prompt_template.edition_type == 'jinja2':
|
||||
enable_jinja = True
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
@@ -588,7 +698,8 @@ class LLMNode(BaseNode):
|
||||
"prompts": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "You are a helpful AI assistant."
|
||||
"text": "You are a helpful AI assistant.",
|
||||
"edition_type": "basic"
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -600,7 +711,8 @@ class LLMNode(BaseNode):
|
||||
"prompt": {
|
||||
"text": "Here is the chat histories between human and assistant, inside "
|
||||
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
||||
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
|
||||
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
|
||||
"edition_type": "basic"
|
||||
},
|
||||
"stop": ["Human:"]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user