refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
-LAN-
2025-07-18 10:08:51 +08:00
committed by GitHub
parent 54c56f2d05
commit 460a825ef1
65 changed files with 2305 additions and 1146 deletions

View File

@@ -122,13 +122,13 @@ class RetryConfig(BaseModel):
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
version: str = "1"
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self):
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}

View File

@@ -1,28 +1,22 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from .entities import BaseNodeData
if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__)
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData]
class BaseNode:
_node_type: ClassVar[NodeType]
def __init__(
@@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]):
self.node_id = node_id
node_data = self._node_data_cls.model_validate(config.get("data", {}))
self.node_data = node_data
@abstractmethod
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
@@ -130,9 +124,9 @@ class BaseNode(Generic[GenericNodeData]):
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {}))
# Pass raw dict data instead of creating NodeData instance
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
)
return data
@@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: GenericNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
Get default config of node.
:param filters: filter by node config parameters.
:return:
"""
return {}
@property
def node_type(self) -> NodeType:
"""
Get node type
:return:
"""
def type_(self) -> NodeType:
return self._node_type
@classmethod
@@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error
Returns:
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
def continue_on_error(self) -> bool:
return False
@property
def should_retry(self) -> bool:
"""judge if should retry
def retry(self) -> bool:
return False
Returns:
bool: if should retry
"""
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
# Abstract methods that subclasses must implement to provide access
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
"""Get the error strategy for this node."""
...
@abstractmethod
def _get_retry_config(self) -> RetryConfig:
"""Get the retry configuration for this node."""
...
@abstractmethod
def _get_title(self) -> str:
"""Get the node title."""
...
@abstractmethod
def _get_description(self) -> Optional[str]:
"""Get the node description."""
...
@abstractmethod
def _get_default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
...
@abstractmethod
def get_base_node_data(self) -> BaseNodeData:
"""Get the BaseNodeData object for this node."""
...
# Public interface properties that delegate to abstract methods
@property
def error_strategy(self) -> Optional[ErrorStrategy]:
"""Get the error strategy for this node."""
return self._get_error_strategy()
@property
def retry_config(self) -> RetryConfig:
"""Get the retry configuration for this node."""
return self._get_retry_config()
@property
def title(self) -> str:
"""Get the node title."""
return self._get_title()
@property
def description(self) -> Optional[str]:
"""Get the node description."""
return self._get_description()
@property
def default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
return self._get_default_value_dict()