feat: Parallel Execution of Nodes in Workflows (#8192)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
from typing import cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
@@ -12,10 +11,9 @@ class EndNode(BaseNode):
|
||||
_node_data_cls = EndNodeData
|
||||
_node_type = NodeType.END
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
@@ -24,7 +22,7 @@ class EndNode(BaseNode):
|
||||
|
||||
outputs = {}
|
||||
for variable_selector in output_variables:
|
||||
value = variable_pool.get_any(variable_selector.value_selector)
|
||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
outputs[variable_selector.variable] = value
|
||||
|
||||
return NodeRunResult(
|
||||
@@ -34,52 +32,16 @@ class EndNode(BaseNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
|
||||
"""
|
||||
Extract generate nodes
|
||||
:param graph: graph
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
node_data = cast(EndNodeData, node_data)
|
||||
|
||||
return cls.extract_generate_nodes_from_node_data(graph, node_data)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
|
||||
"""
|
||||
Extract generate nodes from node data
|
||||
:param graph: graph
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes', [])
|
||||
node_mapping = {node.get('id'): node for node in nodes}
|
||||
|
||||
variable_selectors = node_data.outputs
|
||||
|
||||
generate_nodes = []
|
||||
for variable_selector in variable_selectors:
|
||||
if not variable_selector.value_selector:
|
||||
continue
|
||||
|
||||
node_id = variable_selector.value_selector[0]
|
||||
if node_id != 'sys' and node_id in node_mapping:
|
||||
node = node_mapping[node_id]
|
||||
node_type = node.get('data', {}).get('type')
|
||||
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
|
||||
generate_nodes.append(node_id)
|
||||
|
||||
# remove duplicates
|
||||
generate_nodes = list(set(generate_nodes))
|
||||
|
||||
return generate_nodes
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: EndNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
|
||||
148
api/core/workflow/nodes/end/end_stream_generate_router.py
Normal file
148
api/core/workflow/nodes/end/end_stream_generate_router.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
|
||||
|
||||
|
||||
class EndStreamGeneratorRouter:
    """
    Builds the stream-routing parameters for End nodes: which value selectors
    each End node can stream, and which branch nodes each End node depends on.
    """

    @classmethod
    def init(cls,
             node_id_config_mapping: dict[str, dict],
             reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
             node_parallel_mapping: dict[str, str]
             ) -> EndStreamParam:
        """
        Get stream generate routes.

        :param node_id_config_mapping: node id -> node config dict
        :param reverse_edge_mapping: target node id -> incoming graph edges
        :param node_parallel_mapping: node id -> parallel id (nodes inside a parallel)
        :return: EndStreamParam carrying selector mapping and end dependencies
        """
        # parse stream output node value selector of end nodes
        end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
        for end_node_id, node_config in node_id_config_mapping.items():
            if not node_config.get('data', {}).get('type') == NodeType.END.value:
                continue

            # skip end node in parallel — parallel End nodes do not stream
            if end_node_id in node_parallel_mapping:
                continue

            # get generate route for stream output
            stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
            end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors

        # fetch end dependencies
        end_node_ids = list(end_stream_variable_selectors_mapping.keys())
        end_dependencies = cls._fetch_ends_dependencies(
            end_node_ids=end_node_ids,
            reverse_edge_mapping=reverse_edge_mapping,
            node_id_config_mapping=node_id_config_mapping
        )

        return EndStreamParam(
            end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
            end_dependencies=end_dependencies
        )

    @classmethod
    def extract_stream_variable_selector_from_node_data(cls,
                                                        node_id_config_mapping: dict[str, dict],
                                                        node_data: EndNodeData) -> list[list[str]]:
        """
        Extract stream variable selector from node data.

        Only selectors that point at an LLM node's `text` output are streamable;
        `sys` selectors and unknown node ids are skipped.
        :param node_id_config_mapping: node id config mapping
        :param node_data: node data object
        :return: de-duplicated list of streamable value selectors
        """
        variable_selectors = node_data.outputs

        value_selectors = []
        for variable_selector in variable_selectors:
            if not variable_selector.value_selector:
                continue

            node_id = variable_selector.value_selector[0]
            if node_id != 'sys' and node_id in node_id_config_mapping:
                node = node_id_config_mapping[node_id]
                node_type = node.get('data', {}).get('type')
                if (
                        variable_selector.value_selector not in value_selectors
                        and node_type == NodeType.LLM.value
                        and variable_selector.value_selector[1] == 'text'
                ):
                    value_selectors.append(variable_selector.value_selector)

        return value_selectors

    @classmethod
    def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
            -> list[list[str]]:
        """
        Extract stream variable selector from a raw node config.
        :param node_id_config_mapping: node id config mapping
        :param config: node config
        :return: streamable value selectors for this End node
        """
        node_data = EndNodeData(**config.get("data", {}))
        return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)

    @classmethod
    def _fetch_ends_dependencies(cls,
                                 end_node_ids: list[str],
                                 reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
                                 node_id_config_mapping: dict[str, dict]
                                 ) -> dict[str, list[str]]:
        """
        Fetch end dependencies.
        :param end_node_ids: end node ids
        :param reverse_edge_mapping: reverse edge mapping
        :param node_id_config_mapping: node id config mapping
        :return: end node id -> branch node ids it depends on
        """
        end_dependencies: dict[str, list[str]] = {}
        for end_node_id in end_node_ids:
            if end_dependencies.get(end_node_id) is None:
                end_dependencies[end_node_id] = []

            cls._recursive_fetch_end_dependencies(
                current_node_id=end_node_id,
                end_node_id=end_node_id,
                node_id_config_mapping=node_id_config_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                end_dependencies=end_dependencies
            )

        return end_dependencies

    @classmethod
    def _recursive_fetch_end_dependencies(cls,
                                          current_node_id: str,
                                          end_node_id: str,
                                          node_id_config_mapping: dict[str, dict],
                                          reverse_edge_mapping: dict[str, list["GraphEdge"]],
                                          # type: ignore[name-defined]
                                          end_dependencies: dict[str, list[str]]
                                          ) -> None:
        """
        Recursively walk upstream edges, recording branching nodes
        (if-else / question-classifier) the End node depends on.
        :param current_node_id: current node id
        :param end_node_id: end node id
        :param node_id_config_mapping: node id config mapping
        :param reverse_edge_mapping: reverse edge mapping
        :param end_dependencies: end dependencies (mutated in place)
        :return:
        """
        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
        for edge in reverse_edges:
            source_node_id = edge.source_node_id
            source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
            # Bug fix: compare against the enum *values* on both entries.
            # source_node_type is a plain string from config, so the bare enum
            # member NodeType.QUESTION_CLASSIFIER could never match it.
            if source_node_type in (
                    NodeType.IF_ELSE.value,
                    NodeType.QUESTION_CLASSIFIER.value,
            ):
                end_dependencies[end_node_id].append(source_node_id)
            else:
                cls._recursive_fetch_end_dependencies(
                    current_node_id=source_node_id,
                    end_node_id=end_node_id,
                    node_id_config_mapping=node_id_config_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    end_dependencies=end_dependencies
                )
|
||||
191
api/core/workflow/nodes/end/end_stream_processor.py
Normal file
191
api/core/workflow/nodes/end/end_stream_processor.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndStreamProcessor(StreamProcessor):
    """
    Filters the graph-engine event stream so that only chunks routed to a
    (non-parallel) End node are emitted, in the End node's output order.
    """

    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        super().__init__(graph, variable_pool)
        self.end_stream_param = graph.end_stream_param
        # end node id -> index of the next stream variable selector to output
        self.route_position = {}
        for end_node_id in self.end_stream_param.end_stream_variable_selector_mapping:
            self.route_position[end_node_id] = 0
        # chunk-producing node id -> End node ids its chunks stream out through
        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
        # whether any chunk was emitted yet; used to insert '\n' separators
        # between output of different source nodes
        self.has_outputed = False
        self.outputed_node_ids = set()

    def process(self,
                generator: Generator[GraphEngineEvent, None, None]
                ) -> Generator[GraphEngineEvent, None, None]:
        """
        Consume engine events, passing through / suppressing / injecting
        stream chunk events according to End-node stream routing.
        :param generator: upstream engine event generator
        :return: filtered event generator
        """
        for event in generator:
            if isinstance(event, NodeRunStartedEvent):
                # a fresh run from the root node with no pending nodes resets state
                if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
                    self.reset()

                yield event
            elif isinstance(event, NodeRunStreamChunkEvent):
                if event.in_iteration_id:
                    # chunks inside an iteration pass through untouched,
                    # only prefixed with '\n' when the source node changes
                    if self.has_outputed and event.node_id not in self.outputed_node_ids:
                        event.chunk_content = '\n' + event.chunk_content

                    self.outputed_node_ids.add(event.node_id)
                    self.has_outputed = True
                    yield event
                    continue

                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
                        event.route_node_state.node_id
                    ]
                else:
                    stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
                    self.current_stream_chunk_generating_node_ids[
                        event.route_node_state.node_id
                    ] = stream_out_end_node_ids

                if stream_out_end_node_ids:
                    if self.has_outputed and event.node_id not in self.outputed_node_ids:
                        event.chunk_content = '\n' + event.chunk_content

                    self.outputed_node_ids.add(event.node_id)
                    self.has_outputed = True
                    yield event
            elif isinstance(event, NodeRunSucceededEvent):
                yield event
                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    # update self.route_position after all stream events finished
                    for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
                        self.route_position[end_node_id] += 1

                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]

                # remove unreachable nodes
                self._remove_unreachable_nodes(event)

                # generate stream outputs
                yield from self._generate_stream_outputs_when_node_finished(event)
            else:
                yield event

    def reset(self) -> None:
        """Reset per-run routing state before a new run from the root node."""
        self.route_position = {}
        for end_node_id in self.end_stream_param.end_stream_variable_selector_mapping:
            self.route_position[end_node_id] = 0
        self.rest_node_ids = self.graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids = {}
        # NOTE(review): has_outputed / outputed_node_ids are not cleared here —
        # confirm whether '\n' separators should carry across runs.

    def _generate_stream_outputs_when_node_finished(self,
                                                    event: NodeRunSucceededEvent
                                                    ) -> Generator[GraphEngineEvent, None, None]:
        """
        Flush any End-node stream outputs that became available when a node
        finished (values already present in the variable pool).
        :param event: node run succeeded event
        :return: generator of injected stream chunk events
        """
        for end_node_id, route_position in self.route_position.items():
            # flush only for the End node itself, or once the End node is
            # still pending and every dependency of it has already finished
            if (event.route_node_state.node_id != end_node_id
                    and (end_node_id not in self.rest_node_ids
                         or not all(dep_id not in self.rest_node_ids
                                    for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
                continue

            # remaining (not yet streamed) value selectors for this End node
            value_selectors = \
                self.end_stream_param.end_stream_variable_selector_mapping[end_node_id][route_position:]

            for value_selector in value_selectors:
                if not value_selector:
                    continue

                value = self.variable_pool.get(value_selector)

                # stop at the first selector whose value is not ready yet;
                # route_position still points at it for the next attempt
                if value is None:
                    break

                text = value.markdown

                if text:
                    current_node_id = value_selector[0]
                    if self.has_outputed and current_node_id not in self.outputed_node_ids:
                        text = '\n' + text

                    self.outputed_node_ids.add(current_node_id)
                    self.has_outputed = True
                    yield NodeRunStreamChunkEvent(
                        id=event.id,
                        node_id=event.node_id,
                        node_type=event.node_type,
                        node_data=event.node_data,
                        chunk_content=text,
                        from_variable_selector=value_selector,
                        route_node_state=event.route_node_state,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                    )

                self.route_position[end_node_id] += 1

    def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
        """
        Get the End node ids a stream chunk should stream out through.

        A chunk streams through an End node when that node is still pending,
        all its dependencies have finished, and the chunk's source selector is
        the End node's next expected stream selector.
        :param event: stream chunk event
        :return: matching End node ids
        """
        # (the original checked this same value twice; one guard suffices)
        stream_output_value_selector = event.from_variable_selector
        if not stream_output_value_selector:
            return []

        stream_out_end_node_ids = []
        for end_node_id, route_position in self.route_position.items():
            if end_node_id not in self.rest_node_ids:
                continue

            # all dependencies of this End node have already finished
            if all(dep_id not in self.rest_node_ids
                   for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
                selectors = self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]
                if route_position >= len(selectors):
                    continue

                value_selector = selectors[route_position]
                if not value_selector:
                    continue

                # chunk must match the next expected selector for this End node
                if value_selector != stream_output_value_selector:
                    continue

                stream_out_end_node_ids.append(end_node_id)

        return stream_out_end_node_ids
|
||||
@@ -1,3 +1,5 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
@@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData):
|
||||
END Node Data.
|
||||
"""
|
||||
outputs: list[VariableSelector]
|
||||
|
||||
|
||||
class EndStreamParam(BaseModel):
    """
    EndStreamParam entity.

    Routing parameters for End-node stream output: which upstream nodes each
    End node waits on, and which value selectors it may stream.
    """

    # end node id -> upstream node ids it depends on
    end_dependencies: dict[str, list[str]] = Field(
        ..., description="end dependencies (end node id -> dependent node ids)")

    # end node id -> ordered streamable value selectors
    end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
        ..., description="end stream variable selector mapping (end node id -> stream variable selectors)")
|
||||
|
||||
Reference in New Issue
Block a user