Refactor/remove db from cycle manager (#20455)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from constants import UUID_NIL
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
|
||||
|
||||
class TestMessage:
|
||||
class MockMessage:
|
||||
def __init__(self, id, parent_message_id):
|
||||
self.id = id
|
||||
self.parent_message_id = parent_message_id
|
||||
@@ -14,7 +14,7 @@ class TestMessage:
|
||||
|
||||
|
||||
def test_extract_thread_messages_single_message():
|
||||
messages = [TestMessage(str(uuid4()), UUID_NIL)]
|
||||
messages = [MockMessage(str(uuid4()), UUID_NIL)]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 1
|
||||
assert result[0] == messages[0]
|
||||
@@ -23,11 +23,11 @@ def test_extract_thread_messages_single_message():
|
||||
def test_extract_thread_messages_linear_thread():
|
||||
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id5, id4),
|
||||
TestMessage(id4, id3),
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, id1),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
MockMessage(id5, id4),
|
||||
MockMessage(id4, id3),
|
||||
MockMessage(id3, id2),
|
||||
MockMessage(id2, id1),
|
||||
MockMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 5
|
||||
@@ -37,10 +37,10 @@ def test_extract_thread_messages_linear_thread():
|
||||
def test_extract_thread_messages_branched_thread():
|
||||
id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id4, id2),
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, id1),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
MockMessage(id4, id2),
|
||||
MockMessage(id3, id2),
|
||||
MockMessage(id2, id1),
|
||||
MockMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 3
|
||||
@@ -56,9 +56,9 @@ def test_extract_thread_messages_empty_list():
|
||||
def test_extract_thread_messages_partially_loaded():
|
||||
id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, id1),
|
||||
TestMessage(id1, id0),
|
||||
MockMessage(id3, id2),
|
||||
MockMessage(id2, id1),
|
||||
MockMessage(id1, id0),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 3
|
||||
@@ -68,9 +68,9 @@ def test_extract_thread_messages_partially_loaded():
|
||||
def test_extract_thread_messages_legacy_messages():
|
||||
id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id3, UUID_NIL),
|
||||
TestMessage(id2, UUID_NIL),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
MockMessage(id3, UUID_NIL),
|
||||
MockMessage(id2, UUID_NIL),
|
||||
MockMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 3
|
||||
@@ -80,11 +80,11 @@ def test_extract_thread_messages_legacy_messages():
|
||||
def test_extract_thread_messages_mixed_with_legacy_messages():
|
||||
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id5, id4),
|
||||
TestMessage(id4, id2),
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, UUID_NIL),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
MockMessage(id5, id4),
|
||||
MockMessage(id4, id2),
|
||||
MockMessage(id3, id2),
|
||||
MockMessage(id2, UUID_NIL),
|
||||
MockMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 4
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from flask import Flask
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseNodeEvent,
|
||||
@@ -25,7 +26,7 @@ from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
@@ -11,7 +12,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_execute_answer():
|
||||
|
||||
@@ -4,6 +4,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileVariable, FileVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
@@ -15,7 +16,7 @@ from core.workflow.nodes.http_request import (
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_http_request_node_binary_file(monkeypatch):
|
||||
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
@@ -14,7 +15,7 @@ from core.workflow.nodes.iteration.entities import ErrorHandleMode
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_run():
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
@@ -11,7 +12,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_execute_answer():
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphRunPartialSucceededEvent,
|
||||
@@ -14,7 +15,7 @@ from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class ContinueOnErrorTestHelper:
|
||||
|
||||
@@ -7,6 +7,7 @@ from core.file import File, FileTransferMethod
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
|
||||
from core.workflow.nodes.document_extractor.node import (
|
||||
_extract_text_from_docx,
|
||||
@@ -15,7 +16,6 @@ from core.workflow.nodes.document_extractor.node import (
|
||||
_extract_text_from_plain_text,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
@@ -15,7 +16,7 @@ from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def test_execute_if_else_result_true():
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.list_operator.entities import (
|
||||
ExtractConfig,
|
||||
FilterBy,
|
||||
@@ -14,7 +15,6 @@ from core.workflow.nodes.list_operator.entities import (
|
||||
)
|
||||
from core.workflow.nodes.list_operator.exc import InvalidKeyError
|
||||
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderTyp
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
@@ -14,7 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
|
||||
from models import UserFrom, WorkflowType
|
||||
|
||||
|
||||
def _create_tool_node():
|
||||
|
||||
@@ -12,21 +12,20 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
|
||||
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
NodeExecution,
|
||||
NodeRunMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -93,16 +92,38 @@ def mock_workflow_execution_repository():
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_workflow_entity():
|
||||
return CycleManagerWorkflowInfo(
|
||||
workflow_id="test-workflow-id", # Matches ID used in other fixtures
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
version="1.0.0",
|
||||
graph_data={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"type": "chat", # NodeType is a string enum
|
||||
"name": "Chat Node",
|
||||
"data": {"model": "gpt-3.5-turbo", "prompt": "test prompt"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_cycle_manager(
|
||||
real_app_generate_entity,
|
||||
real_workflow_system_variables,
|
||||
mock_workflow_execution_repository,
|
||||
mock_node_execution_repository,
|
||||
real_workflow_entity,
|
||||
):
|
||||
return WorkflowCycleManager(
|
||||
application_generate_entity=real_app_generate_entity,
|
||||
workflow_system_variables=real_workflow_system_variables,
|
||||
workflow_info=real_workflow_entity,
|
||||
workflow_execution_repository=mock_workflow_execution_repository,
|
||||
workflow_node_execution_repository=mock_node_execution_repository,
|
||||
)
|
||||
@@ -148,7 +169,7 @@ def real_workflow_run():
|
||||
workflow_run.version = "1.0"
|
||||
workflow_run.graph = json.dumps({"nodes": [], "edges": []})
|
||||
workflow_run.inputs = json.dumps({"query": "test query"})
|
||||
workflow_run.status = WorkflowRunStatus.RUNNING
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
workflow_run.outputs = json.dumps({"answer": "test answer"})
|
||||
workflow_run.created_by_role = CreatorUserRole.ACCOUNT
|
||||
workflow_run.created_by = "test-user-id"
|
||||
@@ -171,20 +192,13 @@ def test_init(
|
||||
assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
|
||||
|
||||
|
||||
def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, real_workflow):
|
||||
def test_handle_workflow_run_start(workflow_cycle_manager):
|
||||
"""Test handle_workflow_run_start method"""
|
||||
# Mock session.scalar to return the workflow and max sequence
|
||||
mock_session.scalar.side_effect = [real_workflow, 5]
|
||||
|
||||
# Call the method
|
||||
workflow_execution = workflow_cycle_manager.handle_workflow_run_start(
|
||||
session=mock_session,
|
||||
workflow_id="test-workflow-id",
|
||||
)
|
||||
workflow_execution = workflow_cycle_manager.handle_workflow_run_start()
|
||||
|
||||
# Verify the result
|
||||
assert workflow_execution.workflow_id == real_workflow.id
|
||||
assert workflow_execution.sequence_number == 6 # max_sequence + 1
|
||||
assert workflow_execution.workflow_id == "test-workflow-id"
|
||||
|
||||
# Verify the workflow_execution_repository.save was called
|
||||
workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(workflow_execution)
|
||||
@@ -195,11 +209,10 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id="test-workflow-run-id",
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_version="1.0",
|
||||
sequence_number=1,
|
||||
type=WorkflowType.CHAT,
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
@@ -230,11 +243,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id="test-workflow-run-id",
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_version="1.0",
|
||||
sequence_number=1,
|
||||
type=WorkflowType.CHAT,
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
@@ -251,13 +263,13 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
total_tokens=50,
|
||||
total_steps=3,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
status=WorkflowExecutionStatus.FAILED,
|
||||
error_message="Test error message",
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == workflow_execution
|
||||
assert result.status == WorkflowExecutionStatus(WorkflowRunStatus.FAILED.value)
|
||||
assert result.status == WorkflowExecutionStatus.FAILED
|
||||
assert result.error_message == "Test error message"
|
||||
assert result.total_tokens == 50
|
||||
assert result.total_steps == 3
|
||||
@@ -269,11 +281,10 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id="test-workflow-execution-id",
|
||||
id_="test-workflow-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_version="1.0",
|
||||
sequence_number=1,
|
||||
type=WorkflowType.CHAT,
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
@@ -301,18 +312,18 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
|
||||
|
||||
# Call the method
|
||||
result = workflow_cycle_manager.handle_node_execution_start(
|
||||
workflow_execution_id=workflow_execution.id,
|
||||
workflow_execution_id=workflow_execution.id_,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result.workflow_id == workflow_execution.workflow_id
|
||||
assert result.workflow_run_id == workflow_execution.id
|
||||
assert result.workflow_run_id == workflow_execution.id_
|
||||
assert result.node_execution_id == event.node_execution_id
|
||||
assert result.node_id == event.node_id
|
||||
assert result.node_type == event.node_type
|
||||
assert result.title == event.node_data.title
|
||||
assert result.status == NodeExecutionStatus.RUNNING
|
||||
assert result.status == WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
# Verify save was called
|
||||
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
|
||||
@@ -323,11 +334,10 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id="test-workflow-run-id",
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_version="1.0",
|
||||
sequence_number=1,
|
||||
type=WorkflowType.CHAT,
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
@@ -385,7 +395,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
|
||||
|
||||
# Verify the result
|
||||
assert result == node_execution
|
||||
assert result.status == NodeExecutionStatus.SUCCEEDED
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
# Verify save was called
|
||||
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
|
||||
@@ -396,11 +406,10 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
|
||||
# Create a real WorkflowExecution
|
||||
|
||||
workflow_execution = WorkflowExecution(
|
||||
id="test-workflow-run-id",
|
||||
id_="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
workflow_version="1.0",
|
||||
sequence_number=1,
|
||||
type=WorkflowType.CHAT,
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "test query"},
|
||||
started_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
@@ -464,7 +473,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
|
||||
|
||||
# Verify the result
|
||||
assert result == node_execution
|
||||
assert result.status == NodeExecutionStatus.FAILED
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "Test error message"
|
||||
|
||||
# Verify save was called
|
||||
|
||||
@@ -13,12 +13,15 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
NodeExecution,
|
||||
NodeRunMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from models.account import Account, Tenant
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
def configure_mock_execution(mock_execution):
|
||||
@@ -297,7 +300,7 @@ def test_to_db_model(repository):
|
||||
inputs={"input_key": "input_value"},
|
||||
process_data={"process_key": "process_value"},
|
||||
outputs={"output_key": "output_value"},
|
||||
status=NodeExecutionStatus.RUNNING,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
error=None,
|
||||
elapsed_time=1.5,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: 100, NodeRunMetadataKey.TOTAL_PRICE: Decimal("0.0")},
|
||||
@@ -388,7 +391,7 @@ def test_to_domain_model(repository):
|
||||
assert domain_model.inputs == inputs_dict
|
||||
assert domain_model.process_data == process_data_dict
|
||||
assert domain_model.outputs == outputs_dict
|
||||
assert domain_model.status == NodeExecutionStatus(db_model.status)
|
||||
assert domain_model.status == WorkflowNodeExecutionStatus(db_model.status)
|
||||
assert domain_model.error == db_model.error
|
||||
assert domain_model.elapsed_time == db_model.elapsed_time
|
||||
assert domain_model.metadata == metadata_dict
|
||||
|
||||
Reference in New Issue
Block a user