feat: introduce trigger functionality (#27644)
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: Stream <Stream_2@qq.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do> Co-authored-by: Harry <xh001x@hotmail.com> Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: yessenia <yessenia.contact@gmail.com> Co-authored-by: hjlarry <hjlarry@163.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WTW0313 <twwu@dify.ai> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -37,6 +37,7 @@ from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
@@ -303,6 +304,11 @@ class WorkflowResponseConverter:
|
||||
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
|
||||
self._application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
elif event.node_type == NodeType.TRIGGER_PLUGIN:
|
||||
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
@@ -38,10 +39,16 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@staticmethod
|
||||
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
|
||||
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
@@ -53,7 +60,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@@ -66,6 +76,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@@ -79,7 +92,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@@ -91,7 +107,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
# parse files
|
||||
@@ -126,17 +145,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# for trigger debug run, not prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
inputs = self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
),
|
||||
inputs=inputs,
|
||||
files=list(system_files),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
@@ -155,7 +177,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
if triggered_from is not None:
|
||||
# Use explicitly provided triggered_from (for async triggers)
|
||||
workflow_triggered_from = triggered_from
|
||||
elif invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
@@ -182,8 +207,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
)
|
||||
|
||||
def resume(self, *, workflow_run_id: str) -> None:
|
||||
"""
|
||||
@TBD
|
||||
"""
|
||||
pass
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
*,
|
||||
@@ -196,6 +229,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -231,8 +266,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
"root_node_id": root_node_id,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": graph_engine_layers,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -426,6 +463,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
variable_loader: VariableLoader,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@@ -469,6 +508,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
system_user_id=system_user_id,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@@ -8,6 +9,7 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
@@ -16,6 +18,7 @@ from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
|
||||
@@ -35,17 +38,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
variable_loader: VariableLoader,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
root_node_id: str | None = None,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
self._root_node_id = root_node_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
@@ -60,6 +67,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
@@ -92,6 +100,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
root_node_id=self._root_node_id,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
|
||||
@@ -84,6 +84,7 @@ class WorkflowBasedAppRunner:
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
root_node_id: str | None = None,
|
||||
) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
@@ -117,7 +118,7 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
@@ -32,6 +32,10 @@ class InvokeFrom(StrEnum):
|
||||
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
|
||||
WEB_APP = "web-app"
|
||||
|
||||
# TRIGGER indicates that this invocation is from a trigger.
|
||||
# this is used for plugin trigger and webhook trigger.
|
||||
TRIGGER = "trigger"
|
||||
|
||||
# EXPLORE indicates that this invocation is from
|
||||
# the workflow (or chatflow) explore page.
|
||||
EXPLORE = "explore"
|
||||
@@ -40,6 +44,9 @@ class InvokeFrom(StrEnum):
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED = "published"
|
||||
|
||||
# VALIDATION indicates that this invocation is from validation.
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
@@ -65,6 +72,8 @@ class InvokeFrom(StrEnum):
|
||||
return "dev"
|
||||
elif self == InvokeFrom.EXPLORE:
|
||||
return "explore_app"
|
||||
elif self == InvokeFrom.TRIGGER:
|
||||
return "trigger"
|
||||
elif self == InvokeFrom.SERVICE_API:
|
||||
return "api"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Annotated, Literal, Self, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
@@ -55,7 +55,7 @@ class WorkflowResumptionContext(BaseModel):
|
||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: Engine | sessionmaker,
|
||||
session_factory: Engine | sessionmaker[Session],
|
||||
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
|
||||
state_owner_user_id: str,
|
||||
):
|
||||
@@ -103,10 +103,8 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
entity_wrapper: _GenerateEntityUnion
|
||||
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
|
||||
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
|
||||
elif isinstance(self._generate_entity, AdvancedChatAppGenerateEntity):
|
||||
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
|
||||
else:
|
||||
raise AssertionError(f"unknown entity type: type={type(self._generate_entity)}")
|
||||
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
|
||||
|
||||
state = WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
|
||||
|
||||
21
api/core/app/layers/suspend_layer.py
Normal file
21
api/core/app/layers/suspend_layer.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
|
||||
|
||||
class SuspendLayer(GraphEngineLayer):
|
||||
""" """
|
||||
|
||||
def on_graph_start(self):
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle the paused event, stash runtime state into storage and wait for resume.
|
||||
"""
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
pass
|
||||
|
||||
def on_graph_end(self, error: Exception | None):
|
||||
""" """
|
||||
pass
|
||||
88
api/core/app/layers/timeslice_layer.py
Normal file
88
api/core/app/layers/timeslice_layer.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import ClassVar
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
|
||||
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimeSliceLayer(GraphEngineLayer):
|
||||
"""
|
||||
CFS plan scheduler to control the timeslice of the workflow.
|
||||
"""
|
||||
|
||||
scheduler: ClassVar[BackgroundScheduler] = BackgroundScheduler()
|
||||
|
||||
def __init__(self, cfs_plan_scheduler: CFSPlanScheduler) -> None:
|
||||
"""
|
||||
CFS plan scheduler allows to control the timeslice of the workflow.
|
||||
"""
|
||||
|
||||
if not TimeSliceLayer.scheduler.running:
|
||||
TimeSliceLayer.scheduler.start()
|
||||
|
||||
super().__init__()
|
||||
self.cfs_plan_scheduler = cfs_plan_scheduler
|
||||
self.stopped = False
|
||||
self.schedule_id = ""
|
||||
|
||||
def _checker_job(self, schedule_id: str):
|
||||
"""
|
||||
Check if the workflow need to be suspended.
|
||||
"""
|
||||
try:
|
||||
if self.stopped:
|
||||
self.scheduler.remove_job(schedule_id)
|
||||
return
|
||||
|
||||
if self.cfs_plan_scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED:
|
||||
# remove the job
|
||||
self.scheduler.remove_job(schedule_id)
|
||||
|
||||
if not self.command_channel:
|
||||
logger.exception("No command channel to stop the workflow")
|
||||
return
|
||||
|
||||
# send command to pause the workflow
|
||||
self.command_channel.send_command(
|
||||
GraphEngineCommand(
|
||||
command_type=CommandType.PAUSE,
|
||||
payload={
|
||||
"reason": SchedulerCommand.RESOURCE_LIMIT_REACHED,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("scheduler error during check if the workflow need to be suspended")
|
||||
|
||||
def on_graph_start(self):
|
||||
"""
|
||||
Start timer to check if the workflow need to be suspended.
|
||||
"""
|
||||
|
||||
if self.cfs_plan_scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice:
|
||||
self.schedule_id = uuid.uuid4().hex
|
||||
|
||||
self.scheduler.add_job(
|
||||
lambda: self._checker_job(self.schedule_id),
|
||||
"interval",
|
||||
seconds=self.cfs_plan_scheduler.plan.granularity,
|
||||
id=self.schedule_id,
|
||||
)
|
||||
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
pass
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
self.stopped = True
|
||||
# remove the scheduler
|
||||
if self.schedule_id:
|
||||
self.scheduler.remove_job(self.schedule_id)
|
||||
88
api/core/app/layers/trigger_post_layer.py
Normal file
88
api/core/app/layers/trigger_post_layer.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
|
||||
from models.enums import WorkflowTriggerStatus
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerPostLayer(GraphEngineLayer):
|
||||
"""
|
||||
Trigger post layer.
|
||||
"""
|
||||
|
||||
_STATUS_MAP: ClassVar[dict[type[GraphEngineEvent], WorkflowTriggerStatus]] = {
|
||||
GraphRunSucceededEvent: WorkflowTriggerStatus.SUCCEEDED,
|
||||
GraphRunFailedEvent: WorkflowTriggerStatus.FAILED,
|
||||
GraphRunPausedEvent: WorkflowTriggerStatus.PAUSED,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
|
||||
start_time: datetime,
|
||||
trigger_log_id: str,
|
||||
session_maker: sessionmaker[Session],
|
||||
):
|
||||
self.trigger_log_id = trigger_log_id
|
||||
self.start_time = start_time
|
||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||
self.session_maker = session_maker
|
||||
|
||||
def on_graph_start(self):
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Update trigger log with success or failure.
|
||||
"""
|
||||
if isinstance(event, tuple(self._STATUS_MAP.keys())):
|
||||
with self.session_maker() as session:
|
||||
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = repo.get_by_id(self.trigger_log_id)
|
||||
if not trigger_log:
|
||||
logger.exception("Trigger log not found: %s", self.trigger_log_id)
|
||||
return
|
||||
|
||||
# Calculate elapsed time
|
||||
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
|
||||
|
||||
# Extract relevant data from result
|
||||
if not self.graph_runtime_state:
|
||||
logger.exception("Graph runtime state is not set")
|
||||
return
|
||||
|
||||
outputs = self.graph_runtime_state.outputs
|
||||
|
||||
# BASICLY, workflow_execution_id is the same as workflow_run_id
|
||||
workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id
|
||||
assert workflow_run_id, "Workflow run id is not set"
|
||||
|
||||
total_tokens = self.graph_runtime_state.total_tokens
|
||||
|
||||
# Update trigger log with success
|
||||
trigger_log.status = self._STATUS_MAP[type(event)]
|
||||
trigger_log.workflow_run_id = workflow_run_id
|
||||
trigger_log.outputs = TypeAdapter(dict[str, Any]).dump_json(outputs).decode()
|
||||
|
||||
if trigger_log.elapsed_time is None:
|
||||
trigger_log.elapsed_time = elapsed_time
|
||||
else:
|
||||
trigger_log.elapsed_time += elapsed_time
|
||||
|
||||
trigger_log.total_tokens = total_tokens
|
||||
trigger_log.finished_at = datetime.now(UTC)
|
||||
repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
Reference in New Issue
Block a user