diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 0238b1b6828..1d4d1cc4bda 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -648,7 +648,9 @@ Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, c c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - if (at::cuda::detail::canUse32BitIndexMath(self) && batch_norm_use_channels_last_kernels(self)){ + if (at::cuda::detail::canUse32BitIndexMath(self) && + batch_norm_use_channels_last_kernels(self) && + batch_norm_use_channels_last_kernels(input)) { return batch_norm_backward_elemt_channels_last_cuda_template(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); } diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index af074f5d2c6..6daa2b08580 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -1649,7 +1649,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( const auto stride = input.sizes()[1]; const auto reduction_size = input.numel() / stride; - at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format()); + // Input is guaranteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); dim3 block; dim3 grid; @@ -1716,7 +1717,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( const auto reduction_size = input.numel() / stride; auto norm_fct = 1.0 / reduction_size; - at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format()); + // Input is guaranteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); dim3 block; dim3 grid; diff --git a/test/test_nn.py b/test/test_nn.py index bb4dd59be52..c9815dbf2ee 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -11192,6 
+11192,48 @@ class TestNN(NNTestCase): self.assertEqual(layer.state_dict()[key].device, converted_layer.state_dict()[key].device) self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key]) + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_sync_batchnorm_backward_elemt(self): + device = 'cuda' + saved_input = torch.rand(2, 3, 2, 1, device=device) + grad_output = torch.rand(2, 3, 2, 1, device=device) + mean = torch.rand(3, device=device) + invstd = torch.rand(3, device=device) + weight = torch.rand(3, device=device) + sum_dy = torch.rand(3, device=device) + sum_dy_xmu = torch.rand(3, device=device) + count_tensor = torch.tensor([5, 5, 5], dtype=torch.int32, device=device) + + gI_contiguous = torch.batch_norm_backward_elemt( + grad_output, + saved_input, + mean, + invstd, + weight, + sum_dy, + sum_dy_xmu, + count_tensor + ) + + # Test batch_norm_backward_elemt gives the same answer for all + # combinations of contiguous and channels_last inputs + for a, b in [ + (torch.channels_last, torch.contiguous_format), + (torch.contiguous_format, torch.channels_last), + (torch.channels_last, torch.channels_last), + ]: + gI_actual = torch.batch_norm_backward_elemt( + grad_output.contiguous(memory_format=a), + saved_input.contiguous(memory_format=b), + mean, + invstd, + weight, + sum_dy, + sum_dy_xmu, + count_tensor + ) + self.assertEqual(gI_actual, gI_contiguous) + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_sync_batchnorm_accuracy_cuda(self): + # The target of this test is to test the functionality and accuracy of