orm filter -> where (#22801)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -214,7 +214,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
|
||||
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@@ -44,7 +44,7 @@ class TestEncryptToken:
|
||||
"""Test successful token encryption"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
|
||||
result = encrypt_token("tenant-123", "test_token")
|
||||
@@ -55,7 +55,7 @@ class TestEncryptToken:
|
||||
@patch("models.engine.db.session.query")
|
||||
def test_tenant_not_found(self, mock_query):
|
||||
"""Test error when tenant doesn't exist"""
|
||||
mock_query.return_value.filter.return_value.first.return_value = None
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypt_token("invalid-tenant", "test_token")
|
||||
@@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration:
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Setup mock encryption/decryption
|
||||
original_token = "test_token_123"
|
||||
@@ -153,7 +153,7 @@ class TestSecurity:
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "tenant1_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_for_tenant1"
|
||||
|
||||
# Encrypt token for tenant1
|
||||
@@ -186,7 +186,7 @@ class TestSecurity:
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_query):
|
||||
"""Ensure same plaintext produces different ciphertext"""
|
||||
mock_tenant = MagicMock(encrypt_public_key="key")
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Different outputs for same input
|
||||
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
|
||||
@@ -211,7 +211,7 @@ class TestEdgeCases:
|
||||
"""Test encryption of empty token"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_empty"
|
||||
|
||||
result = encrypt_token("tenant-123", "")
|
||||
@@ -225,7 +225,7 @@ class TestEdgeCases:
|
||||
"""Test tokens containing special/unicode characters"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_special"
|
||||
|
||||
# Test various special characters
|
||||
@@ -248,7 +248,7 @@ class TestEdgeCases:
|
||||
"""Test behavior when token exceeds RSA encryption limits"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
# RSA 2048-bit can only encrypt ~245 bytes
|
||||
# The actual limit depends on padding scheme
|
||||
|
||||
@@ -54,8 +54,7 @@ def mock_tool_file():
|
||||
mock.mimetype = "application/pdf"
|
||||
mock.original_url = "http://example.com/tool.pdf"
|
||||
mock.size = 2048
|
||||
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=mock):
|
||||
yield mock
|
||||
|
||||
|
||||
@@ -153,8 +152,7 @@ def test_build_from_remote_url(mock_http_head):
|
||||
|
||||
def test_tool_file_not_found():
|
||||
"""Test ToolFile not found in database."""
|
||||
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||
mock_query.return_value.filter.return_value.first.return_value = None
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None):
|
||||
mapping = tool_file_mapping()
|
||||
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
|
||||
@@ -114,12 +114,12 @@ class TestEnumText:
|
||||
session.commit()
|
||||
|
||||
with Session(engine) as session:
|
||||
user = session.query(_User).filter(_User.id == admin_user_id).first()
|
||||
user = session.query(_User).where(_User.id == admin_user_id).first()
|
||||
assert user.user_type == _UserType.admin
|
||||
assert user.user_type_nullable is None
|
||||
|
||||
with Session(engine) as session:
|
||||
user = session.query(_User).filter(_User.id == normal_user_id).first()
|
||||
user = session.query(_User).where(_User.id == normal_user_id).first()
|
||||
assert user.user_type == _UserType.normal
|
||||
assert user.user_type_nullable == _UserType.normal
|
||||
|
||||
@@ -188,4 +188,4 @@ class TestEnumText:
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
with Session(engine) as session:
|
||||
_user = session.query(_User).filter(_User.id == 1).first()
|
||||
_user = session.query(_User).where(_User.id == 1).first()
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestApiKeyAuthService:
|
||||
mock_binding.provider = self.provider
|
||||
mock_binding.disabled = False
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [mock_binding]
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
@@ -39,7 +39,7 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_empty(self, mock_session):
|
||||
"""Test get provider auth list - empty result"""
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = []
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
@@ -48,13 +48,13 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_filters_disabled(self, mock_session):
|
||||
"""Test get provider auth list - filters disabled items"""
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = []
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
|
||||
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
# Verify filter conditions include disabled.is_(False)
|
||||
filter_call = mock_session.query.return_value.filter.call_args[0]
|
||||
assert len(filter_call) == 2 # tenant_id and disabled filter conditions
|
||||
# Verify where conditions include disabled.is_(False)
|
||||
where_call = mock_session.query.return_value.where.call_args[0]
|
||||
assert len(where_call) == 2 # tenant_id and disabled filter conditions
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@@ -138,7 +138,8 @@ class TestApiKeyAuthService:
|
||||
# Mock database query result
|
||||
mock_binding = Mock()
|
||||
mock_binding.credentials = json.dumps(self.mock_credentials)
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
@@ -148,7 +149,7 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_auth_credentials_not_found(self, mock_session):
|
||||
"""Test get auth credentials - not found"""
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = None
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
@@ -157,13 +158,13 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_auth_credentials_filters_correctly(self, mock_session):
|
||||
"""Test get auth credentials - applies correct filters"""
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = None
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
# Verify filter conditions are correct
|
||||
filter_call = mock_session.query.return_value.filter.call_args[0]
|
||||
assert len(filter_call) == 4 # tenant_id, category, provider, disabled
|
||||
# Verify where conditions are correct
|
||||
where_call = mock_session.query.return_value.where.call_args[0]
|
||||
assert len(where_call) == 4 # tenant_id, category, provider, disabled
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_auth_credentials_json_parsing(self, mock_session):
|
||||
@@ -173,7 +174,7 @@ class TestApiKeyAuthService:
|
||||
|
||||
mock_binding = Mock()
|
||||
mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False)
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
@@ -185,7 +186,7 @@ class TestApiKeyAuthService:
|
||||
"""Test delete provider auth - success scenario"""
|
||||
# Mock database query result
|
||||
mock_binding = Mock()
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
|
||||
|
||||
@@ -196,7 +197,7 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_delete_provider_auth_not_found(self, mock_session):
|
||||
"""Test delete provider auth - not found"""
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = None
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
|
||||
|
||||
@@ -207,13 +208,13 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_delete_provider_auth_filters_by_tenant(self, mock_session):
|
||||
"""Test delete provider auth - filters by tenant"""
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = None
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
|
||||
|
||||
# Verify filter conditions include tenant_id and binding_id
|
||||
filter_call = mock_session.query.return_value.filter.call_args[0]
|
||||
assert len(filter_call) == 2
|
||||
# Verify where conditions include tenant_id and binding_id
|
||||
where_call = mock_session.query.return_value.where.call_args[0]
|
||||
assert len(where_call) == 2
|
||||
|
||||
def test_validate_api_key_auth_args_success(self):
|
||||
"""Test API key auth args validation - success scenario"""
|
||||
@@ -336,7 +337,7 @@ class TestApiKeyAuthService:
|
||||
# Mock database returning invalid JSON
|
||||
mock_binding = Mock()
|
||||
mock_binding.credentials = "invalid json content"
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
@@ -63,10 +63,10 @@ class TestAuthIntegration:
|
||||
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [tenant1_binding]
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
|
||||
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [tenant2_binding]
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
|
||||
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
|
||||
|
||||
assert len(result1) == 1
|
||||
@@ -77,7 +77,7 @@ class TestAuthIntegration:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_cross_tenant_access_prevention(self, mock_session):
|
||||
"""Test prevention of cross-tenant credential access"""
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = None
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL)
|
||||
|
||||
|
||||
@@ -708,9 +708,9 @@ class TestTenantService:
|
||||
with patch("services.account_service.db") as mock_db:
|
||||
# Mock the join query that returns the tenant_account_join
|
||||
mock_query = MagicMock()
|
||||
mock_filter = MagicMock()
|
||||
mock_filter.first.return_value = mock_tenant_join
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_where = MagicMock()
|
||||
mock_where.first.return_value = mock_tenant_join
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_db.session.query.return_value = mock_query
|
||||
|
||||
@@ -1381,10 +1381,10 @@ class TestRegisterService:
|
||||
|
||||
# Mock database queries - complex query mocking
|
||||
mock_query1 = MagicMock()
|
||||
mock_query1.filter.return_value.first.return_value = mock_tenant
|
||||
mock_query1.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal")
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
|
||||
@@ -1449,7 +1449,7 @@ class TestRegisterService:
|
||||
mock_query1.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.filter.return_value.first.return_value = None # No account found
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = None # No account found
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
|
||||
@@ -1482,7 +1482,7 @@ class TestRegisterService:
|
||||
mock_query1.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
mock_query2 = MagicMock()
|
||||
mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal")
|
||||
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
|
||||
|
||||
mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ def test_delete_workflow_success(workflow_setup):
|
||||
# Setup mocks
|
||||
|
||||
# Mock the tool provider query to return None (not published as a tool)
|
||||
workflow_setup["session"].query.return_value.filter.return_value.first.return_value = None
|
||||
workflow_setup["session"].query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
workflow_setup["session"].scalar = MagicMock(
|
||||
side_effect=[workflow_setup["workflow"], None]
|
||||
@@ -106,7 +106,7 @@ def test_delete_workflow_published_as_tool_error(workflow_setup):
|
||||
|
||||
# Mock the tool provider query
|
||||
mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
|
||||
workflow_setup["session"].query.return_value.filter.return_value.first.return_value = mock_tool_provider
|
||||
workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider
|
||||
|
||||
workflow_setup["session"].scalar = MagicMock(
|
||||
side_effect=[workflow_setup["workflow"], None]
|
||||
|
||||
Reference in New Issue
Block a user