fix(graph_engine): error strategy fall. (#26078)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
|
||||
|
||||
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
|
||||
@@ -103,7 +104,8 @@ class GraphExecution:
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Mark the graph execution as started."""
|
||||
@@ -172,6 +174,7 @@ class GraphExecution:
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
)
|
||||
|
||||
@@ -195,6 +198,7 @@ class GraphExecution:
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
item.node_id: NodeExecution(
|
||||
node_id=item.node_id,
|
||||
@@ -205,3 +209,7 @@ class GraphExecution:
|
||||
)
|
||||
for item in state.node_executions
|
||||
}
|
||||
|
||||
def record_node_failure(self) -> None:
|
||||
"""Increment the count of node failures encountered during execution."""
|
||||
self.exceptions_count += 1
|
||||
|
||||
@@ -3,11 +3,12 @@ Event handler implementations for different event types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
@@ -122,13 +123,15 @@ class EventHandler:
|
||||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
node_execution.mark_started(event.id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
# Collect the event only for the first attempt; retries remain silent
|
||||
if is_initial_attempt:
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
@@ -161,7 +164,7 @@ class EventHandler:
|
||||
node_execution.mark_taken()
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
@@ -191,7 +194,7 @@ class EventHandler:
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event)
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
@@ -207,6 +210,7 @@ class EventHandler:
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
self._graph_execution.record_node_failure()
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
@@ -227,10 +231,40 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch, so it's technically "succeeded"
|
||||
# Node continues via fail-branch/default-value, treat as completion
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
# Persist outputs produced by the exception strategy (e.g. default values)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
|
||||
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update response outputs if applicable
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Collect the exception event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunRetryEvent) -> None:
|
||||
"""
|
||||
@@ -242,21 +276,31 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
# Finish the previous attempt before re-queuing the node
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Emit retry event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
# Re-queue node for execution
|
||||
self._state_manager.enqueue_node(event.node_id)
|
||||
self._state_manager.start_execution(event.node_id)
|
||||
|
||||
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
Args:
|
||||
event: The node succeeded event containing outputs
|
||||
"""
|
||||
for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
for variable_name, variable_value in outputs.items():
|
||||
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
|
||||
# in runtime state, rather than allowing nodes to directly access runtime state.
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
for key, value in outputs.items():
|
||||
if key == "answer":
|
||||
existing = self._graph_runtime_state.get_output("answer", "")
|
||||
if existing:
|
||||
|
||||
@@ -23,6 +23,7 @@ from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
@@ -260,12 +261,23 @@ class GraphEngine:
|
||||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
outputs = self._graph_runtime_state.outputs
|
||||
exceptions_count = self._graph_execution.exceptions_count
|
||||
if exceptions_count > 0:
|
||||
yield GraphRunPartialSucceededEvent(
|
||||
exceptions_count=exceptions_count,
|
||||
outputs=outputs,
|
||||
)
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
yield GraphRunFailedEvent(
|
||||
error=str(e),
|
||||
exceptions_count=self._graph_execution.exceptions_count,
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
||||
@@ -15,6 +15,7 @@ from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
@@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self.logger.warning("⚠️ Graph run partially succeeded")
|
||||
if event.exceptions_count > 0:
|
||||
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.logger.error("❌ Graph run failed: %s", event.error)
|
||||
if event.exceptions_count > 0:
|
||||
|
||||
Reference in New Issue
Block a user