orm filter -> where (#22801)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Asuka Minato
2025-07-24 01:57:45 +09:00
committed by GitHub
parent e64e7563f6
commit ef51678c73
161 changed files with 828 additions and 857 deletions

View File

@@ -28,7 +28,7 @@ class TestApiKeyAuthService:
mock_binding.provider = self.provider
mock_binding.disabled = False
mock_session.query.return_value.filter.return_value.all.return_value = [mock_binding]
mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
@@ -39,7 +39,7 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_empty(self, mock_session):
"""Test get provider auth list - empty result"""
mock_session.query.return_value.filter.return_value.all.return_value = []
mock_session.query.return_value.where.return_value.all.return_value = []
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
@@ -48,13 +48,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_filters_disabled(self, mock_session):
"""Test get provider auth list - filters disabled items"""
mock_session.query.return_value.filter.return_value.all.return_value = []
mock_session.query.return_value.where.return_value.all.return_value = []
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
# Verify filter conditions include disabled.is_(False)
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 2 # tenant_id and disabled filter conditions
# Verify where conditions include disabled.is_(False)
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 2 # tenant_id and disabled filter conditions
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@@ -138,7 +138,8 @@ class TestApiKeyAuthService:
# Mock database query result
mock_binding = Mock()
mock_binding.credentials = json.dumps(self.mock_credentials)
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
@@ -148,7 +149,7 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_not_found(self, mock_session):
"""Test get auth credentials - not found"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
@@ -157,13 +158,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_filters_correctly(self, mock_session):
"""Test get auth credentials - applies correct filters"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
# Verify filter conditions are correct
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 4 # tenant_id, category, provider, disabled
# Verify where conditions are correct
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 4 # tenant_id, category, provider, disabled
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_json_parsing(self, mock_session):
@@ -173,7 +174,7 @@ class TestApiKeyAuthService:
mock_binding = Mock()
mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False)
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
@@ -185,7 +186,7 @@ class TestApiKeyAuthService:
"""Test delete provider auth - success scenario"""
# Mock database query result
mock_binding = Mock()
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
@@ -196,7 +197,7 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_not_found(self, mock_session):
"""Test delete provider auth - not found"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
@@ -207,13 +208,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_filters_by_tenant(self, mock_session):
"""Test delete provider auth - filters by tenant"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify filter conditions include tenant_id and binding_id
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 2
# Verify where conditions include tenant_id and binding_id
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 2
def test_validate_api_key_auth_args_success(self):
"""Test API key auth args validation - success scenario"""
@@ -336,7 +337,7 @@ class TestApiKeyAuthService:
# Mock database returning invalid JSON
mock_binding = Mock()
mock_binding.credentials = "invalid json content"
mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
with pytest.raises(json.JSONDecodeError):
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)

View File

@@ -63,10 +63,10 @@ class TestAuthIntegration:
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
mock_session.query.return_value.filter.return_value.all.return_value = [tenant1_binding]
mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
mock_session.query.return_value.filter.return_value.all.return_value = [tenant2_binding]
mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
assert len(result1) == 1
@@ -77,7 +77,7 @@ class TestAuthIntegration:
@patch("services.auth.api_key_auth_service.db.session")
def test_cross_tenant_access_prevention(self, mock_session):
"""Test prevention of cross-tenant credential access"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL)