feat: add ops trace (#5483)
Co-authored-by: takatost <takatost@gmail.com>
@@ -1,7 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from collections.abc import Generator
-from typing import Union
+from typing import Optional, Union

 from core.agent.base_agent_runner import BaseAgentRunner
 from core.agent.entities import AgentScratchpadUnit
@@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.ops.ops_trace_manager import TraceQueueManager
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool.tool import Tool
@@ -42,6 +43,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         self._repack_app_generate_entity(app_generate_entity)
         self._init_react_state(query)

+        trace_manager = app_generate_entity.trace_manager
+
         # check model mode
         if 'Observation' not in app_generate_entity.model_conf.stop:
             if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
@@ -211,7 +214,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                     action=scratchpad.action,
                     tool_instances=tool_instances,
-                    message_file_ids=message_file_ids
+                    message_file_ids=message_file_ids,
+                    trace_manager=trace_manager,
                 )
                 scratchpad.observation = tool_invoke_response
                 scratchpad.agent_response = tool_invoke_response
@@ -237,8 +241,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

             # update prompt tool message
             for prompt_tool in self._prompt_messages_tools:
-                self.update_prompt_message_tool(
-                    tool_instances[prompt_tool.name], prompt_tool)
+                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

             iteration_step += 1

@@ -275,14 +278,15 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             message=AssistantPromptMessage(
                 content=final_answer
             ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(
-            ),
+            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
             system_fingerprint=''
         )), PublishFrom.APPLICATION_MANAGER)

     def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
                               tool_instances: dict[str, Tool],
-                              message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]:
+                              message_file_ids: list[str],
+                              trace_manager: Optional[TraceQueueManager] = None
+                              ) -> tuple[str, ToolInvokeMeta]:
         """
         handle invoke action
         :param action: action
@@ -312,7 +316,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             tenant_id=self.tenant_id,
             message=self.message,
             invoke_from=self.application_generate_entity.invoke_from,
-            agent_tool_callback=self.agent_callback
+            agent_tool_callback=self.agent_callback,
+            trace_manager=trace_manager,
         )

         # publish files

@@ -50,6 +50,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         }
         final_answer = ''

+        # get tracing instance
+        trace_manager = app_generate_entity.trace_manager
+
         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
             if not final_llm_usage_dict['usage']:
                 final_llm_usage_dict['usage'] = usage
@@ -243,6 +246,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 message=self.message,
                 invoke_from=self.application_generate_entity.invoke_from,
                 agent_tool_callback=self.agent_callback,
+                trace_manager=trace_manager,
             )
             # publish files
             for message_file_id, save_as in message_files:

@@ -183,6 +183,14 @@ class TextToSpeechEntity(BaseModel):
     language: Optional[str] = None


+class TracingConfigEntity(BaseModel):
+    """
+    Tracing Config Entity.
+    """
+    enabled: bool
+    tracing_provider: str
+
+
 class FileExtraConfig(BaseModel):
     """
     File Upload Entity.
@@ -199,7 +207,7 @@ class AppAdditionalFeatures(BaseModel):
     more_like_this: bool = False
     speech_to_text: bool = False
     text_to_speech: Optional[TextToSpeechEntity] = None
-
+    trace_config: Optional[TracingConfigEntity] = None

 class AppConfig(BaseModel):
     """

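Note: `trace_config` is the app-level switch for the new tracing feature; it stays `None` unless a provider is configured. A minimal usage sketch of the entity defined above (the dict shape is an assumption for illustration, not taken from this commit):

```python
# hypothetical raw config; key names mirror the entity's fields
raw = {'enabled': True, 'tracing_provider': 'langfuse'}

trace_config = TracingConfigEntity(**raw)
assert trace_config.enabled
assert trace_config.tracing_provider == 'langfuse'
```
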
@@ -20,6 +20,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity,
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
@@ -29,13 +30,14 @@ logger = logging.getLogger(__name__)


 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
-    def generate(self, app_model: App,
-                 workflow: Workflow,
-                 user: Union[Account, EndUser],
-                 args: dict,
-                 invoke_from: InvokeFrom,
-                 stream: bool = True) \
-            -> Union[dict, Generator[dict, None, None]]:
+    def generate(
+        self, app_model: App,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        args: dict,
+        invoke_from: InvokeFrom,
+        stream: bool = True,
+    ) -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.
@@ -84,6 +86,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             workflow=workflow
         )

+        # get tracing instance
+        trace_manager = TraceQueueManager(app_id=app_model.id)
+
         if invoke_from == InvokeFrom.DEBUGGER:
             # always enable retriever resource in debugger mode
             app_config.additional_features.show_retrieve_source = True
@@ -99,7 +104,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             user_id=user.id,
             stream=stream,
             invoke_from=invoke_from,
-            extras=extras
+            extras=extras,
+            trace_manager=trace_manager
         )

         return self._generate(

@@ -70,7 +70,8 @@ class AdvancedChatAppRunner(AppRunner):
                 app_record=app_record,
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
-                query=query
+                query=query,
+                message_id=message.id
         ):
             return

@@ -156,11 +157,14 @@ class AdvancedChatAppRunner(AppRunner):
         # return workflow
         return workflow

-    def handle_input_moderation(self, queue_manager: AppQueueManager,
-                                app_record: App,
-                                app_generate_entity: AdvancedChatAppGenerateEntity,
-                                inputs: dict,
-                                query: str) -> bool:
+    def handle_input_moderation(
+        self, queue_manager: AppQueueManager,
+        app_record: App,
+        app_generate_entity: AdvancedChatAppGenerateEntity,
+        inputs: dict,
+        query: str,
+        message_id: str
+    ) -> bool:
         """
         Handle input moderation
         :param queue_manager: application queue manager
@@ -168,6 +172,7 @@ class AdvancedChatAppRunner(AppRunner):
         :param app_generate_entity: application generate entity
         :param inputs: inputs
         :param query: query
+        :param message_id: message id
         :return:
         """
         try:
@@ -178,6 +183,7 @@ class AdvancedChatAppRunner(AppRunner):
                 app_generate_entity=app_generate_entity,
                 inputs=inputs,
                 query=query,
+                message_id=message_id,
             )
         except ModerationException as e:
             self._stream_output(

@@ -42,6 +42,7 @@ from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
 from core.file.file_obj import FileVar
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.entities.node_entities import NodeType, SystemVariable
 from core.workflow.nodes.answer.answer_node import AnswerNode
 from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
@@ -69,13 +70,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _workflow_system_variables: dict[SystemVariable, Any]
     _iteration_nested_relations: dict[str, list[str]]

-    def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
-                 workflow: Workflow,
-                 queue_manager: AppQueueManager,
-                 conversation: Conversation,
-                 message: Message,
-                 user: Union[Account, EndUser],
-                 stream: bool) -> None:
+    def __init__(
+        self, application_generate_entity: AdvancedChatAppGenerateEntity,
+        workflow: Workflow,
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
+        user: Union[Account, EndUser],
+        stream: bool
+    ) -> None:
         """
         Initialize AdvancedChatAppGenerateTaskPipeline.
         :param application_generate_entity: application generate entity
@@ -126,14 +129,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._application_generate_entity.query
             )

-        generator = self._process_stream_response()
+        generator = self._process_stream_response(
+            trace_manager=self._application_generate_entity.trace_manager
+        )
         if self._stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)

     def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> ChatbotAppBlockingResponse:
+        -> ChatbotAppBlockingResponse:
         """
         Process blocking response.
         :return:
@@ -164,7 +169,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         raise Exception('Queue listening stopped unexpectedly.')

     def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> Generator[ChatbotAppStreamResponse, None, None]:
+        -> Generator[ChatbotAppStreamResponse, None, None]:
         """
         To stream response.
         :return:
@@ -177,7 +182,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 stream_response=stream_response
             )

-    def _process_stream_response(self) -> Generator[StreamResponse, None, None]:
+    def _process_stream_response(
+        self, trace_manager: Optional[TraceQueueManager] = None
+    ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
         :return:
@@ -249,7 +256,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
                 self._handle_iteration_operation(event)
             elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
-                workflow_run = self._handle_workflow_finished(event)
+                workflow_run = self._handle_workflow_finished(
+                    event, conversation_id=self._conversation.id, trace_manager=trace_manager
+                )
                 if workflow_run:
                     yield self._workflow_finish_to_stream_response(
                         task_id=self._application_generate_entity.task_id,
@@ -292,7 +301,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     continue

                 if not self._is_stream_out_support(
-                    event=event
+                    event=event
                 ):
                     continue
@@ -361,7 +370,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             id=self._message.id,
             **extras
         )
-
+
     def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
         """
         Get stream generate routes.
@@ -391,9 +400,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         )

         return stream_generate_routes
-
+
     def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-            -> list[str]:
+        -> list[str]:
         """
         Get answer start at node id.
         :param graph: graph
@@ -414,14 +423,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
         if not target_node:
             return []
-
+
         node_iteration_id = target_node.get('data', {}).get('iteration_id')
         # get iteration start node id
         for node in nodes:
             if node.get('id') == node_iteration_id:
                 if node.get('data', {}).get('start_node_id') == target_node_id:
                     return [target_node_id]
-
+
             return []

         start_node_ids = []
@@ -457,7 +466,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             start_node_ids.extend(sub_start_node_ids)

         return start_node_ids
-
+
     def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
         """
         Get iteration nested relations.
@@ -466,18 +475,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         nodes = graph.get('nodes')

-        iteration_ids = [node.get('id') for node in nodes
+        iteration_ids = [node.get('id') for node in nodes
                          if node.get('data', {}).get('type') in [
                              NodeType.ITERATION.value,
                              NodeType.LOOP.value,
-                         ]]
+        ]]

         return {
             iteration_id: [
                 node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
             ] for iteration_id in iteration_ids
         }
-
+
     def _generate_stream_outputs_when_node_started(self) -> Generator:
         """
         Generate stream outputs.
@@ -485,8 +494,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         if self._task_state.current_stream_generate_state:
             route_chunks = self._task_state.current_stream_generate_state.generate_route[
-                self._task_state.current_stream_generate_state.current_route_position:
-            ]
+                self._task_state.current_stream_generate_state.current_route_position:
+            ]

             for route_chunk in route_chunks:
                 if route_chunk.type == 'text':
@@ -506,7 +515,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

             # all route chunks are generated
             if self._task_state.current_stream_generate_state.current_route_position == len(
-                    self._task_state.current_stream_generate_state.generate_route):
+                self._task_state.current_stream_generate_state.generate_route
+            ):
                 self._task_state.current_stream_generate_state = None

     def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
@@ -519,7 +529,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

         route_chunks = self._task_state.current_stream_generate_state.generate_route[
             self._task_state.current_stream_generate_state.current_route_position:]
-
+
         for route_chunk in route_chunks:
             if route_chunk.type == 'text':
                 route_chunk = cast(TextGenerateRouteChunk, route_chunk)
@@ -551,7 +561,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         value = iteration_state.current_index
                     elif value_selector[1] == 'item':
                         value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
-                            iterator_selector) else None
+                            iterator_selector
+                        ) else None
                 else:
                     # check chunk node id is before current node id or equal to current node id
                     if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
@@ -562,14 +573,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     # get route chunk node execution info
                     route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
                     if (route_chunk_node_execution_info.node_type == NodeType.LLM
-                        and latest_node_execution_info.node_type == NodeType.LLM):
+                            and latest_node_execution_info.node_type == NodeType.LLM):
                         # only LLM support chunk stream output
                         self._task_state.current_stream_generate_state.current_route_position += 1
                         continue

                     # get route chunk node execution
                     route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
-                        WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id).first()
+                        WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
+                    ).first()

                     outputs = route_chunk_node_execution.outputs_dict
@@ -631,7 +643,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

                 # all route chunks are generated
                 if self._task_state.current_stream_generate_state.current_route_position == len(
-                        self._task_state.current_stream_generate_state.generate_route):
+                    self._task_state.current_stream_generate_state.generate_route
+                ):
                     self._task_state.current_stream_generate_state = None

     def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:

@@ -19,6 +19,7 @@ from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueMa
 from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, EndUser
@@ -108,6 +109,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
                 override_config_dict=override_model_config_dict
             )

+        # get tracing instance
+        trace_manager = TraceQueueManager(app_model.id)
+
         # init application generate entity
         application_generate_entity = AgentChatAppGenerateEntity(
             task_id=str(uuid.uuid4()),
@@ -121,7 +125,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
             invoke_from=invoke_from,
             extras=extras,
-            call_depth=0
+            call_depth=0,
+            trace_manager=trace_manager
         )

         # init generate records
@@ -158,7 +163,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream
+            stream=stream,
         )

         return AgentChatAppGenerateResponseConverter.convert(
@@ -166,11 +171,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from
         )

-    def _generate_worker(self, flask_app: Flask,
-                         application_generate_entity: AgentChatAppGenerateEntity,
-                         queue_manager: AppQueueManager,
-                         conversation_id: str,
-                         message_id: str) -> None:
+    def _generate_worker(
+        self, flask_app: Flask,
+        application_generate_entity: AgentChatAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        conversation_id: str,
+        message_id: str,
+    ) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -192,7 +199,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
                 application_generate_entity=application_generate_entity,
                 queue_manager=queue_manager,
                 conversation=conversation,
-                message=message
+                message=message,
             )
         except GenerateTaskStoppedException:
             pass

@@ -28,10 +28,13 @@ class AgentChatAppRunner(AppRunner):
     """
     Agent Application Runner
     """
-    def run(self, application_generate_entity: AgentChatAppGenerateEntity,
-            queue_manager: AppQueueManager,
-            conversation: Conversation,
-            message: Message) -> None:
+
+    def run(
+        self, application_generate_entity: AgentChatAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
+    ) -> None:
         """
         Run assistant application
         :param application_generate_entity: application generate entity
@@ -100,6 +103,7 @@ class AgentChatAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
                 query=query,
+                message_id=message.id
             )
         except ModerationException as e:
             self.direct_output(
@@ -219,7 +223,7 @@ class AgentChatAppRunner(AppRunner):
             runner_cls = FunctionCallAgentRunner
         else:
             raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
-
+
         runner = runner_cls(
             tenant_id=app_config.tenant_id,
             application_generate_entity=application_generate_entity,

@@ -338,11 +338,14 @@ class AppRunner:
             ), PublishFrom.APPLICATION_MANAGER
         )

-    def moderation_for_inputs(self, app_id: str,
-                              tenant_id: str,
-                              app_generate_entity: AppGenerateEntity,
-                              inputs: dict,
-                              query: str) -> tuple[bool, dict, str]:
+    def moderation_for_inputs(
+        self, app_id: str,
+        tenant_id: str,
+        app_generate_entity: AppGenerateEntity,
+        inputs: dict,
+        query: str,
+        message_id: str,
+    ) -> tuple[bool, dict, str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id
@@ -350,6 +353,7 @@ class AppRunner:
         :param app_generate_entity: app generate entity
         :param inputs: inputs
         :param query: query
+        :param message_id: message id
         :return:
         """
         moderation_feature = InputModeration()
@@ -358,7 +362,9 @@ class AppRunner:
             tenant_id=tenant_id,
             app_config=app_generate_entity.app_config,
             inputs=inputs,
-            query=query if query else ''
+            query=query if query else '',
+            message_id=message_id,
+            trace_manager=app_generate_entity.trace_manager
         )

     def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,

@@ -19,6 +19,7 @@ from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueMa
 from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, EndUser
@@ -27,12 +28,13 @@ logger = logging.getLogger(__name__)


 class ChatAppGenerator(MessageBasedAppGenerator):
-    def generate(self, app_model: App,
-                 user: Union[Account, EndUser],
-                 args: Any,
-                 invoke_from: InvokeFrom,
-                 stream: bool = True) \
-            -> Union[dict, Generator[dict, None, None]]:
+    def generate(
+        self, app_model: App,
+        user: Union[Account, EndUser],
+        args: Any,
+        invoke_from: InvokeFrom,
+        stream: bool = True,
+    ) -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.
@@ -105,6 +107,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
                 override_config_dict=override_model_config_dict
             )

+        # get tracing instance
+        trace_manager = TraceQueueManager(app_model.id)
+
         # init application generate entity
         application_generate_entity = ChatAppGenerateEntity(
             task_id=str(uuid.uuid4()),
@@ -117,7 +122,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             user_id=user.id,
             stream=stream,
             invoke_from=invoke_from,
-            extras=extras
+            extras=extras,
+            trace_manager=trace_manager
         )

         # init generate records
@@ -154,7 +160,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream
+            stream=stream,
         )

         return ChatAppGenerateResponseConverter.convert(

@@ -96,6 +96,7 @@ class ChatAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
                 query=query,
+                message_id=message.id
             )
         except ModerationException as e:
             self.direct_output(
@@ -154,7 +155,7 @@ class ChatAppRunner(AppRunner):
                 application_generate_entity.invoke_from
             )

-            dataset_retrieval = DatasetRetrieval()
+            dataset_retrieval = DatasetRetrieval(application_generate_entity)
             context = dataset_retrieval.retrieve(
                 app_id=app_record.id,
                 user_id=application_generate_entity.user_id,
@@ -165,7 +166,8 @@ class ChatAppRunner(AppRunner):
                 invoke_from=application_generate_entity.invoke_from,
                 show_retrieve_source=app_config.additional_features.show_retrieve_source,
                 hit_callback=hit_callback,
-                memory=memory
+                memory=memory,
+                message_id=message.id,
             )

         # reorganize all inputs and template to prompt messages

@@ -19,6 +19,7 @@ from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueMa
 from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, EndUser, Message
@@ -94,6 +95,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
                 override_config_dict=override_model_config_dict
             )

+        # get tracing instance
+        trace_manager = TraceQueueManager(app_model.id)
+
         # init application generate entity
         application_generate_entity = CompletionAppGenerateEntity(
             task_id=str(uuid.uuid4()),
@@ -105,7 +109,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             user_id=user.id,
             stream=stream,
             invoke_from=invoke_from,
-            extras=extras
+            extras=extras,
+            trace_manager=trace_manager
         )

         # init generate records
@@ -141,7 +146,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream
+            stream=stream,
         )

         return CompletionAppGenerateResponseConverter.convert(
@@ -158,7 +163,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         :param flask_app: Flask app
         :param application_generate_entity: application generate entity
         :param queue_manager: queue manager
-        :param conversation_id: conversation ID
         :param message_id: message ID
         :return:
         """
@@ -300,7 +304,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            stream=stream
+            stream=stream,
         )

         return CompletionAppGenerateResponseConverter.convert(

@@ -77,6 +77,7 @@ class CompletionAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
                 query=query,
+                message_id=message.id
             )
         except ModerationException as e:
             self.direct_output(
@@ -114,7 +115,7 @@ class CompletionAppRunner(AppRunner):
         if dataset_config and dataset_config.retrieve_config.query_variable:
             query = inputs.get(dataset_config.retrieve_config.query_variable, "")

-        dataset_retrieval = DatasetRetrieval()
+        dataset_retrieval = DatasetRetrieval(application_generate_entity)
         context = dataset_retrieval.retrieve(
             app_id=app_record.id,
             user_id=application_generate_entity.user_id,
@@ -124,7 +125,8 @@ class CompletionAppRunner(AppRunner):
             query=query,
             invoke_from=application_generate_entity.invoke_from,
             show_retrieve_source=app_config.additional_features.show_retrieve_source,
-            hit_callback=hit_callback
+            hit_callback=hit_callback,
+            message_id=message.id
         )

         # reorganize all inputs and template to prompt messages

@@ -35,22 +35,23 @@ logger = logging.getLogger(__name__)

 class MessageBasedAppGenerator(BaseAppGenerator):

-    def _handle_response(self, application_generate_entity: Union[
-        ChatAppGenerateEntity,
-        CompletionAppGenerateEntity,
-        AgentChatAppGenerateEntity,
-        AdvancedChatAppGenerateEntity
-    ],
-        queue_manager: AppQueueManager,
-        conversation: Conversation,
-        message: Message,
-        user: Union[Account, EndUser],
-        stream: bool = False) \
-            -> Union[
-                ChatbotAppBlockingResponse,
-                CompletionAppBlockingResponse,
-                Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
-            ]:
+    def _handle_response(
+        self, application_generate_entity: Union[
+            ChatAppGenerateEntity,
+            CompletionAppGenerateEntity,
+            AgentChatAppGenerateEntity,
+            AdvancedChatAppGenerateEntity
+        ],
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
+        user: Union[Account, EndUser],
+        stream: bool = False,
+    ) -> Union[
+        ChatbotAppBlockingResponse,
+        CompletionAppBlockingResponse,
+        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
+    ]:
         """
         Handle response.
         :param application_generate_entity: application generate entity

@@ -20,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, EndUser
@@ -29,14 +30,15 @@ logger = logging.getLogger(__name__)


 class WorkflowAppGenerator(BaseAppGenerator):
-    def generate(self, app_model: App,
-                 workflow: Workflow,
-                 user: Union[Account, EndUser],
-                 args: dict,
-                 invoke_from: InvokeFrom,
-                 stream: bool = True,
-                 call_depth: int = 0) \
-            -> Union[dict, Generator[dict, None, None]]:
+    def generate(
+        self, app_model: App,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        args: dict,
+        invoke_from: InvokeFrom,
+        stream: bool = True,
+        call_depth: int = 0,
+    ) -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.
@@ -46,6 +48,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param args: request args
         :param invoke_from: invoke from source
         :param stream: is stream
+        :param call_depth: call depth
         """
         inputs = args['inputs']
@@ -68,6 +71,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow=workflow
         )

+        # get tracing instance
+        trace_manager = TraceQueueManager(app_model.id)
+
         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(
             task_id=str(uuid.uuid4()),
@@ -77,7 +83,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
             user_id=user.id,
             stream=stream,
             invoke_from=invoke_from,
-            call_depth=call_depth
+            call_depth=call_depth,
+            trace_manager=trace_manager
         )

         return self._generate(
@@ -87,17 +94,18 @@ class WorkflowAppGenerator(BaseAppGenerator):
             application_generate_entity=application_generate_entity,
             invoke_from=invoke_from,
             stream=stream,
-            call_depth=call_depth
+            call_depth=call_depth,
         )

-    def _generate(self, app_model: App,
-                  workflow: Workflow,
-                  user: Union[Account, EndUser],
-                  application_generate_entity: WorkflowAppGenerateEntity,
-                  invoke_from: InvokeFrom,
-                  stream: bool = True,
-                  call_depth: int = 0) \
-            -> Union[dict, Generator[dict, None, None]]:
+    def _generate(
+        self, app_model: App,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        application_generate_entity: WorkflowAppGenerateEntity,
+        invoke_from: InvokeFrom,
+        stream: bool = True,
+        call_depth: int = 0
+    ) -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.
@@ -131,7 +139,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow=workflow,
             queue_manager=queue_manager,
             user=user,
-            stream=stream
+            stream=stream,
         )

         return WorkflowAppGenerateResponseConverter.convert(

@@ -1,6 +1,6 @@
 import logging
 from collections.abc import Generator
-from typing import Any, Union
+from typing import Any, Optional, Union

 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import (
@@ -36,6 +36,7 @@ from core.app.entities.task_entities import (
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
+from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.entities.node_entities import NodeType, SystemVariable
 from core.workflow.nodes.end.end_node import EndNode
 from extensions.ext_database import db
@@ -104,7 +105,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             db.session.refresh(self._user)
             db.session.close()

-        generator = self._process_stream_response()
+        generator = self._process_stream_response(
+            trace_manager=self._application_generate_entity.trace_manager
+        )
         if self._stream:
             return self._to_stream_response(generator)
         else:
@@ -158,7 +161,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 stream_response=stream_response
             )

-    def _process_stream_response(self) -> Generator[StreamResponse, None, None]:
+    def _process_stream_response(
+        self,
+        trace_manager: Optional[TraceQueueManager] = None
+    ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
         :return:
@@ -215,7 +221,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
                 self._handle_iteration_operation(event)
             elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
-                workflow_run = self._handle_workflow_finished(event)
+                workflow_run = self._handle_workflow_finished(
+                    event, trace_manager=trace_manager
+                )

                 # save workflow app log
                 self._save_workflow_app_log(workflow_run)

@@ -7,6 +7,7 @@ from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, Workfl
 from core.entities.provider_configuration import ProviderModelBundle
 from core.file.file_obj import FileVar
 from core.model_runtime.entities.model_entities import AIModelEntity
+from core.ops.ops_trace_manager import TraceQueueManager


 class InvokeFrom(Enum):
@@ -89,6 +90,12 @@ class AppGenerateEntity(BaseModel):
     # extra parameters, like: auto_generate_conversation_name
     extras: dict[str, Any] = {}

+    # tracing instance
+    trace_manager: Optional[TraceQueueManager] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+

 class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
     """

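Note: pydantic cannot build a validation schema for arbitrary classes, so the new `trace_manager` field only works because of `arbitrary_types_allowed`. A minimal, self-contained reproduction of that constraint (assuming nothing beyond pydantic itself):

```python
from typing import Optional

from pydantic import BaseModel


class TraceQueueManager:  # plain class, not a pydantic model
    pass


class AppGenerateEntity(BaseModel):
    trace_manager: Optional[TraceQueueManager] = None

    class Config:
        # without this flag, pydantic raises at class-definition time
        # because it cannot validate TraceQueueManager instances
        arbitrary_types_allowed = True


entity = AppGenerateEntity(trace_manager=TraceQueueManager())
```
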
@@ -44,6 +44,7 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from events.message_event import message_was_created
@@ -100,7 +101,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan

         self._conversation_name_generate_thread = None

-    def process(self) -> Union[
+    def process(
+        self,
+    ) -> Union[
         ChatbotAppBlockingResponse,
         CompletionAppBlockingResponse,
         Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
@@ -120,7 +123,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 self._application_generate_entity.query
             )

-        generator = self._process_stream_response()
+        generator = self._process_stream_response(
+            trace_manager=self._application_generate_entity.trace_manager
+        )
         if self._stream:
             return self._to_stream_response(generator)
         else:
@@ -197,7 +202,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 stream_response=stream_response
             )

-    def _process_stream_response(self) -> Generator[StreamResponse, None, None]:
+    def _process_stream_response(
+        self, trace_manager: Optional[TraceQueueManager] = None
+    ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
         :return:
@@ -224,7 +231,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                     yield self._message_replace_to_stream_response(answer=output_moderation_answer)

                 # Save message
-                self._save_message()
+                self._save_message(trace_manager)

                 yield self._message_end_to_stream_response()
             elif isinstance(event, QueueRetrieverResourcesEvent):
@@ -269,7 +276,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()

-    def _save_message(self) -> None:
+    def _save_message(
+        self, trace_manager: Optional[TraceQueueManager] = None
+    ) -> None:
         """
         Save message.
         :return:
@@ -300,6 +309,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan

         db.session.commit()

+        if trace_manager:
+            trace_manager.add_trace_task(
+                TraceTask(
+                    TraceTaskName.MESSAGE_TRACE,
+                    conversation_id=self._conversation.id,
+                    message_id=self._message.id
+                )
+            )
+
         message_was_created.send(
             self._message,
             application_generate_entity=self._application_generate_entity,

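Note: every instrumented call site in this commit follows the same fire-and-forget pattern — wrap the available context in a `TraceTask` tagged with a `TraceTaskName` and hand it to `TraceQueueManager.add_trace_task`, so tracing never blocks the response path. The manager's internals are not part of this diff; the sketch below is an assumed shape for orientation, not the actual implementation:

```python
import queue


class TraceTask:
    """Carries a trace name plus whatever context the call site has."""

    def __init__(self, name, **kwargs):
        self.name = name
        self.kwargs = kwargs


class TraceQueueManager:
    """Assumed shape: enqueue on the request thread, drain elsewhere."""

    def __init__(self, app_id=None):
        self.app_id = app_id
        self._queue = queue.Queue()

    def add_trace_task(self, trace_task: TraceTask):
        # non-blocking: the caller returns immediately
        self._queue.put(trace_task)

    def run(self):
        # a worker would drain the queue and dispatch each task to the
        # configured trace instance (Langfuse, LangSmith, ...)
        while not self._queue.empty():
            task = self._queue.get()
            ...  # dispatch by task.name
```
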
@@ -22,6 +22,7 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
 from core.file.file_obj import FileVar
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
 from core.tools.tool_manager import ToolManager
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
 from core.workflow.nodes.tool.entities import ToolNodeData
@@ -94,11 +95,15 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_run

-    def _workflow_run_success(self, workflow_run: WorkflowRun,
-                              start_at: float,
-                              total_tokens: int,
-                              total_steps: int,
-                              outputs: Optional[str] = None) -> WorkflowRun:
+    def _workflow_run_success(
+        self, workflow_run: WorkflowRun,
+        start_at: float,
+        total_tokens: int,
+        total_steps: int,
+        outputs: Optional[str] = None,
+        conversation_id: Optional[str] = None,
+        trace_manager: Optional[TraceQueueManager] = None
+    ) -> WorkflowRun:
         """
         Workflow run success
         :param workflow_run: workflow run
@@ -106,6 +111,7 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
         :param total_tokens: total tokens
         :param total_steps: total steps
         :param outputs: outputs
+        :param conversation_id: conversation id
         :return:
         """
         workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
@@ -119,14 +125,27 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
         db.session.refresh(workflow_run)
         db.session.close()

+        if trace_manager:
+            trace_manager.add_trace_task(
+                TraceTask(
+                    TraceTaskName.WORKFLOW_TRACE,
+                    workflow_run=workflow_run,
+                    conversation_id=conversation_id,
+                )
+            )
+
         return workflow_run

-    def _workflow_run_failed(self, workflow_run: WorkflowRun,
-                             start_at: float,
-                             total_tokens: int,
-                             total_steps: int,
-                             status: WorkflowRunStatus,
-                             error: str) -> WorkflowRun:
+    def _workflow_run_failed(
+        self, workflow_run: WorkflowRun,
+        start_at: float,
+        total_tokens: int,
+        total_steps: int,
+        status: WorkflowRunStatus,
+        error: str,
+        conversation_id: Optional[str] = None,
+        trace_manager: Optional[TraceQueueManager] = None
+    ) -> WorkflowRun:
         """
         Workflow run failed
         :param workflow_run: workflow run
@@ -148,6 +167,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
         db.session.refresh(workflow_run)
         db.session.close()

+        trace_manager.add_trace_task(
+            TraceTask(
+                TraceTaskName.WORKFLOW_TRACE,
+                workflow_run=workflow_run,
+                conversation_id=conversation_id,
+            )
+        )
+
         return workflow_run

     def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
@@ -180,7 +207,8 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
             title=node_title,
             status=WorkflowNodeExecutionStatus.RUNNING.value,
             created_by_role=workflow_run.created_by_role,
-            created_by=workflow_run.created_by
+            created_by=workflow_run.created_by,
+            created_at=datetime.now(timezone.utc).replace(tzinfo=None)
         )

         db.session.add(workflow_node_execution)
@@ -440,9 +468,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
         current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
         workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
             WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
-
+
         execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
-
+
         if self._iteration_state and self._iteration_state.current_iterations:
             if not execution_metadata:
                 execution_metadata = {}
@@ -470,7 +498,7 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
         if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
             self._task_state.total_tokens += (
                 int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
-
+
         if self._iteration_state:
             for iteration_node_id in self._iteration_state.current_iterations:
                 data = self._iteration_state.current_iterations[iteration_node_id]
@@ -496,13 +524,18 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_node_execution

-    def _handle_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-            -> Optional[WorkflowRun]:
+    def _handle_workflow_finished(
+        self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
+        conversation_id: Optional[str] = None,
+        trace_manager: Optional[TraceQueueManager] = None
+    ) -> Optional[WorkflowRun]:
         workflow_run = db.session.query(WorkflowRun).filter(
             WorkflowRun.id == self._task_state.workflow_run_id).first()
         if not workflow_run:
             return None

+        if conversation_id is None:
+            conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
         if isinstance(event, QueueStopEvent):
             workflow_run = self._workflow_run_failed(
                 workflow_run=workflow_run,
@@ -510,7 +543,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
                 status=WorkflowRunStatus.STOPPED,
-                error='Workflow stopped.'
+                error='Workflow stopped.',
+                conversation_id=conversation_id,
+                trace_manager=trace_manager
             )

             latest_node_execution_info = self._task_state.latest_node_execution_info
@@ -531,7 +566,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
                 status=WorkflowRunStatus.FAILED,
-                error=event.error
+                error=event.error,
+                conversation_id=conversation_id,
+                trace_manager=trace_manager
             )
         else:
             if self._task_state.latest_node_execution_info:
@@ -546,7 +583,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
-                outputs=outputs
+                outputs=outputs,
+                conversation_id=conversation_id,
+                trace_manager=trace_manager
             )

         self._task_state.workflow_run_id = workflow_run.id

@@ -1,6 +1,7 @@
 import json
 import time
 from collections.abc import Generator
+from datetime import datetime, timezone
 from typing import Optional, Union

 from core.app.entities.queue_entities import (
@@ -131,7 +132,8 @@ class WorkflowIterationCycleManage(WorkflowCycleStateManager):
                 'started_run_index': node_run_index + 1,
                 'current_index': 0,
                 'steps_boundary': [],
-            })
+            }),
+            created_at=datetime.now(timezone.utc).replace(tzinfo=None)
         )

         db.session.add(workflow_node_execution)

@@ -3,6 +3,8 @@ from typing import Any, Optional, TextIO, Union

 from pydantic import BaseModel

+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+
 _TEXT_COLOR_MAPPING = {
     "blue": "36;1",
     "yellow": "33;1",
@@ -51,6 +53,9 @@ class DifyAgentCallbackHandler(BaseModel):
         tool_name: str,
         tool_inputs: dict[str, Any],
         tool_outputs: str,
+        message_id: Optional[str] = None,
+        timer: Optional[Any] = None,
+        trace_manager: Optional[TraceQueueManager] = None
     ) -> None:
         """If not the final action, print out observation."""
         print_text("\n[on_tool_end]\n", color=self.color)
@@ -59,6 +64,18 @@ class DifyAgentCallbackHandler(BaseModel):
         print_text("Outputs: " + str(tool_outputs)[:1000] + "\n", color=self.color)
         print_text("\n")

+        if trace_manager:
+            trace_manager.add_trace_task(
+                TraceTask(
+                    TraceTaskName.TOOL_TRACE,
+                    message_id=message_id,
+                    tool_name=tool_name,
+                    tool_inputs=tool_inputs,
+                    tool_outputs=tool_outputs,
+                    timer=timer,
+                )
+            )
+
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:

@@ -1,5 +1,7 @@
 import json
 import logging
+import re
+from typing import Optional

 from core.llm_generator.output_parser.errors import OutputParserException
 from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
@@ -9,12 +11,16 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.utils import measure_time
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser


 class LLMGenerator:
     @classmethod
-    def generate_conversation_name(cls, tenant_id: str, query):
+    def generate_conversation_name(
+        cls, tenant_id: str, query, conversation_id: Optional[str] = None, app_id: Optional[str] = None
+    ):
         prompt = CONVERSATION_TITLE_PROMPT

         if len(query) > 2000:
@@ -29,25 +35,39 @@ class LLMGenerator:
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
         )

         prompts = [UserPromptMessage(content=prompt)]
-        response = model_instance.invoke_llm(
-            prompt_messages=prompts,
-            model_parameters={
-                "max_tokens": 100,
-                "temperature": 1
-            },
-            stream=False
-        )
-        answer = response.message.content
-
-        result_dict = json.loads(answer)
+        with measure_time() as timer:
+            response = model_instance.invoke_llm(
+                prompt_messages=prompts,
+                model_parameters={
+                    "max_tokens": 100,
+                    "temperature": 1
+                },
+                stream=False
+            )
+        answer = response.message.content
+        cleaned_answer = re.sub(r'^.*(\{.*\}).*$', r'\1', answer, flags=re.DOTALL)
+        result_dict = json.loads(cleaned_answer)
         answer = result_dict['Your Output']
         name = answer.strip()

         if len(name) > 75:
             name = name[:75] + '...'

+        # get tracing instance
+        trace_manager = TraceQueueManager(app_id=app_id)
+        trace_manager.add_trace_task(
+            TraceTask(
+                TraceTaskName.GENERATE_NAME_TRACE,
+                conversation_id=conversation_id,
+                generate_conversation_name=name,
+                inputs=prompt,
+                timer=timer,
+                tenant_id=tenant_id,
+            )
+        )
+
         return name

     @classmethod

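Note: the `re.sub` introduced above strips any prose the model wraps around its JSON answer before `json.loads` runs; with `re.DOTALL` the `.*` parts also cross newlines. A quick worked example:

```python
import json
import re

answer = 'Sure, here is the title:\n{"Your Output": "Trip to Kyoto"}\nHope that helps!'
cleaned_answer = re.sub(r'^.*(\{.*\}).*$', r'\1', answer, flags=re.DOTALL)
# cleaned_answer == '{"Your Output": "Trip to Kyoto"}'
result_dict = json.loads(cleaned_answer)
assert result_dict['Your Output'] == 'Trip to Kyoto'
```
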
@@ -1,4 +1,20 @@
-from core.model_runtime.errors.invoke import InvokeError
+from dashscope.common.error import (
+    AuthenticationError,
+    InvalidParameter,
+    RequestFailure,
+    ServiceUnavailableError,
+    UnsupportedHTTPMethod,
+    UnsupportedModel,
+)
+
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)


 class _CommonTongyi:
@@ -20,4 +36,20 @@ class _CommonTongyi:

         :return: Invoke error mapping
         """
-        pass
+        return {
+            InvokeConnectionError: [
+                RequestFailure,
+            ],
+            InvokeServerUnavailableError: [
+                ServiceUnavailableError,
+            ],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [
+                AuthenticationError,
+            ],
+            InvokeBadRequestError: [
+                InvalidParameter,
+                UnsupportedModel,
+                UnsupportedHTTPMethod,
+            ]
+        }

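Note: replacing `pass` with a real mapping lets the shared model runtime translate dashscope's vendor exceptions into Dify's unified `InvokeError` hierarchy. A hedged sketch of how such a mapping is typically consumed — the `_transform_invoke_error` name and signature below are assumptions for illustration, not code from this commit:

```python
def _transform_invoke_error(self, error: Exception) -> InvokeError:
    # walk the mapping and re-raise the vendor error under the unified
    # type so callers only need to catch InvokeError subclasses
    for invoke_error_type, vendor_errors in self._invoke_error_mapping.items():
        if any(isinstance(error, vendor_error) for vendor_error in vendor_errors):
            return invoke_error_type(description=str(error))
    return InvokeError(description=str(error))
```
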
@@ -1,18 +1,25 @@
 import logging
+from typing import Optional

 from core.app.app_config.entities import AppConfig
 from core.moderation.base import ModerationAction, ModerationException
 from core.moderation.factory import ModerationFactory
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.utils import measure_time

 logger = logging.getLogger(__name__)


 class InputModeration:
-    def check(self, app_id: str,
-              tenant_id: str,
-              app_config: AppConfig,
-              inputs: dict,
-              query: str) -> tuple[bool, dict, str]:
+    def check(
+        self, app_id: str,
+        tenant_id: str,
+        app_config: AppConfig,
+        inputs: dict,
+        query: str,
+        message_id: str,
+        trace_manager: Optional[TraceQueueManager] = None
+    ) -> tuple[bool, dict, str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id
@@ -20,6 +27,8 @@ class InputModeration:
         :param app_config: app config
         :param inputs: inputs
         :param query: query
+        :param message_id: message id
+        :param trace_manager: trace manager
         :return:
         """
         if not app_config.sensitive_word_avoidance:
@@ -35,8 +44,20 @@ class InputModeration:
             config=sensitive_word_avoidance_config.config
         )

-        moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
+        with measure_time() as timer:
+            moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
+
+        if trace_manager:
+            trace_manager.add_trace_task(
+                TraceTask(
+                    TraceTaskName.MODERATION_TRACE,
+                    message_id=message_id,
+                    moderation_result=moderation_result,
+                    inputs=inputs,
+                    timer=timer
+                )
+            )

         if not moderation_result.flagged:
             return False, inputs, query

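Note: `measure_time` (imported from `core.ops.utils`) wraps the moderation call so the trace records wall-clock timing without changing control flow. Its implementation is not shown in this commit; a minimal sketch of a compatible helper, assuming a simple start/end record:

```python
from contextlib import contextmanager
from datetime import datetime


class Timer:
    def __init__(self):
        self.start = None
        self.end = None


@contextmanager
def measure_time():
    # yields an object whose start/end are filled in around the block
    timer = Timer()
    timer.start = datetime.now()
    try:
        yield timer
    finally:
        timer.end = datetime.now()
```
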
api/core/ops/__init__.py (new file, 0 lines)

api/core/ops/base_trace_instance.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+from abc import ABC, abstractmethod
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.entities.trace_entity import BaseTraceInfo
+
+
+class BaseTraceInstance(ABC):
+    """
+    Base trace instance for ops trace services
+    """
+
+    @abstractmethod
+    def __init__(self, trace_config: BaseTracingConfig):
+        """
+        Abstract initializer for the trace instance.
+        Distribute trace tasks by matching entities
+        """
+        self.trace_config = trace_config
+
+    @abstractmethod
+    def trace(self, trace_info: BaseTraceInfo):
+        """
+        Abstract method to trace activities.
+        Subclasses must implement specific tracing logic for activities.
+        """
+        ...
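
Note: `BaseTraceInstance` pins down the contract every provider integration must satisfy — accept a provider config and expose a single `trace()` entry point. A hedged sketch of a concrete subclass (the class name and body are illustrative assumptions; only the two abstract methods come from the file above):

```python
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import BaseTraceInfo


class ExampleTrace(BaseTraceInstance):  # hypothetical subclass
    def __init__(self, trace_config: LangfuseConfig):
        super().__init__(trace_config)

    def trace(self, trace_info: BaseTraceInfo):
        # a real integration would forward trace_info to the provider SDK
        print(trace_info.message_id, trace_info.metadata)
```
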
0
api/core/ops/entities/__init__.py
Normal file
0
api/core/ops/entities/__init__.py
Normal file
51
api/core/ops/entities/config_entity.py
Normal file
51
api/core/ops/entities/config_entity.py
Normal file
@@ -0,0 +1,51 @@
from enum import Enum

from pydantic import BaseModel, ValidationInfo, field_validator


class TracingProviderEnum(Enum):
    LANGFUSE = 'langfuse'
    LANGSMITH = 'langsmith'


class BaseTracingConfig(BaseModel):
    """
    Base model class for tracing
    """
    ...


class LangfuseConfig(BaseTracingConfig):
    """
    Model class for Langfuse tracing config.
    """
    public_key: str
    secret_key: str
    host: str = 'https://api.langfuse.com'

    @field_validator("host")
    def set_value(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = 'https://api.langfuse.com'
        if not v.startswith('https://'):
            raise ValueError('host must start with https://')

        return v


class LangSmithConfig(BaseTracingConfig):
    """
    Model class for Langsmith tracing config.
    """
    api_key: str
    project: str
    endpoint: str = 'https://api.smith.langchain.com'

    @field_validator("endpoint")
    def set_value(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = 'https://api.smith.langchain.com'
        if not v.startswith('https://'):
            raise ValueError('endpoint must start with https://')

        return v
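A short sketch of how these config models behave, assuming only what the validators above show; the keys are placeholders.

from core.ops.entities.config_entity import LangfuseConfig, LangSmithConfig

cfg = LangfuseConfig(public_key="pk", secret_key="sk", host="")
assert cfg.host == 'https://api.langfuse.com'  # empty host falls back to the default

try:
    LangSmithConfig(api_key="key", project="demo", endpoint="http://insecure.example")
except ValueError:
    pass  # non-https endpoints are rejected by set_value
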
98  api/core/ops/entities/trace_entity.py  Normal file
@@ -0,0 +1,98 @@
from datetime import datetime
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, field_validator


class BaseTraceInfo(BaseModel):
    message_id: Optional[str] = None
    message_data: Optional[Any] = None
    inputs: Optional[Union[str, dict[str, Any], list]] = None
    outputs: Optional[Union[str, dict[str, Any], list]] = None
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    metadata: dict[str, Any]

    @field_validator("inputs", "outputs")
    def ensure_type(cls, v):
        if v is None:
            return None
        if isinstance(v, str | dict | list):
            return v
        else:
            return ""


class WorkflowTraceInfo(BaseTraceInfo):
    workflow_data: Any
    conversation_id: Optional[str] = None
    workflow_app_log_id: Optional[str] = None
    workflow_id: str
    tenant_id: str
    workflow_run_id: str
    workflow_run_elapsed_time: Union[int, float]
    workflow_run_status: str
    workflow_run_inputs: dict[str, Any]
    workflow_run_outputs: dict[str, Any]
    workflow_run_version: str
    error: Optional[str] = None
    total_tokens: int
    file_list: list[str]
    query: str
    metadata: dict[str, Any]


class MessageTraceInfo(BaseTraceInfo):
    conversation_model: str
    message_tokens: int
    answer_tokens: int
    total_tokens: int
    error: Optional[str] = None
    file_list: Optional[Union[str, dict[str, Any], list]] = None
    message_file_data: Optional[Any] = None
    conversation_mode: str


class ModerationTraceInfo(BaseTraceInfo):
    flagged: bool
    action: str
    preset_response: str
    query: str


class SuggestedQuestionTraceInfo(BaseTraceInfo):
    total_tokens: int
    status: Optional[str] = None
    error: Optional[str] = None
    from_account_id: Optional[str] = None
    agent_based: Optional[bool] = None
    from_source: Optional[str] = None
    model_provider: Optional[str] = None
    model_id: Optional[str] = None
    suggested_question: list[str]
    level: str
    status_message: Optional[str] = None
    workflow_run_id: Optional[str] = None

    model_config = ConfigDict(protected_namespaces=())


class DatasetRetrievalTraceInfo(BaseTraceInfo):
    documents: Any


class ToolTraceInfo(BaseTraceInfo):
    tool_name: str
    tool_inputs: dict[str, Any]
    tool_outputs: str
    metadata: dict[str, Any]
    message_file_data: Any
    error: Optional[str] = None
    tool_config: dict[str, Any]
    time_cost: Union[int, float]
    tool_parameters: dict[str, Any]
    file_url: Union[str, None, list]


class GenerateNameTraceInfo(BaseTraceInfo):
    conversation_id: str
    tenant_id: str
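A hedged construction example for one of the models above; the ids, action string, and metadata values are placeholders, not values from this diff.

from datetime import datetime
from core.ops.entities.trace_entity import ModerationTraceInfo

info = ModerationTraceInfo(
    message_id="msg-1",            # placeholder id
    inputs={"query": "hello"},
    flagged=False,
    action="direct_output",        # placeholder action string
    preset_response="",
    query="hello",
    start_time=datetime.now(),
    end_time=datetime.now(),
    metadata={"app_id": "app-1"},  # metadata is required on BaseTraceInfo
)
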
0  api/core/ops/langfuse_trace/__init__.py  Normal file
0  api/core/ops/langfuse_trace/entities/__init__.py  Normal file
280  api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py  Normal file
@@ -0,0 +1,280 @@
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core.core_schema import ValidationInfo

from core.ops.utils import replace_text_with_content


def validate_input_output(v, field_name):
    """
    Validate input output
    :param v:
    :param field_name:
    :return:
    """
    if v == {} or v is None:
        return v
    if isinstance(v, str):
        return [
            {
                "role": "assistant" if field_name == "output" else "user",
                "content": v,
            }
        ]
    elif isinstance(v, list):
        if len(v) > 0 and isinstance(v[0], dict):
            v = replace_text_with_content(data=v)
            return v
        else:
            return [
                {
                    "role": "assistant" if field_name == "output" else "user",
                    "content": str(v),
                }
            ]

    return v


class LevelEnum(str, Enum):
    DEBUG = "DEBUG"
    WARNING = "WARNING"
    ERROR = "ERROR"
    DEFAULT = "DEFAULT"


class LangfuseTrace(BaseModel):
    """
    Langfuse trace model
    """
    id: Optional[str] = Field(
        default=None,
        description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems "
                    "or when creating a distributed trace. Traces are upserted on id.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Identifier of the trace. Useful for sorting/filtering in the UI.",
    )
    input: Optional[Union[str, dict[str, Any], list, None]] = Field(
        default=None, description="The input of the trace. Can be any JSON object."
    )
    output: Optional[Union[str, dict[str, Any], list, None]] = Field(
        default=None, description="The output of the trace. Can be any JSON object."
    )
    metadata: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated "
                    "via the API.",
    )
    user_id: Optional[str] = Field(
        default=None,
        description="The id of the user that triggered the execution. Used to provide user-level analytics.",
    )
    session_id: Optional[str] = Field(
        default=None,
        description="Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.",
    )
    version: Optional[str] = Field(
        default=None,
        description="The version of the trace type. Used to understand how changes to the trace type affect metrics. "
                    "Useful in debugging.",
    )
    release: Optional[str] = Field(
        default=None,
        description="The release identifier of the current deployment. Used to understand how changes of different "
                    "deployments affect metrics. Useful in debugging.",
    )
    tags: Optional[list[str]] = Field(
        default=None,
        description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET "
                    "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
    )
    public: Optional[bool] = Field(
        default=None,
        description="You can make a trace public to share it via a public link. This allows others to view the trace "
                    "without needing to log in or be members of your Langfuse project.",
    )

    @field_validator("input", "output")
    def ensure_dict(cls, v, info: ValidationInfo):
        field_name = info.field_name
        return validate_input_output(v, field_name)


class LangfuseSpan(BaseModel):
    """
    Langfuse span model
    """
    id: Optional[str] = Field(
        default=None,
        description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.",
    )
    session_id: Optional[str] = Field(
        default=None,
        description="Used to group multiple spans into a session in Langfuse. Use your own session/thread identifier.",
    )
    trace_id: Optional[str] = Field(
        default=None,
        description="The id of the trace the span belongs to. Used to link spans to traces.",
    )
    user_id: Optional[str] = Field(
        default=None,
        description="The id of the user that triggered the execution. Used to provide user-level analytics.",
    )
    start_time: Optional[datetime | str] = Field(
        default_factory=datetime.now,
        description="The time at which the span started, defaults to the current time.",
    )
    end_time: Optional[datetime | str] = Field(
        default=None,
        description="The time at which the span ended. Automatically set by span.end().",
    )
    name: Optional[str] = Field(
        default=None,
        description="Identifier of the span. Useful for sorting/filtering in the UI.",
    )
    metadata: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
                    "via the API.",
    )
    level: Optional[str] = Field(
        default=None,
        description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
                    "traces with elevated error levels and for highlighting in the UI.",
    )
    status_message: Optional[str] = Field(
        default=None,
        description="The status message of the span. Additional field for context of the event. E.g. the error "
                    "message of an error event.",
    )
    input: Optional[Union[str, dict[str, Any], list, None]] = Field(
        default=None, description="The input of the span. Can be any JSON object."
    )
    output: Optional[Union[str, dict[str, Any], list, None]] = Field(
        default=None, description="The output of the span. Can be any JSON object."
    )
    version: Optional[str] = Field(
        default=None,
        description="The version of the span type. Used to understand how changes to the span type affect metrics. "
                    "Useful in debugging.",
    )
    parent_observation_id: Optional[str] = Field(
        default=None,
        description="The id of the observation the span belongs to. Used to link spans to observations.",
    )

    @field_validator("input", "output")
    def ensure_dict(cls, v, info: ValidationInfo):
        field_name = info.field_name
        return validate_input_output(v, field_name)


class UnitEnum(str, Enum):
    CHARACTERS = "CHARACTERS"
    TOKENS = "TOKENS"
    SECONDS = "SECONDS"
    MILLISECONDS = "MILLISECONDS"
    IMAGES = "IMAGES"


class GenerationUsage(BaseModel):
    promptTokens: Optional[int] = None
    completionTokens: Optional[int] = None
    totalTokens: Optional[int] = None
    input: Optional[int] = None
    output: Optional[int] = None
    total: Optional[int] = None
    unit: Optional[UnitEnum] = None
    inputCost: Optional[float] = None
    outputCost: Optional[float] = None
    totalCost: Optional[float] = None

    @field_validator("input", "output")
    def ensure_dict(cls, v, info: ValidationInfo):
        field_name = info.field_name
        return validate_input_output(v, field_name)


class LangfuseGeneration(BaseModel):
    id: Optional[str] = Field(
        default=None,
        description="The id of the generation can be set, defaults to random id.",
    )
    trace_id: Optional[str] = Field(
        default=None,
        description="The id of the trace the generation belongs to. Used to link generations to traces.",
    )
    parent_observation_id: Optional[str] = Field(
        default=None,
        description="The id of the observation the generation belongs to. Used to link generations to observations.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Identifier of the generation. Useful for sorting/filtering in the UI.",
    )
    start_time: Optional[datetime | str] = Field(
        default_factory=datetime.now,
        description="The time at which the generation started, defaults to the current time.",
    )
    completion_start_time: Optional[datetime | str] = Field(
        default=None,
        description="The time at which the completion started (streaming). Set it to get latency analytics broken "
                    "down into time until completion started and completion duration.",
    )
    end_time: Optional[datetime | str] = Field(
        default=None,
        description="The time at which the generation ended. Automatically set by generation.end().",
    )
    model: Optional[str] = Field(
        default=None, description="The name of the model used for the generation."
    )
    model_parameters: Optional[dict[str, Any]] = Field(
        default=None,
        description="The parameters of the model used for the generation; can be any key-value pairs.",
    )
    input: Optional[Any] = Field(
        default=None,
        description="The prompt used for the generation. Can be any string or JSON object.",
    )
    output: Optional[Any] = Field(
        default=None,
        description="The completion generated by the model. Can be any string or JSON object.",
    )
    usage: Optional[GenerationUsage] = Field(
        default=None,
        description="The usage object supports the OpenAi structure with tokens and a more generic version with "
                    "detailed costs and units.",
    )
    metadata: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being "
                    "updated via the API.",
    )
    level: Optional[LevelEnum] = Field(
        default=None,
        description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering "
                    "of traces with elevated error levels and for highlighting in the UI.",
    )
    status_message: Optional[str] = Field(
        default=None,
        description="The status message of the generation. Additional field for context of the event. E.g. the error "
                    "message of an error event.",
    )
    version: Optional[str] = Field(
        default=None,
        description="The version of the generation type. Used to understand how changes to the span type affect "
                    "metrics. Useful in debugging.",
    )

    model_config = ConfigDict(protected_namespaces=())

    @field_validator("input", "output")
    def ensure_dict(cls, v, info: ValidationInfo):
        field_name = info.field_name
        return validate_input_output(v, field_name)
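A brief sketch of the shared validator's effect, based only on the branches above: plain strings are wrapped into role/content message lists.

from core.ops.langfuse_trace.entities.langfuse_trace_entity import LangfuseTrace

trace = LangfuseTrace(name="demo", input="hi", output="hello")
assert trace.input == [{"role": "user", "content": "hi"}]
assert trace.output == [{"role": "assistant", "content": "hello"}]
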
392  api/core/ops/langfuse_trace/langfuse_trace.py  Normal file
@@ -0,0 +1,392 @@
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Optional

from langfuse import Langfuse

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    WorkflowTraceInfo,
)
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
    GenerationUsage,
    LangfuseGeneration,
    LangfuseSpan,
    LangfuseTrace,
    LevelEnum,
    UnitEnum,
)
from core.ops.utils import filter_none_values
from extensions.ext_database import db
from models.model import EndUser
from models.workflow import WorkflowNodeExecution

logger = logging.getLogger(__name__)


class LangFuseDataTrace(BaseTraceInstance):
    def __init__(
        self,
        langfuse_config: LangfuseConfig,
    ):
        super().__init__(langfuse_config)
        self.langfuse_client = Langfuse(
            public_key=langfuse_config.public_key,
            secret_key=langfuse_config.secret_key,
            host=langfuse_config.host,
        )
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
        if isinstance(trace_info, MessageTraceInfo):
            self.message_trace(trace_info)
        if isinstance(trace_info, ModerationTraceInfo):
            self.moderation_trace(trace_info)
        if isinstance(trace_info, SuggestedQuestionTraceInfo):
            self.suggested_question_trace(trace_info)
        if isinstance(trace_info, DatasetRetrievalTraceInfo):
            self.dataset_retrieval_trace(trace_info)
        if isinstance(trace_info, ToolTraceInfo):
            self.tool_trace(trace_info)
        if isinstance(trace_info, GenerateNameTraceInfo):
            self.generate_name_trace(trace_info)

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
        if trace_info.message_id:
            trace_id = trace_info.message_id
            name = f"message_{trace_info.message_id}"
            trace_data = LangfuseTrace(
                id=trace_info.message_id,
                user_id=trace_info.tenant_id,
                name=name,
                input=trace_info.workflow_run_inputs,
                output=trace_info.workflow_run_outputs,
                metadata=trace_info.metadata,
                session_id=trace_info.conversation_id,
                tags=["message", "workflow"],
            )
            self.add_trace(langfuse_trace_data=trace_data)
            workflow_span_data = LangfuseSpan(
                id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
                name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
                input=trace_info.workflow_run_inputs,
                output=trace_info.workflow_run_outputs,
                trace_id=trace_id,
                start_time=trace_info.start_time,
                end_time=trace_info.end_time,
                metadata=trace_info.metadata,
                level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
                status_message=trace_info.error if trace_info.error else "",
            )
            self.add_span(langfuse_span_data=workflow_span_data)
        else:
            trace_data = LangfuseTrace(
                id=trace_id,
                user_id=trace_info.tenant_id,
                name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
                input=trace_info.workflow_run_inputs,
                output=trace_info.workflow_run_outputs,
                metadata=trace_info.metadata,
                session_id=trace_info.conversation_id,
                tags=["workflow"],
            )
            self.add_trace(langfuse_trace_data=trace_data)

        # through workflow_run_id get all_nodes_execution
        workflow_nodes_executions = (
            db.session.query(WorkflowNodeExecution)
            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
            .order_by(WorkflowNodeExecution.index.desc())
            .all()
        )

        for node_execution in workflow_nodes_executions:
            node_execution_id = node_execution.id
            tenant_id = node_execution.tenant_id
            app_id = node_execution.app_id
            node_name = node_execution.title
            node_type = node_execution.node_type
            status = node_execution.status
            if node_type == "llm":
                inputs = json.loads(node_execution.process_data).get("prompts", {})
            else:
                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
            outputs = (
                json.loads(node_execution.outputs) if node_execution.outputs else {}
            )
            created_at = node_execution.created_at if node_execution.created_at else datetime.now()
            elapsed_time = node_execution.elapsed_time
            finished_at = created_at + timedelta(seconds=elapsed_time)

            metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
            metadata.update(
                {
                    "workflow_run_id": trace_info.workflow_run_id,
                    "node_execution_id": node_execution_id,
                    "tenant_id": tenant_id,
                    "app_id": app_id,
                    "node_name": node_name,
                    "node_type": node_type,
                    "status": status,
                }
            )

            # add span
            if trace_info.message_id:
                span_data = LangfuseSpan(
                    name=f"{node_name}_{node_execution_id}",
                    input=inputs,
                    output=outputs,
                    trace_id=trace_id,
                    start_time=created_at,
                    end_time=finished_at,
                    metadata=metadata,
                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
                    status_message=trace_info.error if trace_info.error else "",
                    parent_observation_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
                )
            else:
                span_data = LangfuseSpan(
                    name=f"{node_name}_{node_execution_id}",
                    input=inputs,
                    output=outputs,
                    trace_id=trace_id,
                    start_time=created_at,
                    end_time=finished_at,
                    metadata=metadata,
                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
                    status_message=trace_info.error if trace_info.error else "",
                )

            self.add_span(langfuse_span_data=span_data)

    def message_trace(
        self, trace_info: MessageTraceInfo, **kwargs
    ):
        # get message file data
        file_list = trace_info.file_list
        metadata = trace_info.metadata
        message_data = trace_info.message_data
        message_id = message_data.id

        user_id = message_data.from_account_id
        if message_data.from_end_user_id:
            end_user_data: EndUser = db.session.query(EndUser).filter(
                EndUser.id == message_data.from_end_user_id
            ).first()
            user_id = end_user_data.session_id

        trace_data = LangfuseTrace(
            id=message_id,
            user_id=user_id,
            name=f"message_{message_id}",
            input={
                "message": trace_info.inputs,
                "files": file_list,
                "message_tokens": trace_info.message_tokens,
                "answer_tokens": trace_info.answer_tokens,
                "total_tokens": trace_info.total_tokens,
                "error": trace_info.error,
                "provider_response_latency": message_data.provider_response_latency,
                "created_at": trace_info.start_time,
            },
            output=trace_info.outputs,
            metadata=metadata,
            session_id=message_data.conversation_id,
            tags=["message", str(trace_info.conversation_mode)],
            version=None,
            release=None,
            public=None,
        )
        self.add_trace(langfuse_trace_data=trace_data)

        # start add span
        generation_usage = GenerationUsage(
            totalTokens=trace_info.total_tokens,
            input=trace_info.message_tokens,
            output=trace_info.answer_tokens,
            total=trace_info.total_tokens,
            unit=UnitEnum.TOKENS,
        )

        langfuse_generation_data = LangfuseGeneration(
            name=f"generation_{message_id}",
            trace_id=message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            model=message_data.model_id,
            input=trace_info.inputs,
            output=message_data.answer,
            metadata=metadata,
            level=LevelEnum.DEFAULT if message_data.status != 'error' else LevelEnum.ERROR,
            status_message=message_data.error if message_data.error else "",
            usage=generation_usage,
        )

        self.add_generation(langfuse_generation_data)

    def moderation_trace(self, trace_info: ModerationTraceInfo):
        span_data = LangfuseSpan(
            name="moderation",
            input=trace_info.inputs,
            output={
                "action": trace_info.action,
                "flagged": trace_info.flagged,
                "preset_response": trace_info.preset_response,
                "inputs": trace_info.inputs,
            },
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
            end_time=trace_info.end_time or trace_info.message_data.created_at,
            metadata=trace_info.metadata,
        )

        self.add_span(langfuse_span_data=span_data)

    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
        message_data = trace_info.message_data
        generation_usage = GenerationUsage(
            totalTokens=len(str(trace_info.suggested_question)),
            input=len(trace_info.inputs),
            output=len(trace_info.suggested_question),
            total=len(trace_info.suggested_question),
            unit=UnitEnum.CHARACTERS,
        )

        generation_data = LangfuseGeneration(
            name="suggested_question",
            input=trace_info.inputs,
            output=str(trace_info.suggested_question),
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            metadata=trace_info.metadata,
            level=LevelEnum.DEFAULT if message_data.status != 'error' else LevelEnum.ERROR,
            status_message=message_data.error if message_data.error else "",
            usage=generation_usage,
        )

        self.add_generation(langfuse_generation_data=generation_data)

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        dataset_retrieval_span_data = LangfuseSpan(
            name="dataset_retrieval",
            input=trace_info.inputs,
            output={"documents": trace_info.documents},
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
            end_time=trace_info.end_time or trace_info.message_data.updated_at,
            metadata=trace_info.metadata,
        )

        self.add_span(langfuse_span_data=dataset_retrieval_span_data)

    def tool_trace(self, trace_info: ToolTraceInfo):
        tool_span_data = LangfuseSpan(
            name=trace_info.tool_name,
            input=trace_info.tool_inputs,
            output=trace_info.tool_outputs,
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            metadata=trace_info.metadata,
            level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
            status_message=trace_info.error,
        )

        self.add_span(langfuse_span_data=tool_span_data)

    def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
        name_generation_trace_data = LangfuseTrace(
            name="generate_name",
            input=trace_info.inputs,
            output=trace_info.outputs,
            user_id=trace_info.tenant_id,
            metadata=trace_info.metadata,
            session_id=trace_info.conversation_id,
        )

        self.add_trace(langfuse_trace_data=name_generation_trace_data)

        name_generation_span_data = LangfuseSpan(
            name="generate_name",
            input=trace_info.inputs,
            output=trace_info.outputs,
            trace_id=trace_info.conversation_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            metadata=trace_info.metadata,
        )
        self.add_span(langfuse_span_data=name_generation_span_data)

    def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None):
        format_trace_data = (
            filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
        )
        try:
            self.langfuse_client.trace(**format_trace_data)
            logger.debug("LangFuse Trace created successfully")
        except Exception as e:
            raise ValueError(f"LangFuse Failed to create trace: {str(e)}")

    def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None):
        format_span_data = (
            filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
        )
        try:
            self.langfuse_client.span(**format_span_data)
            logger.debug("LangFuse Span created successfully")
        except Exception as e:
            raise ValueError(f"LangFuse Failed to create span: {str(e)}")

    def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None):
        format_span_data = (
            filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
        )

        span.end(**format_span_data)

    def add_generation(
        self, langfuse_generation_data: Optional[LangfuseGeneration] = None
    ):
        format_generation_data = (
            filter_none_values(langfuse_generation_data.model_dump())
            if langfuse_generation_data
            else {}
        )
        try:
            self.langfuse_client.generation(**format_generation_data)
            logger.debug("LangFuse Generation created successfully")
        except Exception as e:
            raise ValueError(f"LangFuse Failed to create generation: {str(e)}")

    def update_generation(
        self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None
    ):
        format_generation_data = (
            filter_none_values(langfuse_generation_data.model_dump())
            if langfuse_generation_data
            else {}
        )

        generation.end(**format_generation_data)

    def api_check(self):
        try:
            return self.langfuse_client.auth_check()
        except Exception as e:
            logger.debug(f"LangFuse API check failed: {str(e)}")
            raise ValueError(f"LangFuse API check failed: {str(e)}")
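Usage sketch (illustrative, not part of this commit): constructing the Langfuse tracer from a config and verifying credentials; the keys are placeholders.

from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace

config = LangfuseConfig(public_key="pk-lf-...", secret_key="sk-lf-...")
tracer = LangFuseDataTrace(config)
tracer.api_check()  # wraps langfuse_client.auth_check(); raises ValueError on failure
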
0  api/core/ops/langsmith_trace/__init__.py  Normal file
0  api/core/ops/langsmith_trace/entities/__init__.py  Normal file
167  api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py  Normal file
@@ -0,0 +1,167 @@
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union

from pydantic import BaseModel, Field, field_validator
from pydantic_core.core_schema import ValidationInfo

from core.ops.utils import replace_text_with_content


class LangSmithRunType(str, Enum):
    tool = "tool"
    chain = "chain"
    llm = "llm"
    retriever = "retriever"
    embedding = "embedding"
    prompt = "prompt"
    parser = "parser"


class LangSmithTokenUsage(BaseModel):
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    total_tokens: Optional[int] = None


class LangSmithMultiModel(BaseModel):
    file_list: Optional[list[str]] = Field(None, description="List of files")


class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
    name: Optional[str] = Field(..., description="Name of the run")
    inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run")
    outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run")
    run_type: LangSmithRunType = Field(..., description="Type of the run")
    start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
    end_time: Optional[datetime | str] = Field(None, description="End time of the run")
    extra: Optional[dict[str, Any]] = Field(
        None, description="Extra information of the run"
    )
    error: Optional[str] = Field(None, description="Error message of the run")
    serialized: Optional[dict[str, Any]] = Field(
        None, description="Serialized data of the run"
    )
    parent_run_id: Optional[str] = Field(None, description="Parent run ID")
    events: Optional[list[dict[str, Any]]] = Field(
        None, description="Events associated with the run"
    )
    tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
    trace_id: Optional[str] = Field(
        None, description="Trace ID associated with the run"
    )
    dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
    id: Optional[str] = Field(None, description="ID of the run")
    session_id: Optional[str] = Field(
        None, description="Session ID associated with the run"
    )
    session_name: Optional[str] = Field(
        None, description="Session name associated with the run"
    )
    reference_example_id: Optional[str] = Field(
        None, description="Reference example ID associated with the run"
    )
    input_attachments: Optional[dict[str, Any]] = Field(
        None, description="Input attachments of the run"
    )
    output_attachments: Optional[dict[str, Any]] = Field(
        None, description="Output attachments of the run"
    )

    @field_validator("inputs", "outputs")
    def ensure_dict(cls, v, info: ValidationInfo):
        field_name = info.field_name
        values = info.data
        if v == {} or v is None:
            return v
        usage_metadata = {
            "input_tokens": values.get('input_tokens', 0),
            "output_tokens": values.get('output_tokens', 0),
            "total_tokens": values.get('total_tokens', 0),
        }
        file_list = values.get("file_list", [])
        if isinstance(v, str):
            if field_name == "inputs":
                return {
                    "messages": {
                        "role": "user",
                        "content": v,
                        "usage_metadata": usage_metadata,
                        "file_list": file_list,
                    },
                }
            elif field_name == "outputs":
                return {
                    "choices": {
                        "role": "ai",
                        "content": v,
                        "usage_metadata": usage_metadata,
                        "file_list": file_list,
                    },
                }
        elif isinstance(v, list):
            data = {}
            if len(v) > 0 and isinstance(v[0], dict):
                # rename text to content
                v = replace_text_with_content(data=v)
                if field_name == "inputs":
                    data = {
                        "messages": v,
                    }
                elif field_name == "outputs":
                    data = {
                        "choices": {
                            "role": "ai",
                            "content": v,
                            "usage_metadata": usage_metadata,
                            "file_list": file_list,
                        },
                    }
                return data
            else:
                return {
                    "choices": {
                        "role": "ai" if field_name == "outputs" else "user",
                        "content": str(v),
                        "usage_metadata": usage_metadata,
                        "file_list": file_list,
                    },
                }
        if isinstance(v, dict):
            v["usage_metadata"] = usage_metadata
            v["file_list"] = file_list
            return v
        return v

    @field_validator("start_time", "end_time")
    def format_time(cls, v, info: ValidationInfo):
        if not isinstance(v, datetime):
            raise ValueError(f"{info.field_name} must be a datetime object")
        else:
            return v.strftime("%Y-%m-%dT%H:%M:%S.%fZ")


class LangSmithRunUpdateModel(BaseModel):
    run_id: str = Field(..., description="ID of the run")
    trace_id: Optional[str] = Field(
        None, description="Trace ID associated with the run"
    )
    dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
    parent_run_id: Optional[str] = Field(None, description="Parent run ID")
    end_time: Optional[datetime | str] = Field(None, description="End time of the run")
    error: Optional[str] = Field(None, description="Error message of the run")
    inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run")
    outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run")
    events: Optional[list[dict[str, Any]]] = Field(
        None, description="Events associated with the run"
    )
    tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
    extra: Optional[dict[str, Any]] = Field(
        None, description="Extra information of the run"
    )
    input_attachments: Optional[dict[str, Any]] = Field(
        None, description="Input attachments of the run"
    )
    output_attachments: Optional[dict[str, Any]] = Field(
        None, description="Output attachments of the run"
    )
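A hedged sketch of how the run model normalizes values via the validators above; it assumes the inherited token fields are validated before inputs, which is how pydantic orders base-class fields.

from datetime import datetime
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
    LangSmithRunModel,
    LangSmithRunType,
)

run = LangSmithRunModel(
    name="demo_run",
    run_type=LangSmithRunType.chain,
    inputs="hello",             # strings become a {"messages": ...} payload
    start_time=datetime.now(),  # datetimes are serialized to strings by format_time
    end_time=datetime.now(),
)
assert run.inputs["messages"]["content"] == "hello"
assert isinstance(run.start_time, str)
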
355  api/core/ops/langsmith_trace/langsmith_trace.py  Normal file
@@ -0,0 +1,355 @@
import json
import logging
import os
from datetime import datetime, timedelta

from langsmith import Client

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    WorkflowTraceInfo,
)
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
    LangSmithRunModel,
    LangSmithRunType,
    LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution

logger = logging.getLogger(__name__)


class LangSmithDataTrace(BaseTraceInstance):
    def __init__(
        self,
        langsmith_config: LangSmithConfig,
    ):
        super().__init__(langsmith_config)
        self.langsmith_key = langsmith_config.api_key
        self.project_name = langsmith_config.project
        self.project_id = None
        self.langsmith_client = Client(
            api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint
        )
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
        if isinstance(trace_info, MessageTraceInfo):
            self.message_trace(trace_info)
        if isinstance(trace_info, ModerationTraceInfo):
            self.moderation_trace(trace_info)
        if isinstance(trace_info, SuggestedQuestionTraceInfo):
            self.suggested_question_trace(trace_info)
        if isinstance(trace_info, DatasetRetrievalTraceInfo):
            self.dataset_retrieval_trace(trace_info)
        if isinstance(trace_info, ToolTraceInfo):
            self.tool_trace(trace_info)
        if isinstance(trace_info, GenerateNameTraceInfo):
            self.generate_name_trace(trace_info)

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        if trace_info.message_id:
            message_run = LangSmithRunModel(
                id=trace_info.message_id,
                name=f"message_{trace_info.message_id}",
                inputs=trace_info.workflow_run_inputs,
                outputs=trace_info.workflow_run_outputs,
                run_type=LangSmithRunType.chain,
                start_time=trace_info.start_time,
                end_time=trace_info.end_time,
                extra={
                    "metadata": trace_info.metadata,
                },
                tags=["message"],
                error=trace_info.error
            )
            self.add_run(message_run)

        langsmith_run = LangSmithRunModel(
            file_list=trace_info.file_list,
            total_tokens=trace_info.total_tokens,
            id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
            name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
            inputs=trace_info.workflow_run_inputs,
            run_type=LangSmithRunType.tool,
            start_time=trace_info.workflow_data.created_at,
            end_time=trace_info.workflow_data.finished_at,
            outputs=trace_info.workflow_run_outputs,
            extra={
                "metadata": trace_info.metadata,
            },
            error=trace_info.error,
            tags=["workflow"],
            parent_run_id=trace_info.message_id if trace_info.message_id else None,
        )

        self.add_run(langsmith_run)

        # through workflow_run_id get all_nodes_execution
        workflow_nodes_executions = (
            db.session.query(WorkflowNodeExecution)
            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
            .order_by(WorkflowNodeExecution.index.desc())
            .all()
        )

        for node_execution in workflow_nodes_executions:
            node_execution_id = node_execution.id
            tenant_id = node_execution.tenant_id
            app_id = node_execution.app_id
            node_name = node_execution.title
            node_type = node_execution.node_type
            status = node_execution.status
            if node_type == "llm":
                inputs = json.loads(node_execution.process_data).get("prompts", {})
            else:
                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
            outputs = (
                json.loads(node_execution.outputs) if node_execution.outputs else {}
            )
            created_at = node_execution.created_at if node_execution.created_at else datetime.now()
            elapsed_time = node_execution.elapsed_time
            finished_at = created_at + timedelta(seconds=elapsed_time)

            execution_metadata = (
                json.loads(node_execution.execution_metadata)
                if node_execution.execution_metadata
                else {}
            )
            node_total_tokens = execution_metadata.get("total_tokens", 0)

            metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
            metadata.update(
                {
                    "workflow_run_id": trace_info.workflow_run_id,
                    "node_execution_id": node_execution_id,
                    "tenant_id": tenant_id,
                    "app_id": app_id,
                    "app_name": node_name,
                    "node_type": node_type,
                    "status": status,
                }
            )

            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
            if process_data and process_data.get("model_mode") == "chat":
                run_type = LangSmithRunType.llm
            elif node_type == "knowledge-retrieval":
                run_type = LangSmithRunType.retriever
            else:
                run_type = LangSmithRunType.tool

            langsmith_run = LangSmithRunModel(
                total_tokens=node_total_tokens,
                name=f"{node_name}_{node_execution_id}",
                inputs=inputs,
                run_type=run_type,
                start_time=created_at,
                end_time=finished_at,
                outputs=outputs,
                file_list=trace_info.file_list,
                extra={
                    "metadata": metadata,
                },
                parent_run_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
                tags=["node_execution"],
            )

            self.add_run(langsmith_run)

    def message_trace(self, trace_info: MessageTraceInfo):
        # get message file data
        file_list = trace_info.file_list
        message_file_data: MessageFile = trace_info.message_file_data
        file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
        file_list.append(file_url)
        metadata = trace_info.metadata
        message_data = trace_info.message_data
        message_id = message_data.id

        user_id = message_data.from_account_id
        if message_data.from_end_user_id:
            end_user_data: EndUser = db.session.query(EndUser).filter(
                EndUser.id == message_data.from_end_user_id
            ).first()
            end_user_id = end_user_data.session_id
            metadata["end_user_id"] = end_user_id
        metadata["user_id"] = user_id

        message_run = LangSmithRunModel(
            input_tokens=trace_info.message_tokens,
            output_tokens=trace_info.answer_tokens,
            total_tokens=trace_info.total_tokens,
            id=message_id,
            name=f"message_{message_id}",
            inputs=trace_info.inputs,
            run_type=LangSmithRunType.chain,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            outputs=message_data.answer,
            extra={
                "metadata": metadata,
            },
            tags=["message", str(trace_info.conversation_mode)],
            error=trace_info.error,
            file_list=file_list,
        )
        self.add_run(message_run)

        # create llm run parented to message run
        llm_run = LangSmithRunModel(
            input_tokens=trace_info.message_tokens,
            output_tokens=trace_info.answer_tokens,
            total_tokens=trace_info.total_tokens,
            name=f"llm_{message_id}",
            inputs=trace_info.inputs,
            run_type=LangSmithRunType.llm,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            outputs=message_data.answer,
            extra={
                "metadata": metadata,
            },
            parent_run_id=message_id,
            tags=["llm", str(trace_info.conversation_mode)],
            error=trace_info.error,
            file_list=file_list,
        )
        self.add_run(llm_run)

    def moderation_trace(self, trace_info: ModerationTraceInfo):
        langsmith_run = LangSmithRunModel(
            name="moderation",
            inputs=trace_info.inputs,
            outputs={
                "action": trace_info.action,
                "flagged": trace_info.flagged,
                "preset_response": trace_info.preset_response,
                "inputs": trace_info.inputs,
            },
            run_type=LangSmithRunType.tool,
            extra={
                "metadata": trace_info.metadata,
            },
            tags=["moderation"],
            parent_run_id=trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
            end_time=trace_info.end_time or trace_info.message_data.updated_at,
        )

        self.add_run(langsmith_run)

    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
        message_data = trace_info.message_data
        suggested_question_run = LangSmithRunModel(
            name="suggested_question",
            inputs=trace_info.inputs,
            outputs=trace_info.suggested_question,
            run_type=LangSmithRunType.tool,
            extra={
                "metadata": trace_info.metadata,
            },
            tags=["suggested_question"],
            parent_run_id=trace_info.message_id,
            start_time=trace_info.start_time or message_data.created_at,
            end_time=trace_info.end_time or message_data.updated_at,
        )

        self.add_run(suggested_question_run)

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        dataset_retrieval_run = LangSmithRunModel(
            name="dataset_retrieval",
            inputs=trace_info.inputs,
            outputs={"documents": trace_info.documents},
            run_type=LangSmithRunType.retriever,
            extra={
                "metadata": trace_info.metadata,
            },
            tags=["dataset_retrieval"],
            parent_run_id=trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
            end_time=trace_info.end_time or trace_info.message_data.updated_at,
        )

        self.add_run(dataset_retrieval_run)

    def tool_trace(self, trace_info: ToolTraceInfo):
        tool_run = LangSmithRunModel(
            name=trace_info.tool_name,
            inputs=trace_info.tool_inputs,
            outputs=trace_info.tool_outputs,
            run_type=LangSmithRunType.tool,
            extra={
                "metadata": trace_info.metadata,
            },
            tags=["tool", trace_info.tool_name],
            parent_run_id=trace_info.message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            file_list=[trace_info.file_url],
        )

        self.add_run(tool_run)

    def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
        name_run = LangSmithRunModel(
            name="generate_name",
            inputs=trace_info.inputs,
            outputs=trace_info.outputs,
            run_type=LangSmithRunType.tool,
            extra={
                "metadata": trace_info.metadata,
            },
            tags=["generate_name"],
            start_time=trace_info.start_time or datetime.now(),
            end_time=trace_info.end_time or datetime.now(),
        )

        self.add_run(name_run)

    def add_run(self, run_data: LangSmithRunModel):
        data = run_data.model_dump()
        if self.project_id:
            data["session_id"] = self.project_id
        elif self.project_name:
            data["session_name"] = self.project_name

        data = filter_none_values(data)
        try:
            self.langsmith_client.create_run(**data)
            logger.debug("LangSmith Run created successfully.")
        except Exception as e:
            raise ValueError(f"LangSmith Failed to create run: {str(e)}")

    def update_run(self, update_run_data: LangSmithRunUpdateModel):
        data = update_run_data.model_dump()
        data = filter_none_values(data)
        try:
            self.langsmith_client.update_run(**data)
            logger.debug("LangSmith Run updated successfully.")
        except Exception as e:
            raise ValueError(f"LangSmith Failed to update run: {str(e)}")

    def api_check(self):
        try:
            random_project_name = f"test_project_{datetime.now().strftime('%Y%m%d%H%M%S')}"
            self.langsmith_client.create_project(project_name=random_project_name)
            self.langsmith_client.delete_project(project_name=random_project_name)
            return True
        except Exception as e:
            logger.debug(f"LangSmith API check failed: {str(e)}")
            raise ValueError(f"LangSmith API check failed: {str(e)}")
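Usage sketch (illustrative, not part of this commit): constructing the LangSmith tracer and verifying credentials; the key and project name are placeholders.

from core.ops.entities.config_entity import LangSmithConfig
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace

config = LangSmithConfig(api_key="ls-...", project="dify-demo")
tracer = LangSmithDataTrace(config)
tracer.api_check()  # creates then deletes a throwaway project; raises ValueError on failure
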
687  api/core/ops/ops_trace_manager.py  Normal file
@@ -0,0 +1,687 @@
import json
import os
import queue
import threading
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, Union
from uuid import UUID

from flask import Flask, current_app

from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import (
    LangfuseConfig,
    LangSmithConfig,
    TracingProviderEnum,
)
from core.ops.entities.trace_entity import (
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    WorkflowTraceInfo,
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun

provider_config_map = {
    TracingProviderEnum.LANGFUSE.value: {
        'config_class': LangfuseConfig,
        'secret_keys': ['public_key', 'secret_key'],
        'other_keys': ['host'],
        'trace_instance': LangFuseDataTrace
    },
    TracingProviderEnum.LANGSMITH.value: {
        'config_class': LangSmithConfig,
        'secret_keys': ['api_key'],
        'other_keys': ['project', 'endpoint'],
        'trace_instance': LangSmithDataTrace
    }
}
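# Usage sketch (illustrative, not part of this commit): how the registry above
# is consumed later in this module; the keys are placeholders.
#
#     entry = provider_config_map[TracingProviderEnum.LANGFUSE.value]
#     config = entry['config_class'](public_key="pk", secret_key="sk")
#     tracer = entry['trace_instance'](config)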


class OpsTraceManager:
    @classmethod
    def encrypt_tracing_config(
        cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
    ):
        """
        Encrypt tracing config.
        :param tenant_id: tenant id
        :param tracing_provider: tracing provider
        :param tracing_config: tracing config dictionary to be encrypted
        :param current_trace_config: current tracing configuration for keeping existing values
        :return: encrypted tracing configuration
        """
        # Get the configuration class and the keys that require encryption
        config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
            provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']

        new_config = {}
        # Encrypt necessary keys
        for key in secret_keys:
            if key in tracing_config:
                if '*' in tracing_config[key]:
                    # If the key contains '*', retain the original value from the current config
                    new_config[key] = current_trace_config.get(key, tracing_config[key])
                else:
                    # Otherwise, encrypt the key
                    new_config[key] = encrypt_token(tenant_id, tracing_config[key])

        for key in other_keys:
            new_config[key] = tracing_config.get(key, "")

        # Create a new instance of the config class with the new configuration
        encrypted_config = config_class(**new_config)
        return encrypted_config.model_dump()

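    # Usage sketch (illustrative, not part of this commit): values containing '*'
    # are treated as obfuscated placeholders from the UI and keep their stored
    # values; encrypt_token needs a running app context, so this is a sketch only.
    #
    #     OpsTraceManager.encrypt_tracing_config(
    #         tenant_id="tenant-1",
    #         tracing_provider="langfuse",
    #         tracing_config={"public_key": "pk-***", "secret_key": "sk-new", "host": ""},
    #         current_trace_config={"public_key": "<stored-encrypted-pk>"},
    #     )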
    @classmethod
    def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
        """
        Decrypt tracing config
        :param tenant_id: tenant id
        :param tracing_provider: tracing provider
        :param tracing_config: tracing config
        :return:
        """
        config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
            provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
        new_config = {}
        for key in secret_keys:
            if key in tracing_config:
                new_config[key] = decrypt_token(tenant_id, tracing_config[key])

        for key in other_keys:
            new_config[key] = tracing_config.get(key, "")

        return config_class(**new_config).model_dump()

    @classmethod
    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
        """
        Obfuscate the secret keys of a decrypted tracing config
        :param tracing_provider: tracing provider
        :param decrypt_tracing_config: decrypted tracing config
        :return:
        """
        config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \
            provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys']
        new_config = {}
        for key in secret_keys:
            if key in decrypt_tracing_config:
                new_config[key] = obfuscated_token(decrypt_tracing_config[key])

        for key in other_keys:
            new_config[key] = decrypt_tracing_config.get(key, "")

        return config_class(**new_config).model_dump()

    @classmethod
    def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
        """
        Get decrypted tracing config
        :param app_id: app id
        :param tracing_provider: tracing provider
        :return: decrypted tracing config, or None if the app has no config for this provider
        """
        trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
            TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
        ).first()

        if not trace_config_data:
            return None

        # decrypt_token
        tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
        decrypt_tracing_config = cls.decrypt_tracing_config(
            tenant_id, tracing_provider, trace_config_data.tracing_config
        )

        return decrypt_tracing_config

    @classmethod
    def get_ops_trace_instance(
        cls,
        app_id: Optional[Union[UUID, str]] = None,
        message_id: Optional[str] = None,
        conversation_id: Optional[str] = None
    ):
        """
        Get ops trace instance through the app's tracing config
        :param app_id: app_id
        :param message_id: message_id
        :param conversation_id: conversation_id
        :return: trace instance if tracing is enabled for the app, otherwise None
        """
        if conversation_id is not None:
            conversation_data: Conversation = db.session.query(Conversation).filter(
                Conversation.id == conversation_id
            ).first()
            if conversation_data:
                app_id = conversation_data.app_id

        if message_id is not None:
            record: Message = db.session.query(Message).filter(Message.id == message_id).first()
            app_id = record.app_id

        if isinstance(app_id, UUID):
            app_id = str(app_id)

        if app_id is None:
            return None

        app: App = db.session.query(App).filter(
            App.id == app_id
        ).first()
        app_ops_trace_config = json.loads(app.tracing) if app.tracing else None

        if app_ops_trace_config is None:
            return None
        tracing_provider = app_ops_trace_config.get('tracing_provider')

        # decrypt_token
        decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
        if app_ops_trace_config.get('enabled'):
            trace_instance, config_class = provider_config_map[tracing_provider]['trace_instance'], \
                provider_config_map[tracing_provider]['config_class']
            tracing_instance = trace_instance(config_class(**decrypt_trace_config))
            return tracing_instance

        return None

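A hedged usage sketch (the identifiers are hypothetical): any one of the three ids is enough, since the method resolves the owning app before reading its tracing config.

    tracing_instance = OpsTraceManager.get_ops_trace_instance(message_id='some-message-uuid')
    if tracing_instance is not None:
        # tracing is enabled for the app; the instance wraps the provider client
        ...
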
    @classmethod
    def get_app_config_through_message_id(cls, message_id: str):
        app_model_config = None
        message_data = db.session.query(Message).filter(Message.id == message_id).first()
        conversation_id = message_data.conversation_id
        conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()

        if conversation_data.app_model_config_id:
            app_model_config = db.session.query(AppModelConfig).filter(
                AppModelConfig.id == conversation_data.app_model_config_id
            ).first()
        elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
            app_model_config = conversation_data.override_model_configs

        return app_model_config

    @classmethod
    def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
        """
        Update app tracing config
        :param app_id: app id
        :param enabled: whether tracing is enabled
        :param tracing_provider: tracing provider
        :return:
        """
        # validity check
        if tracing_provider is not None and tracing_provider not in provider_config_map:
            raise ValueError(f"Invalid tracing provider: {tracing_provider}")

        app_config: App = db.session.query(App).filter(App.id == app_id).first()
        app_config.tracing = json.dumps(
            {
                "enabled": enabled,
                "tracing_provider": tracing_provider,
            }
        )
        db.session.commit()

    @classmethod
    def get_app_tracing_config(cls, app_id: str):
        """
        Get app tracing config
        :param app_id: app id
        :return: dict with 'enabled' and 'tracing_provider'
        """
        app: App = db.session.query(App).filter(App.id == app_id).first()
        if not app.tracing:
            return {
                "enabled": False,
                "tracing_provider": None
            }
        app_trace_config = json.loads(app.tracing)
        return app_trace_config

    @staticmethod
    def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
        """
        Check whether a trace config is effective by calling the provider's API
        :param tracing_config: tracing config
        :param tracing_provider: tracing provider
        :return:
        """
        config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \
            provider_config_map[tracing_provider]['trace_instance']
        tracing_config = config_type(**tracing_config)
        return trace_instance(tracing_config).api_check()

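The repeated provider_config_map lookups assume entries shaped roughly like the sketch below. The four keys ('config_class', 'secret_keys', 'other_keys', 'trace_instance') are the ones read above; the classes here are illustrative stand-ins, not the commit's real providers:

    from pydantic import BaseModel

    class ExampleTraceConfig(BaseModel):          # stand-in for a provider config class
        api_key: str = ''
        host: str = ''

    class ExampleTraceInstance:                   # stand-in for a BaseTraceInstance subclass
        def __init__(self, config: ExampleTraceConfig):
            self.config = config

        def api_check(self) -> bool:
            return True                           # real providers ping their API here

    provider_config_map_sketch = {
        'example_provider': {
            'config_class': ExampleTraceConfig,
            'secret_keys': ['api_key'],
            'other_keys': ['host'],
            'trace_instance': ExampleTraceInstance,
        },
    }
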

class TraceTaskName(str, Enum):
    CONVERSATION_TRACE = 'conversation_trace'
    WORKFLOW_TRACE = 'workflow_trace'
    MESSAGE_TRACE = 'message_trace'
    MODERATION_TRACE = 'moderation_trace'
    SUGGESTED_QUESTION_TRACE = 'suggested_question_trace'
    DATASET_RETRIEVAL_TRACE = 'dataset_retrieval_trace'
    TOOL_TRACE = 'tool_trace'
    GENERATE_NAME_TRACE = 'generate_name_trace'

class TraceTask:
    def __init__(
        self,
        trace_type: Any,
        message_id: Optional[str] = None,
        workflow_run: Optional[WorkflowRun] = None,
        conversation_id: Optional[str] = None,
        timer: Optional[Any] = None,
        **kwargs
    ):
        self.trace_type = trace_type
        self.message_id = message_id
        self.workflow_run = workflow_run
        self.conversation_id = conversation_id
        self.timer = timer
        self.kwargs = kwargs
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

    def execute(self, trace_instance: BaseTraceInstance):
        method_name, trace_info = self.preprocess()
        if trace_instance:
            method = trace_instance.trace
            method(trace_info)

    def preprocess(self):
        if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
            return TraceTaskName.CONVERSATION_TRACE, self.conversation_trace(**self.kwargs)
        if self.trace_type == TraceTaskName.WORKFLOW_TRACE:
            return TraceTaskName.WORKFLOW_TRACE, self.workflow_trace(self.workflow_run, self.conversation_id)
        elif self.trace_type == TraceTaskName.MESSAGE_TRACE:
            return TraceTaskName.MESSAGE_TRACE, self.message_trace(self.message_id)
        elif self.trace_type == TraceTaskName.MODERATION_TRACE:
            return TraceTaskName.MODERATION_TRACE, self.moderation_trace(self.message_id, self.timer, **self.kwargs)
        elif self.trace_type == TraceTaskName.SUGGESTED_QUESTION_TRACE:
            return TraceTaskName.SUGGESTED_QUESTION_TRACE, self.suggested_question_trace(
                self.message_id, self.timer, **self.kwargs
            )
        elif self.trace_type == TraceTaskName.DATASET_RETRIEVAL_TRACE:
            return TraceTaskName.DATASET_RETRIEVAL_TRACE, self.dataset_retrieval_trace(
                self.message_id, self.timer, **self.kwargs
            )
        elif self.trace_type == TraceTaskName.TOOL_TRACE:
            return TraceTaskName.TOOL_TRACE, self.tool_trace(self.message_id, self.timer, **self.kwargs)
        elif self.trace_type == TraceTaskName.GENERATE_NAME_TRACE:
            return TraceTaskName.GENERATE_NAME_TRACE, self.generate_name_trace(
                self.conversation_id, self.timer, **self.kwargs
            )
        else:
            return '', {}

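An editorial aside, not part of the commit: the if/elif chain in preprocess is equivalent to a dispatch table keyed by TraceTaskName, which would keep the trace-type-to-handler mapping in a single literal. A sketch mirroring the branches above:

    # sketch only; same behaviour as the chain in preprocess
    handlers = {
        TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
        TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
        TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs),
    }
    handler = handlers.get(self.trace_type)
    return (self.trace_type, handler()) if handler else ('', {})
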
    # process methods for different trace types
    def conversation_trace(self, **kwargs):
        return kwargs

    def workflow_trace(self, workflow_run: WorkflowRun, conversation_id):
        workflow_id = workflow_run.workflow_id
        tenant_id = workflow_run.tenant_id
        workflow_run_id = workflow_run.id
        workflow_run_elapsed_time = workflow_run.elapsed_time
        workflow_run_status = workflow_run.status
        workflow_run_inputs = (
            json.loads(workflow_run.inputs) if workflow_run.inputs else {}
        )
        workflow_run_outputs = (
            json.loads(workflow_run.outputs) if workflow_run.outputs else {}
        )
        workflow_run_version = workflow_run.version
        error = workflow_run.error if workflow_run.error else ""

        total_tokens = workflow_run.total_tokens

        file_list = workflow_run_inputs.get("sys.file") if workflow_run_inputs.get("sys.file") else []
        query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""

        # get workflow_app_log_id
        workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(workflow_run_id=workflow_run.id).first()
        workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
        # get message_id
        message_data = db.session.query(Message.id).filter_by(workflow_run_id=workflow_run_id).first()
        message_id = str(message_data.id) if message_data else None

        metadata = {
            "workflow_id": workflow_id,
            "conversation_id": conversation_id,
            "workflow_run_id": workflow_run_id,
            "tenant_id": tenant_id,
            "elapsed_time": workflow_run_elapsed_time,
            "status": workflow_run_status,
            "version": workflow_run_version,
            "total_tokens": total_tokens,
            "file_list": file_list,
            "triggered_from": workflow_run.triggered_from,
        }

        workflow_trace_info = WorkflowTraceInfo(
            workflow_data=workflow_run,
            conversation_id=conversation_id,
            workflow_id=workflow_id,
            tenant_id=tenant_id,
            workflow_run_id=workflow_run_id,
            workflow_run_elapsed_time=workflow_run_elapsed_time,
            workflow_run_status=workflow_run_status,
            workflow_run_inputs=workflow_run_inputs,
            workflow_run_outputs=workflow_run_outputs,
            workflow_run_version=workflow_run_version,
            error=error,
            total_tokens=total_tokens,
            file_list=file_list,
            query=query,
            metadata=metadata,
            workflow_app_log_id=workflow_app_log_id,
            message_id=message_id,
            start_time=workflow_run.created_at,
            end_time=workflow_run.finished_at,
        )

        return workflow_trace_info

    def message_trace(self, message_id):
        message_data = get_message_data(message_id)
        if not message_data:
            return {}
        conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first()
        conversation_mode = conversation_mode[0]
        created_at = message_data.created_at
        inputs = message_data.message

        # get message file data
        message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
        file_list = []
        if message_file_data and message_file_data.url is not None:
            file_url = f"{self.file_base_url}/{message_file_data.url}"
            file_list.append(file_url)

        metadata = {
            "conversation_id": message_data.conversation_id,
            "ls_provider": message_data.model_provider,
            "ls_model_name": message_data.model_id,
            "status": message_data.status,
            "from_end_user_id": message_data.from_account_id,
            "from_account_id": message_data.from_account_id,
            "agent_based": message_data.agent_based,
            "workflow_run_id": message_data.workflow_run_id,
            "from_source": message_data.from_source,
            "message_id": message_id,
        }

        message_tokens = message_data.message_tokens

        message_trace_info = MessageTraceInfo(
            message_data=message_data,
            conversation_model=conversation_mode,
            message_tokens=message_tokens,
            answer_tokens=message_data.answer_tokens,
            total_tokens=message_tokens + message_data.answer_tokens,
            error=message_data.error if message_data.error else "",
            inputs=inputs,
            outputs=message_data.answer,
            file_list=file_list,
            start_time=created_at,
            end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
            metadata=metadata,
            message_file_data=message_file_data,
            conversation_mode=conversation_mode,
        )

        return message_trace_info

    def moderation_trace(self, message_id, timer, **kwargs):
        moderation_result = kwargs.get("moderation_result")
        inputs = kwargs.get("inputs")
        message_data = get_message_data(message_id)
        if not message_data:
            return {}
        metadata = {
            "message_id": message_id,
            "action": moderation_result.action,
            "preset_response": moderation_result.preset_response,
            "query": moderation_result.query,
        }

        # get workflow_app_log_id
        workflow_app_log_id = None
        if message_data.workflow_run_id:
            workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
                workflow_run_id=message_data.workflow_run_id
            ).first()
            workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None

        moderation_trace_info = ModerationTraceInfo(
            message_id=workflow_app_log_id if workflow_app_log_id else message_id,
            inputs=inputs,
            message_data=message_data,
            flagged=moderation_result.flagged,
            action=moderation_result.action,
            preset_response=moderation_result.preset_response,
            query=moderation_result.query,
            start_time=timer.get("start"),
            end_time=timer.get("end"),
            metadata=metadata,
        )

        return moderation_trace_info

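The moderation_result passed through kwargs is assumed to expose flagged, action, preset_response and query; a hedged stand-in for what callers must supply:

    from dataclasses import dataclass

    @dataclass
    class ModerationResultSketch:      # illustrative stand-in for the real result object
        flagged: bool
        action: str
        preset_response: str
        query: str
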
    def suggested_question_trace(self, message_id, timer, **kwargs):
        suggested_question = kwargs.get("suggested_question")
        message_data = get_message_data(message_id)
        if not message_data:
            return {}
        metadata = {
            "message_id": message_id,
            "ls_provider": message_data.model_provider,
            "ls_model_name": message_data.model_id,
            "status": message_data.status,
            "from_end_user_id": message_data.from_account_id,
            "from_account_id": message_data.from_account_id,
            "agent_based": message_data.agent_based,
            "workflow_run_id": message_data.workflow_run_id,
            "from_source": message_data.from_source,
        }

        # get workflow_app_log_id
        workflow_app_log_id = None
        if message_data.workflow_run_id:
            workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by(
                workflow_run_id=message_data.workflow_run_id
            ).first()
            workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None

        suggested_question_trace_info = SuggestedQuestionTraceInfo(
            message_id=workflow_app_log_id if workflow_app_log_id else message_id,
            message_data=message_data,
            inputs=message_data.message,
            outputs=message_data.answer,
            start_time=timer.get("start"),
            end_time=timer.get("end"),
            metadata=metadata,
            total_tokens=message_data.message_tokens + message_data.answer_tokens,
            status=message_data.status,
            error=message_data.error,
            from_account_id=message_data.from_account_id,
            agent_based=message_data.agent_based,
            from_source=message_data.from_source,
            model_provider=message_data.model_provider,
            model_id=message_data.model_id,
            suggested_question=suggested_question,
            level=message_data.status,
            status_message=message_data.error,
        )

        return suggested_question_trace_info

    def dataset_retrieval_trace(self, message_id, timer, **kwargs):
        documents = kwargs.get("documents")
        message_data = get_message_data(message_id)
        if not message_data:
            return {}

        metadata = {
            "message_id": message_id,
            "ls_provider": message_data.model_provider,
            "ls_model_name": message_data.model_id,
            "status": message_data.status,
            "from_end_user_id": message_data.from_account_id,
            "from_account_id": message_data.from_account_id,
            "agent_based": message_data.agent_based,
            "workflow_run_id": message_data.workflow_run_id,
            "from_source": message_data.from_source,
        }

        dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
            message_id=message_id,
            inputs=message_data.query if message_data.query else message_data.inputs,
            documents=documents,
            start_time=timer.get("start"),
            end_time=timer.get("end"),
            metadata=metadata,
            message_data=message_data,
        )

        return dataset_retrieval_trace_info

    def tool_trace(self, message_id, timer, **kwargs):
        tool_name = kwargs.get('tool_name')
        tool_inputs = kwargs.get('tool_inputs')
        tool_outputs = kwargs.get('tool_outputs')
        message_data = get_message_data(message_id)
        if not message_data:
            return {}
        tool_config = {}
        time_cost = 0
        error = None
        tool_parameters = {}
        created_time = message_data.created_at
        end_time = message_data.updated_at
        agent_thoughts: list[MessageAgentThought] = message_data.agent_thoughts
        for agent_thought in agent_thoughts:
            if tool_name in agent_thought.tools:
                created_time = agent_thought.created_at
                tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
                tool_config = tool_meta_data.get('tool_config', {})
                time_cost = tool_meta_data.get('time_cost', 0)
                end_time = created_time + timedelta(seconds=time_cost)
                error = tool_meta_data.get('error', "")
                tool_parameters = tool_meta_data.get('tool_parameters', {})
        metadata = {
            "message_id": message_id,
            "tool_name": tool_name,
            "tool_inputs": tool_inputs,
            "tool_outputs": tool_outputs,
            "tool_config": tool_config,
            "time_cost": time_cost,
            "error": error,
            "tool_parameters": tool_parameters,
        }

        file_url = ""
        message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
        if message_file_data:
            message_file_id = message_file_data.id
            file_type = message_file_data.type
            created_by_role = message_file_data.created_by_role
            created_user_id = message_file_data.created_by
            file_url = f"{self.file_base_url}/{message_file_data.url}"

            metadata.update(
                {
                    "message_file_id": message_file_id,
                    "created_by_role": created_by_role,
                    "created_user_id": created_user_id,
                    "type": file_type,
                }
            )

        tool_trace_info = ToolTraceInfo(
            message_id=message_id,
            message_data=message_data,
            tool_name=tool_name,
            start_time=timer.get("start") if timer else created_time,
            end_time=timer.get("end") if timer else end_time,
            tool_inputs=tool_inputs,
            tool_outputs=tool_outputs,
            metadata=metadata,
            message_file_data=message_file_data,
            error=error,
            inputs=message_data.message,
            outputs=message_data.answer,
            tool_config=tool_config,
            time_cost=time_cost,
            tool_parameters=tool_parameters,
            file_url=file_url,
        )

        return tool_trace_info

    def generate_name_trace(self, conversation_id, timer, **kwargs):
        generate_conversation_name = kwargs.get("generate_conversation_name")
        inputs = kwargs.get("inputs")
        tenant_id = kwargs.get("tenant_id")
        start_time = timer.get("start")
        end_time = timer.get("end")

        metadata = {
            "conversation_id": conversation_id,
            "tenant_id": tenant_id,
        }

        generate_name_trace_info = GenerateNameTraceInfo(
            conversation_id=conversation_id,
            inputs=inputs,
            outputs=generate_conversation_name,
            start_time=start_time,
            end_time=end_time,
            metadata=metadata,
            tenant_id=tenant_id,
        )

        return generate_name_trace_info

class TraceQueueManager:
    def __init__(self, app_id=None, conversation_id=None, message_id=None):
        # pass by keyword: get_ops_trace_instance expects (app_id, message_id, conversation_id)
        tracing_instance = OpsTraceManager.get_ops_trace_instance(
            app_id=app_id, message_id=message_id, conversation_id=conversation_id
        )
        self.queue = queue.Queue()
        self.is_running = True
        self.thread = threading.Thread(
            target=self.process_queue, kwargs={
                'flask_app': current_app._get_current_object(),
                'trace_instance': tracing_instance
            }
        )
        self.thread.start()

    def stop(self):
        self.is_running = False

    def process_queue(self, flask_app: Flask, trace_instance: BaseTraceInstance):
        with flask_app.app_context():
            while self.is_running:
                try:
                    task = self.queue.get(timeout=60)
                    task.execute(trace_instance)
                    self.queue.task_done()
                except queue.Empty:
                    self.stop()

    def add_trace_task(self, trace_task: TraceTask):
        self.queue.put(trace_task)
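A hedged usage sketch of the queue (the ids and the workload are hypothetical; timer comes from measure_time in api/core/ops/utils.py just below). Note the design choice: the worker thread exits via the queue.Empty branch after 60 idle seconds, so a manager is effectively per-request and is not restarted once stopped.

    trace_manager = TraceQueueManager(app_id='some-app-uuid')   # resolves the trace instance itself
    with measure_time() as timer:
        result = run_moderation_check()                         # hypothetical workload
    trace_manager.add_trace_task(
        TraceTask(
            TraceTaskName.MODERATION_TRACE,
            message_id='some-message-uuid',
            timer=timer,
            moderation_result=result,
            inputs={'query': 'hello'},
        )
    )
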

api/core/ops/utils.py (new file)
@@ -0,0 +1,43 @@
from contextlib import contextmanager
from datetime import datetime

from extensions.ext_database import db
from models.model import Message


def filter_none_values(data: dict):
    # convert datetime values to ISO strings in place, then drop None values
    for key, value in data.items():
        if value is None:
            continue
        if isinstance(value, datetime):
            data[key] = value.isoformat()
    return {key: value for key, value in data.items() if value is not None}


def get_message_data(message_id):
    return db.session.query(Message).filter(Message.id == message_id).first()


@contextmanager
def measure_time():
    timing_info = {'start': datetime.now(), 'end': None}
    try:
        yield timing_info
    finally:
        timing_info['end'] = datetime.now()
        print(f"Execution time: {timing_info['end'] - timing_info['start']}")

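Usage sketch: 'end' is only filled in the finally clause, so read the timings after the with-block exits.

    with measure_time() as timer:
        total = sum(range(1_000_000))          # stand-in workload
    elapsed = timer['end'] - timer['start']    # a datetime.timedelta
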

def replace_text_with_content(data):
    if isinstance(data, dict):
        new_data = {}
        for key, value in data.items():
            if key == 'text':
                new_data['content'] = value
            else:
                new_data[key] = replace_text_with_content(value)
        return new_data
    elif isinstance(data, list):
        return [replace_text_with_content(item) for item in data]
    else:
        return data
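Example of the recursive rename: every 'text' key becomes 'content' at any nesting depth (e.g. in prompt message lists), while all other keys pass through unchanged.

    messages = [{'role': 'user', 'text': 'hi'}, {'role': 'assistant', 'text': 'hello'}]
    assert replace_text_with_content(messages) == [
        {'role': 'user', 'content': 'hi'},
        {'role': 'assistant', 'content': 'hello'},
    ]
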

@@ -12,6 +12,8 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.ops_trace_manager import TraceTask, TraceTaskName
from core.ops.utils import measure_time
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.rerank.rerank import RerankRunner

@@ -38,14 +40,20 @@ default_retrieval_model = {


class DatasetRetrieval:
    def retrieve(self, app_id: str, user_id: str, tenant_id: str,
                 model_config: ModelConfigWithCredentialsEntity,
                 config: DatasetEntity,
                 query: str,
                 invoke_from: InvokeFrom,
                 show_retrieve_source: bool,
                 hit_callback: DatasetIndexToolCallbackHandler,
                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
    def __init__(self, application_generate_entity=None):
        self.application_generate_entity = application_generate_entity

    def retrieve(
        self, app_id: str, user_id: str, tenant_id: str,
        model_config: ModelConfigWithCredentialsEntity,
        config: DatasetEntity,
        query: str,
        invoke_from: InvokeFrom,
        show_retrieve_source: bool,
        hit_callback: DatasetIndexToolCallbackHandler,
        message_id: str,
        memory: Optional[TokenBufferMemory] = None,
    ) -> Optional[str]:
        """
        Retrieve dataset.
        :param app_id: app_id
@@ -57,6 +65,7 @@ class DatasetRetrieval:
        :param invoke_from: invoke from
        :param show_retrieve_source: show retrieve source
        :param hit_callback: hit callback
        :param message_id: message id
        :param memory: memory
        :return:
        """

@@ -113,15 +122,20 @@ class DatasetRetrieval:
        all_documents = []
        user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query,
                                                 model_instance,
                                                 model_config, planning_strategy)
            all_documents = self.single_retrieve(
                app_id, tenant_id, user_id, user_from, available_datasets, query,
                model_instance,
                model_config, planning_strategy, message_id
            )
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from,
                                                   available_datasets, query, retrieve_config.top_k,
                                                   retrieve_config.score_threshold,
                                                   retrieve_config.reranking_model.get('reranking_provider_name'),
                                                   retrieve_config.reranking_model.get('reranking_model_name'))
            all_documents = self.multiple_retrieve(
                app_id, tenant_id, user_id, user_from,
                available_datasets, query, retrieve_config.top_k,
                retrieve_config.score_threshold,
                retrieve_config.reranking_model.get('reranking_provider_name'),
                retrieve_config.reranking_model.get('reranking_model_name'),
                message_id,
            )

        document_score_list = {}
        for item in all_documents:

@@ -189,16 +203,18 @@ class DatasetRetrieval:
                return str("\n".join(document_context_list))
        return ''

    def single_retrieve(self, app_id: str,
                        tenant_id: str,
                        user_id: str,
                        user_from: str,
                        available_datasets: list,
                        query: str,
                        model_instance: ModelInstance,
                        model_config: ModelConfigWithCredentialsEntity,
                        planning_strategy: PlanningStrategy,
                        ):
    def single_retrieve(
        self, app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        planning_strategy: PlanningStrategy,
        message_id: Optional[str] = None,
    ):
        tools = []
        for dataset in available_datasets:
            description = dataset.description

@@ -251,27 +267,35 @@ class DatasetRetrieval:
            if score_threshold_enabled:
                score_threshold = retrieval_model_config.get("score_threshold")

            results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
                                                query=query,
                                                top_k=top_k, score_threshold=score_threshold,
                                                reranking_model=reranking_model)
            with measure_time() as timer:
                results = RetrievalService.retrieve(
                    retrival_method=retrival_method, dataset_id=dataset.id,
                    query=query,
                    top_k=top_k, score_threshold=score_threshold,
                    reranking_model=reranking_model
                )
            self._on_query(query, [dataset_id], app_id, user_from, user_id)

            if results:
                self._on_retrival_end(results)
                self._on_retrival_end(results, message_id, timer)

            return results
        return []

    def multiple_retrieve(self,
                          app_id: str,
                          tenant_id: str,
                          user_id: str,
                          user_from: str,
                          available_datasets: list,
                          query: str,
                          top_k: int,
                          score_threshold: float,
                          reranking_provider_name: str,
                          reranking_model_name: str):
    def multiple_retrieve(
        self,
        app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        top_k: int,
        score_threshold: float,
        reranking_provider_name: str,
        reranking_model_name: str,
        message_id: Optional[str] = None,
    ):
        threads = []
        all_documents = []
        dataset_ids = [dataset.id for dataset in available_datasets]

@@ -297,15 +321,23 @@ class DatasetRetrieval:
        )

        rerank_runner = RerankRunner(rerank_model_instance)
        all_documents = rerank_runner.run(query, all_documents,
                                          score_threshold,
                                          top_k)

        with measure_time() as timer:
            all_documents = rerank_runner.run(
                query, all_documents,
                score_threshold,
                top_k
            )
        self._on_query(query, dataset_ids, app_id, user_from, user_id)

        if all_documents:
            self._on_retrival_end(all_documents)
            self._on_retrival_end(all_documents, message_id, timer)

        return all_documents

    def _on_retrival_end(self, documents: list[Document]) -> None:
    def _on_retrival_end(
        self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
    ) -> None:
        """Handle retrival end."""
        for document in documents:
            query = db.session.query(DocumentSegment).filter(

@@ -324,6 +356,18 @@ class DatasetRetrieval:

        db.session.commit()

        # get tracing instance
        trace_manager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.DATASET_RETRIEVAL_TRACE,
                    message_id=message_id,
                    documents=documents,
                    timer=timer
                )
            )

    def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
        """
        Handle query.

@@ -31,9 +31,10 @@ class WorkflowTool(Tool):
        :return: the tool provider type
        """
        return ToolProviderType.WORKFLOW

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
            -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke the tool
        """

@@ -2,7 +2,7 @@ import json
from copy import deepcopy
from datetime import datetime, timezone
from mimetypes import guess_type
from typing import Any, Union
from typing import Any, Optional, Union

from yarl import URL

@@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
from core.tools.errors import (
    ToolEngineInvokeError,

@@ -32,10 +33,12 @@ class ToolEngine:
    Tool runtime engine take care of the tool executions.
    """
    @staticmethod
    def agent_invoke(tool: Tool, tool_parameters: Union[str, dict],
                     user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
                     agent_tool_callback: DifyAgentCallbackHandler) \
        -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
    def agent_invoke(
        tool: Tool, tool_parameters: Union[str, dict],
        user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
        agent_tool_callback: DifyAgentCallbackHandler,
        trace_manager: Optional[TraceQueueManager] = None
    ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
        """
        Agent invokes the tool with the given arguments.
        """

@@ -83,9 +86,11 @@ class ToolEngine:

        # hit the callback handler
        agent_tool_callback.on_tool_end(
            tool_name=tool.identity.name,
            tool_inputs=tool_parameters,
            tool_outputs=plain_text
            tool_name=tool.identity.name,
            tool_inputs=tool_parameters,
            tool_outputs=plain_text,
            message_id=message.id,
            trace_manager=trace_manager
        )

        # transform tool invoke message to get LLM friendly message

@@ -121,8 +126,8 @@ class ToolEngine:
    def workflow_invoke(tool: Tool, tool_parameters: dict,
                        user_id: str, workflow_id: str,
                        workflow_tool_callback: DifyWorkflowCallbackHandler,
                        workflow_call_depth: int) \
        -> list[ToolInvokeMessage]:
                        workflow_call_depth: int,
                        ) -> list[ToolInvokeMessage]:
        """
        Workflow invokes the tool with the given arguments.
        """

@@ -140,9 +145,9 @@ class ToolEngine:

        # hit the callback handler
        workflow_tool_callback.on_tool_end(
            tool_name=tool.identity.name,
            tool_inputs=tool_parameters,
            tool_outputs=response
            tool_name=tool.identity.name,
            tool_inputs=tool_parameters,
            tool_outputs=response,
        )

        return response

@@ -66,44 +66,43 @@ class ParameterExtractorNode(LLMNode):
        }
    }

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run the node.
        """
        node_data = cast(ParameterExtractorNodeData, self.node_data)
        query = variable_pool.get_variable_value(node_data.query)
        if not query:
            raise ValueError("Query not found")
            raise ValueError("Input variable content not found or is empty")

        inputs={
        inputs = {
            'query': query,
            'parameters': jsonable_encoder(node_data.parameters),
            'instruction': jsonable_encoder(node_data.instruction),
        }

        model_instance, model_config = self._fetch_model_config(node_data.model)
        if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
            raise ValueError("Model is not a Large Language Model")

        llm_model = model_instance.model_type_instance
        model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
        if not model_schema:
            raise ValueError("Model schema not found")

        # fetch memory
        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

        if set(model_schema.features or []) & set([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \
                and node_data.reasoning_mode == 'function_call':
            # use function call
            prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
                node_data, query, variable_pool, model_config, memory
            )
        else:
            # use prompt engineering
            prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, memory)
            prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config,
                                                                       memory)
            prompt_message_tools = []

        process_data = {

@@ -202,7 +201,7 @@ class ParameterExtractorNode(LLMNode):
        # handle invoke result
        if not isinstance(invoke_result, LLMResult):
            raise ValueError(f"Invalid invoke result: {invoke_result}")

        text = invoke_result.message.content
        usage = invoke_result.usage
        tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None

@@ -212,21 +211,23 @@ class ParameterExtractorNode(LLMNode):

        return text, usage, tool_call

    def _generate_function_call_prompt(self,
                                       node_data: ParameterExtractorNodeData,
                                       query: str,
                                       variable_pool: VariablePool,
                                       model_config: ModelConfigWithCredentialsEntity,
                                       memory: Optional[TokenBufferMemory],
                                       ) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
        """
        Generate function call prompt.
        """
        query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(node_data.get_parameter_json_schema()))
        query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(
            node_data.get_parameter_json_schema()))

        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
        prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, rest_token)
        prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory,
                                                                     rest_token)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},

@@ -259,8 +260,8 @@ class ParameterExtractorNode(LLMNode):
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=example['assistant']['function_call']['name'],
                            arguments=json.dumps(example['assistant']['function_call']['parameters']
                                                 )
                        ))
                ]
            ),
            ToolPromptMessage(

@@ -273,8 +274,8 @@ class ParameterExtractorNode(LLMNode):
        ])

        prompt_messages = prompt_messages[:last_user_message_idx] + \
                          example_messages + prompt_messages[last_user_message_idx:]

        # generate tool
        tool = PromptMessageTool(
            name=FUNCTION_CALLING_EXTRACTOR_NAME,

@@ -284,13 +285,13 @@ class ParameterExtractorNode(LLMNode):

        return prompt_messages, [tool]

    def _generate_prompt_engineering_prompt(self,
                                            data: ParameterExtractorNodeData,
                                            query: str,
                                            variable_pool: VariablePool,
                                            model_config: ModelConfigWithCredentialsEntity,
                                            memory: Optional[TokenBufferMemory],
                                            ) -> list[PromptMessage]:
        """
        Generate prompt engineering prompt.
        """

@@ -308,18 +309,19 @@ class ParameterExtractorNode(LLMNode):
        raise ValueError(f"Invalid model mode: {model_mode}")

    def _generate_prompt_engineering_completion_prompt(self,
                                                       node_data: ParameterExtractorNodeData,
                                                       query: str,
                                                       variable_pool: VariablePool,
                                                       model_config: ModelConfigWithCredentialsEntity,
                                                       memory: Optional[TokenBufferMemory],
                                                       ) -> list[PromptMessage]:
        """
        Generate completion prompt.
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
        prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, rest_token)
        prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory,
                                                                       rest_token)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={

@@ -336,23 +338,23 @@ class ParameterExtractorNode(LLMNode):
        return prompt_messages

    def _generate_prompt_engineering_chat_prompt(self,
                                                 node_data: ParameterExtractorNodeData,
                                                 query: str,
                                                 variable_pool: VariablePool,
                                                 model_config: ModelConfigWithCredentialsEntity,
                                                 memory: Optional[TokenBufferMemory],
                                                 ) -> list[PromptMessage]:
        """
        Generate chat prompt.
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
        prompt_template = self._get_prompt_engineering_prompt_template(
            node_data,
            CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
                structure=json.dumps(node_data.get_parameter_json_schema()),
                text=query
            ),
            variable_pool, memory, rest_token
        )

@@ -387,7 +389,7 @@ class ParameterExtractorNode(LLMNode):
        ])

        prompt_messages = prompt_messages[:last_user_message_idx] + \
                          example_messages + prompt_messages[last_user_message_idx:]

        return prompt_messages

@@ -397,23 +399,23 @@ class ParameterExtractorNode(LLMNode):
        """
        if len(data.parameters) != len(result):
            raise ValueError("Invalid number of parameters")

        for parameter in data.parameters:
            if parameter.required and parameter.name not in result:
                raise ValueError(f"Parameter {parameter.name} is required")

            if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options:
                raise ValueError(f"Invalid `select` value for parameter {parameter.name}")

            if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float):
                raise ValueError(f"Invalid `number` value for parameter {parameter.name}")

            if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool):
                raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")

            if parameter.type == 'string' and not isinstance(result.get(parameter.name), str):
                raise ValueError(f"Invalid `string` value for parameter {parameter.name}")

            if parameter.type.startswith('array'):
                if not isinstance(result.get(parameter.name), list):
                    raise ValueError(f"Invalid `array` value for parameter {parameter.name}")

@@ -499,6 +501,7 @@ class ParameterExtractorNode(LLMNode):
        """
        Extract complete json response.
        """

        def extract_json(text):
            """
            From a given JSON started from '{' or '[' extract the complete JSON object.

@@ -515,11 +518,11 @@ class ParameterExtractorNode(LLMNode):
                if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['):
                    stack.pop()
                    if not stack:
                        return text[:i+1]
                        return text[:i + 1]
                else:
                    return text[:i]
            return None

        # extract json from the text
        for idx in range(len(result)):
            if result[idx] == '{' or result[idx] == '[':
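A self-contained restatement of the bracket-matching idea behind extract_json above (an illustrative sketch, not the commit's code): scan from the first opening bracket and return the prefix that balances it.

    def extract_first_json(text: str):
        stack = []
        for i, c in enumerate(text):
            if c in '{[':
                stack.append(c)
            elif c in '}]':
                if stack and ((c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '[')):
                    stack.pop()
                    if not stack:
                        return text[:i + 1]    # balanced: complete JSON prefix
                else:
                    return None                # mismatched closing bracket
        return None                            # text ended before brackets closed

    assert extract_first_json('{"a": [1, 2]} trailing words') == '{"a": [1, 2]}'
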

@@ -536,9 +539,9 @@ class ParameterExtractorNode(LLMNode):
        """
        if not tool_call or not tool_call.function.arguments:
            return None

        return json.loads(tool_call.function.arguments)

    def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
        """
        Generate default result.

@@ -551,7 +554,7 @@ class ParameterExtractorNode(LLMNode):
                result[parameter.name] = False
            elif parameter.type in ['string', 'select']:
                result[parameter.name] = ''

        return result

    def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str:

@@ -562,13 +565,13 @@ class ParameterExtractorNode(LLMNode):
        inputs = {}
        for selector in variable_template_parser.extract_variable_selectors():
            inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)

        return variable_template_parser.format(inputs)
    def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
                                              variable_pool: VariablePool,
                                              memory: Optional[TokenBufferMemory],
                                              max_token_limit: int = 2000) \
            -> list[ChatModelMessage]:
        model_mode = ModelMode.value_of(node_data.model.mode)
        input_text = query

@@ -590,12 +593,12 @@ class ParameterExtractorNode(LLMNode):
            return [system_prompt_messages, user_prompt_message]
        else:
            raise ValueError(f"Model mode {model_mode} not support.")

    def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
                                                variable_pool: VariablePool,
                                                memory: Optional[TokenBufferMemory],
                                                max_token_limit: int = 2000) \
            -> list[ChatModelMessage]:

        model_mode = ModelMode.value_of(node_data.model.mode)
        input_text = query

@@ -620,8 +623,8 @@ class ParameterExtractorNode(LLMNode):
                text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str,
                                                            text=input_text,
                                                            instruction=instruction)
                .replace('{γγγ', '')
                .replace('}γγγ', '')
            )
        else:
            raise ValueError(f"Model mode {model_mode} not support.")

@@ -635,7 +638,7 @@ class ParameterExtractorNode(LLMNode):
        model_instance, model_config = self._fetch_model_config(node_data.model)
        if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
            raise ValueError("Model is not a Large Language Model")

        llm_model = model_instance.model_type_instance
        model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
        if not model_schema:

@@ -667,7 +670,7 @@ class ParameterExtractorNode(LLMNode):
            model_config.model,
            model_config.credentials,
            prompt_messages
        ) + 1000  # add 1000 to ensure tool call messages

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:

@@ -680,8 +683,9 @@ class ParameterExtractorNode(LLMNode):
        rest_tokens = max(rest_tokens, 0)

        return rest_tokens

    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
        ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config.
        """

@@ -689,9 +693,10 @@ class ParameterExtractorNode(LLMNode):
        self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)

        return self._model_instance, self._model_config

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[str, list[str]]:
    def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[
        str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data

@@ -708,4 +713,4 @@ class ParameterExtractorNode(LLMNode):
        for selector in variable_template_parser.extract_variable_selectors():
            variable_mapping[selector.variable] = selector.value_selector

        return variable_mapping