fix(graph_engine): error strategy fall. (#26078)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-09-23 01:51:43 +08:00
committed by GitHub
parent f4522fd695
commit 2e2c87c5a1
8 changed files with 255 additions and 84 deletions

View File

@@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
error: GraphExecutionErrorState | None = Field(default=None)
node_executions: list[NodeExecutionState] = Field(default_factory=list)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
@@ -103,7 +104,8 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
def start(self) -> None:
"""Mark the graph execution as started."""
@@ -172,6 +174,7 @@ class GraphExecution:
completed=self.completed,
aborted=self.aborted,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
)
@@ -195,6 +198,7 @@ class GraphExecution:
self.completed = state.completed
self.aborted = state.aborted
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {
item.node_id: NodeExecution(
node_id=item.node_id,
@@ -205,3 +209,7 @@ class GraphExecution:
)
for item in state.node_executions
}
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1

View File

@@ -3,11 +3,12 @@ Event handler implementations for different event types.
"""
import logging
from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
@@ -122,13 +123,15 @@ class EventHandler:
"""
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
self._event_collector.collect(event)
# Collect the event only for the first attempt; retries remain silent
if is_initial_attempt:
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None:
@@ -161,7 +164,7 @@ class EventHandler:
node_execution.mark_taken()
# Store outputs in variable pool
self._store_node_outputs(event)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event)
@@ -191,7 +194,7 @@ class EventHandler:
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event)
self._update_response_outputs(event.node_run_result.outputs)
# Collect the event
self._event_collector.collect(event)
@@ -207,6 +210,7 @@ class EventHandler:
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
result = self._error_handler.handle_node_failure(event)
@@ -227,10 +231,40 @@ class EventHandler:
Args:
event: The node exception event
"""
# Node continues via fail-branch, so it's technically "succeeded"
# Node continues via fail-branch/default-value, treat as completion
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
node = self._graph.nodes[event.node_id]
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update response outputs if applicable
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
self._state_manager.finish_execution(event.node_id)
# Collect the exception event for observers
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None:
"""
@@ -242,21 +276,31 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
# Finish the previous attempt before re-queuing the node
self._state_manager.finish_execution(event.node_id)
# Emit retry event for observers
self._event_collector.collect(event)
# Re-queue node for execution
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
"""
Store node outputs in the variable pool.
Args:
event: The node succeeded event containing outputs
"""
for variable_name, variable_value in event.node_run_result.outputs.items():
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
for variable_name, variable_value in outputs.items():
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
"""Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in event.node_run_result.outputs.items():
for key, value in outputs.items():
if key == "answer":
existing = self._graph_runtime_state.get_output("answer", "")
if existing:

View File

@@ -23,6 +23,7 @@ from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
@@ -260,12 +261,23 @@ class GraphEngine:
if self._graph_execution.error:
raise self._graph_execution.error
else:
yield GraphRunSucceededEvent(
outputs=self._graph_runtime_state.outputs,
)
outputs = self._graph_runtime_state.outputs
exceptions_count = self._graph_execution.exceptions_count
if exceptions_count > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
else:
yield GraphRunSucceededEvent(
outputs=outputs,
)
except Exception as e:
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
raise
finally:

View File

@@ -15,6 +15,7 @@ from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
@@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self.logger.warning("⚠️ Graph run partially succeeded")
if event.exceptions_count > 0:
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0: