refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
-LAN-
2025-07-18 10:08:51 +08:00
committed by GitHub
parent 54c56f2d05
commit 460a825ef1
65 changed files with 2305 additions and 1146 deletions

View File

@@ -1,5 +1,4 @@
import json
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
@@ -11,8 +10,10 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
@@ -25,45 +26,75 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
AgentInputTypeError,
AgentInvocationError,
AgentMessageTransformError,
AgentVariableNotFoundError,
AgentVariableTypeError,
ToolFileNotFoundError,
)
class AgentNode(ToolNode):
class AgentNode(BaseNode):
"""
Agent Node
"""
_node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
"""
Run the agent node
"""
node_data = cast(AgentNodeData, self.node_data)
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
agent_strategy_provider_name=node_data.agent_strategy_provider_name,
agent_strategy_name=node_data.agent_strategy_name,
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
agent_strategy_name=self._node_data.agent_strategy_name,
)
except Exception as e:
yield RunCompletedEvent(
@@ -81,13 +112,13 @@ class AgentNode(ToolNode):
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
node_data=self._node_data,
strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
node_data=self._node_data,
for_log=True,
strategy=strategy,
)
@@ -105,59 +136,39 @@ class AgentNode(ToolNode):
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=f"Failed to invoke agent: {str(e)}",
error=str(error),
)
)
return
try:
# convert tool messages
agent_thoughts: list = []
thought_log_message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LOG,
message=ToolInvokeMessage.LogMessage(
id=str(uuid.uuid4()),
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
parent_id=None,
error=None,
status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"parameters": parameters_for_log,
"thought_process": "Agent strategy execution started",
},
metadata={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
),
)
def enhanced_message_stream():
yield thought_log_message
yield from message_stream
yield from self._transform_message(
message_stream,
{
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
},
parameters_for_log,
agent_thoughts,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_type=self.type_,
node_id=self.node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=f"Failed to transform agent message: {str(e)}",
error=str(transform_error),
)
)
@@ -194,7 +205,7 @@ class AgentNode(ToolNode):
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise ValueError(f"Variable {agent_input.value} does not exist")
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template,
@@ -216,7 +227,7 @@ class AgentNode(ToolNode):
except json.JSONDecodeError:
parameter_value = parameter_value
else:
raise ValueError(f"Unknown agent input type '{agent_input.type}'")
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
@@ -259,7 +270,7 @@ class AgentNode(ToolNode):
)
extra = tool.get("extra", {})
runtime_variable_pool = variable_pool if self.node_data.version != "1" else None
runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
)
@@ -343,19 +354,14 @@ class AgentNode(ToolNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: BaseNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(AgentNodeData, node_data)
# Create typed NodeData from dict
typed_node_data = AgentNodeData.model_validate(node_data)
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
input = node_data.agent_parameters[parameter_name]
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
@@ -380,7 +386,7 @@ class AgentNode(ToolNode):
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self.node_data).agent_strategy_provider_name
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
@@ -448,3 +454,236 @@ class AgentNode(ToolNode):
return tools
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, File)
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
json_output.extend(json)
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

View File

@@ -0,0 +1,124 @@
from typing import Optional
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
strategy_name,
provider_name,
)
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
self.original_error = original_error
super().__init__(message)
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: Optional[str] = None):
self.parameter_name = parameter_name
super().__init__(message)
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: Optional[str] = None):
self.variable_name = variable_name
super().__init__(message)
class AgentVariableNotFoundError(AgentVariableError):
"""Exception raised when a variable is not found in the variable pool."""
def __init__(self, variable_name: str):
super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
class AgentInputTypeError(AgentNodeError):
"""Exception raised when an unknown agent input type is encountered."""
def __init__(self, input_type: str):
super().__init__(f"Unknown agent input type '{input_type}'")
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: Optional[str] = None):
self.file_id = file_id
super().__init__(message)
class ToolFileNotFoundError(ToolFileError):
"""Exception raised when a tool file is not found."""
def __init__(self, file_id: str):
super().__init__(f"Tool file '{file_id}' does not exist", file_id)
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
self.original_error = original_error
super().__init__(message)
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: Optional[str] = None):
self.conversation_id = conversation_id
super().__init__(message)
class AgentVariableTypeError(AgentNodeError):
"""Exception raised when a variable has an unexpected type."""
def __init__(
self,
message: str,
variable_name: Optional[str] = None,
expected_type: Optional[str] = None,
actual_type: Optional[str] = None,
):
self.variable_name = variable_name
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)