Add workflow graph validation checks (#27106)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
from ..runtime.graph_runtime_state import GraphRuntimeState
|
||||||
|
from ..runtime.variable_pool import VariablePool
|
||||||
from .agent import AgentNodeStrategyInit
|
from .agent import AgentNodeStrategyInit
|
||||||
from .graph_init_params import GraphInitParams
|
from .graph_init_params import GraphInitParams
|
||||||
from .workflow_execution import WorkflowExecution
|
from .workflow_execution import WorkflowExecution
|
||||||
@@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentNodeStrategyInit",
|
"AgentNodeStrategyInit",
|
||||||
"GraphInitParams",
|
"GraphInitParams",
|
||||||
|
"GraphRuntimeState",
|
||||||
|
"VariablePool",
|
||||||
"WorkflowExecution",
|
"WorkflowExecution",
|
||||||
"WorkflowNodeExecution",
|
"WorkflowNodeExecution",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ from collections import defaultdict
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Protocol, cast, final
|
from typing import Protocol, cast, final
|
||||||
|
|
||||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from libs.typing import is_str, is_str_dict
|
from libs.typing import is_str, is_str_dict
|
||||||
|
|
||||||
from .edge import Edge
|
from .edge import Edge
|
||||||
|
from .validation import get_graph_validator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -201,6 +202,17 @@ class Graph:
|
|||||||
|
|
||||||
return GraphBuilder(graph_cls=cls)
|
return GraphBuilder(graph_cls=cls)
|
||||||
|
|
||||||
|
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
    """
    Switch every fail-branch node over to the branch execution type.

    A node is affected when its ``error_strategy`` equals ``ErrorStrategy.FAIL_BRANCH``;
    its ``execution_type`` is then set to ``NodeExecutionType.BRANCH``. Nodes are
    updated in place and all other nodes are left untouched.

    :param nodes: mapping of node ID to node instance
    """
    # Lazy filter: each node is checked and updated before the next one is examined.
    fail_branch_nodes = (
        candidate for candidate in nodes.values() if candidate.error_strategy == ErrorStrategy.FAIL_BRANCH
    )
    for candidate in fail_branch_nodes:
        candidate.execution_type = NodeExecutionType.BRANCH
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _mark_inactive_root_branches(
|
def _mark_inactive_root_branches(
|
||||||
cls,
|
cls,
|
||||||
@@ -307,6 +319,9 @@ class Graph:
|
|||||||
# Create node instances
|
# Create node instances
|
||||||
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
||||||
|
|
||||||
|
# Promote fail-branch nodes to branch execution type at graph level
|
||||||
|
cls._promote_fail_branch_nodes(nodes)
|
||||||
|
|
||||||
# Get root node instance
|
# Get root node instance
|
||||||
root_node = nodes[root_node_id]
|
root_node = nodes[root_node_id]
|
||||||
|
|
||||||
@@ -314,7 +329,7 @@ class Graph:
|
|||||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||||
|
|
||||||
# Create and return the graph
|
# Create and return the graph
|
||||||
return cls(
|
graph = cls(
|
||||||
nodes=nodes,
|
nodes=nodes,
|
||||||
edges=edges,
|
edges=edges,
|
||||||
in_edges=in_edges,
|
in_edges=in_edges,
|
||||||
@@ -322,6 +337,11 @@ class Graph:
|
|||||||
root_node=root_node,
|
root_node=root_node,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the graph structure using built-in validators
|
||||||
|
get_graph_validator().validate(graph)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def node_ids(self) -> list[str]:
|
def node_ids(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
125
api/core/workflow/graph/validation.py
Normal file
125
api/core/workflow/graph/validation.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
from core.workflow.enums import NodeExecutionType, NodeType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .graph import Graph
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class GraphValidationIssue:
|
||||||
|
"""Immutable value object describing a single validation issue."""
|
||||||
|
|
||||||
|
code: str
|
||||||
|
message: str
|
||||||
|
node_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GraphValidationError(ValueError):
    """Raised when graph validation fails.

    The exception message lists every issue as ``[code] message``, joined by ``"; "``.
    The issues themselves remain available as a tuple on ``issues``.
    """

    def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
        """
        :param issues: the validation issues that caused the failure; must not be empty
        :raises ValueError: if ``issues`` yields no items
        """
        # Materialize before checking for emptiness: a one-shot iterable (iterator,
        # generator) is always truthy, so testing it directly would let an empty one
        # through and produce an error with no issues and an empty message.
        collected: tuple[GraphValidationIssue, ...] = tuple(issues or ())
        if not collected:
            raise ValueError("GraphValidationError requires at least one issue.")
        self.issues: tuple[GraphValidationIssue, ...] = collected
        message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
        super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphValidationRule(Protocol):
    """Structural interface that every graph validation rule implements."""

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        """Inspect ``graph`` and report the issues found.

        :param graph: the graph to check
        :return: the issues discovered by this rule
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class _EdgeEndpointValidator:
|
||||||
|
"""Ensures all edges reference existing nodes."""
|
||||||
|
|
||||||
|
missing_node_code: str = "MISSING_NODE"
|
||||||
|
|
||||||
|
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||||
|
issues: list[GraphValidationIssue] = []
|
||||||
|
for edge in graph.edges.values():
|
||||||
|
if edge.tail not in graph.nodes:
|
||||||
|
issues.append(
|
||||||
|
GraphValidationIssue(
|
||||||
|
code=self.missing_node_code,
|
||||||
|
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
|
||||||
|
node_id=edge.tail,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if edge.head not in graph.nodes:
|
||||||
|
issues.append(
|
||||||
|
GraphValidationIssue(
|
||||||
|
code=self.missing_node_code,
|
||||||
|
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
|
||||||
|
node_id=edge.head,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return issues
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
    """Rule checking that the graph's root node is registered and is a valid entry point."""

    invalid_root_code: str = "INVALID_ROOT"
    # Node types accepted as a root even when their execution type is not ROOT.
    container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        """Validate ``graph.root_node``.

        :param graph: graph whose ``root_node`` and ``nodes`` are inspected
        :return: a single issue when the root is missing from ``graph.nodes``, or when
            it is neither of execution type ROOT nor one of ``container_entry_types``;
            otherwise an empty list
        """
        root = graph.root_node

        # An unregistered root is reported on its own; the entry-point check is skipped.
        if root.id not in graph.nodes:
            return [
                self._invalid_root(
                    root.id,
                    f"Root node '{root.id}' is missing from the node registry.",
                )
            ]

        # A root without a ``node_type`` attribute is treated as having type None.
        root_type = getattr(root, "node_type", None)
        if root.execution_type == NodeExecutionType.ROOT or root_type in self.container_entry_types:
            return []

        return [
            self._invalid_root(
                root.id,
                f"Root node '{root.id}' must declare execution type 'root'.",
            )
        ]

    def _invalid_root(self, node_id: str, message: str) -> GraphValidationIssue:
        """Build an issue tagged with ``invalid_root_code`` for ``node_id``."""
        return GraphValidationIssue(code=self.invalid_root_code, message=message, node_id=node_id)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class GraphValidator:
|
||||||
|
"""Coordinates execution of graph validation rules."""
|
||||||
|
|
||||||
|
rules: tuple[GraphValidationRule, ...]
|
||||||
|
|
||||||
|
def validate(self, graph: Graph) -> None:
|
||||||
|
"""Validate the graph against all configured rules."""
|
||||||
|
issues: list[GraphValidationIssue] = []
|
||||||
|
for rule in self.rules:
|
||||||
|
issues.extend(rule.validate(graph))
|
||||||
|
|
||||||
|
if issues:
|
||||||
|
raise GraphValidationError(issues)
|
||||||
|
|
||||||
|
|
||||||
|
# Rules handed to the validator built by get_graph_validator(); GraphValidator runs them in this order.
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (_EdgeEndpointValidator(), _RootNodeValidator())
|
||||||
|
|
||||||
|
|
||||||
|
def get_graph_validator() -> GraphValidator:
    """Return a new ``GraphValidator`` configured with the module's default rules."""
    return GraphValidator(rules=_DEFAULT_RULES)
|
||||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
|
|||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
from core.workflow.enums import NodeType
|
||||||
from core.workflow.graph import NodeFactory
|
from core.workflow.graph import NodeFactory
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from libs.typing import is_str, is_str_dict
|
from libs.typing import is_str, is_str_dict
|
||||||
@@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
|
|||||||
raise ValueError(f"Node {node_id} missing data information")
|
raise ValueError(f"Node {node_id} missing data information")
|
||||||
node_instance.init_node_data(node_data)
|
node_instance.init_node_data(node_data)
|
||||||
|
|
||||||
# If node has fail branch, change execution type to branch
|
|
||||||
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
|
||||||
node_instance.execution_type = NodeExecutionType.BRANCH
|
|
||||||
|
|
||||||
return node_instance
|
return node_instance
|
||||||
|
|||||||
@@ -0,0 +1,181 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||||
|
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||||
|
from core.workflow.graph import Graph
|
||||||
|
from core.workflow.graph.validation import GraphValidationError
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
|
from core.workflow.nodes.base.node import Node
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
from models.enums import UserFrom
|
||||||
|
|
||||||
|
|
||||||
|
class _TestNode(Node):
    """Minimal ``Node`` subclass used to build graphs in these tests.

    The class-level execution type is EXECUTABLE; an instance can override it
    through a string ``execution_type`` entry in ``config["data"]``.
    """

    node_type = NodeType.ANSWER
    execution_type = NodeExecutionType.EXECUTABLE

    @classmethod
    def version(cls) -> str:
        """Return the fixed version label ``"test"``."""
        return "test"

    def __init__(
        self,
        *,
        id: str,
        config: Mapping[str, object],
        graph_init_params: GraphInitParams,
        graph_runtime_state: GraphRuntimeState,
    ) -> None:
        """Initialise the base node, then apply the optional per-node overrides.

        :param id: node ID, forwarded to ``Node.__init__``
        :param config: node configuration; its ``"data"`` entry is read for
            ``execution_type`` and ``title``
        :param graph_init_params: forwarded to ``Node.__init__``
        :param graph_runtime_state: forwarded to ``Node.__init__``
        """
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        data = config.get("data", {})
        if isinstance(data, Mapping):
            execution_type = data.get("execution_type")
            # Only a string value overrides the class-level execution type.
            if isinstance(execution_type, str):
                self.execution_type = NodeExecutionType(execution_type)
        # NOTE(review): assumed to run regardless of the Mapping check above — confirm nesting.
        # Title falls back to the node ID when the config does not provide one.
        self._base_node_data = BaseNodeData(title=str(data.get("title", self.id)))
        self.data: dict[str, object] = {}

    def init_node_data(self, data: Mapping[str, object]) -> None:
        """Rebuild the base node data from ``data`` and keep a copy of the raw mapping.

        :param data: node data; ``title``, ``description`` and ``error_strategy`` are read
        """
        title = str(data.get("title", self.id))
        desc = data.get("description")
        error_strategy_value = data.get("error_strategy")
        error_strategy: ErrorStrategy | None = None
        # Accept either an ErrorStrategy member or its string value; anything else stays None.
        if isinstance(error_strategy_value, ErrorStrategy):
            error_strategy = error_strategy_value
        elif isinstance(error_strategy_value, str):
            error_strategy = ErrorStrategy(error_strategy_value)
        self._base_node_data = BaseNodeData(
            title=title,
            desc=str(desc) if desc is not None else None,
            error_strategy=error_strategy,
        )
        self.data = dict(data)

    def _run(self):
        """Not implemented: always raises ``NotImplementedError``."""
        raise NotImplementedError

    def _get_error_strategy(self) -> ErrorStrategy | None:
        """Return the error strategy stored in the base node data."""
        return self._base_node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        """Return the retry configuration stored in the base node data."""
        return self._base_node_data.retry_config

    def _get_title(self) -> str:
        """Return the title stored in the base node data."""
        return self._base_node_data.title

    def _get_description(self) -> str | None:
        """Return the description stored in the base node data."""
        return self._base_node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        """Return the default-value dictionary of the base node data."""
        return self._base_node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        """Return the current ``BaseNodeData`` instance."""
        return self._base_node_data
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class _SimpleNodeFactory:
    """Node factory for these tests: turns every node config into a ``_TestNode``."""

    graph_init_params: GraphInitParams
    graph_runtime_state: GraphRuntimeState

    def create_node(self, node_config: Mapping[str, object]) -> _TestNode:
        """Build a ``_TestNode`` from ``node_config`` and load its node data.

        :param node_config: node configuration; ``"id"`` is required, ``"data"`` is optional
        :return: the created node, after ``init_node_data`` has been called on it
        """
        instance = _TestNode(
            id=str(node_config["id"]),
            config=node_config,
            graph_init_params=self.graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )
        node_data = node_config.get("data", {})
        instance.init_node_data(node_data)
        return instance
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]:
    """Provide a node factory together with the graph config it was built with.

    The config starts with empty ``"nodes"`` and ``"edges"`` lists; the same dict
    object is passed to ``GraphInitParams`` and returned to the test.
    """
    empty_config: dict[str, object] = {"edges": [], "nodes": []}
    params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=empty_config,
        user_id="user",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    system_variables = SystemVariable(user_id="user", files=[])
    pool = VariablePool(system_variables=system_variables, user_inputs={})
    state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
    return _SimpleNodeFactory(graph_init_params=params, graph_runtime_state=state), empty_config
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_initialization_runs_default_validators(
    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
):
    """A well-formed start -> answer graph is built by ``Graph.init`` without errors."""
    factory, config = graph_init_dependencies
    start_node = {
        "id": "start",
        "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
    }
    answer_node = {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}}
    config["nodes"] = [start_node, answer_node]
    config["edges"] = [{"source": "start", "target": "answer", "sourceHandle": "success"}]

    built = Graph.init(graph_config=config, node_factory=factory)

    assert built.root_node.id == "start"
    assert "answer" in built.nodes
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_validation_fails_for_unknown_edge_targets(
    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
) -> None:
    """An edge whose target node is not declared makes ``Graph.init`` raise MISSING_NODE."""
    factory, config = graph_init_dependencies
    start_node = {
        "id": "start",
        "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
    }
    dangling_edge = {"source": "start", "target": "missing", "sourceHandle": "success"}
    config["nodes"] = [start_node]
    config["edges"] = [dangling_edge]

    with pytest.raises(GraphValidationError) as exc_info:
        Graph.init(graph_config=config, node_factory=factory)

    reported_codes = [issue.code for issue in exc_info.value.issues]
    assert "MISSING_NODE" in reported_codes
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_promotes_fail_branch_nodes_to_branch_execution_type(
    graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
) -> None:
    """A node configured with the FAIL_BRANCH error strategy ends up with BRANCH execution type."""
    factory, config = graph_init_dependencies
    start_node = {
        "id": "start",
        "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT},
    }
    fail_branch_node = {
        "id": "branch",
        "data": {
            "type": NodeType.IF_ELSE,
            "title": "Branch",
            "error_strategy": ErrorStrategy.FAIL_BRANCH,
        },
    }
    config["nodes"] = [start_node, fail_branch_node]
    config["edges"] = [{"source": "start", "target": "branch", "sourceHandle": "success"}]

    built = Graph.init(graph_config=config, node_factory=factory)

    promoted = built.nodes["branch"]
    assert promoted.execution_type == NodeExecutionType.BRANCH
|
||||||
Reference in New Issue
Block a user