refactor(api): Decouple ParameterExtractorNode from LLMNode (#20843)

- Extract methods used by `ParameterExtractorNode` from `LLMNode` into a separate file.
- Convert `ParameterExtractorNode` into a subclass of `BaseNode`.
- Refactor code referencing the extracted methods to ensure functionality and clarity.
- Fixes the issue that `ParameterExtractorNode` returns error when executed.
- Fix relevant test cases.

Closes #20840.
This commit is contained in:
QuantumGhost
2025-06-10 11:47:50 +08:00
committed by GitHub
parent a97ff587d2
commit c439e82038
8 changed files with 226 additions and 171 deletions

View File

@@ -353,7 +353,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock):
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
"""
Test chat parameter extractor with memory.
"""
@@ -384,7 +384,8 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
node._fetch_memory = get_mocked_fetch_memory("customized memory")
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()
result = node._run()

View File

@@ -25,6 +25,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
@@ -170,7 +171,7 @@ def model_config():
)
def test_fetch_files_with_file_segment(llm_node):
def test_fetch_files_with_file_segment():
file = File(
id="1",
tenant_id="test",
@@ -180,13 +181,14 @@ def test_fetch_files_with_file_segment(llm_node):
related_id="1",
storage_key="",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], file)
result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == [file]
def test_fetch_files_with_array_file_segment(llm_node):
def test_fetch_files_with_array_file_segment():
files = [
File(
id="1",
@@ -207,28 +209,32 @@ def test_fetch_files_with_array_file_segment(llm_node):
storage_key="",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == files
def test_fetch_files_with_none_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
def test_fetch_files_with_none_segment():
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []
def test_fetch_files_with_array_any_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
def test_fetch_files_with_array_any_segment():
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []
def test_fetch_files_with_non_existent_variable(llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
def test_fetch_files_with_non_existent_variable():
variable_pool = VariablePool()
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []