feat: knowledge pipeline (#25360)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
This commit is contained in:
-LAN-
2025-09-18 12:49:10 +08:00
committed by GitHub
parent 7dadb33003
commit 85cda47c70
1772 changed files with 102407 additions and 31710 deletions

View File

@@ -1,4 +0,0 @@
# Public API of the Answer node package: the node implementation and the
# entity describing its precomputed stream-generate routes.
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute

__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]

View File

@@ -1,24 +1,19 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base import BaseNode
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode):
_node_type = NodeType.ANSWER
class AnswerNode(Node):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.RESPONSE
_node_data: AnswerNodeData
@@ -48,35 +43,29 @@ class AnswerNode(BaseNode):
return "1"
def _run(self) -> NodeRunResult:
    """Render the answer template and return it as this node's outputs.

    Resolves the node's answer template against the variable pool,
    flattens any file-bearing segments into the ``files`` output, and
    returns the rendered markdown as the ``answer`` output.

    :return: a succeeded ``NodeRunResult`` with ``answer`` and ``files`` outputs
    """
    # convert_template resolves every variable reference in the template
    # into concrete segments in one pass.
    segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
    files = self._extract_files_from_segments(segments.value)
    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)},
    )
def _extract_files_from_segments(self, segments: Sequence[Segment]):
    """Flatten every file carried by the given segments into one list.

    A ``FileSegment`` contributes its single file; an ``ArrayFileSegment``
    contributes each of its files. Segments of any other type are ignored.
    """

    def _files_of(segment):
        # FileSegment is checked first, mirroring the original dispatch order.
        if isinstance(segment, FileSegment):
            return [segment.value]
        if isinstance(segment, ArrayFileSegment):
            return list(segment.value)
        return []

    collected = []
    for segment in segments:
        collected.extend(_files_of(segment))
    return collected
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@@ -96,3 +85,12 @@ class AnswerNode(BaseNode):
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
def get_streaming_template(self) -> Template:
    """Build the streaming ``Template`` for this Answer node.

    Returns:
        Template constructed from this node's configured answer text.
    """
    answer_text = self._node_data.answer
    return Template.from_answer_template(answer_text)

View File

@@ -1,174 +0,0 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamGeneratorRouter:
    """Precomputes streaming metadata for Answer nodes.

    For every Answer node in a graph this router builds (a) the ordered list
    of text/variable chunks its answer template renders to, and (b) the list
    of upstream node ids whose completion the Answer node's streaming depends
    on.
    """

    @classmethod
    def init(
        cls,
        node_id_config_mapping: dict[str, dict],
        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
    ) -> AnswerStreamGenerateRoute:
        """Build stream-generate routes for every Answer node in the graph.

        :param node_id_config_mapping: node id -> raw node config dict
        :param reverse_edge_mapping: target node id -> list of incoming edges
        :return: per-node generate routes plus per-node dependency lists
        """
        # parse stream output node value selectors of answer nodes
        answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
        for answer_node_id, node_config in node_id_config_mapping.items():
            # Only Answer-typed nodes get a generate route.
            if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
                continue
            # get generate route for stream output
            generate_route = cls._extract_generate_route_selectors(node_config)
            answer_generate_route[answer_node_id] = generate_route
        # fetch answer dependencies
        answer_node_ids = list(answer_generate_route.keys())
        answer_dependencies = cls._fetch_answers_dependencies(
            answer_node_ids=answer_node_ids,
            reverse_edge_mapping=reverse_edge_mapping,
            node_id_config_mapping=node_id_config_mapping,
        )
        return AnswerStreamGenerateRoute(
            answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
        )

    @classmethod
    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
        """Split an answer template into ordered text and variable chunks.

        Only keys reported by BOTH parsers (the variable-selector parser and
        the prompt-template parser) are treated as variables; everything else
        stays literal text.

        :param node_data: node data object
        :return: ordered chunks for streaming
        """
        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()
        value_selector_mapping = {
            variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
        }
        variable_keys = list(value_selector_mapping.keys())
        # format answer template
        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
        template_variable_keys = template_parser.variable_keys
        # Take the intersection of variable_keys and template_variable_keys
        variable_keys = list(set(variable_keys) & set(template_variable_keys))
        template = node_data.answer
        # Wrap each recognized variable occurrence in Ω sentinels so the
        # template can be split into alternating text/variable parts below.
        # NOTE(review): assumes a literal Ω never appears in the template
        # itself — confirm upstream validation.
        for var in variable_keys:
            template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")
        generate_routes: list[GenerateRouteChunk] = []
        for part in template.split("Ω"):
            if part:
                if cls._is_variable(part, variable_keys):
                    var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
                    value_selector = value_selector_mapping[var_key]
                    generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
                else:
                    generate_routes.append(TextGenerateRouteChunk(text=part))
        return generate_routes

    @classmethod
    def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
        """Parse a raw node config dict into generate-route chunks.

        :param config: node config
        :return: ordered chunks for streaming
        """
        node_data = AnswerNodeData(**config.get("data", {}))
        return cls.extract_generate_route_from_node_data(node_data)

    @classmethod
    def _is_variable(cls, part, variable_keys):
        # A part is a variable chunk when it still starts with the {{ wrapper
        # and its stripped key survived the intersection computed by the caller.
        cleaned_part = part.replace("{{", "").replace("}}", "")
        return part.startswith("{{") and cleaned_part in variable_keys

    @classmethod
    def _fetch_answers_dependencies(
        cls,
        answer_node_ids: list[str],
        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
        node_id_config_mapping: dict[str, dict],
    ) -> dict[str, list[str]]:
        """Collect, for each Answer node, the upstream node ids it depends on.

        :param answer_node_ids: answer node ids
        :param reverse_edge_mapping: reverse edge mapping
        :param node_id_config_mapping: node id config mapping
        :return: answer node id -> list of dependency node ids
        """
        answer_dependencies: dict[str, list[str]] = {}
        for answer_node_id in answer_node_ids:
            if answer_dependencies.get(answer_node_id) is None:
                answer_dependencies[answer_node_id] = []
            cls._recursive_fetch_answer_dependencies(
                current_node_id=answer_node_id,
                answer_node_id=answer_node_id,
                node_id_config_mapping=node_id_config_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                answer_dependencies=answer_dependencies,
            )
        return answer_dependencies

    @classmethod
    def _recursive_fetch_answer_dependencies(
        cls,
        current_node_id: str,
        answer_node_id: str,
        node_id_config_mapping: dict[str, dict],
        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
        answer_dependencies: dict[str, list[str]],
    ):
        """Walk incoming edges, recording dependency node ids in place.

        Branching/container node types (and fail-branch sources) become
        direct dependencies; any other source node is traversed through.
        NOTE(review): no visited-set — assumes the edge graph is acyclic,
        otherwise this recurses forever; confirm upstream validation.

        :param current_node_id: current node id
        :param answer_node_id: answer node id being resolved
        :param node_id_config_mapping: node id config mapping
        :param reverse_edge_mapping: reverse edge mapping
        :param answer_dependencies: accumulator dict, mutated in place
        """
        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
        for edge in reverse_edges:
            source_node_id = edge.source_node_id
            if source_node_id not in node_id_config_mapping:
                continue
            source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
            source_node_data = node_id_config_mapping[source_node_id].get("data", {})
            if (
                source_node_type
                in {
                    NodeType.ANSWER,
                    NodeType.IF_ELSE,
                    NodeType.QUESTION_CLASSIFIER,
                    NodeType.ITERATION,
                    NodeType.LOOP,
                    NodeType.VARIABLE_ASSIGNER,
                }
                or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
            ):
                answer_dependencies[answer_node_id].append(source_node_id)
            else:
                cls._recursive_fetch_answer_dependencies(
                    current_node_id=source_node_id,
                    answer_node_id=answer_node_id,
                    node_id_config_mapping=node_id_config_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    answer_dependencies=answer_dependencies,
                )

View File

@@ -1,199 +0,0 @@
import logging
from collections.abc import Generator
from typing import cast
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunExceptionEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
    """Rewrites the graph-engine event stream so Answer nodes stream output.

    Tracks, per Answer node, how far through its generate route the stream
    has progressed (``route_position``) and which source nodes are currently
    feeding chunks into which Answer nodes.
    """

    def __init__(self, graph: Graph, variable_pool: VariablePool):
        super().__init__(graph, variable_pool)
        # Precomputed routes/dependencies produced by AnswerStreamGeneratorRouter.
        self.generate_routes = graph.answer_stream_generate_routes
        # answer node id -> index of the next route chunk to emit
        self.route_position = {}
        for answer_node_id in self.generate_routes.answer_generate_route:
            self.route_position[answer_node_id] = 0
        # source node id -> answer node ids currently streaming from it
        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}

    def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
        """Consume engine events and yield them, duplicating stream chunks once
        per Answer node that should output them, and flushing pending route
        chunks whenever a node finishes.
        """
        for event in generator:
            if isinstance(event, NodeRunStartedEvent):
                # A fresh run from the root with no pending nodes resets state.
                if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
                    self.reset()
                yield event
            elif isinstance(event, NodeRunStreamChunkEvent):
                # Chunks inside iteration/loop containers pass through untouched.
                if event.in_iteration_id or event.in_loop_id:
                    yield event
                    continue
                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
                        event.route_node_state.node_id
                    ]
                else:
                    # First chunk from this node: resolve (and cache) which
                    # Answer nodes it streams into.
                    stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
                    self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
                        stream_out_answer_node_ids
                    )
                # Emit the chunk once per consuming Answer node.
                for _ in stream_out_answer_node_ids:
                    yield event
            elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
                yield event
                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:  # ty: ignore [unresolved-attribute]
                    # update self.route_position after all stream event finished
                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:  # ty: ignore [unresolved-attribute]
                        self.route_position[answer_node_id] += 1
                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]  # ty: ignore [unresolved-attribute]
                self._remove_unreachable_nodes(event)
                # generate stream outputs
                yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
            else:
                yield event

    def reset(self):
        """Reset all streaming progress for a new run of the same graph."""
        self.route_position = {}
        for answer_node_id, _ in self.generate_routes.answer_generate_route.items():
            self.route_position[answer_node_id] = 0
        self.rest_node_ids = self.graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids = {}

    def _generate_stream_outputs_when_node_finished(
        self, event: NodeRunSucceededEvent
    ) -> Generator[GraphEngineEvent, None, None]:
        """Flush pending route chunks for every Answer node whose dependencies
        have all completed.

        :param event: node run succeeded event
        :return: generator of synthesized stream-chunk events
        """
        for answer_node_id in self.route_position:
            # all depends on answer node id not in rest node ids
            if event.route_node_state.node_id != answer_node_id and (
                answer_node_id not in self.rest_node_ids
                or not all(
                    dep_id not in self.rest_node_ids
                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
                )
            ):
                continue
            route_position = self.route_position[answer_node_id]
            route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
            for route_chunk in route_chunks:
                if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
                    # Literal text: emit as-is.
                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
                    yield NodeRunStreamChunkEvent(
                        id=event.id,
                        node_id=event.node_id,
                        node_type=event.node_type,
                        node_data=event.node_data,
                        chunk_content=route_chunk.text,
                        route_node_state=event.route_node_state,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        from_variable_selector=[answer_node_id, "answer"],
                        node_version=event.node_version,
                    )
                else:
                    # Variable chunk: stop flushing if the value is not yet
                    # available in the pool (break keeps route_position here).
                    route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                    value_selector = route_chunk.value_selector
                    if not value_selector:
                        break
                    value = self.variable_pool.get(value_selector)
                    if value is None:
                        break
                    text = value.markdown
                    if text:
                        yield NodeRunStreamChunkEvent(
                            id=event.id,
                            node_id=event.node_id,
                            node_type=event.node_type,
                            node_data=event.node_data,
                            chunk_content=text,
                            from_variable_selector=list(value_selector),
                            route_node_state=event.route_node_state,
                            parallel_id=event.parallel_id,
                            parallel_start_node_id=event.parallel_start_node_id,
                            node_version=event.node_version,
                        )
                # Advance past the chunk just emitted.
                self.route_position[answer_node_id] += 1

    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
        """Resolve which Answer nodes should stream this chunk event.

        An Answer node qualifies when all of its remaining dependencies have
        finished and its next pending route chunk is the variable this chunk
        comes from.

        :param event: queue text chunk event
        :return: list of answer node ids that should emit this chunk
        """
        if not event.from_variable_selector:
            return []
        stream_output_value_selector = event.from_variable_selector
        stream_out_answer_node_ids = []
        for answer_node_id, route_position in self.route_position.items():
            if answer_node_id not in self.rest_node_ids:
                continue
            # Remove current node id from answer dependencies to support stream output if it is a success branch
            answer_dependencies = self.generate_routes.answer_dependencies
            edge_mapping = self.graph.edge_mapping.get(event.node_id)
            success_edge = (
                next(
                    (
                        edge
                        for edge in edge_mapping
                        if edge.run_condition
                        and edge.run_condition.type == "branch_identify"
                        and edge.run_condition.branch_identify == "success-branch"
                    ),
                    None,
                )
                if edge_mapping
                else None
            )
            if (
                event.node_id in answer_dependencies[answer_node_id]
                and success_edge
                and success_edge.target_node_id == answer_node_id
            ):
                # NOTE(review): mutates the shared dependency list in place —
                # intentional per the comment above, but state leaks across calls.
                answer_dependencies[answer_node_id].remove(event.node_id)
            answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
            # all depends on answer node id not in rest node ids
            if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids):
                if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
                    continue
                route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
                if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
                    continue
                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                value_selector = route_chunk.value_selector
                # check chunk node id is before current node id or equal to current node id
                if value_selector != stream_output_value_selector:
                    continue
                stream_out_answer_node_ids.append(answer_node_id)
        return stream_out_answer_node_ids

View File

@@ -1,108 +0,0 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
logger = logging.getLogger(__name__)
class StreamProcessor(ABC):
    """Base class for event-stream processors that track which graph nodes
    may still run (``rest_node_ids``) and prune branches that became
    unreachable after a conditional node finished.
    """

    def __init__(self, graph: Graph, variable_pool: VariablePool):
        self.graph = graph
        self.variable_pool = variable_pool
        # Nodes that have not yet finished and are still reachable.
        self.rest_node_ids = graph.node_ids.copy()

    @abstractmethod
    def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
        """Transform the engine event stream; implemented by subclasses."""
        raise NotImplementedError

    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent):
        """Drop the finished node from ``rest_node_ids`` and, if it branched,
        prune the nodes only reachable through the branches not taken.
        """
        finished_node_id = event.route_node_state.node_id
        if finished_node_id not in self.rest_node_ids:
            return
        # remove finished node id
        self.rest_node_ids.remove(finished_node_id)
        run_result = event.route_node_state.node_run_result
        if not run_result:
            return
        if run_result.edge_source_handle:
            reachable_node_ids: list[str] = []
            unreachable_first_node_ids: list[str] = []
            if finished_node_id not in self.graph.edge_mapping:
                logger.warning("node %s has no edge mapping", finished_node_id)
                return
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (
                    edge.run_condition
                    and edge.run_condition.branch_identify
                    and run_result.edge_source_handle == edge.run_condition.branch_identify
                ):
                    # remove unreachable nodes
                    # FIXME: because of the code branch can combine directly, so for answer node
                    # we remove the node maybe shortcut the answer node, so comment this code for now
                    # there is not effect on the answer node and the workflow, when we have a better solution
                    # we can open this code. Issues: #11542 #9560 #10638 #10564
                    # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
                    # if "answer" in ids:
                    #     continue
                    # else:
                    #     reachable_node_ids.extend(ids)
                    # The branch_identify parameter is added to ensure that
                    # only nodes in the correct logical branch are included.
                    ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
                    reachable_node_ids.extend(ids)
                else:
                    # if the condition edge in parallel, and the target node is not in parallel, we should not remove it
                    # Issues: #13626
                    if (
                        finished_node_id in self.graph.node_parallel_mapping
                        and edge.target_node_id not in self.graph.node_parallel_mapping
                    ):
                        continue
                    unreachable_first_node_ids.append(edge.target_node_id)
            # Nodes reachable via the taken branch must never be pruned.
            unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
            for node_id in unreachable_first_node_ids:
                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: str | None = None) -> list[str]:
        """Recursively collect node ids reachable from *node_id* along edges
        that either carry no run condition or match *branch_identify*.

        Side effect: re-adds *node_id* to ``rest_node_ids`` if missing.
        NOTE(review): no visited-set — assumes the edge graph is acyclic.
        """
        if node_id not in self.rest_node_ids:
            self.rest_node_ids.append(node_id)
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
                continue
            # Only follow edges that match the branch_identify or have no run_condition
            if edge.run_condition and edge.run_condition.branch_identify:
                if not branch_identify or edge.run_condition.branch_identify != branch_identify:
                    continue
            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]):
        """
        remove target node ids until merge

        Recursively drops *node_id* and its downstream nodes from
        ``rest_node_ids``, stopping where a branch merges back into the
        reachable set.
        """
        if node_id not in self.rest_node_ids:
            return
        if node_id in reachable_node_ids:
            return
        self.rest_node_ids.remove(node_id)
        # NOTE(review): re-adding reachable ids here looks intentional (they may
        # have been pruned by a sibling call) but is worth confirming.
        self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id in reachable_node_ids:
                continue
            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)