feat(graph_engine): Support pausing workflow graph executions (#26585)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-10-19 21:33:41 +08:00
committed by GitHub
parent 9a5f214623
commit 578247ffbc
112 changed files with 3766 additions and 2415 deletions

View File

@@ -447,6 +447,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
@@ -466,8 +468,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
@@ -483,6 +483,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id: str,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
"""
Generate worker in a new thread.
@@ -538,6 +540,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow=workflow,
system_user_id=system_user_id,
app=app,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
try:
@@ -570,8 +574,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@@ -584,7 +586,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param message: message
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@@ -596,8 +597,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message=message,
user=user,
dialogue_count=self._dialogue_count,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=draft_var_saver_factory,
)

View File

@@ -23,8 +23,12 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@@ -55,6 +59,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow: Workflow,
system_user_id: str,
app: App,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
super().__init__(
queue_manager=queue_manager,
@@ -68,11 +74,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._workflow = workflow
self.system_user_id = system_user_id
self._app = app
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def run(self):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
system_inputs = SystemVariable(
query=self.application_generate_entity.query,
files=self.application_generate_entity.files,
conversation_id=self.conversation.id,
user_id=self.system_user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
@@ -89,7 +108,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
@@ -114,17 +132,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = self._initialize_conversation_variables()
# Create a variable pool.
system_inputs = SystemVariable(
query=query,
files=files,
conversation_id=self.conversation.id,
user_id=self.system_user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
@@ -172,6 +179,23 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
command_channel=command_channel,
)
self._queue_manager.graph_runtime_state = graph_runtime_state
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:

View File

@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
@@ -60,14 +61,11 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
@@ -77,7 +75,7 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline:
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
@@ -92,8 +90,6 @@ class AdvancedChatAppGenerateTaskPipeline:
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
@@ -113,31 +109,20 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables=SystemVariable(
query=message.query,
files=application_generate_entity.files,
conversation_id=conversation.id,
user_id=user_session_id,
dialogue_count=dialogue_count,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_run_id,
),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
self._workflow_system_variables = SystemVariable(
query=message.query,
files=application_generate_entity.files,
conversation_id=conversation.id,
user_id=user_session_id,
dialogue_count=dialogue_count,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_run_id,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=self._workflow_system_variables,
)
self._task_state = WorkflowTaskState()
@@ -156,6 +141,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
self._seed_graph_runtime_state_from_queue_manager()
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@@ -288,12 +275,6 @@ class AdvancedChatAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
return graph_runtime_state
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline.ping_stream_response()
@@ -304,21 +285,28 @@ class AdvancedChatAppGenerateTaskPipeline:
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
def _handle_workflow_started_event(
self,
event: QueueWorkflowStartedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
runtime_state = self._resolve_graph_runtime_state()
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_run_id = run_id
with self._database_session() as session:
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_execution.id_
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
message.workflow_run_id = run_id
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow_id,
)
yield workflow_start_resp
@@ -326,13 +314,9 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_retry_resp:
@@ -344,14 +328,9 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node started events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_resp:
@@ -367,14 +346,12 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_finish_resp:
yield node_finish_resp
@@ -385,16 +362,13 @@ class AdvancedChatAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_finish_resp:
yield node_finish_resp
@@ -504,29 +478,19 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.SUCCEEDED,
graph_runtime_state=validated_state,
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
@@ -535,30 +499,20 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
@@ -567,32 +521,25 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowFailedEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.FAILED,
graph_runtime_state=validated_state,
error=event.error,
exceptions_count=event.exceptions_count,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED,
error_message=event.error,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {event.error}"))
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
yield workflow_finish_resp
@@ -607,25 +554,23 @@ class AdvancedChatAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle stop events."""
if self._workflow_run_id and graph_runtime_state:
_ = trace_manager
resolved_state = None
if self._workflow_run_id:
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
if self._workflow_run_id and resolved_state:
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.STOPPED,
graph_runtime_state=resolved_state,
error=event.get_stop_reason(),
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowExecutionStatus.STOPPED,
error_message=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
# Save message
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield workflow_finish_resp
elif event.stopped_by in (
@@ -647,7 +592,7 @@ class AdvancedChatAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
self._ensure_graph_runtime_initialized(graph_runtime_state)
resolved_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
self._task_state.answer
@@ -661,7 +606,7 @@ class AdvancedChatAppGenerateTaskPipeline:
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@@ -670,10 +615,6 @@ class AdvancedChatAppGenerateTaskPipeline:
) -> Generator[StreamResponse, None, None]:
"""Handle retriever resources events."""
self._message_cycle_manager.handle_retriever_resources(event)
with self._database_session() as session:
message = self._get_message(session=session)
message.message_metadata = self._task_state.metadata.model_dump_json()
return
yield # Make this a generator
@@ -682,10 +623,6 @@ class AdvancedChatAppGenerateTaskPipeline:
) -> Generator[StreamResponse, None, None]:
"""Handle annotation reply events."""
self._message_cycle_manager.handle_annotation_reply(event)
with self._database_session() as session:
message = self._get_message(session=session)
message.message_metadata = self._task_state.metadata.model_dump_json()
return
yield # Make this a generator
@@ -739,7 +676,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: Any,
*,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
@@ -752,7 +688,6 @@ class AdvancedChatAppGenerateTaskPipeline:
if handler := handlers.get(event_type):
yield from handler(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -769,7 +704,6 @@ class AdvancedChatAppGenerateTaskPipeline:
):
yield from self._handle_node_failed_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -788,15 +722,12 @@ class AdvancedChatAppGenerateTaskPipeline:
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 57-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state: GraphRuntimeState | None = None
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:
case QueueWorkflowStartedEvent():
graph_runtime_state = event.graph_runtime_state
self._resolve_graph_runtime_state()
yield from self._handle_workflow_started_event(event)
case QueueErrorEvent():
@@ -804,15 +735,11 @@ class AdvancedChatAppGenerateTaskPipeline:
break
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_event(
event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
)
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
break
case QueueStopEvent():
yield from self._handle_stop_event(
event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
)
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
break
# Handle all other events through elegant dispatch
@@ -820,7 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline:
if responses := list(
self._dispatch_event(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -878,6 +804,12 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
if candidate is not None:
self._graph_runtime_state = candidate
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
Message end to stream response.

View File

@@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
WorkflowQueueMessage,
)
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class AppQueueManager:
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q
self._graph_runtime_state: GraphRuntimeState | None = None
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
self._cache_lock = threading.Lock()
@@ -109,6 +111,16 @@ class AppQueueManager:
"""
self.publish(QueueErrorEvent(error=e), pub_from)
@property
def graph_runtime_state(self) -> GraphRuntimeState | None:
"""Retrieve the attached graph runtime state, if available."""
return self._graph_runtime_state
@graph_runtime_state.setter
def graph_runtime_state(self, graph_runtime_state: GraphRuntimeState | None) -> None:
"""Attach the live graph runtime state reference for downstream consumers."""
self._graph_runtime_state = graph_runtime_state
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue

View File

@@ -0,0 +1,55 @@
"""Shared helpers for managing GraphRuntimeState across task pipelines."""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.workflow.runtime import GraphRuntimeState
if TYPE_CHECKING:
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
class GraphRuntimeStateSupport:
"""
Mixin that centralises common GraphRuntimeState access patterns used by task pipelines.
Subclasses are expected to provide:
* `_base_task_pipeline` exposing the queue manager with an optional cached runtime state.
* `_graph_runtime_state` attribute used as the local cache for the runtime state.
"""
_base_task_pipeline: BasedGenerateTaskPipeline
_graph_runtime_state: GraphRuntimeState | None = None
def _ensure_graph_runtime_initialized(
self,
graph_runtime_state: GraphRuntimeState | None = None,
) -> GraphRuntimeState:
"""Validate and return the active graph runtime state."""
return self._resolve_graph_runtime_state(graph_runtime_state)
def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str:
system_variables = graph_runtime_state.variable_pool.system_variables
if not system_variables or not system_variables.workflow_execution_id:
raise ValueError("workflow_execution_id missing from runtime state")
return str(system_variables.workflow_execution_id)
def _resolve_graph_runtime_state(
self,
graph_runtime_state: GraphRuntimeState | None = None,
) -> GraphRuntimeState:
"""Return the cached runtime state or bootstrap it from the queue manager."""
if graph_runtime_state is not None:
self._graph_runtime_state = graph_runtime_state
return graph_runtime_state
if self._graph_runtime_state is None:
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
if candidate is not None:
self._graph_runtime_state = candidate
if self._graph_runtime_state is None:
raise ValueError("graph runtime state not initialized.")
return self._graph_runtime_state

View File

@@ -1,9 +1,8 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Union
from sqlalchemy.orm import Session
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
@@ -39,16 +38,36 @@ from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import (
NodeType,
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
Account,
EndUser,
)
from models import Account, EndUser
from services.variable_truncator import VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@dataclass(slots=True)
class _NodeSnapshot:
"""In-memory cache for node metadata between start and completion events."""
title: str
index: int
start_at: datetime
iteration_id: str = ""
"""Empty string means the node is not executing inside an iteration."""
loop_id: str = ""
"""Empty string means the node is not executing inside a loop."""
class WorkflowResponseConverter:
def __init__(
@@ -56,37 +75,151 @@ class WorkflowResponseConverter:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
system_variables: SystemVariable,
):
self._application_generate_entity = application_generate_entity
self._user = user
self._system_variables = system_variables
self._workflow_inputs = self._prepare_workflow_inputs()
self._truncator = VariableTruncator.default()
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
self._workflow_execution_id: str | None = None
self._workflow_started_at: datetime | None = None
# ------------------------------------------------------------------
# Workflow lifecycle helpers
# ------------------------------------------------------------------
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
inputs = dict(self._application_generate_entity.inputs)
for field_name, value in self._system_variables.to_dict().items():
# TODO(@future-refactor): store system variables separately from user inputs so we don't
# need to flatten `sys.*` entries into the input payload just for rerun/export tooling.
if field_name == SystemVariableKey.CONVERSATION_ID:
# Conversation IDs are session-scoped; omitting them keeps workflow inputs
# reusable without pinning new runs to a prior conversation.
continue
inputs[f"sys.{field_name}"] = value
handled = WorkflowEntry.handle_special_values(inputs)
return dict(handled or {})
def _ensure_workflow_run_id(self, workflow_run_id: str | None = None) -> str:
"""Return the memoized workflow run id, optionally seeding it during start events."""
if workflow_run_id is not None:
self._workflow_execution_id = workflow_run_id
if not self._workflow_execution_id:
raise ValueError("workflow_run_id missing before streaming workflow events")
return self._workflow_execution_id
# ------------------------------------------------------------------
# Node snapshot helpers
# ------------------------------------------------------------------
def _store_snapshot(self, event: QueueNodeStartedEvent) -> _NodeSnapshot:
snapshot = _NodeSnapshot(
title=event.node_title,
index=event.node_run_index,
start_at=event.start_at,
iteration_id=event.in_iteration_id or "",
loop_id=event.in_loop_id or "",
)
node_execution_id = NodeExecutionId(event.node_execution_id)
self._node_snapshots[node_execution_id] = snapshot
return snapshot
def _get_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
return self._node_snapshots.get(NodeExecutionId(node_execution_id))
def _pop_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
return self._node_snapshots.pop(NodeExecutionId(node_execution_id), None)
@staticmethod
def _merge_metadata(
base_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None,
snapshot: _NodeSnapshot | None,
) -> Mapping[WorkflowNodeExecutionMetadataKey, Any] | None:
if not base_metadata and not snapshot:
return base_metadata
merged: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
if base_metadata:
merged.update(base_metadata)
if snapshot:
if snapshot.iteration_id:
merged[WorkflowNodeExecutionMetadataKey.ITERATION_ID] = snapshot.iteration_id
if snapshot.loop_id:
merged[WorkflowNodeExecutionMetadataKey.LOOP_ID] = snapshot.loop_id
return merged or None
def _truncate_mapping(
self,
mapping: Mapping[str, Any] | None,
) -> tuple[Mapping[str, Any] | None, bool]:
if mapping is None:
return None, False
if not mapping:
return {}, False
normalized = WorkflowEntry.handle_special_values(dict(mapping))
if normalized is None:
return None, False
truncated, is_truncated = self._truncator.truncate_variable_mapping(dict(normalized))
return truncated, is_truncated
@staticmethod
def _encode_outputs(outputs: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
if outputs is None:
return None
converter = WorkflowRuntimeTypeConverter()
return converter.to_json_encodable(outputs)
def workflow_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution: WorkflowExecution,
workflow_run_id: str,
workflow_id: str,
) -> WorkflowStartStreamResponse:
run_id = self._ensure_workflow_run_id(workflow_run_id)
started_at = naive_utc_now()
self._workflow_started_at = started_at
return WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id_,
workflow_run_id=run_id,
data=WorkflowStartStreamResponse.Data(
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
inputs=workflow_execution.inputs,
created_at=int(workflow_execution.started_at.timestamp()),
id=run_id,
workflow_id=workflow_id,
inputs=self._workflow_inputs,
created_at=int(started_at.timestamp()),
),
)
def workflow_finish_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_execution: WorkflowExecution,
workflow_id: str,
status: WorkflowExecutionStatus,
graph_runtime_state: GraphRuntimeState,
error: str | None = None,
exceptions_count: int = 0,
) -> WorkflowFinishStreamResponse:
created_by = None
run_id = self._ensure_workflow_run_id()
started_at = self._workflow_started_at
if started_at is None:
raise ValueError(
"workflow_finish_to_stream_response called before workflow_start_to_stream_response",
)
finished_at = naive_utc_now()
elapsed_time = (finished_at - started_at).total_seconds()
outputs_mapping = graph_runtime_state.outputs or {}
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
created_by: Mapping[str, object] | None
user = self._user
if isinstance(user, Account):
created_by = {
@@ -94,38 +227,29 @@ class WorkflowResponseConverter:
"name": user.name,
"email": user.email,
}
elif isinstance(user, EndUser):
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (
int(workflow_execution.finished_at.timestamp())
if workflow_execution.finished_at
else int(datetime.now(UTC).timestamp())
)
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id_,
workflow_run_id=run_id,
data=WorkflowFinishStreamResponse.Data(
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
status=workflow_execution.status,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
total_steps=workflow_execution.total_steps,
id=run_id,
workflow_id=workflow_id,
status=status.value,
outputs=encoded_outputs,
error=error,
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
created_by=created_by,
created_at=int(workflow_execution.started_at.timestamp()),
finished_at=finished_at_timestamp,
files=self.fetch_files_from_node_outputs(workflow_execution.outputs),
exceptions_count=workflow_execution.exceptions_count,
created_at=int(started_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(outputs_mapping),
exceptions_count=exceptions_count,
),
)
@@ -134,38 +258,28 @@ class WorkflowResponseConverter:
*,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> NodeStartStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._store_snapshot(event)
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
workflow_run_id=run_id,
data=NodeStartStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
title=workflow_node_execution.title,
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=snapshot.title,
index=snapshot.index,
created_at=int(snapshot.start_at.timestamp()),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
parallel_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
),
)
# extras logic
if event.node_type == NodeType.TOOL:
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
@@ -189,41 +303,54 @@ class WorkflowResponseConverter:
*,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> NodeFinishStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
return None
if not workflow_node_execution.finished_at:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._pop_snapshot(event.node_execution_id)
json_converter = WorkflowRuntimeTypeConverter()
start_at = snapshot.start_at if snapshot else event.start_at
finished_at = naive_utc_now()
elapsed_time = (finished_at - start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
encoded_outputs = self._encode_outputs(event.outputs)
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
error_message = event.error
elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED.value
error_message = event.error
else:
status = WorkflowNodeExecutionStatus.EXCEPTION.value
error_message = event.error
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
workflow_run_id=run_id,
data=NodeFinishStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
index=snapshot.index if snapshot else 0,
title=snapshot.title if snapshot else "",
inputs=inputs,
inputs_truncated=inputs_truncated,
process_data=process_data,
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=status,
error=error_message,
elapsed_time=elapsed_time,
execution_metadata=metadata,
created_at=int(start_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
@@ -234,44 +361,45 @@ class WorkflowResponseConverter:
*,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
return None
if not workflow_node_execution.finished_at:
) -> NodeRetryStreamResponse | None:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
json_converter = WorkflowRuntimeTypeConverter()
snapshot = self._get_snapshot(event.node_execution_id)
if snapshot is None:
raise AssertionError("node retry event arrived without a stored snapshot")
finished_at = naive_utc_now()
elapsed_time = (finished_at - event.start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
encoded_outputs = self._encode_outputs(event.outputs)
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
workflow_run_id=run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
index=snapshot.index,
title=snapshot.title,
inputs=inputs,
inputs_truncated=inputs_truncated,
process_data=process_data,
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=WorkflowNodeExecutionStatus.RETRY.value,
error=event.error,
elapsed_time=elapsed_time,
execution_metadata=metadata,
created_at=int(snapshot.start_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
retry_index=event.retry_index,
@@ -379,8 +507,6 @@ class WorkflowResponseConverter:
inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
@@ -405,9 +531,6 @@ class WorkflowResponseConverter:
pre_loop_output={},
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)
@@ -446,8 +569,6 @@ class WorkflowResponseConverter:
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)

View File

@@ -352,6 +352,8 @@ class PipelineGenerator(BaseAppGenerator):
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
@@ -367,8 +369,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
)
@@ -573,6 +573,8 @@ class PipelineGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
@@ -620,6 +622,8 @@ class PipelineGenerator(BaseAppGenerator):
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
runner.run()
@@ -648,8 +652,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -660,7 +662,6 @@ class PipelineGenerator(BaseAppGenerator):
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@@ -670,8 +671,6 @@ class PipelineGenerator(BaseAppGenerator):
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
)

View File

@@ -11,11 +11,14 @@ from core.app.entities.app_invoke_entities import (
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@@ -40,6 +43,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
@@ -56,6 +61,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
@@ -163,6 +170,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_pool=variable_pool,
)
self._queue_manager.graph_runtime_state = graph_runtime_state
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:

View File

@@ -231,6 +231,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
@@ -244,8 +246,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
@@ -424,6 +424,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
"""
Generate worker in a new thread.
@@ -465,6 +467,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
try:
@@ -493,8 +497,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -514,8 +516,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=stream,
)

View File

@@ -5,12 +5,13 @@ from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@@ -34,6 +35,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
super().__init__(
queue_manager=queue_manager,
@@ -43,6 +46,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self.application_generate_entity = application_generate_entity
self._workflow = workflow
self._sys_user_id = system_user_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def run(self):
"""
@@ -51,6 +56,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
@@ -60,18 +73,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
files=files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
@@ -96,6 +100,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
self._queue_manager.graph_runtime_state = graph_runtime_state
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
@@ -115,6 +121,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
command_channel=command_channel,
)
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:

View File

@@ -8,11 +8,9 @@ from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
@@ -53,27 +51,20 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
)
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom
logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline:
class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
@@ -85,8 +76,6 @@ class WorkflowAppGenerateTaskPipeline:
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
@@ -99,42 +88,30 @@ class WorkflowAppGenerateTaskPipeline:
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
else:
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables=SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._workflow_execution_id = ""
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
self._workflow = workflow
self._workflow_system_variables = SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=self._workflow_system_variables,
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -261,15 +238,9 @@ class WorkflowAppGenerateTaskPipeline:
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
if not self._workflow_execution_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
return graph_runtime_state
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline.ping_stream_response()
@@ -283,12 +254,14 @@ class WorkflowAppGenerateTaskPipeline:
self, event: QueueWorkflowStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
runtime_state = self._resolve_graph_runtime_state()
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
)
yield start_resp
@@ -296,14 +269,9 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
@@ -315,13 +283,9 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node started events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
@@ -331,14 +295,12 @@ class WorkflowAppGenerateTaskPipeline:
self, event: QueueNodeSucceededEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_success_response:
yield node_success_response
@@ -349,17 +311,13 @@ class WorkflowAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_failed_response:
yield node_failed_response
@@ -372,7 +330,7 @@ class WorkflowAppGenerateTaskPipeline:
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_start_resp
@@ -385,7 +343,7 @@ class WorkflowAppGenerateTaskPipeline:
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_next_resp
@@ -398,7 +356,7 @@ class WorkflowAppGenerateTaskPipeline:
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_finish_resp
@@ -409,7 +367,7 @@ class WorkflowAppGenerateTaskPipeline:
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_start_resp
@@ -420,7 +378,7 @@ class WorkflowAppGenerateTaskPipeline:
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_next_resp
@@ -433,7 +391,7 @@ class WorkflowAppGenerateTaskPipeline:
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_finish_resp
@@ -442,33 +400,22 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.SUCCEEDED,
graph_runtime_state=validated_state,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
@@ -476,34 +423,23 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
@@ -511,37 +447,33 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
if isinstance(event, QueueWorkflowFailedEvent):
status = WorkflowExecutionStatus.FAILED
error = event.error
exceptions_count = event.exceptions_count
else:
status = WorkflowExecutionStatus.STOPPED
error = event.get_stop_reason()
exceptions_count = 0
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=status,
graph_runtime_state=validated_state,
error=error,
exceptions_count=exceptions_count,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowExecutionStatus.STOPPED,
error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
@@ -601,7 +533,6 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: AppQueueEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
@@ -614,7 +545,6 @@ class WorkflowAppGenerateTaskPipeline:
if handler := handlers.get(event_type):
yield from handler(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -631,7 +561,6 @@ class WorkflowAppGenerateTaskPipeline:
):
yield from self._handle_node_failed_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -642,7 +571,6 @@ class WorkflowAppGenerateTaskPipeline:
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -661,15 +589,12 @@ class WorkflowAppGenerateTaskPipeline:
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 44-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state = None
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:
case QueueWorkflowStartedEvent():
graph_runtime_state = event.graph_runtime_state
self._resolve_graph_runtime_state()
yield from self._handle_workflow_started_event(event)
case QueueTextChunkEvent():
@@ -681,12 +606,19 @@ class WorkflowAppGenerateTaskPipeline:
yield from self._handle_error_event(event)
break
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
# Handle all other events through elegant dispatch
case _:
if responses := list(
self._dispatch_event(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -697,7 +629,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@@ -709,11 +641,14 @@ class WorkflowAppGenerateTaskPipeline:
# not save log for debugging
return
if not workflow_run_id:
return
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
workflow_app_log.workflow_id = workflow_execution.workflow_id
workflow_app_log.workflow_run_id = workflow_execution.id_
workflow_app_log.workflow_id = self._workflow.id
workflow_app_log.workflow_run_id = workflow_run_id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id

View File

@@ -25,7 +25,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
@@ -54,6 +54,7 @@ from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
@@ -346,9 +347,7 @@ class WorkflowBasedAppRunner:
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
)
self._publish_event(QueueWorkflowStartedEvent())
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
@@ -372,7 +371,6 @@ class WorkflowBasedAppRunner:
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
inputs=inputs,
@@ -393,7 +391,6 @@ class WorkflowBasedAppRunner:
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=event.agent_strategy,
@@ -494,7 +491,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata,
)
)
@@ -536,7 +532,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata,
)
)

View File

@@ -3,11 +3,11 @@ from datetime import datetime
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@@ -54,6 +54,7 @@ class AppQueueEvent(BaseModel):
"""
event: QueueEvent
model_config = ConfigDict(arbitrary_types_allowed=True)
class QueueLLMChunkEvent(AppQueueEvent):
@@ -80,7 +81,6 @@ class QueueIterationStartEvent(AppQueueEvent):
node_run_index: int
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
@@ -132,19 +132,10 @@ class QueueLoopStartEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
@@ -160,16 +151,6 @@ class QueueLoopNextEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Any = None # output for the current loop
@@ -185,14 +166,6 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
@@ -285,12 +258,9 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
class QueueWorkflowStartedEvent(AppQueueEvent):
"""
QueueWorkflowStartedEvent entity
"""
"""QueueWorkflowStartedEvent entity."""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState
class QueueWorkflowSucceededEvent(AppQueueEvent):
@@ -334,15 +304,9 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_title: str
node_type: NodeType
node_run_index: int = 1 # FIXME(-LAN-): may not used
predecessor_node_id: str | None = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
in_iteration_id: str | None = None
in_loop_id: str | None = None
start_at: datetime
parallel_mode_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
# FIXME(-LAN-): only for ToolNode, need to refactor
@@ -360,14 +324,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
@@ -423,14 +379,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
@@ -455,7 +403,6 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None

View File

@@ -257,13 +257,8 @@ class NodeStartStreamResponse(StreamResponse):
inputs_truncated: bool = False
created_at: int
extras: dict[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
parallel_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
@@ -285,10 +280,6 @@ class NodeStartStreamResponse(StreamResponse):
"inputs": None,
"created_at": self.data.created_at,
"extras": {},
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
},
@@ -324,10 +315,6 @@ class NodeFinishStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
@@ -357,10 +344,6 @@ class NodeFinishStreamResponse(StreamResponse):
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
},
@@ -396,10 +379,6 @@ class NodeRetryStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
retry_index: int = 0
@@ -430,10 +409,6 @@ class NodeRetryStreamResponse(StreamResponse):
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"retry_index": self.data.retry_index,
@@ -541,8 +516,6 @@ class LoopNodeStartStreamResponse(StreamResponse):
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_STARTED
workflow_run_id: str
@@ -567,9 +540,6 @@ class LoopNodeNextStreamResponse(StreamResponse):
created_at: int
pre_loop_output: Any = None
extras: Mapping[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parallel_mode_run_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_NEXT
workflow_run_id: str
@@ -603,8 +573,6 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_COMPLETED
workflow_run_id: str