update sql in batch (#24801)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Asuka Minato
2025-09-10 14:00:17 +09:00
committed by GitHub
parent b51c724a94
commit cbc0e639e4
49 changed files with 281 additions and 277 deletions

View File

@@ -32,11 +32,16 @@ class TokenBufferMemory:
self.model_instance = model_instance
def _build_prompt_message_with_files(
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
self,
message_files: Sequence[MessageFile],
text_content: str,
message: Message,
app_record,
is_user_message: bool,
) -> PromptMessage:
"""
Build prompt message with files.
:param message_files: list of MessageFile objects
:param message_files: Sequence of MessageFile objects
:param text_content: text content of the message
:param message: Message object
:param app_record: app record
@@ -128,14 +133,12 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = []
for message in messages:
# Process user message with files
user_files = (
db.session.query(MessageFile)
.where(
user_files = db.session.scalars(
select(MessageFile).where(
MessageFile.message_id == message.id,
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
)
.all()
)
).all()
if user_files:
user_prompt_message = self._build_prompt_message_with_files(
@@ -150,11 +153,9 @@ class TokenBufferMemory:
prompt_messages.append(UserPromptMessage(content=message.query))
# Process assistant message with files
assistant_files = (
db.session.query(MessageFile)
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
.all()
)
assistant_files = db.session.scalars(
select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
).all()
if assistant_files:
assistant_prompt_message = self._build_prompt_message_with_files(

View File

@@ -15,6 +15,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from sqlalchemy import select
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
@@ -699,8 +700,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = (
db.session.query(
workflow_nodes = db.session.scalars(
select(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
@@ -713,10 +714,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all()
)
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
).all()
return workflow_nodes
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:

View File

@@ -1,5 +1,6 @@
import time
import uuid
from collections.abc import Sequence
import requests
from requests.auth import HTTPDigestAuth
@@ -139,7 +140,7 @@ class TidbService:
@staticmethod
def batch_update_tidb_serverless_cluster_status(
tidb_serverless_list: list[TidbAuthBinding],
tidb_serverless_list: Sequence[TidbAuthBinding],
project_id: str,
api_url: str,
iam_url: str,

View File

@@ -1,4 +1,5 @@
from pydantic import Field
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool_provider import ToolProviderController
@@ -176,11 +177,11 @@ class ApiToolProviderController(ToolProviderController):
tools: list[ApiTool] = []
# get tenant api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all()
)
db_providers = db.session.scalars(
select(ApiToolProvider).where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name
)
).all()
if db_providers and len(db_providers) != 0:
for db_provider in db_providers:

View File

@@ -87,9 +87,7 @@ class ToolLabelManager:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}

View File

@@ -667,9 +667,9 @@ class ToolManager:
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
)
db_api_providers = db.session.scalars(
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
).all()
api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
@@ -690,9 +690,9 @@ class ToolManager:
if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_providers = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for workflow_provider in workflow_providers: