Refactor workflow nodes to use generic node_data (#28782)
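The diff below replaces every direct self._node_data access in LLMNode with the inherited self.node_data accessor and drops the subclass-level "_node_data: LLMNodeData" declaration that the generic base class now makes redundant. As a minimal sketch of that pattern only (not the actual Dify base class; BaseNodeData and NodeDataT are hypothetical stand-in names):

    # Sketch of the generic node_data pattern this commit moves LLMNode onto:
    # private storage and the typed accessor live in the base Node class, and
    # subclasses just parameterize it.
    from typing import Generic, TypeVar

    class BaseNodeData:  # hypothetical stand-in for the shared node-data model
        pass

    NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)

    class Node(Generic[NodeDataT]):
        def __init__(self, node_data: NodeDataT) -> None:
            self._node_data = node_data  # private storage, owned by the base class

        @property
        def node_data(self) -> NodeDataT:
            # Subclasses such as LLMNode(Node[LLMNodeData]) read their typed
            # payload through this accessor instead of touching _node_data.
            return self._node_data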
@@ -102,8 +102,6 @@ logger = logging.getLogger(__name__)
 class LLMNode(Node[LLMNodeData]):
     node_type = NodeType.LLM
 
-    _node_data: LLMNodeData
-
     # Compiled regex for extracting <think> blocks (with compatibility for attributes)
     _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
 
@@ -154,13 +152,13 @@ class LLMNode(Node[LLMNodeData]):
 
         try:
             # init messages template
-            self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
+            self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
 
             # fetch variables and fetch values from variable pool
-            inputs = self._fetch_inputs(node_data=self._node_data)
+            inputs = self._fetch_inputs(node_data=self.node_data)
 
             # fetch jinja2 inputs
-            jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
+            jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
 
             # merge inputs
             inputs.update(jinja_inputs)
@@ -169,9 +167,9 @@ class LLMNode(Node[LLMNodeData]):
             files = (
                 llm_utils.fetch_files(
                     variable_pool=variable_pool,
-                    selector=self._node_data.vision.configs.variable_selector,
+                    selector=self.node_data.vision.configs.variable_selector,
                 )
-                if self._node_data.vision.enabled
+                if self.node_data.vision.enabled
                 else []
             )
 
@@ -179,7 +177,7 @@ class LLMNode(Node[LLMNodeData]):
             node_inputs["#files#"] = [file.to_dict() for file in files]
 
             # fetch context value
-            generator = self._fetch_context(node_data=self._node_data)
+            generator = self._fetch_context(node_data=self.node_data)
             context = None
             for event in generator:
                 context = event.context
@@ -189,7 +187,7 @@ class LLMNode(Node[LLMNodeData]):
 
             # fetch model config
             model_instance, model_config = LLMNode._fetch_model_config(
-                node_data_model=self._node_data.model,
+                node_data_model=self.node_data.model,
                 tenant_id=self.tenant_id,
             )
 
@@ -197,13 +195,13 @@ class LLMNode(Node[LLMNodeData]):
             memory = llm_utils.fetch_memory(
                 variable_pool=variable_pool,
                 app_id=self.app_id,
-                node_data_memory=self._node_data.memory,
+                node_data_memory=self.node_data.memory,
                 model_instance=model_instance,
             )
 
             query: str | None = None
-            if self._node_data.memory:
-                query = self._node_data.memory.query_prompt_template
+            if self.node_data.memory:
+                query = self.node_data.memory.query_prompt_template
             if not query and (
                 query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
             ):
@@ -215,29 +213,29 @@ class LLMNode(Node[LLMNodeData]):
                 context=context,
                 memory=memory,
                 model_config=model_config,
-                prompt_template=self._node_data.prompt_template,
-                memory_config=self._node_data.memory,
-                vision_enabled=self._node_data.vision.enabled,
-                vision_detail=self._node_data.vision.configs.detail,
+                prompt_template=self.node_data.prompt_template,
+                memory_config=self.node_data.memory,
+                vision_enabled=self.node_data.vision.enabled,
+                vision_detail=self.node_data.vision.configs.detail,
                 variable_pool=variable_pool,
-                jinja2_variables=self._node_data.prompt_config.jinja2_variables,
+                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
                 tenant_id=self.tenant_id,
             )
 
             # handle invoke result
             generator = LLMNode.invoke_llm(
-                node_data_model=self._node_data.model,
+                node_data_model=self.node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop,
                 user_id=self.user_id,
-                structured_output_enabled=self._node_data.structured_output_enabled,
-                structured_output=self._node_data.structured_output,
+                structured_output_enabled=self.node_data.structured_output_enabled,
+                structured_output=self.node_data.structured_output,
                 file_saver=self._llm_file_saver,
                 file_outputs=self._file_outputs,
                 node_id=self._node_id,
                 node_type=self.node_type,
-                reasoning_format=self._node_data.reasoning_format,
+                reasoning_format=self.node_data.reasoning_format,
             )
 
             structured_output: LLMStructuredOutput | None = None
@@ -253,12 +251,12 @@ class LLMNode(Node[LLMNodeData]):
                 reasoning_content = event.reasoning_content or ""
 
                 # For downstream nodes, determine clean text based on reasoning_format
-                if self._node_data.reasoning_format == "tagged":
+                if self.node_data.reasoning_format == "tagged":
                     # Keep <think> tags for backward compatibility
                     clean_text = result_text
                 else:
                     # Extract clean text from <think> tags
-                    clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
+                    clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
 
                 # Process structured output if available from the event.
                 structured_output = (
@@ -1204,7 +1202,7 @@ class LLMNode(Node[LLMNodeData]):
 
     @property
     def retry(self) -> bool:
-        return self._node_data.retry_config.retry_enabled
+        return self.node_data.retry_config.retry_enabled
 
 
     def _combine_message_content_with_role(
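The non-tagged branch above delegates to LLMNode._split_reasoning, whose body is outside this diff. As a rough illustration only (assumed behaviour, not the repository's implementation), the _THINK_PATTERN regex from the first hunk can separate reasoning from the visible answer like this:

    import re

    # Same regex as the _THINK_PATTERN class attribute in the first hunk.
    THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

    def split_reasoning(text: str) -> tuple[str, str]:
        # Collect the contents of every <think>...</think> block, then strip
        # the blocks out of the text handed to downstream nodes.
        reasoning = "\n".join(part.strip() for part in THINK_PATTERN.findall(text))
        clean_text = THINK_PATTERN.sub("", text).strip()
        return clean_text, reasoning

    print(split_reasoning('<think reason="plan">check units</think>The answer is 42.'))
    # -> ('The answer is 42.', 'check units')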