feat(workflow): workflow as tool output schema (#26241)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Novice <novice12185727@gmail.com>
This commit is contained in:
CrabSAMA
2025-11-27 16:50:48 +08:00
committed by GitHub
parent 299bd351fd
commit 820925a866
21 changed files with 438 additions and 34 deletions

View File

@@ -257,7 +257,6 @@ class TestWorkflowToolManageService:
# Attempt to create second workflow tool with same name
second_tool_parameters = self._create_test_workflow_tool_parameters()
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
@@ -309,7 +308,6 @@ class TestWorkflowToolManageService:
# Attempt to create workflow tool with non-existent app
tool_parameters = self._create_test_workflow_tool_parameters()
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
@@ -365,7 +363,6 @@ class TestWorkflowToolManageService:
"required": True,
}
]
# Attempt to create workflow tool with invalid parameters
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
@@ -416,7 +413,6 @@ class TestWorkflowToolManageService:
# Create first workflow tool
first_tool_name = fake.word()
first_tool_parameters = self._create_test_workflow_tool_parameters()
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
@@ -431,7 +427,6 @@ class TestWorkflowToolManageService:
# Attempt to create second workflow tool with same app_id but different name
second_tool_name = fake.word()
second_tool_parameters = self._create_test_workflow_tool_parameters()
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
@@ -486,7 +481,6 @@ class TestWorkflowToolManageService:
# Attempt to create workflow tool for app without workflow
tool_parameters = self._create_test_workflow_tool_parameters()
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
@@ -534,7 +528,6 @@ class TestWorkflowToolManageService:
# Create initial workflow tool
initial_tool_name = fake.word()
initial_tool_parameters = self._create_test_workflow_tool_parameters()
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
@@ -621,7 +614,6 @@ class TestWorkflowToolManageService:
# Attempt to update non-existent workflow tool
tool_parameters = self._create_test_workflow_tool_parameters()
with pytest.raises(ValueError) as exc_info:
WorkflowToolManageService.update_workflow_tool(
user_id=account.id,
@@ -671,7 +663,6 @@ class TestWorkflowToolManageService:
# Create first workflow tool
first_tool_name = fake.word()
first_tool_parameters = self._create_test_workflow_tool_parameters()
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,

View File

@@ -3,7 +3,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
from core.tools.errors import ToolInvokeError
from core.tools.workflow_as_tool.tool import WorkflowTool
@@ -51,3 +51,166 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
# actually `run` the tool.
list(tool.invoke("test_user", {}))
assert exc_info.value.args == ("oops",)
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
"""Test that WorkflowTool should generate variable messages when there are outputs"""
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
# Mock workflow outputs
mock_outputs = {"result": "success", "count": 42, "data": {"key": "value"}}
# needs to patch those methods to avoid database access.
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
# Mock user resolution to avoid database access
from unittest.mock import Mock
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
# replace `WorkflowAppGenerator.generate` 's return value.
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
lambda *args, **kwargs: {"data": {"outputs": mock_outputs}},
)
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
# Execute tool invocation
messages = list(tool.invoke("test_user", {}))
# Verify generated messages
# Should contain: 3 variable messages + 1 text message + 1 JSON message = 5 messages
assert len(messages) == 5
# Verify variable messages
variable_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.VARIABLE]
assert len(variable_messages) == 3
# Verify content of each variable message
variable_dict = {msg.message.variable_name: msg.message.variable_value for msg in variable_messages}
assert variable_dict["result"] == "success"
assert variable_dict["count"] == 42
assert variable_dict["data"] == {"key": "value"}
# Verify text message
text_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.TEXT]
assert len(text_messages) == 1
assert '{"result": "success", "count": 42, "data": {"key": "value"}}' in text_messages[0].message.text
# Verify JSON message
json_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.JSON]
assert len(json_messages) == 1
assert json_messages[0].message.json_object == mock_outputs
def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPatch):
"""Test that WorkflowTool should handle empty outputs correctly"""
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
# needs to patch those methods to avoid database access.
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
# Mock user resolution to avoid database access
from unittest.mock import Mock
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
# replace `WorkflowAppGenerator.generate` 's return value.
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
lambda *args, **kwargs: {"data": {}},
)
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
# Execute tool invocation
messages = list(tool.invoke("test_user", {}))
# Verify generated messages
# Should contain: 0 variable messages + 1 text message + 1 JSON message = 2 messages
assert len(messages) == 2
# Verify no variable messages
variable_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.VARIABLE]
assert len(variable_messages) == 0
# Verify text message
text_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.TEXT]
assert len(text_messages) == 1
assert text_messages[0].message.text == "{}"
# Verify JSON message
json_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.JSON]
assert len(json_messages) == 1
assert json_messages[0].message.json_object == {}
def test_create_variable_message():
"""Test the functionality of creating variable messages"""
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
# Test different types of variable values
test_cases = [
("string_var", "test string"),
("int_var", 42),
("float_var", 3.14),
("bool_var", True),
("list_var", [1, 2, 3]),
("dict_var", {"key": "value"}),
]
for var_name, var_value in test_cases:
message = tool.create_variable_message(var_name, var_value)
assert message.type == ToolInvokeMessage.MessageType.VARIABLE
assert message.message.variable_name == var_name
assert message.message.variable_value == var_value
assert message.message.stream is False

View File

@@ -14,7 +14,7 @@ from core.workflow.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
@@ -110,8 +110,12 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
end_primary_data = EndNodeData(
title="End Primary",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
OutputVariableEntity(
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
),
OutputVariableEntity(
variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"]
),
],
desc=None,
)
@@ -126,8 +130,14 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
end_secondary_data = EndNodeData(
title="End Secondary",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
OutputVariableEntity(
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
),
OutputVariableEntity(
variable="secondary_text",
value_type=OutputVariableType.STRING,
value_selector=["llm_secondary", "text"],
),
],
desc=None,
)

View File

@@ -13,7 +13,7 @@ from core.workflow.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
@@ -108,8 +108,12 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
end_data = EndNodeData(
title="End",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="resume_text", value_selector=["llm_resume", "text"]),
OutputVariableEntity(
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
),
OutputVariableEntity(
variable="resume_text", value_type=OutputVariableType.STRING, value_selector=["llm_resume", "text"]
),
],
desc=None,
)

View File

@@ -11,7 +11,7 @@ from core.workflow.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.if_else.entities import IfElseNodeData
@@ -123,8 +123,12 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
end_primary_data = EndNodeData(
title="End Primary",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
OutputVariableEntity(
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
),
OutputVariableEntity(
variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"]
),
],
desc=None,
)
@@ -139,8 +143,14 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
end_secondary_data = EndNodeData(
title="End Secondary",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
OutputVariableEntity(
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
),
OutputVariableEntity(
variable="secondary_text",
value_type=OutputVariableType.STRING,
value_selector=["llm_secondary", "text"],
),
],
desc=None,
)