feat: knowledge pipeline (#25360)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: jyong <718720800@qq.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: quicksand <quicksandzn@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Hanqing Zhao <sherry9277@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Harry <xh001x@hotmail.com>
This commit is contained in:
@@ -9,7 +9,8 @@ from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app_factory import create_app
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin, db
|
||||
from extensions.ext_database import db
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
|
||||
|
||||
@@ -3,16 +3,27 @@ import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes import NodeType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from factories.variable_factory import build_segment
|
||||
from libs import datetime_utils
|
||||
from models import db
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, WorkflowNodeExecutionModel
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
DraftVarLoader,
|
||||
VariableResetError,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
@@ -175,6 +186,23 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
_node1_id = "test_loader_node_1"
|
||||
_node_exec_id = str(uuid.uuid4())
|
||||
|
||||
# @pytest.fixture
|
||||
# def test_app_id(self):
|
||||
# return str(uuid.uuid4())
|
||||
|
||||
# @pytest.fixture
|
||||
# def test_tenant_id(self):
|
||||
# return str(uuid.uuid4())
|
||||
|
||||
# @pytest.fixture
|
||||
# def session(self):
|
||||
# with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# yield session
|
||||
|
||||
# @pytest.fixture
|
||||
# def node_var(self, session):
|
||||
# pass
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._test_tenant_id = str(uuid.uuid4())
|
||||
@@ -241,6 +269,246 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
|
||||
assert node1_var.id == self._node_var_id
|
||||
|
||||
@pytest.mark.usefixtures("setup_account")
|
||||
def test_load_offloaded_variable_string_type_integration(self, setup_account):
|
||||
"""Test _load_offloaded_variable with string type using DraftVariableSaver for data creation."""
|
||||
|
||||
# Create a large string that will be offloaded
|
||||
test_content = "x" * 15000 # Create a string larger than LARGE_VARIABLE_THRESHOLD (10KB)
|
||||
large_string_segment = StringSegment(value=test_content)
|
||||
|
||||
node_execution_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Use DraftVariableSaver to create offloaded variable (this mimics production)
|
||||
saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=self._test_app_id,
|
||||
node_id="test_offload_node",
|
||||
node_type=NodeType.LLM, # Use a real node type
|
||||
node_execution_id=node_execution_id,
|
||||
user=setup_account,
|
||||
)
|
||||
|
||||
# Save the variable - this will trigger offloading due to large size
|
||||
saver.save(outputs={"offloaded_string_var": large_string_segment})
|
||||
session.commit()
|
||||
|
||||
# Now test loading using DraftVarLoader
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
|
||||
# Load the variable using the standard workflow
|
||||
variables = var_loader.load_variables([["test_offload_node", "offloaded_string_var"]])
|
||||
|
||||
# Verify results
|
||||
assert len(variables) == 1
|
||||
loaded_variable = variables[0]
|
||||
assert loaded_variable.name == "offloaded_string_var"
|
||||
assert loaded_variable.selector == ["test_offload_node", "offloaded_string_var"]
|
||||
assert isinstance(loaded_variable.value, StringSegment)
|
||||
assert loaded_variable.value.value == test_content
|
||||
|
||||
finally:
|
||||
# Clean up - delete all draft variables for this app
|
||||
with Session(bind=db.engine) as session:
|
||||
service = WorkflowDraftVariableService(session)
|
||||
service.delete_workflow_variables(self._test_app_id)
|
||||
session.commit()
|
||||
|
||||
def test_load_offloaded_variable_object_type_integration(self):
|
||||
"""Test _load_offloaded_variable with object type using real storage and service."""
|
||||
|
||||
# Create a test object
|
||||
test_object = {"key1": "value1", "key2": 42, "nested": {"inner": "data"}}
|
||||
test_json = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
|
||||
content_bytes = test_json.encode()
|
||||
|
||||
# Create an upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
key=f"test_offload_{uuid.uuid4()}.json",
|
||||
name="test_offload.json",
|
||||
size=len(content_bytes),
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=datetime_utils.naive_utc_now(),
|
||||
used=True,
|
||||
used_by=str(uuid.uuid4()),
|
||||
used_at=datetime_utils.naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store the content in storage
|
||||
storage.save(upload_file.key, content_bytes)
|
||||
|
||||
# Create a variable file record
|
||||
variable_file = WorkflowDraftVariableFile(
|
||||
upload_file_id=upload_file.id,
|
||||
value_type=SegmentType.OBJECT,
|
||||
tenant_id=self._test_tenant_id,
|
||||
app_id=self._test_app_id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
size=len(content_bytes),
|
||||
created_at=datetime_utils.naive_utc_now(),
|
||||
)
|
||||
|
||||
try:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Add upload file and variable file first to get their IDs
|
||||
session.add_all([upload_file, variable_file])
|
||||
session.flush() # This generates the IDs
|
||||
|
||||
# Now create the offloaded draft variable with the correct file_id
|
||||
offloaded_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id="test_offload_node",
|
||||
name="offloaded_object_var",
|
||||
value=build_segment({"truncated": True}),
|
||||
visible=True,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
offloaded_var.file_id = variable_file.id
|
||||
|
||||
session.add(offloaded_var)
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
# Use the service method that properly preloads relationships
|
||||
service = WorkflowDraftVariableService(session)
|
||||
draft_vars = service.get_draft_variables_by_selectors(
|
||||
self._test_app_id, [["test_offload_node", "offloaded_object_var"]]
|
||||
)
|
||||
|
||||
assert len(draft_vars) == 1
|
||||
loaded_var = draft_vars[0]
|
||||
assert loaded_var.is_truncated()
|
||||
|
||||
# Create DraftVarLoader and test loading
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
|
||||
# Test the _load_offloaded_variable method
|
||||
selector_tuple, variable = var_loader._load_offloaded_variable(loaded_var)
|
||||
|
||||
# Verify the results
|
||||
assert selector_tuple == ("test_offload_node", "offloaded_object_var")
|
||||
assert variable.id == loaded_var.id
|
||||
assert variable.name == "offloaded_object_var"
|
||||
assert variable.value.value == test_object
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
with Session(bind=db.engine) as session:
|
||||
# Query and delete by ID to ensure they're tracked in this session
|
||||
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
|
||||
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
|
||||
session.query(UploadFile).filter_by(id=upload_file.id).delete()
|
||||
session.commit()
|
||||
# Clean up storage
|
||||
try:
|
||||
storage.delete(upload_file.key)
|
||||
except Exception:
|
||||
pass # Ignore cleanup failures
|
||||
|
||||
def test_load_variables_with_offloaded_variables_integration(self):
|
||||
"""Test load_variables method with mix of regular and offloaded variables using real storage."""
|
||||
# Create a regular variable (already exists from setUp)
|
||||
# Create offloaded variable content
|
||||
test_content = "This is offloaded content for integration test"
|
||||
content_bytes = test_content.encode()
|
||||
|
||||
# Create upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
key=f"test_integration_{uuid.uuid4()}.txt",
|
||||
name="test_integration.txt",
|
||||
size=len(content_bytes),
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=datetime_utils.naive_utc_now(),
|
||||
used=True,
|
||||
used_by=str(uuid.uuid4()),
|
||||
used_at=datetime_utils.naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store the content
|
||||
storage.save(upload_file.key, content_bytes)
|
||||
|
||||
# Create variable file
|
||||
variable_file = WorkflowDraftVariableFile(
|
||||
upload_file_id=upload_file.id,
|
||||
value_type=SegmentType.STRING,
|
||||
tenant_id=self._test_tenant_id,
|
||||
app_id=self._test_app_id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
size=len(content_bytes),
|
||||
created_at=datetime_utils.naive_utc_now(),
|
||||
)
|
||||
|
||||
try:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Add upload file and variable file first to get their IDs
|
||||
session.add_all([upload_file, variable_file])
|
||||
session.flush() # This generates the IDs
|
||||
|
||||
# Now create the offloaded draft variable with the correct file_id
|
||||
offloaded_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id="test_integration_node",
|
||||
name="offloaded_integration_var",
|
||||
value=build_segment("truncated"),
|
||||
visible=True,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
offloaded_var.file_id = variable_file.id
|
||||
|
||||
session.add(offloaded_var)
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
# Test load_variables with both regular and offloaded variables
|
||||
# This method should handle the relationship preloading internally
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
|
||||
variables = var_loader.load_variables(
|
||||
[
|
||||
[SYSTEM_VARIABLE_NODE_ID, "sys_var"], # Regular variable from setUp
|
||||
["test_integration_node", "offloaded_integration_var"], # Offloaded variable
|
||||
]
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert len(variables) == 2
|
||||
|
||||
# Find regular variable
|
||||
regular_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
|
||||
assert regular_var.id == self._sys_var_id
|
||||
assert regular_var.value == "sys_value"
|
||||
|
||||
# Find offloaded variable
|
||||
offloaded_loaded_var = next(v for v in variables if v.selector[0] == "test_integration_node")
|
||||
assert offloaded_loaded_var.id == offloaded_var.id
|
||||
assert offloaded_loaded_var.value == test_content
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
with Session(bind=db.engine) as session:
|
||||
# Query and delete by ID to ensure they're tracked in this session
|
||||
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
|
||||
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
|
||||
session.query(UploadFile).filter_by(id=upload_file.id).delete()
|
||||
session.commit()
|
||||
# Clean up storage
|
||||
try:
|
||||
storage.delete(upload_file.key)
|
||||
except Exception:
|
||||
pass # Ignore cleanup failures
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
@@ -272,7 +540,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
index=1,
|
||||
node_execution_id=self._node_exec_id,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
node_id=self._node_id,
|
||||
node_type=NodeType.LLM.value,
|
||||
title="Test Node",
|
||||
@@ -281,7 +549,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
outputs='{"test_var": "output_value", "other_var": "other_output"}',
|
||||
status="succeeded",
|
||||
elapsed_time=1.5,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
@@ -336,10 +604,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
)
|
||||
self._conv_var.last_edited_at = datetime_utils.naive_utc_now()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as persistent_session, persistent_session.begin():
|
||||
persistent_session.add(
|
||||
self._workflow_node_execution,
|
||||
)
|
||||
|
||||
# Add all to database
|
||||
db.session.add_all(
|
||||
[
|
||||
self._workflow_node_execution,
|
||||
self._node_var_with_exec,
|
||||
self._node_var_without_exec,
|
||||
self._node_var_missing_exec,
|
||||
@@ -354,6 +626,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
self._node_var_missing_exec_id = self._node_var_missing_exec.id
|
||||
self._conv_var_id = self._conv_var.id
|
||||
|
||||
def tearDown(self):
|
||||
self._session.rollback()
|
||||
with Session(db.engine) as session, session.begin():
|
||||
stmt = delete(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.id == self._workflow_node_execution.id
|
||||
)
|
||||
session.execute(stmt)
|
||||
|
||||
def _get_test_srv(self) -> WorkflowDraftVariableService:
|
||||
return WorkflowDraftVariableService(session=self._session)
|
||||
|
||||
@@ -377,12 +657,10 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
created_by=str(uuid.uuid4()),
|
||||
environment_variables=[],
|
||||
conversation_variables=conversation_vars,
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
return workflow
|
||||
|
||||
def tearDown(self):
|
||||
self._session.rollback()
|
||||
|
||||
def test_reset_node_variable_with_valid_execution_record(self):
|
||||
"""Test resetting a node variable with valid execution record - should restore from execution"""
|
||||
srv = self._get_test_srv()
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from models import Tenant, db
|
||||
from models.model import App
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, UploadFile
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
|
||||
|
||||
|
||||
@@ -212,3 +215,256 @@ class TestDeleteDraftVariablesIntegration:
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(query)
|
||||
|
||||
|
||||
class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
"""Integration tests for draft variable deletion with Offload data."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_offload_test_data(self, app_and_tenant):
|
||||
"""Create test data with draft variables that have associated Offload files."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create UploadFile records
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
db.session.add(upload_file1)
|
||||
db.session.add(upload_file2)
|
||||
db.session.flush()
|
||||
|
||||
# Create WorkflowDraftVariableFile records
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
var_file1 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file1.id,
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.STRING,
|
||||
)
|
||||
var_file2 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file2.id,
|
||||
size=2048,
|
||||
length=20,
|
||||
value_type=SegmentType.OBJECT,
|
||||
)
|
||||
db.session.add(var_file1)
|
||||
db.session.add(var_file2)
|
||||
db.session.flush()
|
||||
|
||||
# Create WorkflowDraftVariable records with file associations
|
||||
draft_var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_1",
|
||||
name="large_var_1",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file1.id,
|
||||
)
|
||||
draft_var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_2",
|
||||
name="large_var_2",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file2.id,
|
||||
)
|
||||
# Create a regular variable without Offload data
|
||||
draft_var3 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_3",
|
||||
name="regular_var",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
db.session.add(draft_var1)
|
||||
db.session.add(draft_var2)
|
||||
db.session.add(draft_var3)
|
||||
db.session.commit()
|
||||
|
||||
yield {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"upload_files": [upload_file1, upload_file2],
|
||||
"variable_files": [var_file1, var_file2],
|
||||
"draft_variables": [draft_var1, draft_var2, draft_var3],
|
||||
}
|
||||
|
||||
# Cleanup
|
||||
db.session.rollback()
|
||||
|
||||
# Clean up any remaining records
|
||||
for table, ids in [
|
||||
(WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
|
||||
(WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
|
||||
(UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
|
||||
]:
|
||||
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
|
||||
db.session.execute(cleanup_query)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that deleting draft variables also cleans up associated Offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Mock storage deletion to succeed
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Verify initial state
|
||||
draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_before = db.session.query(UploadFile).count()
|
||||
|
||||
assert draft_vars_before == 3 # 2 with files + 1 regular
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
|
||||
# Delete draft variables
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 3
|
||||
|
||||
# Check that all draft variables are deleted
|
||||
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert draft_vars_after == 0
|
||||
|
||||
# Check that associated Offload data is cleaned up
|
||||
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = db.session.query(UploadFile).count()
|
||||
|
||||
assert var_files_after == 0 # All variable files should be deleted
|
||||
assert upload_files_after == 0 # All upload files should be deleted
|
||||
|
||||
# Verify storage deletion was called for both files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
|
||||
assert "test/file1.json" in storage_keys_deleted
|
||||
assert "test/file2.json" in storage_keys_deleted
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that database cleanup continues even when storage deletion fails."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Mock storage deletion to fail for first file, succeed for second
|
||||
mock_storage.delete.side_effect = [Exception("Storage error"), None]
|
||||
|
||||
# Delete draft variables
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
|
||||
# Verify that all draft variables are still deleted
|
||||
assert deleted_count == 3
|
||||
|
||||
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert draft_vars_after == 0
|
||||
|
||||
# Database cleanup should still succeed even with storage errors
|
||||
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = db.session.query(UploadFile).count()
|
||||
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
# Verify storage deletion was attempted for both files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
"""Test deletion with mix of variables with and without Offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Create additional app with only regular variables (no offload data)
|
||||
tenant = data["tenant"]
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app2)
|
||||
db.session.flush()
|
||||
|
||||
# Add regular variables to app2
|
||||
regular_vars = []
|
||||
for i in range(3):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(var)
|
||||
regular_vars.append(var)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Mock storage deletion
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Delete variables for app2 (no offload data)
|
||||
deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
|
||||
assert deleted_count_app2 == 3
|
||||
|
||||
# Verify storage wasn't called for app2 (no offload files)
|
||||
mock_storage.delete.assert_not_called()
|
||||
|
||||
# Delete variables for original app (with offload data)
|
||||
deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
assert deleted_count_app1 == 3
|
||||
|
||||
# Now storage should be called for the offload files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
finally:
|
||||
# Cleanup app2 and its variables
|
||||
cleanup_vars_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id == app2.id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(cleanup_vars_query)
|
||||
|
||||
app2_obj = db.session.get(App, app2.id)
|
||||
if app2_obj:
|
||||
db.session.delete(app2_obj)
|
||||
db.session.commit()
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import environs
|
||||
import os
|
||||
|
||||
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
|
||||
|
||||
env = environs.Env()
|
||||
|
||||
|
||||
class Config:
|
||||
SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070")
|
||||
SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
|
||||
SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN")
|
||||
USING_UGC = env.bool("USING_UGC", True)
|
||||
SEARCH_ENDPOINT = os.environ.get(
|
||||
"SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070"
|
||||
)
|
||||
SEARCH_USERNAME = os.environ.get("SEARCH_USERNAME", "ADMIN")
|
||||
SEARCH_PWD = os.environ.get("SEARCH_PWD", "ADMIN")
|
||||
USING_UGC = os.environ.get("USING_UGC", "True").lower() == "true"
|
||||
|
||||
|
||||
class TestLindormVectorStore(AbstractVectorTest):
|
||||
|
||||
@@ -5,16 +5,14 @@ from os import getenv
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
|
||||
@@ -29,15 +27,12 @@ def init_code_node(code_config: dict):
|
||||
"target": "code",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start"}, "id": "start"}, code_config],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@@ -56,12 +51,21 @@ def init_code_node(code_config: dict):
|
||||
variable_pool.add(["code", "args1"], 1)
|
||||
variable_pool.add(["code", "args2"], 2)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = CodeNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=code_config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
@@ -85,6 +89,7 @@ def test_execute_code(setup_code_executor_mock):
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "number",
|
||||
@@ -114,7 +119,7 @@ def test_execute_code(setup_code_executor_mock):
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] == 3
|
||||
assert result.error is None
|
||||
assert result.error == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
@@ -131,6 +136,7 @@ def test_execute_code_output_validator(setup_code_executor_mock):
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "string",
|
||||
@@ -158,7 +164,7 @@ def test_execute_code_output_validator(setup_code_executor_mock):
|
||||
result = node._run()
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "Output variable `result` must be a string"
|
||||
assert result.error == "Output result must be a string, got int instead"
|
||||
|
||||
|
||||
def test_execute_code_output_validator_depth():
|
||||
@@ -176,6 +182,7 @@ def test_execute_code_output_validator_depth():
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"string_validator": {
|
||||
"type": "string",
|
||||
@@ -294,6 +301,7 @@ def test_execute_code_output_object_list():
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"object_list": {
|
||||
"type": "array[object]",
|
||||
@@ -354,7 +362,8 @@ def test_execute_code_output_object_list():
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
|
||||
def test_execute_code_scientific_notation():
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
def test_execute_code_scientific_notation(setup_code_executor_mock):
|
||||
code = """
|
||||
def main():
|
||||
return {
|
||||
@@ -366,6 +375,7 @@ def test_execute_code_scientific_notation():
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"type": "code",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "number",
|
||||
|
||||
@@ -5,14 +5,12 @@ from urllib.parse import urlencode
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
|
||||
|
||||
|
||||
@@ -25,15 +23,12 @@ def init_http_node(config: dict):
|
||||
"target": "1",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@@ -52,12 +47,21 @@ def init_http_node(config: dict):
|
||||
variable_pool.add(["a", "args1"], 1)
|
||||
variable_pool.add(["a", "args2"], 2)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
@@ -73,6 +77,7 @@ def test_get(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -106,6 +111,7 @@ def test_no_auth(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -135,6 +141,7 @@ def test_custom_authorization_header(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -227,6 +234,7 @@ def test_bearer_authorization_with_custom_header_ignored(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -267,6 +275,7 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -306,6 +315,7 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -339,6 +349,7 @@ def test_template(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -374,6 +385,7 @@ def test_json(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "post",
|
||||
@@ -416,6 +428,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "post",
|
||||
@@ -463,6 +476,7 @@ def test_form_data(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "post",
|
||||
@@ -513,6 +527,7 @@ def test_none_data(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "post",
|
||||
@@ -546,6 +561,7 @@ def test_mock_404(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -575,6 +591,7 @@ def test_multi_colons_parse(setup_http_mock):
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -627,10 +644,11 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{"data": {"type": "start", "title": "Start"}, "id": "start"},
|
||||
{
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
@@ -651,12 +669,9 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@@ -676,12 +691,21 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
variable_pool.add(["a", "args2"], 2)
|
||||
variable_pool.add(["a", "args3"], {"nested": "nested_value"}) # Only for this test
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=graph_config["nodes"][1],
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
|
||||
@@ -6,17 +6,15 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.output_parser.structured_output import _parse_structured_output
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import StreamCompletedEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
|
||||
@@ -30,11 +28,9 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
"target": "llm",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
# Use proper UUIDs for database compatibility
|
||||
tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
|
||||
app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
|
||||
@@ -44,7 +40,6 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
@@ -69,12 +64,21 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
)
|
||||
variable_pool.add(["abc", "output"], "sunny")
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = LLMNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
@@ -173,15 +177,15 @@ def test_execute_llm():
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
print(f"Error: {item.run_result.error}")
|
||||
print(f"Error type: {item.run_result.error_type}")
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert item.run_result.outputs is not None
|
||||
assert item.run_result.outputs.get("text") is not None
|
||||
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
if item.node_run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
print(f"Error: {item.node_run_result.error}")
|
||||
print(f"Error type: {item.node_run_result.error_type}")
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.process_data is not None
|
||||
assert item.node_run_result.outputs is not None
|
||||
assert item.node_run_result.outputs.get("text") is not None
|
||||
assert item.node_run_result.outputs.get("usage", {})["total_tokens"] > 0
|
||||
|
||||
|
||||
def test_execute_llm_with_jinja2():
|
||||
@@ -284,11 +288,11 @@ def test_execute_llm_with_jinja2():
|
||||
result = node._run()
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert "sunny" in json.dumps(item.run_result.process_data)
|
||||
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.process_data is not None
|
||||
assert "sunny" in json.dumps(item.node_run_result.process_data)
|
||||
assert "what's the weather today?" in json.dumps(item.node_run_result.process_data)
|
||||
|
||||
|
||||
def test_extract_json():
|
||||
|
||||
@@ -5,11 +5,10 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_runtime.entities import AssistantPromptMessage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
@@ -17,7 +16,6 @@ from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from models.workflow import WorkflowType
|
||||
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
|
||||
|
||||
|
||||
@@ -44,15 +42,12 @@ def init_parameter_extractor_node(config: dict):
|
||||
"target": "llm",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@@ -73,12 +68,21 @@ def init_parameter_extractor_node(config: dict):
|
||||
variable_pool.add(["a", "args1"], 1)
|
||||
variable_pool.add(["a", "args2"], 2)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = ParameterExtractorNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(config.get("data", {}))
|
||||
return node
|
||||
|
||||
@@ -4,15 +4,13 @@ import uuid
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
|
||||
@@ -22,6 +20,7 @@ def test_execute_code(setup_code_executor_mock):
|
||||
config = {
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "template-transform",
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
@@ -42,15 +41,12 @@ def test_execute_code(setup_code_executor_mock):
|
||||
"target": "1",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@@ -69,12 +65,21 @@ def test_execute_code(setup_code_executor_mock):
|
||||
variable_pool.add(["1", "args1"], 1)
|
||||
variable_pool.add(["1", "args2"], 3)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(config.get("data", {}))
|
||||
|
||||
|
||||
@@ -4,16 +4,14 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import StreamCompletedEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def init_tool_node(config: dict):
|
||||
@@ -25,15 +23,12 @@ def init_tool_node(config: dict):
|
||||
"target": "1",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
@@ -50,12 +45,21 @@ def init_tool_node(config: dict):
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = ToolNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(config.get("data", {}))
|
||||
return node
|
||||
@@ -66,6 +70,7 @@ def test_tool_variable_invoke():
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "tool",
|
||||
"title": "a",
|
||||
"desc": "a",
|
||||
"provider_id": "time",
|
||||
@@ -86,10 +91,10 @@ def test_tool_variable_invoke():
|
||||
# execute node
|
||||
result = node._run()
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs is not None
|
||||
assert item.run_result.outputs.get("text") is not None
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.outputs is not None
|
||||
assert item.node_run_result.outputs.get("text") is not None
|
||||
|
||||
|
||||
def test_tool_mixed_invoke():
|
||||
@@ -97,6 +102,7 @@ def test_tool_mixed_invoke():
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"type": "tool",
|
||||
"title": "a",
|
||||
"desc": "a",
|
||||
"provider_id": "time",
|
||||
@@ -117,7 +123,7 @@ def test_tool_mixed_invoke():
|
||||
# execute node
|
||||
result = node._run()
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs is not None
|
||||
assert item.run_result.outputs.get("text") is not None
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.outputs is not None
|
||||
assert item.node_run_result.outputs.get("text") is not None
|
||||
|
||||
Reference in New Issue
Block a user