feat(graph_engine): Support pausing workflow graph executions (#26585)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-10-19 21:33:41 +08:00
committed by GitHub
parent 9a5f214623
commit 578247ffbc
112 changed files with 3766 additions and 2415 deletions

View File

@@ -1,16 +1,11 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .graph import Graph, GraphBuilder, NodeFactory
from .graph_template import GraphTemplate
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
# Names this package exposes as its public API; each entry matches a name
# imported from a sibling module above.
__all__ = [
"Edge",
"Graph",
"GraphBuilder",
"GraphTemplate",
"NodeFactory",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
]

View File

@@ -195,6 +195,12 @@ class Graph:
return nodes
@classmethod
def new(cls) -> "GraphBuilder":
    """Return a fresh ``GraphBuilder`` that will instantiate this graph class.

    The builder offers a chainable API for assembling a graph in code
    instead of parsing it from a configuration mapping.
    """
    builder = GraphBuilder(graph_cls=cls)
    return builder
@classmethod
def _mark_inactive_root_branches(
cls,
@@ -344,3 +350,96 @@ class Graph:
"""
edge_ids = self.in_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
@final
class GraphBuilder:
    """Fluent helper for constructing simple graphs, primarily for tests.

    Nodes are kept in insertion order and the first node added (through
    ``add_root``) becomes the root of the built graph. Every mutating method
    returns the builder itself so calls can be chained; ``build`` turns the
    accumulated nodes and edges into a graph instance.
    """

    def __init__(self, *, graph_cls: type[Graph]):
        """Create an empty builder.

        Args:
            graph_cls: Class instantiated by ``build``. It is called with the
                keyword arguments ``nodes``, ``edges``, ``in_edges``,
                ``out_edges`` and ``root_node``.
        """
        self._graph_cls = graph_cls
        self._nodes: list[Node] = []
        self._nodes_by_id: dict[str, Node] = {}
        self._edges: list[Edge] = []
        # Monotonic counter used to derive unique edge ids ("edge_0", "edge_1", ...).
        self._edge_counter = 0

    def add_root(self, node: Node) -> "GraphBuilder":
        """Register the root node. Must be called exactly once.

        Raises:
            ValueError: If a root was already added, or if the node id is
                empty or duplicated.
        """
        if self._nodes:
            raise ValueError("Root node has already been added")
        self._register_node(node)
        self._nodes.append(node)
        return self

    def add_node(
        self,
        node: Node,
        *,
        from_node_id: str | None = None,
        source_handle: str = "source",
    ) -> "GraphBuilder":
        """Append a node and connect it from the specified predecessor.

        Args:
            node: Node to add; its id must be non-empty and not yet registered.
            from_node_id: Id of the predecessor node. When ``None`` the most
                recently added node is used.
            source_handle: Handle name stored on the created edge.

        Raises:
            ValueError: If no root has been added yet, if the predecessor is
                unknown, or if the node id is empty or duplicated.
        """
        if not self._nodes:
            raise ValueError("Root node must be added before adding other nodes")

        # Only fall back to the last node when no predecessor was given at
        # all. An explicit but unknown id (including "") must be reported
        # rather than silently replaced by the default.
        predecessor_id = self._nodes[-1].id if from_node_id is None else from_node_id
        if predecessor_id not in self._nodes_by_id:
            raise ValueError(f"Predecessor node '{predecessor_id}' not found")

        self._register_node(node)
        self._nodes.append(node)
        self._add_edge(tail=predecessor_id, head=node.id, source_handle=source_handle)
        return self

    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
        """Connect two existing nodes without adding a new node.

        Raises:
            ValueError: If either ``tail`` or ``head`` is not a registered node id.
        """
        if tail not in self._nodes_by_id:
            raise ValueError(f"Tail node '{tail}' not found")
        if head not in self._nodes_by_id:
            raise ValueError(f"Head node '{head}' not found")

        self._add_edge(tail=tail, head=head, source_handle=source_handle)
        return self

    def build(self) -> Graph:
        """Materialize the graph instance from the accumulated nodes and edges.

        Raises:
            ValueError: If no node has been added.
        """
        if not self._nodes:
            raise ValueError("Cannot build an empty graph")

        nodes = {node.id: node for node in self._nodes}
        edges = {edge.id: edge for edge in self._edges}

        # Adjacency indexes: node id -> ids of incoming / outgoing edges.
        in_edges: dict[str, list[str]] = defaultdict(list)
        out_edges: dict[str, list[str]] = defaultdict(list)
        for edge in self._edges:
            out_edges[edge.tail].append(edge.id)
            in_edges[edge.head].append(edge.id)

        # Hand over plain dicts so lookups on the graph never create entries.
        return self._graph_cls(
            nodes=nodes,
            edges=edges,
            in_edges=dict(in_edges),
            out_edges=dict(out_edges),
            root_node=self._nodes[0],
        )

    def _add_edge(self, *, tail: str, head: str, source_handle: str) -> None:
        """Create an edge with the next sequential id and record it."""
        edge_id = f"edge_{self._edge_counter}"
        self._edge_counter += 1
        self._edges.append(Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle))

    def _register_node(self, node: Node) -> None:
        """Index ``node`` by id, rejecting empty and duplicate ids."""
        if not node.id:
            raise ValueError("Node must have a non-empty id")
        if node.id in self._nodes_by_id:
            raise ValueError(f"Duplicate node id detected: {node.id}")
        self._nodes_by_id[node.id] = node

View File

@@ -1,61 +0,0 @@
from collections.abc import Mapping
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
class ReadOnlyVariablePool(Protocol):
    """Structural interface covering only the read operations of a variable pool."""

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Return the variable ``variable_key`` of node ``node_id``, or ``None``."""

    def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
        """Return a mapping of every variable that belongs to node ``node_id``."""
class ReadOnlyGraphRuntimeState(Protocol):
    """
    Observation-only interface of the graph runtime state, intended for layers.

    Layers receive this protocol so they can inspect the runtime state
    without being able to change it. Implementations are expected to hand
    out defensive copies so the underlying state stays untouched.
    """

    @property
    def variable_pool(self) -> ReadOnlyVariablePool:
        """Read-only view of the variable pool."""

    @property
    def start_at(self) -> float:
        """Start time of the run."""

    @property
    def total_tokens(self) -> int:
        """Total token count."""

    @property
    def llm_usage(self) -> LLMUsage:
        """Copy of the LLM usage information."""

    @property
    def outputs(self) -> dict[str, Any]:
        """Defensive copy of the outputs."""

    @property
    def node_run_steps(self) -> int:
        """Count of node run steps."""

    def get_output(self, key: str, default: Any = None) -> Any:
        """Return the output stored under ``key`` (as a copy), else ``default``."""

View File

@@ -1,77 +0,0 @@
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
class ReadOnlyVariablePoolWrapper:
    """Read-only facade over a ``VariablePool`` that hands out deep copies."""

    def __init__(self, variable_pool: VariablePool):
        """Keep a reference to the pool that lookups are forwarded to."""
        self._variable_pool = variable_pool

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Look up ``[node_id, variable_key]`` and return a deep copy, or ``None``."""
        selector = [node_id, variable_key]
        found = self._variable_pool.get(selector)
        if found is None:
            return None
        return deepcopy(found)

    def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
        """Return deep copies of the values of every variable stored for ``node_id``.

        An empty mapping is returned when the pool has no entry for the node.
        """
        per_node = self._variable_pool.variable_dictionary
        # Membership test first, so an unknown node is never indexed.
        if node_id not in per_node:
            return {}
        # Each stored variable carries its payload on ``.value``.
        return {name: deepcopy(variable.value) for name, variable in per_node[node_id].items()}
class ReadOnlyGraphRuntimeStateWrapper:
    """
    Read-only facade over a ``GraphRuntimeState``.

    Layers get this wrapper so they can observe the runtime state without
    mutating it. Container-like values are copied before they are returned.
    """

    def __init__(self, state: GraphRuntimeState):
        """Wrap ``state`` and prepare a read-only view of its variable pool."""
        self._state = state
        # Built once so every access to ``variable_pool`` yields the same wrapper.
        self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)

    @property
    def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
        """Read-only wrapper around the state's variable pool."""
        return self._variable_pool_wrapper

    @property
    def start_at(self) -> float:
        """The wrapped state's ``start_at`` value."""
        state = self._state
        return state.start_at

    @property
    def total_tokens(self) -> int:
        """The wrapped state's ``total_tokens`` value."""
        state = self._state
        return state.total_tokens

    @property
    def llm_usage(self) -> LLMUsage:
        """Copy of the LLM usage info, produced with ``model_copy()``."""
        usage = self._state.llm_usage
        # Copy so callers cannot modify the usage object held by the state.
        return usage.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        """Deep copy of the wrapped state's outputs."""
        current_outputs = self._state.outputs
        return deepcopy(current_outputs)

    @property
    def node_run_steps(self) -> int:
        """The wrapped state's ``node_run_steps`` value."""
        state = self._state
        return state.node_run_steps

    def get_output(self, key: str, default: Any = None) -> Any:
        """Return a single output by delegating to the state's ``get_output``.

        No extra copy is made here; whether the value is a copy depends on
        the wrapped state's implementation.
        """
        lookup = self._state.get_output
        return lookup(key, default)