feat(graph_engine): Support pausing workflow graph executions (#26585)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -1,16 +1,11 @@
|
||||
from .edge import Edge
|
||||
from .graph import Graph, NodeFactory
|
||||
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
|
||||
from .graph import Graph, GraphBuilder, NodeFactory
|
||||
from .graph_template import GraphTemplate
|
||||
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
|
||||
|
||||
__all__ = [
|
||||
"Edge",
|
||||
"Graph",
|
||||
"GraphBuilder",
|
||||
"GraphTemplate",
|
||||
"NodeFactory",
|
||||
"ReadOnlyGraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeStateWrapper",
|
||||
"ReadOnlyVariablePool",
|
||||
"ReadOnlyVariablePoolWrapper",
|
||||
]
|
||||
|
||||
@@ -195,6 +195,12 @@ class Graph:
|
||||
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def new(cls) -> "GraphBuilder":
|
||||
"""Create a fluent builder for assembling a graph programmatically."""
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
@@ -344,3 +350,96 @@ class Graph:
|
||||
"""
|
||||
edge_ids = self.in_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
|
||||
|
||||
@final
|
||||
class GraphBuilder:
|
||||
"""Fluent helper for constructing simple graphs, primarily for tests."""
|
||||
|
||||
def __init__(self, *, graph_cls: type[Graph]):
|
||||
self._graph_cls = graph_cls
|
||||
self._nodes: list[Node] = []
|
||||
self._nodes_by_id: dict[str, Node] = {}
|
||||
self._edges: list[Edge] = []
|
||||
self._edge_counter = 0
|
||||
|
||||
def add_root(self, node: Node) -> "GraphBuilder":
|
||||
"""Register the root node. Must be called exactly once."""
|
||||
|
||||
if self._nodes:
|
||||
raise ValueError("Root node has already been added")
|
||||
self._register_node(node)
|
||||
self._nodes.append(node)
|
||||
return self
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
node: Node,
|
||||
*,
|
||||
from_node_id: str | None = None,
|
||||
source_handle: str = "source",
|
||||
) -> "GraphBuilder":
|
||||
"""Append a node and connect it from the specified predecessor."""
|
||||
|
||||
if not self._nodes:
|
||||
raise ValueError("Root node must be added before adding other nodes")
|
||||
|
||||
predecessor_id = from_node_id or self._nodes[-1].id
|
||||
if predecessor_id not in self._nodes_by_id:
|
||||
raise ValueError(f"Predecessor node '{predecessor_id}' not found")
|
||||
|
||||
predecessor = self._nodes_by_id[predecessor_id]
|
||||
self._register_node(node)
|
||||
self._nodes.append(node)
|
||||
|
||||
edge_id = f"edge_{self._edge_counter}"
|
||||
self._edge_counter += 1
|
||||
edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle)
|
||||
self._edges.append(edge)
|
||||
|
||||
return self
|
||||
|
||||
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
|
||||
"""Connect two existing nodes without adding a new node."""
|
||||
|
||||
if tail not in self._nodes_by_id:
|
||||
raise ValueError(f"Tail node '{tail}' not found")
|
||||
if head not in self._nodes_by_id:
|
||||
raise ValueError(f"Head node '{head}' not found")
|
||||
|
||||
edge_id = f"edge_{self._edge_counter}"
|
||||
self._edge_counter += 1
|
||||
edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle)
|
||||
self._edges.append(edge)
|
||||
|
||||
return self
|
||||
|
||||
def build(self) -> Graph:
|
||||
"""Materialize the graph instance from the accumulated nodes and edges."""
|
||||
|
||||
if not self._nodes:
|
||||
raise ValueError("Cannot build an empty graph")
|
||||
|
||||
nodes = {node.id: node for node in self._nodes}
|
||||
edges = {edge.id: edge for edge in self._edges}
|
||||
in_edges: dict[str, list[str]] = defaultdict(list)
|
||||
out_edges: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for edge in self._edges:
|
||||
out_edges[edge.tail].append(edge.id)
|
||||
in_edges[edge.head].append(edge.id)
|
||||
|
||||
return self._graph_cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=dict(in_edges),
|
||||
out_edges=dict(out_edges),
|
||||
root_node=self._nodes[0],
|
||||
)
|
||||
|
||||
def _register_node(self, node: Node) -> None:
|
||||
if not node.id:
|
||||
raise ValueError("Node must have a non-empty id")
|
||||
if node.id in self._nodes_by_id:
|
||||
raise ValueError(f"Duplicate node id detected: {node.id}")
|
||||
self._nodes_by_id[node.id] = node
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables.segments import Segment
|
||||
|
||||
|
||||
class ReadOnlyVariablePool(Protocol):
|
||||
"""Read-only interface for VariablePool."""
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
"""Get a variable value (read-only)."""
|
||||
...
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
|
||||
"""Get all variables for a node (read-only)."""
|
||||
...
|
||||
|
||||
|
||||
class ReadOnlyGraphRuntimeState(Protocol):
|
||||
"""
|
||||
Read-only view of GraphRuntimeState for layers.
|
||||
|
||||
This protocol defines a read-only interface that prevents layers from
|
||||
modifying the graph runtime state while still allowing observation.
|
||||
All methods return defensive copies to ensure immutability.
|
||||
"""
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||
"""Get read-only access to the variable pool."""
|
||||
...
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get a copy of LLM usage info (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
"""Get a defensive copy of outputs (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count (read-only)."""
|
||||
...
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value (returns a copy)."""
|
||||
...
|
||||
@@ -1,77 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
class ReadOnlyVariablePoolWrapper:
|
||||
"""Wrapper that provides read-only access to VariablePool."""
|
||||
|
||||
def __init__(self, variable_pool: VariablePool):
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
"""Get a variable value (returns a defensive copy)."""
|
||||
value = self._variable_pool.get([node_id, variable_key])
|
||||
return deepcopy(value) if value is not None else None
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
|
||||
"""Get all variables for a node (returns defensive copies)."""
|
||||
variables: dict[str, object] = {}
|
||||
if node_id in self._variable_pool.variable_dictionary:
|
||||
for key, var in self._variable_pool.variable_dictionary[node_id].items():
|
||||
# Variables have a value property that contains the actual data
|
||||
variables[key] = deepcopy(var.value)
|
||||
return variables
|
||||
|
||||
|
||||
class ReadOnlyGraphRuntimeStateWrapper:
|
||||
"""
|
||||
Wrapper that provides read-only access to GraphRuntimeState.
|
||||
|
||||
This wrapper ensures that layers can observe the state without
|
||||
modifying it. All returned values are defensive copies.
|
||||
"""
|
||||
|
||||
def __init__(self, state: GraphRuntimeState):
|
||||
self._state = state
|
||||
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
|
||||
"""Get read-only access to the variable pool."""
|
||||
return self._variable_pool_wrapper
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time (read-only)."""
|
||||
return self._state.start_at
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count (read-only)."""
|
||||
return self._state.total_tokens
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get a copy of LLM usage info (read-only)."""
|
||||
# Return a copy to prevent modification
|
||||
return self._state.llm_usage.model_copy()
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
"""Get a defensive copy of outputs (read-only)."""
|
||||
return deepcopy(self._state.outputs)
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count (read-only)."""
|
||||
return self._state.node_run_steps
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value (returns a copy)."""
|
||||
return self._state.get_output(key, default)
|
||||
Reference in New Issue
Block a user