Refactor: centralize node data hydration (#27771)
This commit is contained in:
@@ -2,7 +2,7 @@ import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -49,12 +49,121 @@ from models.enums import UserFrom
|
||||
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Node:
|
||||
class Node(Generic[NodeDataT]):
|
||||
node_type: ClassVar["NodeType"]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
"""
|
||||
Automatically extract and validate the node data type from the generic parameter.
|
||||
|
||||
When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method:
|
||||
1. Inspects `__orig_bases__` to find the `Node[T]` parameterization
|
||||
2. Extracts `T` (e.g., `MyNodeData`) from the generic argument
|
||||
3. Validates that `T` is a proper `BaseNodeData` subclass
|
||||
4. Stores it in `_node_data_type` for automatic hydration in `__init__`
|
||||
|
||||
This eliminates the need for subclasses to manually implement boilerplate
|
||||
accessor methods like `_get_title()`, `_get_error_strategy()`, etc.
|
||||
|
||||
How it works:
|
||||
::
|
||||
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
│ │
|
||||
│ └─────────────────────────────────┐
|
||||
│ │
|
||||
▼ ▼
|
||||
┌─────────────────────────────┐ ┌─────────────────────────────────┐
|
||||
│ __orig_bases__ = ( │ │ CodeNodeData(BaseNodeData) │
|
||||
│ Node[CodeNodeData], │ │ title: str │
|
||||
│ ) │ │ desc: str | None │
|
||||
└──────────────┬──────────────┘ │ ... │
|
||||
│ └─────────────────────────────────┘
|
||||
▼ ▲
|
||||
┌─────────────────────────────┐ │
|
||||
│ get_origin(base) -> Node │ │
|
||||
│ get_args(base) -> ( │ │
|
||||
│ CodeNodeData, │ ──────────────────────┘
|
||||
│ ) │
|
||||
└──────────────┬──────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────┐
|
||||
│ Validate: │
|
||||
│ - Is it a type? │
|
||||
│ - Is it a BaseNodeData │
|
||||
│ subclass? │
|
||||
└──────────────┬──────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────┐
|
||||
│ cls._node_data_type = │
|
||||
│ CodeNodeData │
|
||||
└─────────────────────────────┘
|
||||
|
||||
Later, in __init__:
|
||||
::
|
||||
|
||||
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
|
||||
Example:
|
||||
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
|
||||
node_type = NodeType.CODE
|
||||
# No need to implement _get_title, _get_error_strategy, etc.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
if cls is Node:
|
||||
return
|
||||
|
||||
node_data_type = cls._extract_node_data_type_from_generic()
|
||||
|
||||
if node_data_type is None:
|
||||
raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype")
|
||||
|
||||
cls._node_data_type = node_data_type
|
||||
|
||||
@classmethod
|
||||
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||
"""
|
||||
Extract the node data type from the generic parameter `Node[T]`.
|
||||
|
||||
Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`.
|
||||
|
||||
Returns:
|
||||
The extracted BaseNodeData subtype, or None if not found.
|
||||
|
||||
Raises:
|
||||
TypeError: If the generic argument is invalid (not exactly one argument,
|
||||
or not a BaseNodeData subtype).
|
||||
"""
|
||||
# __orig_bases__ contains the original generic bases before type erasure.
|
||||
# For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`.
|
||||
for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined]
|
||||
origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]`
|
||||
if origin is Node:
|
||||
args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]`
|
||||
if len(args) != 1:
|
||||
raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument")
|
||||
|
||||
candidate = args[0]
|
||||
if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData):
|
||||
raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype")
|
||||
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -63,6 +172,7 @@ class Node:
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
self._graph_init_params = graph_init_params
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
@@ -83,8 +193,24 @@ class Node:
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
@abstractmethod
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
raise ValueError("Node config data must be a mapping.")
|
||||
|
||||
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self) -> None:
|
||||
"""Optional hook for subclasses requiring extra initialization."""
|
||||
return
|
||||
|
||||
@property
|
||||
def graph_init_params(self) -> "GraphInitParams":
|
||||
return self._graph_init_params
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
@@ -273,38 +399,29 @@ class Node:
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
|
||||
# Abstract methods that subclasses must implement to provide access
|
||||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
return self._node_data.error_strategy
|
||||
|
||||
@abstractmethod
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
...
|
||||
return self._node_data.retry_config
|
||||
|
||||
@abstractmethod
|
||||
def _get_title(self) -> str:
|
||||
"""Get the node title."""
|
||||
...
|
||||
return self._node_data.title
|
||||
|
||||
@abstractmethod
|
||||
def _get_description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
...
|
||||
return self._node_data.desc
|
||||
|
||||
@abstractmethod
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
...
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
@abstractmethod
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
"""Get the BaseNodeData object for this node."""
|
||||
...
|
||||
return self._node_data
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
@@ -332,6 +449,11 @@ class Node:
|
||||
"""Get the default values dictionary for this node."""
|
||||
return self._get_default_value_dict()
|
||||
|
||||
@property
|
||||
def node_data(self) -> NodeDataT:
|
||||
"""Typed access to this node's configuration data."""
|
||||
return self._node_data
|
||||
|
||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
|
||||
Reference in New Issue
Block a user