feat: Persist Variables for Enhanced Debugging Workflow (#20699)

This pull request introduces a feature aimed at improving the debugging experience during workflow editing. With the addition of variable persistence, the system will automatically retain the output variables from previously executed nodes. These persisted variables can then be reused when debugging subsequent nodes, eliminating the need for repetitive manual input.

By streamlining this aspect of the workflow, the feature minimizes user errors and significantly reduces debugging effort, offering a smoother and more efficient experience.

Key highlights of this change:

- Automatic persistence of output variables for executed nodes.
- Reuse of persisted variables to simplify input steps for nodes requiring them (e.g., `code`, `template`, `variable_assigner`).
- Enhanced debugging experience with reduced friction.

Closes #19735.
This commit is contained in:
QuantumGhost
2025-06-24 09:05:29 +08:00
committed by GitHub
parent 3113350e51
commit 10b738a296
106 changed files with 6025 additions and 718 deletions

View File

@@ -0,0 +1,39 @@
import abc
from typing import Protocol
from core.variables import Variable
class ConversationVariableUpdater(Protocol):
    """Abstraction for updating conversation variable values.

    It is intended for use by `v1.VariableAssignerNode` and
    `v2.VariableAssignerNode` when updating conversation variables.

    Implementations may choose to batch updates. If batching is used,
    `update` should buffer the change and `flush` should persist the
    buffered changes.

    Note: Since implementations may buffer updates, instances of
    ConversationVariableUpdater are not thread-safe. Each
    VariableAssignerNode should create its own instance during execution.
    """

    @abc.abstractmethod
    def update(self, conversation_id: str, variable: "Variable") -> None:
        """Update the value of a conversation variable in the underlying storage.

        :param conversation_id: The ID of the conversation to update.
            Typically references `ConversationVariable.id`.
        :param variable: The `Variable` instance containing the updated value.
        """

    @abc.abstractmethod
    def flush(self) -> None:
        """Flush all pending updates to the underlying storage system.

        If the implementation does not buffer updates, this method can be
        a no-op.
        """

View File

@@ -7,12 +7,12 @@ from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from factories import variable_factory
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from ..enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, File]
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@@ -30,9 +30,11 @@ class VariablePool(BaseModel):
# TODO: This user inputs is not used for pool.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
default_factory=dict,
)
system_variables: Mapping[SystemVariableKey, Any] = Field(
description="System variables",
default_factory=dict,
)
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
@@ -43,28 +45,7 @@ class VariablePool(BaseModel):
default_factory=list,
)
def __init__(
self,
*,
system_variables: Mapping[SystemVariableKey, Any] | None = None,
user_inputs: Mapping[str, Any] | None = None,
environment_variables: Sequence[Variable] | None = None,
conversation_variables: Sequence[Variable] | None = None,
**kwargs,
):
environment_variables = environment_variables or []
conversation_variables = conversation_variables or []
user_inputs = user_inputs or {}
system_variables = system_variables or {}
super().__init__(
system_variables=system_variables,
user_inputs=user_inputs,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
**kwargs,
)
def model_post_init(self, context: Any, /) -> None:
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
@@ -91,12 +72,12 @@ class VariablePool(BaseModel):
Returns:
None
"""
if len(selector) < 2:
if len(selector) < MIN_SELECTORS_LENGTH:
raise ValueError("Invalid selector")
if isinstance(value, Variable):
variable = value
if isinstance(value, Segment):
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else:
segment = variable_factory.build_segment(value)
@@ -118,7 +99,7 @@ class VariablePool(BaseModel):
Raises:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
if len(selector) < MIN_SELECTORS_LENGTH:
return None
hash_key = hash(tuple(selector[1:]))

View File

@@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent):
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
class NodeRunStartedEvent(BaseNodeEvent):

View File

@@ -314,6 +314,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
raise e
@@ -627,6 +628,7 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy,
node_version=node_instance.version(),
)
max_retries = node_instance.node_data.retry_config.max_retries
@@ -677,6 +679,7 @@ class GraphEngine:
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
node_version=node_instance.version(),
)
time.sleep(retry_interval)
break
@@ -712,6 +715,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
else:
@@ -726,6 +730,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@@ -786,6 +791,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
@@ -803,6 +809,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
elif isinstance(event, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
@@ -817,6 +824,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
except GenerateTaskStoppedError:
# trigger node run failed event
@@ -833,6 +841,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
return
except Exception as e:

View File

@@ -18,7 +18,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER
_node_type = NodeType.ANSWER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
@@ -45,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]):
part = cast(TextGenerateRouteChunk, part)
answer += part.text
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
)
@classmethod
def _extract_variable_selector_to_variable_mapping(

View File

@@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[answer_node_id] += 1

View File

@@ -1,7 +1,7 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData]
_node_type: NodeType
_node_type: ClassVar[NodeType]
def __init__(
self,
@@ -90,8 +90,38 @@ class BaseNode(Generic[GenericNodeData]):
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
"""Extracts referenced variable selectors from node configuration.
The `config` parameter represents the configuration for a specific node type and corresponds
to the `data` field in the node definition object.
The returned mapping has the following structure:
{'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
For loop and iteration nodes, the mapping may look like this:
{
"1748332301644.input_selector": ["1748332363630", "result"],
"1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
}
where `1748332301644` is the ID of the loop / iteration node,
and `1748332325079` is the ID of the node inside the loop or iteration node.
Here, the key consists of two parts: the current node ID (provided as the `node_id`
parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
enclosed in `#` symbols. These two parts are separated by a dot (`.`).
The value is a list of strings representing the variable selector, where the first element is the node ID
of the referenced variable, and the second element is the variable name within that node.
The meaning of the above response is:
The node with ID `1747829548239` references the variable `result` from the node with
ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a
reference to the `result` output variable of node `1747829667553`.
:param graph_config: graph config
:param config: node config
:return:
@@ -101,9 +131,10 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
)
return data
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -139,6 +170,16 @@ class BaseNode(Generic[GenericNodeData]):
"""
return self._node_type
@classmethod
@abstractmethod
def version(cls) -> str:
"""`version` returns the version of the current node type."""
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/core/workflow/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error

View File

@@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
return code_provider.get_default_config()
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
# Get code language
code_language = self.node_data.code_language
@@ -126,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]):
prefix: str = "",
depth: int = 1,
):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")

View File

@@ -24,7 +24,7 @@ from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import FileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
@@ -67,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"text": extracted_text_list},
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
)
elif isinstance(value, File):
extracted_text = _extract_text_from_file(value)

View File

@@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
_node_type = NodeType.END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run node

View File

@@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[end_node_id] += 1

View File

@@ -6,6 +6,7 @@ from typing import Any, Optional
from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -60,6 +61,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
process_data = {}
try:
@@ -92,7 +97,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"status_code": response.status_code,
"body": response.text if not files else "",
"body": response.text if not files.value else "",
"headers": response.headers,
"files": files,
},
@@ -166,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
return mapping
def extract_files(self, url: str, response: Response) -> list[File]:
def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
"""
Extract files from response by checking both Content-Type header and URL
"""
@@ -178,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
content_disposition_type = None
if not is_file:
return files
return ArrayFileSegment(value=[])
if parsed_content_disposition:
content_disposition_filename = parsed_content_disposition.get_filename()
@@ -211,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
)
files.append(file)
return files
return ArrayFileSegment(value=files)

View File

@@ -1,4 +1,5 @@
from typing import Literal
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing_extensions import deprecated
@@ -16,6 +17,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run node
@@ -87,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]):
return data
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
var_mapping: dict[str, list[str]] = {}
for case in node_data.cases or []:
for condition in case.conditions:
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
var_mapping[key] = condition.variable_selector
return var_mapping
@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(

View File

@@ -11,6 +11,7 @@ from flask import Flask, current_app
from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import (
NodeRunResult,
)
@@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
from libs.flask_utils import preserve_flask_contexts
from .exc import (
@@ -72,6 +74,10 @@ class IterationNode(BaseNode[IterationNodeData]):
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
@@ -85,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
# Try our best to preserve the type information.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": []},
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
)
)
return
@@ -231,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]):
# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]
output_segment = build_segment(outputs)
yield IterationRunSucceededEvent(
iteration_id=self.id,
@@ -247,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
outputs={"output": output_segment},
metadata={
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,

View File

@@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the node.

View File

@@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
@@ -115,9 +116,12 @@ class KnowledgeRetrievalNode(LLMNode):
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results}
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs=outputs, # type: ignore
)
except KnowledgeRetrievalNodeError as e:

View File

@@ -3,6 +3,7 @@ from typing import Any, Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@@ -16,6 +17,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}
@@ -30,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None}
if isinstance(variable, ArraySegment):
result = variable.model_copy(update={"value": []})
else:
result = ArrayAnySegment(value=[])
outputs = {"result": result, "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@@ -71,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
variable = self._apply_slice(variable)
outputs = {
"result": variable.value,
"result": variable,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}

View File

@@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver):
size=len(data),
related_id=tool_file.id,
url=url,
# TODO(QuantumGhost): how should I set the following key?
# What's the difference between `remote_url` and `url`?
# What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key,
)

View File

@@ -138,6 +138,10 @@ class LLMNode(BaseNode[LLMNodeData]):
)
self._llm_file_saver = llm_file_saver
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
"""Process structured output if enabled"""
@@ -255,7 +259,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if structured_output:
outputs["structured_output"] = structured_output
if self._file_outputs is not None:
outputs["files"] = self._file_outputs
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
yield RunCompletedEvent(
run_result=NodeRunResult(

View File

@@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the node.

View File

@@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs
@@ -482,6 +486,13 @@ class LoopNode(BaseNode[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping)
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
selector = loop_variable.value
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids

View File

@@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the node.

View File

@@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
LATEST_VERSION = "latest"
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,

View File

@@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
class _ParameterConfigError(Exception):
pass
class ParameterConfig(BaseModel):
"""
Parameter Config.
@@ -27,6 +31,19 @@ class ParameterConfig(BaseModel):
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return str(value)
def is_array_type(self) -> bool:
return self.type in ("array[string]", "array[number]", "array[object]")
def element_type(self) -> Literal["string", "number", "object"]:
if self.type == "array[number]":
return "number"
elif self.type == "array[string]":
return "string"
elif self.type == "array[object]":
return "object"
else:
raise _ParameterConfigError(f"{self.type} is not array type.")
class ParameterExtractorNodeData(BaseNodeData):
"""

View File

@@ -25,6 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -32,6 +33,7 @@ from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
@@ -109,6 +111,10 @@ class ParameterExtractorNode(BaseNode):
}
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
"""
Run the node.
@@ -584,28 +590,30 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
elif parameter.type.startswith("array"):
elif parameter.is_array_type():
if isinstance(result[parameter.name], list):
nested_type = parameter.type[6:-1]
transformed_result[parameter.name] = []
nested_type = parameter.element_type()
assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in result[parameter.name]:
if nested_type == "number":
if isinstance(item, int | float):
transformed_result[parameter.name].append(item)
segment_value.value.append(item)
elif isinstance(item, str):
try:
if "." in item:
transformed_result[parameter.name].append(float(item))
segment_value.value.append(float(item))
else:
transformed_result[parameter.name].append(int(item))
segment_value.value.append(int(item))
except ValueError:
pass
elif nested_type == "string":
if isinstance(item, str):
transformed_result[parameter.name].append(item)
segment_value.value.append(item)
elif nested_type == "object":
if isinstance(item, dict):
transformed_result[parameter.name].append(item)
segment_value.value.append(item)
if parameter.name not in transformed_result:
if parameter.type == "number":
@@ -615,7 +623,9 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
transformed_result[parameter.name] = []
transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)
return transformed_result

View File

@@ -10,6 +10,10 @@ class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData
_node_type = NodeType.START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
@@ -18,5 +22,6 @@ class StartNode(BaseNode[StartNodeData]):
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)

View File

@@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
# Get variables
variables = {}

View File

@@ -12,7 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
@@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
"""
Run the tool node
@@ -300,6 +304,7 @@ class ToolNode(BaseNode[ToolNodeData]):
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, File)
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
@@ -363,7 +368,7 @@ class ToolNode(BaseNode[ToolNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables},
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,

View File

@@ -1,3 +1,6 @@
from collections.abc import Mapping
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@@ -9,16 +12,20 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR
@classmethod
def version(cls) -> str:
    """Return the version tag of this node implementation (``"1"``)."""
    return "1"
def _run(self) -> NodeRunResult:
# Get variables
outputs = {}
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable.to_object()}
outputs = {"output": variable}
inputs = {".".join(selector[1:]): variable.to_object()}
break
@@ -28,7 +35,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs[group.group_name] = {"output": variable.to_object()}
outputs[group.group_name] = {"output": variable}
inputs[".".join(selector[1:])] = variable.to_object()
break

View File

@@ -1,19 +1,55 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, TypeVar
from core.variables import Variable
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
from pydantic import BaseModel
from core.variables import Segment
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.types import SegmentType
# Use double underscore (`__`) prefix for internal variables
# to minimize risk of collision with user-defined variable names.
_UPDATED_VARIABLES_KEY = "__updated_variables"
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
class UpdatedVariable(BaseModel):
    """Record describing one variable after an assigner node has updated it."""

    # Variable name.
    name: str
    # Selector identifying the variable.
    selector: Sequence[str]
    # Segment type of the new value.
    value_type: SegmentType
    # The value held by the variable after the update.
    new_value: Any
_T = TypeVar("_T", bound=MutableMapping[str, Any])
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
    """Describe an updated variable as an `UpdatedVariable` record.

    Only the first two elements of ``selector`` are kept; the second one is
    used as the variable name.

    :param selector: selector of the updated variable.
    :param seg: segment holding the variable's new value and value type.
    :return: an `UpdatedVariable` built from ``selector`` and ``seg``.
    :raises ValueError: if ``selector`` has fewer than ``MIN_SELECTORS_LENGTH`` elements.
    """
    if len(selector) < MIN_SELECTORS_LENGTH:
        # ValueError (a subclass of Exception) lets callers handle malformed
        # selectors without having to catch every exception.
        raise ValueError("selector too short")
    _, var_name = selector[:2]
    return UpdatedVariable(
        name=var_name,
        selector=list(selector[:2]),
        value_type=seg.value_type,
        new_value=seg.value,
    )
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T:
    """Store ``updates`` in ``m`` under the internal updated-variables key.

    The mapping is modified in place and returned, so the call can be used
    inline as an expression.
    """
    m[_UPDATED_VARIABLES_KEY] = updates
    return m
def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None:
    """Read the updated-variable records stored in ``m``.

    :param m: mapping that may contain the internal updated-variables key.
    :return: ``None`` when the key is absent (or maps to ``None``); otherwise a
        list of `UpdatedVariable`, where plain dict entries are validated into
        `UpdatedVariable` instances.
    :raises TypeError: if an entry is neither an `UpdatedVariable` nor a dict.
    """
    updated_values = m.get(_UPDATED_VARIABLES_KEY)
    if updated_values is None:
        return None

    def _coerce(items: Any) -> UpdatedVariable:
        # Accept both already-built models and their dict form.
        if isinstance(items, UpdatedVariable):
            return items
        if isinstance(items, dict):
            return UpdatedVariable.model_validate(items)
        raise TypeError(f"Invalid updated variable: {items}, type={type(items)}")

    return [_coerce(items) for items in updated_values]

View File

@@ -0,0 +1,38 @@
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.variables.variables import Variable
from models.engine import db
from models.workflow import ConversationVariable
from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
    """Conversation variable updater that writes each update straight to the database.

    An engine may be supplied explicitly; when none is given, ``db.engine`` is
    looked up each time an engine is needed.
    """

    _engine: Engine | None

    def __init__(self, engine: Engine | None = None) -> None:
        self._engine = engine

    def _get_engine(self) -> Engine:
        """Return the injected engine, falling back to ``db.engine``."""
        return self._engine if self._engine else db.engine

    def update(self, conversation_id: str, variable: Variable):
        """Persist ``variable`` into its `ConversationVariable` row and commit.

        :raises VariableOperatorNodeError: if no row matches the variable id and
            conversation id.
        """
        query = select(ConversationVariable).where(
            ConversationVariable.id == variable.id,
            ConversationVariable.conversation_id == conversation_id,
        )
        with Session(self._get_engine()) as session:
            record = session.scalar(query)
            if not record:
                raise VariableOperatorNodeError("conversation variable not found in the database")
            record.data = variable.model_dump_json()
            session.commit()

    def flush(self):
        """Do nothing: `update` commits immediately, so nothing is buffered."""
        pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
    """Create a `ConversationVariableUpdaterImpl` without an explicit engine."""
    updater = ConversationVariableUpdaterImpl()
    return updater

View File

@@ -1,4 +1,9 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@@ -7,16 +12,71 @@ from core.workflow.nodes.variable_assigner.common import helpers as common_helpe
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
def __init__(
    self,
    id: str,
    config: Mapping[str, Any],
    graph_init_params: "GraphInitParams",
    graph: "Graph",
    graph_runtime_state: "GraphRuntimeState",
    previous_node_id: Optional[str] = None,
    thread_pool_id: Optional[str] = None,
    conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
) -> None:
    """Initialise the node and remember how to build conversation variable updaters.

    All arguments except ``conv_var_updater_factory`` are forwarded unchanged to
    the base class initialiser.

    :param conv_var_updater_factory: zero-argument callable returning a
        `ConversationVariableUpdater`.
    """
    super().__init__(
        id=id,
        config=config,
        graph_init_params=graph_init_params,
        graph=graph,
        graph_runtime_state=graph_runtime_state,
        previous_node_id=previous_node_id,
        thread_pool_id=thread_pool_id,
    )
    # The factory itself is stored, not an updater instance.
    self._conv_var_updater_factory = conv_var_updater_factory
@classmethod
def version(cls) -> str:
    """Return the version tag of this node implementation (``"1"``)."""
    return "1"
@classmethod
def _extract_variable_selector_to_variable_mapping(
    cls,
    *,
    graph_config: Mapping[str, Any],
    node_id: str,
    node_data: VariableAssignerData,
) -> Mapping[str, Sequence[str]]:
    """Map ``"<node_id>.#<selector>#"`` keys to the selectors this node uses.

    The assigned (target) selector is included only when its first element is
    the conversation-variable node id; the input (source) selector is always
    included.
    """

    def _key_for(selector: Sequence[str]) -> str:
        selector_key = ".".join(selector)
        return f"{node_id}.#{selector_key}#"

    target_selector = node_data.assigned_variable_selector
    source_selector = node_data.input_variable_selector

    # Target first, source second: a later identical key overrides the earlier one.
    selectors = []
    if target_selector[0] == CONVERSATION_VARIABLE_NODE_ID:
        selectors.append(target_selector)
    selectors.append(source_selector)

    return {_key_for(selector): selector for selector in selectors}
def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")
@@ -44,20 +104,28 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Overwrite the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found")
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater = self._conv_var_updater_factory()
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater.flush()
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
"value": income_value.to_object(),
},
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
# we still set `output_variables` as a list to ensure the schema of output is
# compatible with `v2.VariableAssignerNode`.
process_data=common_helpers.set_updated_variables({}, updated_variables),
outputs={},
)

View File

@@ -12,6 +12,12 @@ class VariableOperationItem(BaseModel):
variable_selector: Sequence[str]
input_type: InputType
operation: Operation
# NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context:
#
# 1. For CONSTANT input_type: Contains the literal value to be used in the operation.
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
# 3. During the variable updating procedure: The `value` field is reassigned to hold
# the resolved actual value that will be applied to the target variable.
value: Any | None = None

View File

@@ -29,3 +29,8 @@ class InvalidInputValueError(VariableOperatorNodeError):
class ConversationIDNotFoundError(VariableOperatorNodeError):
    """Error carrying the fixed message ``"conversation_id not found"``."""

    def __init__(self):
        super().__init__("conversation_id not found")
class InvalidDataError(VariableOperatorNodeError):
    """Error for malformed node data; the caller supplies the message."""

    def __init__(self, message: str) -> None:
        super().__init__(message)

View File

@@ -1,34 +1,84 @@
import json
from collections.abc import Sequence
from typing import Any, cast
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, TypeAlias, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .constants import EMPTY_VALUE_MAPPING
from .entities import VariableAssignerNodeData
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidDataError,
InvalidInputValueError,
OperationNotSupportedError,
VariableNotFoundError,
)
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
    """Record the item's target selector in ``mapping`` if it is a conversation variable.

    Nothing is added when the first element of the selector is not the
    conversation-variable node id. ``mapping`` is modified in place.
    """
    target = item.variable_selector
    if target[0] == CONVERSATION_VARIABLE_NODE_ID:
        selector_str = ".".join(target)
        mapping[f"{node_id}.#{selector_str}#"] = target
def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
    """Record the item's source selector in ``mapping`` when its input is a variable.

    Items whose ``input_type`` is not ``InputType.VARIABLE`` are ignored.
    ``mapping`` is modified in place.

    :raises InvalidDataError: if ``item.value`` is not a list or has fewer than
        ``MIN_SELECTORS_LENGTH`` elements.
    """
    # Keep this in sync with the logic in _run methods...
    if item.input_type == InputType.VARIABLE:
        selector = item.value
        if not isinstance(selector, list):
            raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
        if len(selector) < MIN_SELECTORS_LENGTH:
            raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
        selector_str = ".".join(selector)
        mapping[f"{node_id}.#{selector_str}#"] = selector
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
    """Build the updater via ``conversation_variable_updater_factory``."""
    updater = conversation_variable_updater_factory()
    return updater
@classmethod
def version(cls) -> str:
    """Return the version tag of this node implementation (``"2"``)."""
    return "2"
@classmethod
def _extract_variable_selector_to_variable_mapping(
    cls,
    *,
    graph_config: Mapping[str, Any],
    node_id: str,
    node_data: VariableAssignerNodeData,
) -> Mapping[str, Sequence[str]]:
    """Collect the target and source selector mappings of every operation item.

    For each item the target mapping is collected first, then the source
    mapping, both into the same dictionary.
    """
    collectors = (_target_mapping_from_item, _source_mapping_from_item)
    collected: dict[str, Sequence[str]] = {}
    for operation_item in node_data.items:
        for collect in collectors:
            collect(collected, node_id, operation_item)
    return collected
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {}
@@ -114,6 +164,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
conv_var_updater = self._conv_var_updater_factory()
# Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
@@ -128,15 +179,23 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conv_var_updater.update(
conversation_id=cast(str, conversation_id),
variable=variable,
)
conv_var_updater.flush()
updated_variables = [
common_helpers.variable_to_processed_data(selector, seg)
for selector in updated_variable_selectors
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
]
process_data = common_helpers.set_updated_variables(process_data, updated_variables)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={},
)
def _handle_item(

View File

@@ -0,0 +1,79 @@
import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.variables import Variable
from core.workflow.entities.variable_pool import VariablePool
class VariableLoader(Protocol):
    """Interface for loading variables based on selectors.

    A `VariableLoader` retrieves the additional variables that a single node
    needs during execution and that were not provided as user inputs.

    NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the
    same application and share the same `app_id`. This interface does not enforce that constraint,
    and `load_variables` intentionally takes no `app_id` parameter, to keep concerns separated and
    leave implementations flexible.

    Implementations of `VariableLoader` should almost always have an `app_id` parameter in
    their constructor.

    TODO(QuantumGhost): this is a temporary workaround. If we can move the creation of node instance
    into `WorkflowService.single_step_run`, we may get rid of this interface.
    """

    @abc.abstractmethod
    def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
        """Load the variables matching ``selectors``.

        An empty ``selectors`` list should yield an empty list. The order of the
        returned variables is not guaranteed; callers needing a specific order
        should sort the result themselves.

        :param selectors: a list of string lists; each inner list should have at
            least two elements:
            - the first element is the node ID,
            - the second element is the variable name.
        :return: a list of Variable objects that match the provided selectors.
        """
        ...
class _DummyVariableLoader(VariableLoader):
    """No-op `VariableLoader` used as a placeholder when nothing has to be loaded."""

    def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
        """Ignore ``selectors`` and return a new empty list."""
        return []
DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
def load_into_variable_pool(
    variable_loader: VariableLoader,
    variable_pool: VariablePool,
    variable_mapping: Mapping[str, Sequence[str]],
    user_inputs: Mapping[str, Any],
):
    """Load variables missing from ``variable_pool`` and add them to it.

    A mapping entry is skipped when the user supplied a value for it, either
    under the full mapping key or under the key with its first dot-separated
    element removed. For the remaining entries, selectors that the pool cannot
    resolve are requested from ``variable_loader`` in a single call, and every
    returned variable is added to the pool under its own selector.

    :param variable_loader: source of the variables that are not in the pool.
    :param variable_pool: pool to check and to fill; modified in place.
    :param variable_mapping: mapping from variable key to selector.
    :param user_inputs: values provided by the user, keyed as described above.
    :raises ValueError: if a mapping key splits into no elements.
    """
    missing_selectors: list[list[str]] = []
    for key, selector in variable_mapping.items():
        # NOTE(QuantumGhost): this logic needs to be in sync with
        # `WorkflowEntry.mapping_user_inputs_to_variable_pool`.
        key_parts = key.split(".")
        if len(key_parts) < 1:
            raise ValueError(f"Invalid variable key: {key}. It should have at least one element.")
        # The user already provided this value; nothing to load.
        if key in user_inputs or ".".join(key_parts[1:]) in user_inputs:
            continue
        if variable_pool.get(selector) is None:
            missing_selectors.append(list(selector))

    for loaded_variable in variable_loader.load_variables(missing_selectors):
        variable_pool.add(loaded_variable.selector, loaded_variable)

View File

@@ -92,7 +92,7 @@ class WorkflowCycleManager:
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(outputs)
# outputs = WorkflowEntry.handle_special_values(outputs)
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
workflow_execution.outputs = outputs or {}
@@ -125,7 +125,7 @@ class WorkflowCycleManager:
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
# outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.outputs = outputs or {}
@@ -242,9 +242,9 @@ class WorkflowCycleManager:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
inputs = event.inputs
process_data = event.process_data
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
@@ -289,7 +289,7 @@ class WorkflowCycleManager:
# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
@@ -326,7 +326,7 @@ class WorkflowCycleManager:
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
outputs = event.outputs
# Convert metadata keys to strings
origin_metadata = {

View File

@@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory
from models.enums import UserFrom
from models.workflow import (
@@ -119,7 +120,9 @@ class WorkflowEntry:
workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict,
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
"""
Single step run workflow node
@@ -129,29 +132,14 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
# fetch node info from workflow graph
workflow_graph = workflow.graph_dict
if not workflow_graph:
raise ValueError("workflow graph not found")
nodes = workflow_graph.get("nodes")
if not nodes:
raise ValueError("nodes not found in workflow graph")
# fetch node config from node id
try:
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise ValueError("node id not found in workflow graph")
node_config = workflow.get_node_config_by_id(node_id)
node_config_data = node_config.get("data", {})
# Get node class
node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1")
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(environment_variables=workflow.environment_variables)
# init graph
graph = Graph.init(graph_config=workflow.graph_dict)
@@ -182,16 +170,33 @@ class WorkflowEntry:
except NotImplementedError:
variable_mapping = {}
# Loading missing variable from draft var here, and set it into
# variable_pool.
load_into_variable_pool(
variable_loader=variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
try:
# run node
generator = node_instance.run()
except Exception as e:
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
workflow.id,
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator
@@ -294,10 +299,20 @@ class WorkflowEntry:
return node_instance, generator
except Exception as e:
# Placeholders must match the arguments one-to-one: only the node id, node type
# and node version are passed here, so the message names exactly those three.
logger.exception(
    "error while running node_instance, node_id=%s, type=%s, version=%s",
    node_instance.id,
    node_instance.node_type,
    node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
    """Convert ``value`` with ``_handle_special_values`` and return it as a mapping.

    ``None`` and mapping results are returned unchanged; any other result is
    passed through ``dict(...)``.
    """
    # NOTE(QuantumGhost): Avoid using this function in new code.
    # Keep values structured as long as possible and only convert to dict
    # immediately before serialization (e.g., JSON serialization) to maintain
    # data integrity and type information.
    converted = WorkflowEntry._handle_special_values(value)
    if converted is None or isinstance(converted, Mapping):
        return converted
    return dict(converted)
@@ -324,10 +339,17 @@ class WorkflowEntry:
cls,
*,
variable_mapping: Mapping[str, Sequence[str]],
user_inputs: dict,
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
tenant_id: str,
) -> None:
# NOTE(QuantumGhost): This logic should remain synchronized with
# the implementation of `load_into_variable_pool`, specifically the logic about
# variable existence checking.
# WARNING(QuantumGhost): The semantics of this method are not clearly defined,
# and multiple parts of the codebase depend on its current behavior.
# Modify with caution.
for node_variable, variable_selector in variable_mapping.items():
# fetch node id and variable key from node_variable
node_variable_list = node_variable.split(".")

View File

@@ -0,0 +1,49 @@
import json
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel
from core.file.models import File
from core.variables import Segment
class WorkflowRuntimeTypeEncoder(json.JSONEncoder):
    """JSON encoder with fallbacks for `Segment`, `File` and pydantic models."""

    def default(self, o: Any):
        """Return a JSON-friendly stand-in for ``o``.

        `Segment` objects are replaced by their ``value``, `File` objects by
        ``to_dict()`` and other `BaseModel` objects by ``model_dump(mode="json")``.
        Anything else is delegated to the base class.
        """
        # `Segment` and `File` are tested before the generic `BaseModel` branch.
        if isinstance(o, Segment):
            return o.value
        if isinstance(o, File):
            return o.to_dict()
        if isinstance(o, BaseModel):
            return o.model_dump(mode="json")
        return super().default(o)
class WorkflowRuntimeTypeConverter:
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
def _to_json_encodable_recursive(self, value: Any) -> Any:
if value is None:
return value
if isinstance(value, (bool, int, str, float)):
return value
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
if isinstance(value, File):
return value.to_dict()
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = self._to_json_encodable_recursive(v)
return res
if isinstance(value, list):
res_list = []
for item in value:
res_list.append(self._to_json_encodable_recursive(item))
return res_list
return value