feat(api): automatically NODE_TYPE_CLASSES_MAPPING generation from node class definitions (#28525)

This commit is contained in:
wangxiaolei
2025-12-01 14:14:19 +08:00
committed by GitHub
parent 2f8cb2a1af
commit d162f7e5ef
11 changed files with 245 additions and 189 deletions

View File

@@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]):
@classmethod
def version(cls) -> str:
return "test"
return "1"
def __init__(
self,

View File

@@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock LLM node."""
@@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock agent node."""
@@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock tool node."""
@@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock knowledge retrieval node."""
@@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock HTTP request node."""
@@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock question classifier node."""
@@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock parameter extractor node."""
@@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock document extractor node."""
@@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _create_graph_engine(self, index: int, item: Any):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _create_graph_engine(self, start_at, root_node_id: str):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> NodeRunResult:
"""Execute mock template transform node."""
@@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> NodeRunResult:
"""Execute mock code node."""

View File

@@ -33,6 +33,10 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
type_version_set: set[tuple[NodeType, str]] = set()
for cls in classes:
# Only validate production node classes; skip test-defined subclasses and external helpers
module_name = getattr(cls, "__module__", "")
if not module_name.startswith("core."):
continue
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
node_type = cls.node_type

View File

@@ -0,0 +1,84 @@
import types
from collections.abc import Mapping
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
# Import concrete nodes we will assert on (numeric version path)
from core.workflow.nodes.variable_assigner.v1.node import (
VariableAssignerNode as VariableAssignerV1,
)
from core.workflow.nodes.variable_assigner.v2.node import (
VariableAssignerNode as VariableAssignerV2,
)
def test_variable_assigner_latest_prefers_highest_numeric_version():
# Act
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
# Assert basic presence
assert NodeType.VARIABLE_ASSIGNER in mapping
va_versions = mapping[NodeType.VARIABLE_ASSIGNER]
# Both concrete versions must be present
assert va_versions.get("1") is VariableAssignerV1
assert va_versions.get("2") is VariableAssignerV2
# And latest should point to numerically-highest version ("2")
assert va_versions.get("latest") is VariableAssignerV2
def test_latest_prefers_highest_numeric_version():
# Arrange: define two ephemeral subclasses with numeric versions under a NodeType
# that has no concrete implementations in production to avoid interference.
class _Version1(Node[BaseNodeData]): # type: ignore[misc]
node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR
def init_node_data(self, data):
pass
def _run(self):
raise NotImplementedError
@classmethod
def version(cls) -> str:
return "1"
def _get_error_strategy(self):
return None
def _get_retry_config(self):
return types.SimpleNamespace() # not used
def _get_title(self) -> str:
return "version1"
def _get_description(self):
return None
def _get_default_value_dict(self):
return {}
def get_base_node_data(self):
return types.SimpleNamespace(title="version1")
class _Version2(_Version1): # type: ignore[misc]
@classmethod
def version(cls) -> str:
return "2"
def _get_title(self) -> str:
return "version2"
# Act: build a fresh mapping (it should now see our ephemeral subclasses)
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
# Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version
assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping
legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR]
assert legacy_versions.get("1") is _Version1
assert legacy_versions.get("2") is _Version2
assert legacy_versions.get("latest") is _Version2

View File

@@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]):
@classmethod
def version(cls) -> str:
return "sample-test"
return "1"
def _run(self):
raise NotImplementedError