feat(api): automatically NODE_TYPE_CLASSES_MAPPING generation from node class definitions (#28525)
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
import importlib
|
||||
import logging
|
||||
import operator
|
||||
import pkgutil
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from types import MappingProxyType
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -134,6 +138,34 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
cls._node_data_type = node_data_type
|
||||
|
||||
# Skip base class itself
|
||||
if cls is Node:
|
||||
return
|
||||
# Only register production node implementations defined under core.workflow.nodes.*
|
||||
# This prevents test helper subclasses from polluting the global registry and
|
||||
# accidentally overriding real node types (e.g., a test Answer node).
|
||||
module_name = getattr(cls, "__module__", "")
|
||||
# Only register concrete subclasses that define node_type and version()
|
||||
node_type = cls.node_type
|
||||
version = cls.version()
|
||||
bucket = Node._registry.setdefault(node_type, {})
|
||||
if module_name.startswith("core.workflow.nodes."):
|
||||
# Production node definitions take precedence and may override
|
||||
bucket[version] = cls # type: ignore[index]
|
||||
else:
|
||||
# External/test subclasses may register but must not override production
|
||||
bucket.setdefault(version, cls) # type: ignore[index]
|
||||
# Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic
|
||||
version_keys = [v for v in bucket if v != "latest"]
|
||||
numeric_pairs: list[tuple[str, int]] = []
|
||||
for v in version_keys:
|
||||
numeric_pairs.append((v, int(v)))
|
||||
if numeric_pairs:
|
||||
latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
|
||||
else:
|
||||
latest_key = max(version_keys) if version_keys else version
|
||||
bucket["latest"] = bucket[latest_key]
|
||||
|
||||
@classmethod
|
||||
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||
"""
|
||||
@@ -165,6 +197,9 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
return None
|
||||
|
||||
# Global registry populated via __init_subclass__
|
||||
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
@@ -395,6 +430,29 @@ class Node(Generic[NodeDataT]):
|
||||
# in `api/core/workflow/nodes/__init__.py`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@classmethod
|
||||
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under core.workflow.nodes so subclasses register themselves on import.
|
||||
Then we return a readonly view of the registry to avoid accidental mutation.
|
||||
"""
|
||||
# Import all node modules to ensure they are loaded (thus registered)
|
||||
import core.workflow.nodes as _nodes_pkg
|
||||
|
||||
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports
|
||||
# e.g. node_factory imports node_mapping which builds the mapping here.
|
||||
if _modname in {
|
||||
"core.workflow.nodes.node_factory",
|
||||
"core.workflow.nodes.node_mapping",
|
||||
}:
|
||||
continue
|
||||
importlib.import_module(_modname)
|
||||
|
||||
# Return a readonly view so callers can't mutate the registry by accident
|
||||
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user