Fix: surface workflow container LLM usage (#27021)
This commit is contained in:
@@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import Segment, SegmentType
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
@@ -27,6 +28,7 @@ from core.workflow.node_events import (
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
||||
@@ -40,7 +42,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopNode(Node):
|
||||
class LoopNode(LLMUsageTrackingMixin, Node):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
@@ -117,6 +119,7 @@ class LoopNode(Node):
|
||||
|
||||
loop_duration_map: dict[str, float] = {}
|
||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||
loop_usage = LLMUsage.empty_usage()
|
||||
|
||||
# Start Loop event
|
||||
yield LoopStartedEvent(
|
||||
@@ -163,6 +166,9 @@ class LoopNode(Node):
|
||||
# Update the total tokens from this iteration
|
||||
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
|
||||
# Accumulate usage from the sub-graph execution
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
|
||||
# Collect loop variable values after iteration
|
||||
single_loop_variable = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
@@ -189,6 +195,7 @@ class LoopNode(Node):
|
||||
)
|
||||
|
||||
self.graph_runtime_state.total_tokens += cost_tokens
|
||||
self._accumulate_usage(loop_usage)
|
||||
# Loop completed successfully
|
||||
yield LoopSucceededEvent(
|
||||
start_at=start_at,
|
||||
@@ -196,7 +203,9 @@ class LoopNode(Node):
|
||||
outputs=self._node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
@@ -207,22 +216,28 @@ class LoopNode(Node):
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self._node_data.outputs,
|
||||
inputs=inputs,
|
||||
llm_usage=loop_usage,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._accumulate_usage(loop_usage)
|
||||
yield LoopFailedEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
"completed_reason": "error",
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
@@ -235,10 +250,13 @@ class LoopNode(Node):
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
llm_usage=loop_usage,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user