feat(workflow): workflow as tool output schema (#26241)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Novice <novice12185727@gmail.com>
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
@@ -25,3 +27,5 @@ class ApiToolBundle(BaseModel):
|
||||
icon: str | None = None
|
||||
# openapi operation
|
||||
openapi: dict
|
||||
# output schema
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@@ -24,6 +25,31 @@ class WorkflowToolConfigurationUtils:
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_output(cls, graph: Mapping[str, Any]) -> Sequence[OutputVariableEntity]:
|
||||
"""
|
||||
get workflow graph output
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
outputs_by_variable: dict[str, OutputVariableEntity] = {}
|
||||
variable_order: list[str] = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") != "end":
|
||||
continue
|
||||
|
||||
for output in node.get("data", {}).get("outputs", []):
|
||||
entity = OutputVariableEntity.model_validate(output)
|
||||
variable = entity.variable
|
||||
|
||||
if variable not in variable_order:
|
||||
variable_order.append(variable)
|
||||
|
||||
# Later end nodes override duplicated variable definitions.
|
||||
outputs_by_variable[variable] = entity
|
||||
|
||||
return [outputs_by_variable[variable] for variable in variable_order]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
|
||||
@@ -162,6 +162,20 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
else:
|
||||
raise ValueError("variable not found")
|
||||
|
||||
# get output schema from workflow
|
||||
outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph)
|
||||
|
||||
reserved_keys = {"json", "text", "files"}
|
||||
|
||||
properties = {}
|
||||
for output in outputs:
|
||||
if output.variable not in reserved_keys:
|
||||
properties[output.variable] = {
|
||||
"type": output.value_type,
|
||||
"description": "",
|
||||
}
|
||||
output_schema = {"type": "object", "properties": properties}
|
||||
|
||||
return WorkflowTool(
|
||||
workflow_as_tool_id=db_provider.id,
|
||||
entity=ToolEntity(
|
||||
@@ -177,6 +191,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm=db_provider.description,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
output_schema=output_schema,
|
||||
),
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
|
||||
@@ -114,6 +114,11 @@ class WorkflowTool(Tool):
|
||||
for file in files:
|
||||
yield self.create_file_message(file) # type: ignore
|
||||
|
||||
# traverse `outputs` field and create variable messages
|
||||
for key, value in outputs.items():
|
||||
if key not in {"text", "json", "files"}:
|
||||
yield self.create_variable_message(variable_name=key, variable_value=value)
|
||||
|
||||
self._latest_usage = self._derive_usage_from_result(data)
|
||||
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
|
||||
Reference in New Issue
Block a user