update sql in batch (#24801)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
from models.model import Account, Tenant
|
||||
@@ -468,7 +469,7 @@ class TestModelLoadBalancingService:
|
||||
assert load_balancing_config.id is not None
|
||||
|
||||
# Verify inherit config was created in database
|
||||
inherit_configs = (
|
||||
db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all()
|
||||
)
|
||||
inherit_configs = db.session.scalars(
|
||||
select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
|
||||
).all()
|
||||
assert len(inherit_configs) == 1
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@@ -954,7 +955,9 @@ class TestTagService:
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Verify only one binding exists
|
||||
bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
|
||||
bindings = db.session.scalars(
|
||||
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
|
||||
).all()
|
||||
assert len(bindings) == 1
|
||||
|
||||
def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
@@ -1064,7 +1067,9 @@ class TestTagService:
|
||||
# No error should be raised, and database state should remain unchanged
|
||||
from extensions.ext_database import db
|
||||
|
||||
bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
|
||||
bindings = db.session.scalars(
|
||||
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
|
||||
).all()
|
||||
assert len(bindings) == 0
|
||||
|
||||
def test_check_target_exists_knowledge_success(
|
||||
|
||||
@@ -2,6 +2,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.account import Account
|
||||
@@ -354,16 +355,14 @@ class TestWebConversationService:
|
||||
# Verify only one pinned conversation record exists
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversations = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
pinned_conversations = db.session.scalars(
|
||||
select(PinnedConversation).where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "account",
|
||||
PinnedConversation.created_by == account.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
assert len(pinned_conversations) == 1
|
||||
|
||||
|
||||
@@ -28,18 +28,20 @@ class TestApiKeyAuthService:
|
||||
mock_binding.provider = self.provider
|
||||
mock_binding.disabled = False
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [mock_binding]
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].tenant_id == self.tenant_id
|
||||
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
|
||||
assert mock_session.scalars.call_count == 1
|
||||
select_arg = mock_session.scalars.call_args[0][0]
|
||||
assert "data_source_api_key_auth_binding" in str(select_arg).lower()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_empty(self, mock_session):
|
||||
"""Test get provider auth list - empty result"""
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
@@ -48,13 +50,15 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_filters_disabled(self, mock_session):
|
||||
"""Test get provider auth list - filters disabled items"""
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
# Verify where conditions include disabled.is_(False)
|
||||
where_call = mock_session.query.return_value.where.call_args[0]
|
||||
assert len(where_call) == 2 # tenant_id and disabled filter conditions
|
||||
select_stmt = mock_session.scalars.call_args[0][0]
|
||||
where_clauses = list(getattr(select_stmt, "_where_criteria", []) or [])
|
||||
# Ensure both tenant filter and disabled filter exist
|
||||
where_strs = [str(c).lower() for c in where_clauses]
|
||||
assert any("tenant_id" in s for s in where_strs)
|
||||
assert any("disabled" in s for s in where_strs)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
|
||||
@@ -63,10 +63,10 @@ class TestAuthIntegration:
|
||||
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [tenant1_binding]
|
||||
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [tenant2_binding]
|
||||
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
|
||||
|
||||
assert len(result1) == 1
|
||||
|
||||
Reference in New Issue
Block a user