Refactor: centralize node data hydration (#27771)

This commit is contained in:
-LAN-
2025-11-27 15:41:56 +08:00
committed by GitHub
parent 1b733abe82
commit 13bf6547ee
58 changed files with 381 additions and 899 deletions

View File

@@ -2,7 +2,7 @@ import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from typing import Any, ClassVar
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -49,12 +49,121 @@ from models.enums import UserFrom
from .entities import BaseNodeData, RetryConfig
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
logger = logging.getLogger(__name__)
class Node:
class Node(Generic[NodeDataT]):
node_type: ClassVar["NodeType"]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
def __init_subclass__(cls, **kwargs: Any) -> None:
    """
    Resolve and record the node data model declared via ``Node[T]``.

    Runs automatically whenever a subclass is created. For a declaration such
    as ``class MyNode(Node[MyNodeData])`` it asks
    ``_extract_node_data_type_from_generic()`` for ``MyNodeData`` and stores it
    on the class as ``_node_data_type``. ``__init__`` later uses that type to
    hydrate ``config["data"]``, which is why subclasses do not have to write
    accessor boilerplate such as ``_get_title()`` or ``_get_error_strategy()``.

    Resolution flow for ``class CodeNode(Node[CodeNodeData])``::

        CodeNode.__orig_bases__          -> (Node[CodeNodeData],)
        get_origin(Node[CodeNodeData])   -> Node
        get_args(Node[CodeNodeData])     -> (CodeNodeData,)
        validate                         -> is a type, subclasses BaseNodeData
        cls._node_data_type              =  CodeNodeData

    Later, in ``__init__``::

        config["data"] -> _hydrate_node_data() -> _node_data_type.model_validate()
                       -> CodeNodeData instance stored in self._node_data

    Example:
        class CodeNode(Node[CodeNodeData]):  # CodeNodeData is auto-extracted
            node_type = NodeType.CODE
            # No need to implement _get_title, _get_error_strategy, etc.

    Raises:
        TypeError: If no ``Node[T]`` parameterization could be resolved for the
            subclass, or if the extractor rejects the generic argument.
    """
    super().__init_subclass__(**kwargs)

    # The base class itself has no data type to resolve.
    if cls is not Node:
        resolved_type = cls._extract_node_data_type_from_generic()
        if resolved_type is None:
            raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype")
        cls._node_data_type = resolved_type
@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
    """
    Extract the node data type from the generic parameter `Node[T]`.

    Walks the MRO of `cls` and inspects each class's own `__orig_bases__`
    until a `Node[T]` parameterization is found, then extracts and validates `T`.

    The MRO walk matters because `__orig_bases__` is set per class: a subclass
    that declares its own generic base (for example an extra `Generic[...]`
    mixin) gets a fresh `__orig_bases__` that hides the `Node[T]` declared on
    an ancestor. Reading the attribute only through normal attribute lookup
    would then miss the ancestor's parameterization.

    Returns:
        The extracted BaseNodeData subtype, or None if no class in the MRO
        is declared with a `Node[T]` base.

    Raises:
        TypeError: If the generic argument is invalid (not exactly one argument,
            or not a BaseNodeData subtype).
    """
    for klass in cls.__mro__:
        # Use the class's own namespace so every level of the hierarchy is
        # examined exactly once, nearest class first.
        # For `class CodeNode(Node[CodeNodeData])` this is `(Node[CodeNodeData],)`.
        for base in vars(klass).get("__orig_bases__", ()):
            # get_origin returns `Node` for `Node[CodeNodeData]`, None for plain classes.
            if get_origin(base) is not Node:
                continue

            args = get_args(base)  # `(CodeNodeData,)` for `Node[CodeNodeData]`
            if len(args) != 1:
                raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument")

            candidate = args[0]
            # Rejects unresolved TypeVars and any non-BaseNodeData argument.
            if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData):
                raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype")

            return candidate

    return None
def __init__(
self,
@@ -63,6 +172,7 @@ class Node:
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
self._graph_init_params = graph_init_params
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
@@ -83,8 +193,24 @@ class Node:
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
@abstractmethod
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
raise ValueError("Node config data must be a mapping.")
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
self.post_init()
def post_init(self) -> None:
    """
    Hook for subclasses that need extra setup after construction.

    ``__init__`` calls this once the node data has been hydrated into
    ``self._node_data``. The base implementation intentionally does nothing.
    """
    return None
@property
def graph_init_params(self) -> "GraphInitParams":
    """Return the ``GraphInitParams`` object this node was constructed with."""
    init_params = self._graph_init_params
    return init_params
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
    """
    Build this node's typed data object from raw configuration data.

    Args:
        data: The raw mapping taken from the node config.

    Returns:
        The result of ``model_validate`` on the class's ``_node_data_type``.
    """
    validated = self._node_data_type.model_validate(data)
    # `_node_data_type` is annotated with the base data type; the cast only
    # narrows the static type to this node's `NodeDataT`.
    return cast(NodeDataT, validated)
@abstractmethod
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
@@ -273,38 +399,29 @@ class Node:
def retry(self) -> bool:
return False
# Accessors that expose BaseNodeData properties in a type-safe way.
# Each one reads from the hydrated `self._node_data`.
def _get_error_strategy(self) -> ErrorStrategy | None:
    """
    Get the error strategy for this node.

    Concrete implementation shared by all nodes: returns the
    ``error_strategy`` attribute of ``self._node_data``. It is deliberately
    not abstract, so subclasses do not need to override it.
    """
    return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
    """
    Get the retry configuration for this node.

    Concrete implementation shared by all nodes: returns the
    ``retry_config`` attribute of ``self._node_data``. It is deliberately
    not abstract, so subclasses do not need to override it.
    """
    return self._node_data.retry_config
@abstractmethod
def _get_title(self) -> str:
"""Get the node title."""
...
return self._node_data.title
@abstractmethod
def _get_description(self) -> str | None:
"""Get the node description."""
...
return self._node_data.desc
@abstractmethod
def _get_default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
...
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
    """
    Get the BaseNodeData object for this node.

    Concrete implementation shared by all nodes: returns the hydrated
    ``self._node_data`` object itself. It is deliberately not abstract, so
    subclasses do not need to override it.
    """
    return self._node_data
# Public interface properties that delegate to the `_get_*` accessor methods
@property
@@ -332,6 +449,11 @@ class Node:
"""Get the default values dictionary for this node."""
return self._get_default_value_dict()
@property
def node_data(self) -> NodeDataT:
    """Return this node's configuration data, typed as the node's own data model."""
    typed_data = self._node_data
    return typed_data
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
match result.status:
case WorkflowNodeExecutionStatus.FAILED: