feat(workflow): domain model for workflow node execution (#19430)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel):
|
||||
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
id: str = Field(..., description="ID of the trace")
|
||||
op: str = Field(..., description="Name of the operation")
|
||||
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
|
||||
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
|
||||
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace")
|
||||
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace")
|
||||
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
||||
None, description="Metadata and attributes associated with trace"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
@@ -7,6 +6,7 @@ from typing import Any, Optional, cast
|
||||
|
||||
import wandb
|
||||
import weave
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
@@ -22,9 +22,11 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -128,58 +130,57 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
|
||||
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution
|
||||
workflow_nodes_execution_id_records = (
|
||||
db.session.query(WorkflowNodeExecution.id)
|
||||
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
||||
.all()
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
for node_execution_id_record in workflow_nodes_execution_id_records:
|
||||
node_execution = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecution.id,
|
||||
WorkflowNodeExecution.tenant_id,
|
||||
WorkflowNodeExecution.app_id,
|
||||
WorkflowNodeExecution.title,
|
||||
WorkflowNodeExecution.node_type,
|
||||
WorkflowNodeExecution.status,
|
||||
WorkflowNodeExecution.inputs,
|
||||
WorkflowNodeExecution.outputs,
|
||||
WorkflowNodeExecution.created_at,
|
||||
WorkflowNodeExecution.elapsed_time,
|
||||
WorkflowNodeExecution.process_data,
|
||||
WorkflowNodeExecution.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not node_execution:
|
||||
continue
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = node_execution.tenant_id
|
||||
app_id = node_execution.app_id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == "llm":
|
||||
inputs = (
|
||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
||||
)
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = (
|
||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||
)
|
||||
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
||||
attributes = execution_metadata.copy()
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||
attributes.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
@@ -192,7 +193,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
attributes.update(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user