Promote GraphRuntimeState snapshot loading to class factory (#27222)

This commit is contained in:
-LAN-
2025-10-23 22:29:02 +08:00
committed by GitHub
parent 2f3a61b51b
commit 53b21eea61
2 changed files with 195 additions and 71 deletions

View File

@@ -5,6 +5,7 @@ import json
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from collections.abc import Mapping as TypingMapping from collections.abc import Mapping as TypingMapping
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol from typing import Any, Protocol
from pydantic.json import pydantic_encoder from pydantic.json import pydantic_encoder
@@ -106,6 +107,23 @@ class GraphProtocol(Protocol):
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ... def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
@dataclass(slots=True)
class _GraphRuntimeStateSnapshot:
"""Immutable view of a serialized runtime state snapshot."""
start_at: float
total_tokens: int
node_run_steps: int
llm_usage: LLMUsage
outputs: dict[str, Any]
variable_pool: VariablePool
has_variable_pool: bool
ready_queue_dump: str | None
graph_execution_dump: str | None
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
class GraphRuntimeState: class GraphRuntimeState:
"""Mutable runtime state shared across graph execution components.""" """Mutable runtime state shared across graph execution components."""
@@ -293,69 +311,28 @@ class GraphRuntimeState:
return json.dumps(snapshot, default=pydantic_encoder) return json.dumps(snapshot, default=pydantic_encoder)
def loads(self, data: str | Mapping[str, Any]) -> None: @classmethod
def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState:
"""Restore runtime state from a serialized snapshot.""" """Restore runtime state from a serialized snapshot."""
payload: dict[str, Any] snapshot = cls._parse_snapshot_payload(data)
if isinstance(data, str):
payload = json.loads(data)
else:
payload = dict(data)
version = payload.get("version") state = cls(
if version != "1.0": variable_pool=snapshot.variable_pool,
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") start_at=snapshot.start_at,
total_tokens=snapshot.total_tokens,
llm_usage=snapshot.llm_usage,
outputs=snapshot.outputs,
node_run_steps=snapshot.node_run_steps,
)
state._apply_snapshot(snapshot)
return state
self._start_at = float(payload.get("start_at", 0.0)) def loads(self, data: str | Mapping[str, Any]) -> None:
total_tokens = int(payload.get("total_tokens", 0)) """Restore runtime state from a serialized snapshot (legacy API)."""
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = total_tokens
node_run_steps = int(payload.get("node_run_steps", 0)) snapshot = self._parse_snapshot_payload(data)
if node_run_steps < 0: self._apply_snapshot(snapshot)
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
llm_usage_payload = payload.get("llm_usage", {})
self._llm_usage = LLMUsage.model_validate(llm_usage_payload)
self._outputs = deepcopy(payload.get("outputs", {}))
variable_pool_payload = payload.get("variable_pool")
if variable_pool_payload is not None:
self._variable_pool = VariablePool.model_validate(variable_pool_payload)
ready_queue_payload = payload.get("ready_queue")
if ready_queue_payload is not None:
self._ready_queue = self._build_ready_queue()
self._ready_queue.loads(ready_queue_payload)
else:
self._ready_queue = None
graph_execution_payload = payload.get("graph_execution")
self._graph_execution = None
self._pending_graph_execution_workflow_id = None
if graph_execution_payload is not None:
try:
execution_payload = json.loads(graph_execution_payload)
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
except (json.JSONDecodeError, TypeError, AttributeError):
self._pending_graph_execution_workflow_id = None
self.graph_execution.loads(graph_execution_payload)
response_payload = payload.get("response_coordinator")
if response_payload is not None:
if self._graph is not None:
self.response_coordinator.loads(response_payload)
else:
self._pending_response_coordinator_dump = response_payload
else:
self._pending_response_coordinator_dump = None
self._response_coordinator = None
paused_nodes_payload = payload.get("paused_nodes", [])
self._paused_nodes = set(map(str, paused_nodes_payload))
def register_paused_node(self, node_id: str) -> None: def register_paused_node(self, node_id: str) -> None:
"""Record a node that should resume when execution is continued.""" """Record a node that should resume when execution is continued."""
@@ -391,3 +368,106 @@ class GraphRuntimeState:
module = importlib.import_module("core.workflow.graph_engine.response_coordinator") module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
coordinator_cls = module.ResponseStreamCoordinator coordinator_cls = module.ResponseStreamCoordinator
return coordinator_cls(variable_pool=self.variable_pool, graph=graph) return coordinator_cls(variable_pool=self.variable_pool, graph=graph)
# ------------------------------------------------------------------
# Snapshot helpers
# ------------------------------------------------------------------
@classmethod
def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot:
    """Parse and validate a serialized snapshot without touching any instance.

    Args:
        data: Either a JSON string or an already-decoded mapping.

    Returns:
        A ``_GraphRuntimeStateSnapshot`` holding the validated scalar fields,
        the validated ``LLMUsage`` / ``VariablePool`` models and the raw dumps
        of the ready queue, graph execution and response coordinator.

    Raises:
        ValueError: If the JSON text does not decode to an object, the version
            is not ``"1.0"``, or ``total_tokens`` / ``node_run_steps`` is
            negative. Malformed JSON also surfaces as ``ValueError`` because
            ``json.JSONDecodeError`` derives from it.
    """
    decoded = json.loads(data) if isinstance(data, str) else dict(data)
    # JSON text may legally decode to a list, number, string or null; reject
    # those up front instead of failing later with an unrelated AttributeError.
    if not isinstance(decoded, dict):
        raise ValueError(
            f"GraphRuntimeState snapshot must be a JSON object, got {type(decoded).__name__}"
        )
    payload: dict[str, Any] = decoded

    version = payload.get("version")
    if version != "1.0":
        raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")

    start_at = float(payload.get("start_at", 0.0))

    total_tokens = int(payload.get("total_tokens", 0))
    if total_tokens < 0:
        raise ValueError("total_tokens must be non-negative")

    node_run_steps = int(payload.get("node_run_steps", 0))
    if node_run_steps < 0:
        raise ValueError("node_run_steps must be non-negative")

    llm_usage_payload = payload.get("llm_usage", {})
    llm_usage = LLMUsage.model_validate(llm_usage_payload)

    # Copy so the snapshot never aliases the caller's mapping.
    outputs_payload = deepcopy(payload.get("outputs", {}))

    variable_pool_payload = payload.get("variable_pool")
    has_variable_pool = variable_pool_payload is not None
    # Fall back to an empty pool so the snapshot always carries a usable object;
    # ``has_variable_pool`` tells the consumer whether it came from the payload.
    variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool()

    ready_queue_payload = payload.get("ready_queue")
    graph_execution_payload = payload.get("graph_execution")
    response_payload = payload.get("response_coordinator")
    paused_nodes_payload = payload.get("paused_nodes", [])

    return _GraphRuntimeStateSnapshot(
        start_at=start_at,
        total_tokens=total_tokens,
        node_run_steps=node_run_steps,
        llm_usage=llm_usage,
        outputs=outputs_payload,
        variable_pool=variable_pool,
        has_variable_pool=has_variable_pool,
        ready_queue_dump=ready_queue_payload,
        graph_execution_dump=graph_execution_payload,
        response_coordinator_dump=response_payload,
        paused_nodes=tuple(map(str, paused_nodes_payload)),
    )
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
    """Overwrite this instance's state with the contents of ``snapshot``.

    Scalars are assigned directly, the usage model is copied via
    ``model_copy()`` and ``outputs`` is deep-copied. The sub-components are
    then restored in a fixed order: ready queue, graph execution, response
    coordinator.
    """
    self._start_at = snapshot.start_at
    self._total_tokens = snapshot.total_tokens
    self._node_run_steps = snapshot.node_run_steps
    self._llm_usage = snapshot.llm_usage.model_copy()
    self._outputs = deepcopy(snapshot.outputs)

    # Keep the pool already held by this instance only when the snapshot did
    # not carry one and there is an existing pool to keep.
    keep_current_pool = not snapshot.has_variable_pool and self._variable_pool is not None
    if not keep_current_pool:
        self._variable_pool = snapshot.variable_pool

    # The variable pool is settled before any sub-component is rebuilt.
    restore_steps = (
        (self._restore_ready_queue, snapshot.ready_queue_dump),
        (self._restore_graph_execution, snapshot.graph_execution_dump),
        (self._restore_response_coordinator, snapshot.response_coordinator_dump),
    )
    for restore, dump in restore_steps:
        restore(dump)

    self._paused_nodes = set(snapshot.paused_nodes)
def _restore_ready_queue(self, payload: str | None) -> None:
    """Rebuild the ready queue from ``payload``, or clear it when there is none."""
    if payload is None:
        self._ready_queue = None
        return

    # Install a freshly built queue first, then hydrate it from the dump.
    self._ready_queue = self._build_ready_queue()
    self._ready_queue.loads(payload)
def _restore_graph_execution(self, payload: str | None) -> None:
    """Reset graph-execution state and reload it from ``payload`` if one is given.

    The cached execution object and the pending workflow id are always
    cleared first. When a dump is present, its ``workflow_id`` is extracted
    on a best-effort basis before the dump is handed to
    ``self.graph_execution.loads``.
    """
    self._graph_execution = None
    self._pending_graph_execution_workflow_id = None

    if payload is not None:
        try:
            workflow_id = json.loads(payload).get("workflow_id")
        except (json.JSONDecodeError, TypeError, AttributeError):
            # Undecodable or non-object dump: leave the pending id unset.
            workflow_id = None
        self._pending_graph_execution_workflow_id = workflow_id

        # NOTE(review): presumably the ``graph_execution`` accessor lazily
        # rebuilds the object using the pending workflow id -- confirm.
        self.graph_execution.loads(payload)
def _restore_response_coordinator(self, payload: str | None) -> None:
    """Load the response coordinator dump now, or stash it for later.

    With a graph attached and a dump available, the dump is loaded into
    ``self.response_coordinator`` immediately and the pending slot is
    cleared. Otherwise the cached coordinator is dropped and ``payload``
    (possibly ``None``) is stored as the pending dump.
    """
    can_load_now = payload is not None and self._graph is not None
    if can_load_now:
        self.response_coordinator.loads(payload)
        self._pending_response_coordinator_dump = None
        return

    # Either there is nothing to restore or no graph is attached yet.
    self._pending_response_coordinator_dump = payload
    self._response_coordinator = None

View File

@@ -8,6 +8,18 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
class StubCoordinator:
    """Minimal coordinator stand-in whose whole state is one string.

    It round-trips that string through JSON via ``dumps`` / ``loads`` so tests
    can observe whether a dump was replayed onto a coordinator.
    """

    def __init__(self) -> None:
        self.state = "initial"

    def dumps(self) -> str:
        """Serialize the current state as a JSON object with a ``state`` key."""
        snapshot = {"state": self.state}
        return json.dumps(snapshot)

    def loads(self, data: str) -> None:
        """Replace the current state with the ``state`` value found in ``data``."""
        self.state = json.loads(data)["state"]
class TestGraphRuntimeState: class TestGraphRuntimeState:
def test_property_getters_and_setters(self): def test_property_getters_and_setters(self):
# FIXME(-LAN-): Mock VariablePool if needed # FIXME(-LAN-): Mock VariablePool if needed
@@ -191,17 +203,6 @@ class TestGraphRuntimeState:
graph_execution.exceptions_count = 4 graph_execution.exceptions_count = 4
graph_execution.started = True graph_execution.started = True
class StubCoordinator:
def __init__(self) -> None:
self.state = "initial"
def dumps(self) -> str:
return json.dumps({"state": self.state})
def loads(self, data: str) -> None:
payload = json.loads(data)
self.state = payload["state"]
mock_graph = MagicMock() mock_graph = MagicMock()
stub = StubCoordinator() stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub): with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
@@ -211,8 +212,7 @@ class TestGraphRuntimeState:
snapshot = state.dumps() snapshot = state.dumps()
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) restored = GraphRuntimeState.from_snapshot(snapshot)
restored.loads(snapshot)
assert restored.total_tokens == 10 assert restored.total_tokens == 10
assert restored.node_run_steps == 3 assert restored.node_run_steps == 3
@@ -235,3 +235,47 @@ class TestGraphRuntimeState:
restored.attach_graph(mock_graph) restored.attach_graph(mock_graph)
assert new_stub.state == "configured" assert new_stub.state == "configured"
def test_loads_rehydrates_existing_instance(self):
    """``loads`` must overwrite an already-constructed, graph-attached state."""
    # --- Build and populate the source state -------------------------------
    seeded_pool = VariablePool()
    seeded_pool.add(("node", "key"), "value")

    source_state = GraphRuntimeState(variable_pool=seeded_pool, start_at=time())
    source_state.total_tokens = 7
    source_state.node_run_steps = 2
    source_state.set_output("foo", "bar")
    source_state.ready_queue.put("node-1")

    source_execution = source_state.graph_execution
    source_execution.workflow_id = "wf-456"
    source_execution.started = True

    mock_graph = MagicMock()
    source_coordinator = StubCoordinator()
    with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=source_coordinator):
        source_state.attach_graph(mock_graph)

    # Changed after attachment so the dump carries a non-default value.
    source_coordinator.state = "configured"

    serialized = source_state.dumps()

    # --- Rehydrate a separate instance that already has a graph attached ----
    target_coordinator = StubCoordinator()
    with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=target_coordinator):
        rehydrated = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
        rehydrated.attach_graph(mock_graph)
        rehydrated.loads(serialized)

    # --- Verify every restored component ------------------------------------
    assert rehydrated.total_tokens == 7
    assert rehydrated.node_run_steps == 2
    assert rehydrated.get_output("foo") == "bar"

    assert rehydrated.ready_queue.qsize() == 1
    assert rehydrated.ready_queue.get(timeout=0.01) == "node-1"

    segment = rehydrated.variable_pool.get(("node", "key"))
    assert segment is not None
    assert segment.value == "value"

    rehydrated_execution = rehydrated.graph_execution
    assert rehydrated_execution.workflow_id == "wf-456"
    assert rehydrated_execution.started is True

    assert target_coordinator.state == "configured"