Refactor/remove db from cycle manager (#20455)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-05-30 04:34:13 +08:00
committed by GitHub
parent cd0a05f114
commit 482e50aae9
81 changed files with 345 additions and 362 deletions

View File

@@ -4,7 +4,7 @@ from constants import UUID_NIL
from core.prompt.utils.extract_thread_messages import extract_thread_messages
class TestMessage:
class MockMessage:
def __init__(self, id, parent_message_id):
self.id = id
self.parent_message_id = parent_message_id
@@ -14,7 +14,7 @@ class TestMessage:
def test_extract_thread_messages_single_message():
messages = [TestMessage(str(uuid4()), UUID_NIL)]
messages = [MockMessage(str(uuid4()), UUID_NIL)]
result = extract_thread_messages(messages)
assert len(result) == 1
assert result[0] == messages[0]
@@ -23,11 +23,11 @@ def test_extract_thread_messages_single_message():
def test_extract_thread_messages_linear_thread():
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id5, id4),
TestMessage(id4, id3),
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, UUID_NIL),
MockMessage(id5, id4),
MockMessage(id4, id3),
MockMessage(id3, id2),
MockMessage(id2, id1),
MockMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 5
@@ -37,10 +37,10 @@ def test_extract_thread_messages_linear_thread():
def test_extract_thread_messages_branched_thread():
id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id4, id2),
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, UUID_NIL),
MockMessage(id4, id2),
MockMessage(id3, id2),
MockMessage(id2, id1),
MockMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 3
@@ -56,9 +56,9 @@ def test_extract_thread_messages_empty_list():
def test_extract_thread_messages_partially_loaded():
id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, id0),
MockMessage(id3, id2),
MockMessage(id2, id1),
MockMessage(id1, id0),
]
result = extract_thread_messages(messages)
assert len(result) == 3
@@ -68,9 +68,9 @@ def test_extract_thread_messages_partially_loaded():
def test_extract_thread_messages_legacy_messages():
id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id3, UUID_NIL),
TestMessage(id2, UUID_NIL),
TestMessage(id1, UUID_NIL),
MockMessage(id3, UUID_NIL),
MockMessage(id2, UUID_NIL),
MockMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 3
@@ -80,11 +80,11 @@ def test_extract_thread_messages_legacy_messages():
def test_extract_thread_messages_mixed_with_legacy_messages():
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id5, id4),
TestMessage(id4, id2),
TestMessage(id3, id2),
TestMessage(id2, UUID_NIL),
TestMessage(id1, UUID_NIL),
MockMessage(id5, id4),
MockMessage(id4, id2),
MockMessage(id3, id2),
MockMessage(id2, UUID_NIL),
MockMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 4

View File

@@ -1,5 +1,5 @@
import pytest
from pydantic.error_wrappers import ValidationError
from pydantic import ValidationError
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig

View File

@@ -6,6 +6,7 @@ from flask import Flask
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
@@ -25,7 +26,7 @@ from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from models.workflow import WorkflowType
@pytest.fixture

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
@@ -11,7 +12,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from models.workflow import WorkflowType
def test_execute_answer():

View File

@@ -4,6 +4,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileVariable, FileVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
@@ -15,7 +16,7 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeData,
)
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from models.workflow import WorkflowType
def test_http_request_node_binary_file(monkeypatch):

View File

@@ -5,6 +5,7 @@ from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
@@ -14,7 +15,7 @@ from core.workflow.nodes.iteration.entities import ErrorHandleMode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from models.workflow import WorkflowType
def test_run():

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
@@ -11,7 +12,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from models.workflow import WorkflowType
def test_execute_answer():

View File

@@ -2,6 +2,7 @@ from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
@@ -14,7 +15,7 @@ from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from models.workflow import WorkflowType
class ContinueOnErrorTestHelper:

View File

@@ -7,6 +7,7 @@ from core.file import File, FileTransferMethod
from core.variables import ArrayFileSegment
from core.variables.variables import StringVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
from core.workflow.nodes.document_extractor.node import (
_extract_text_from_docx,
@@ -15,7 +16,6 @@ from core.workflow.nodes.document_extractor.node import (
_extract_text_from_plain_text,
)
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
@pytest.fixture

View File

@@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
@@ -15,7 +16,7 @@ from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from models.workflow import WorkflowType
def test_execute_if_else_result_true():

View File

@@ -4,6 +4,7 @@ import pytest
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.list_operator.entities import (
ExtractConfig,
FilterBy,
@@ -14,7 +15,6 @@ from core.workflow.nodes.list_operator.entities import (
)
from core.workflow.nodes.list_operator.exc import InvalidKeyError
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
from models.workflow import WorkflowNodeExecutionStatus
@pytest.fixture

View File

@@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderTyp
from core.tools.errors import ToolInvokeError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
@@ -14,7 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.tool.entities import ToolNodeData
from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
from models import UserFrom, WorkflowType
def _create_tool_node():

View File

@@ -12,21 +12,20 @@ from core.app.entities.queue_entities import (
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_node_execution import (
NodeExecution,
NodeRunMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import (
Workflow,
WorkflowRun,
WorkflowRunStatus,
)
from models.workflow import Workflow, WorkflowRun
@pytest.fixture
@@ -93,16 +92,38 @@ def mock_workflow_execution_repository():
return repo
@pytest.fixture
def real_workflow_entity():
return CycleManagerWorkflowInfo(
workflow_id="test-workflow-id", # Matches ID used in other fixtures
workflow_type=WorkflowType.CHAT,
version="1.0.0",
graph_data={
"nodes": [
{
"id": "node1",
"type": "chat", # NodeType is a string enum
"name": "Chat Node",
"data": {"model": "gpt-3.5-turbo", "prompt": "test prompt"},
}
],
"edges": [],
},
)
@pytest.fixture
def workflow_cycle_manager(
real_app_generate_entity,
real_workflow_system_variables,
mock_workflow_execution_repository,
mock_node_execution_repository,
real_workflow_entity,
):
return WorkflowCycleManager(
application_generate_entity=real_app_generate_entity,
workflow_system_variables=real_workflow_system_variables,
workflow_info=real_workflow_entity,
workflow_execution_repository=mock_workflow_execution_repository,
workflow_node_execution_repository=mock_node_execution_repository,
)
@@ -148,7 +169,7 @@ def real_workflow_run():
workflow_run.version = "1.0"
workflow_run.graph = json.dumps({"nodes": [], "edges": []})
workflow_run.inputs = json.dumps({"query": "test query"})
workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.status = WorkflowExecutionStatus.RUNNING
workflow_run.outputs = json.dumps({"answer": "test answer"})
workflow_run.created_by_role = CreatorUserRole.ACCOUNT
workflow_run.created_by = "test-user-id"
@@ -171,20 +192,13 @@ def test_init(
assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, real_workflow):
def test_handle_workflow_run_start(workflow_cycle_manager):
"""Test handle_workflow_run_start method"""
# Mock session.scalar to return the workflow and max sequence
mock_session.scalar.side_effect = [real_workflow, 5]
# Call the method
workflow_execution = workflow_cycle_manager.handle_workflow_run_start(
session=mock_session,
workflow_id="test-workflow-id",
)
workflow_execution = workflow_cycle_manager.handle_workflow_run_start()
# Verify the result
assert workflow_execution.workflow_id == real_workflow.id
assert workflow_execution.sequence_number == 6 # max_sequence + 1
assert workflow_execution.workflow_id == "test-workflow-id"
# Verify the workflow_execution_repository.save was called
workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(workflow_execution)
@@ -195,11 +209,10 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
# Create a real WorkflowExecution
workflow_execution = WorkflowExecution(
id="test-workflow-run-id",
id_="test-workflow-run-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
@@ -230,11 +243,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
# Create a real WorkflowExecution
workflow_execution = WorkflowExecution(
id="test-workflow-run-id",
id_="test-workflow-run-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
@@ -251,13 +263,13 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
workflow_run_id="test-workflow-run-id",
total_tokens=50,
total_steps=3,
status=WorkflowRunStatus.FAILED,
status=WorkflowExecutionStatus.FAILED,
error_message="Test error message",
)
# Verify the result
assert result == workflow_execution
assert result.status == WorkflowExecutionStatus(WorkflowRunStatus.FAILED.value)
assert result.status == WorkflowExecutionStatus.FAILED
assert result.error_message == "Test error message"
assert result.total_tokens == 50
assert result.total_steps == 3
@@ -269,11 +281,10 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
# Create a real WorkflowExecution
workflow_execution = WorkflowExecution(
id="test-workflow-execution-id",
id_="test-workflow-execution-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
@@ -301,18 +312,18 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
# Call the method
result = workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=workflow_execution.id,
workflow_execution_id=workflow_execution.id_,
event=event,
)
# Verify the result
assert result.workflow_id == workflow_execution.workflow_id
assert result.workflow_run_id == workflow_execution.id
assert result.workflow_run_id == workflow_execution.id_
assert result.node_execution_id == event.node_execution_id
assert result.node_id == event.node_id
assert result.node_type == event.node_type
assert result.title == event.node_data.title
assert result.status == NodeExecutionStatus.RUNNING
assert result.status == WorkflowNodeExecutionStatus.RUNNING
# Verify save was called
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
@@ -323,11 +334,10 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
# Create a real WorkflowExecution
workflow_execution = WorkflowExecution(
id="test-workflow-run-id",
id_="test-workflow-run-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
@@ -385,7 +395,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
# Verify the result
assert result == node_execution
assert result.status == NodeExecutionStatus.SUCCEEDED
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
# Verify save was called
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
@@ -396,11 +406,10 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
# Create a real WorkflowExecution
workflow_execution = WorkflowExecution(
id="test-workflow-run-id",
id_="test-workflow-run-id",
workflow_id="test-workflow-id",
workflow_version="1.0",
sequence_number=1,
type=WorkflowType.CHAT,
workflow_type=WorkflowType.CHAT,
graph={"nodes": [], "edges": []},
inputs={"query": "test query"},
started_at=datetime.now(UTC).replace(tzinfo=None),
@@ -464,7 +473,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
# Verify the result
assert result == node_execution
assert result.status == NodeExecutionStatus.FAILED
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == "Test error message"
# Verify save was called

View File

@@ -13,12 +13,15 @@ from sqlalchemy.orm import Session, sessionmaker
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.entities.workflow_node_execution import (
NodeExecution,
NodeRunMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models.account import Account, Tenant
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionTriggeredFrom
def configure_mock_execution(mock_execution):
@@ -297,7 +300,7 @@ def test_to_db_model(repository):
inputs={"input_key": "input_value"},
process_data={"process_key": "process_value"},
outputs={"output_key": "output_value"},
status=NodeExecutionStatus.RUNNING,
status=WorkflowNodeExecutionStatus.RUNNING,
error=None,
elapsed_time=1.5,
metadata={NodeRunMetadataKey.TOTAL_TOKENS: 100, NodeRunMetadataKey.TOTAL_PRICE: Decimal("0.0")},
@@ -388,7 +391,7 @@ def test_to_domain_model(repository):
assert domain_model.inputs == inputs_dict
assert domain_model.process_data == process_data_dict
assert domain_model.outputs == outputs_dict
assert domain_model.status == NodeExecutionStatus(db_model.status)
assert domain_model.status == WorkflowNodeExecutionStatus(db_model.status)
assert domain_model.error == db_model.error
assert domain_model.elapsed_time == db_model.elapsed_time
assert domain_model.metadata == metadata_dict