refactor: decouple Node and NodeData (#22581)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
@@ -122,13 +122,13 @@ class RetryConfig(BaseModel):
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
version: str = "1"
|
||||
error_strategy: Optional[ErrorStrategy] = None
|
||||
default_value: Optional[list[DefaultValue]] = None
|
||||
version: str = "1"
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self):
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
@@ -1,28 +1,22 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
|
||||
from .entities import BaseNodeData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
|
||||
|
||||
|
||||
class BaseNode(Generic[GenericNodeData]):
|
||||
_node_data_cls: type[GenericNodeData]
|
||||
class BaseNode:
|
||||
_node_type: ClassVar[NodeType]
|
||||
|
||||
def __init__(
|
||||
@@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
|
||||
self.node_id = node_id
|
||||
|
||||
node_data = self._node_data_cls.model_validate(config.get("data", {}))
|
||||
self.node_data = node_data
|
||||
@abstractmethod
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
@@ -130,9 +124,9 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
# Pass raw dict data instead of creating NodeData instance
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
|
||||
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: GenericNodeData,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def node_type(self) -> NodeType:
|
||||
"""
|
||||
Get node type
|
||||
:return:
|
||||
"""
|
||||
def type_(self) -> NodeType:
|
||||
return self._node_type
|
||||
|
||||
@classmethod
|
||||
@@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@property
|
||||
def should_continue_on_error(self) -> bool:
|
||||
"""judge if should continue on error
|
||||
|
||||
Returns:
|
||||
bool: if should continue on error
|
||||
"""
|
||||
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
|
||||
def continue_on_error(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def should_retry(self) -> bool:
|
||||
"""judge if should retry
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
|
||||
Returns:
|
||||
bool: if should retry
|
||||
"""
|
||||
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
|
||||
# Abstract methods that subclasses must implement to provide access
|
||||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_title(self) -> str:
|
||||
"""Get the node title."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_description(self) -> Optional[str]:
|
||||
"""Get the node description."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
"""Get the BaseNodeData object for this node."""
|
||||
...
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._get_error_strategy()
|
||||
|
||||
@property
|
||||
def retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
return self._get_retry_config()
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
"""Get the node title."""
|
||||
return self._get_title()
|
||||
|
||||
@property
|
||||
def description(self) -> Optional[str]:
|
||||
"""Get the node description."""
|
||||
return self._get_description()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
return self._get_default_value_dict()
|
||||
|
||||
Reference in New Issue
Block a user