feat(api): automatically NODE_TYPE_CLASSES_MAPPING generation from node class definitions (#28525)
This commit is contained in:
@@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]):
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "test"
|
||||
return "1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock LLM node."""
|
||||
@@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock agent node."""
|
||||
@@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock tool node."""
|
||||
@@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock knowledge retrieval node."""
|
||||
@@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock HTTP request node."""
|
||||
@@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock question classifier node."""
|
||||
@@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock parameter extractor node."""
|
||||
@@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock document extractor node."""
|
||||
@@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _create_graph_engine(self, index: int, item: Any):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
@@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _create_graph_engine(self, start_at, root_node_id: str):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
@@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""Execute mock template transform node."""
|
||||
@@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""Execute mock code node."""
|
||||
|
||||
@@ -33,6 +33,10 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
|
||||
type_version_set: set[tuple[NodeType, str]] = set()
|
||||
|
||||
for cls in classes:
|
||||
# Only validate production node classes; skip test-defined subclasses and external helpers
|
||||
module_name = getattr(cls, "__module__", "")
|
||||
if not module_name.startswith("core."):
|
||||
continue
|
||||
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
|
||||
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
|
||||
node_type = cls.node_type
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
import types
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
# Import concrete nodes we will assert on (numeric version path)
|
||||
from core.workflow.nodes.variable_assigner.v1.node import (
|
||||
VariableAssignerNode as VariableAssignerV1,
|
||||
)
|
||||
from core.workflow.nodes.variable_assigner.v2.node import (
|
||||
VariableAssignerNode as VariableAssignerV2,
|
||||
)
|
||||
|
||||
|
||||
def test_variable_assigner_latest_prefers_highest_numeric_version():
|
||||
# Act
|
||||
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
||||
# Assert basic presence
|
||||
assert NodeType.VARIABLE_ASSIGNER in mapping
|
||||
va_versions = mapping[NodeType.VARIABLE_ASSIGNER]
|
||||
|
||||
# Both concrete versions must be present
|
||||
assert va_versions.get("1") is VariableAssignerV1
|
||||
assert va_versions.get("2") is VariableAssignerV2
|
||||
|
||||
# And latest should point to numerically-highest version ("2")
|
||||
assert va_versions.get("latest") is VariableAssignerV2
|
||||
|
||||
|
||||
def test_latest_prefers_highest_numeric_version():
|
||||
# Arrange: define two ephemeral subclasses with numeric versions under a NodeType
|
||||
# that has no concrete implementations in production to avoid interference.
|
||||
class _Version1(Node[BaseNodeData]): # type: ignore[misc]
|
||||
node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR
|
||||
|
||||
def init_node_data(self, data):
|
||||
pass
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return None
|
||||
|
||||
def _get_retry_config(self):
|
||||
return types.SimpleNamespace() # not used
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return "version1"
|
||||
|
||||
def _get_description(self):
|
||||
return None
|
||||
|
||||
def _get_default_value_dict(self):
|
||||
return {}
|
||||
|
||||
def get_base_node_data(self):
|
||||
return types.SimpleNamespace(title="version1")
|
||||
|
||||
class _Version2(_Version1): # type: ignore[misc]
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return "version2"
|
||||
|
||||
# Act: build a fresh mapping (it should now see our ephemeral subclasses)
|
||||
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
||||
# Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version
|
||||
assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping
|
||||
legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR]
|
||||
|
||||
assert legacy_versions.get("1") is _Version1
|
||||
assert legacy_versions.get("2") is _Version2
|
||||
assert legacy_versions.get("latest") is _Version2
|
||||
@@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]):
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "sample-test"
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user