Fix: surface workflow container LLM usage (#27021)
This commit is contained in:
@@ -72,6 +72,19 @@ default_retrieval_model: dict[str, Any] = {
|
||||
class DatasetRetrieval:
|
||||
def __init__(self, application_generate_entity=None):
    """Build a retrieval helper, optionally bound to an app-generate entity.

    Args:
        application_generate_entity: Optional entity describing the current
            app invocation; stored as-is for later retrieval calls. Defaults
            to None so the class can be constructed standalone.
    """
    # Start with a zeroed accumulator; _record_usage() folds usage into it.
    self._llm_usage = LLMUsage.empty_usage()
    self.application_generate_entity = application_generate_entity
|
||||
|
||||
@property
def llm_usage(self) -> LLMUsage:
    """Snapshot of the LLM usage accumulated so far.

    Returns:
        A copy of the internal accumulator, so callers cannot mutate the
        running totals through the returned object.
    """
    snapshot = self._llm_usage
    return snapshot.model_copy()
|
||||
|
||||
def _record_usage(self, usage: LLMUsage | None) -> None:
    """Fold *usage* into the running accumulator, skipping empty reports.

    Args:
        usage: Usage reported by a single LLM invocation, or None when the
            caller has nothing to record.
    """
    # Ignore missing or zero-token reports so they don't clobber totals.
    if usage is None or usage.total_tokens <= 0:
        return
    current = self._llm_usage
    # First real report replaces the empty accumulator outright; later
    # reports are merged via plus() to keep cumulative token counts.
    self._llm_usage = usage if current.total_tokens == 0 else current.plus(usage)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
@@ -312,15 +325,18 @@ class DatasetRetrieval:
|
||||
)
|
||||
tools.append(message_tool)
|
||||
dataset_id = None
|
||||
router_usage = LLMUsage.empty_usage()
|
||||
if planning_strategy == PlanningStrategy.REACT_ROUTER:
|
||||
react_multi_dataset_router = ReactMultiDatasetRouter()
|
||||
dataset_id = react_multi_dataset_router.invoke(
|
||||
dataset_id, router_usage = react_multi_dataset_router.invoke(
|
||||
query, tools, model_config, model_instance, user_id, tenant_id
|
||||
)
|
||||
|
||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||
function_call_router = FunctionCallMultiDatasetRouter()
|
||||
dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
|
||||
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
|
||||
|
||||
self._record_usage(router_usage)
|
||||
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
@@ -983,7 +999,8 @@ class DatasetRetrieval:
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
self._record_usage(usage)
|
||||
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
automatic_metadata_filters = []
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
|
||||
|
||||
|
||||
@@ -13,15 +13,15 @@ class FunctionCallMultiDatasetRouter:
|
||||
dataset_tools: list[PromptMessageTool],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
) -> Union[str, None]:
|
||||
) -> tuple[Union[str, None], LLMUsage]:
|
||||
"""Given input, decided what to do.
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(dataset_tools) == 0:
|
||||
return None
|
||||
return None, LLMUsage.empty_usage()
|
||||
elif len(dataset_tools) == 1:
|
||||
return dataset_tools[0].name
|
||||
return dataset_tools[0].name, LLMUsage.empty_usage()
|
||||
|
||||
try:
|
||||
prompt_messages = [
|
||||
@@ -34,9 +34,10 @@ class FunctionCallMultiDatasetRouter:
|
||||
stream=False,
|
||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||
)
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
if result.message.tool_calls:
|
||||
# get retrieval model config
|
||||
return result.message.tool_calls[0].function.name
|
||||
return None
|
||||
return result.message.tool_calls[0].function.name, usage
|
||||
return None, usage
|
||||
except Exception:
|
||||
return None
|
||||
return None, LLMUsage.empty_usage()
|
||||
|
||||
@@ -58,15 +58,15 @@ class ReactMultiDatasetRouter:
|
||||
model_instance: ModelInstance,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
) -> Union[str, None]:
|
||||
) -> tuple[Union[str, None], LLMUsage]:
|
||||
"""Given input, decided what to do.
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(dataset_tools) == 0:
|
||||
return None
|
||||
return None, LLMUsage.empty_usage()
|
||||
elif len(dataset_tools) == 1:
|
||||
return dataset_tools[0].name
|
||||
return dataset_tools[0].name, LLMUsage.empty_usage()
|
||||
|
||||
try:
|
||||
return self._react_invoke(
|
||||
@@ -78,7 +78,7 @@ class ReactMultiDatasetRouter:
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
return None, LLMUsage.empty_usage()
|
||||
|
||||
def _react_invoke(
|
||||
self,
|
||||
@@ -91,7 +91,7 @@ class ReactMultiDatasetRouter:
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
) -> Union[str, None]:
|
||||
) -> tuple[Union[str, None], LLMUsage]:
|
||||
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
|
||||
if model_config.mode == "chat":
|
||||
prompt = self.create_chat_prompt(
|
||||
@@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
result_text, _ = self._invoke_llm(
|
||||
result_text, usage = self._invoke_llm(
|
||||
completion_param=model_config.parameters,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
@@ -131,8 +131,8 @@ class ReactMultiDatasetRouter:
|
||||
output_parser = StructuredChatOutputParser()
|
||||
react_decision = output_parser.parse(result_text)
|
||||
if isinstance(react_decision, ReactAction):
|
||||
return react_decision.tool
|
||||
return None
|
||||
return react_decision.tool, usage
|
||||
return None, usage
|
||||
|
||||
def _invoke_llm(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user