Fix bad use of channels last kernel in sync batch norm backward (#64100)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/64039

There are two distinct problems here.
1. If `grad_output` is channels last but `input` is not, then `input` was still read as if it were channels last, so the kernel read the wrong values (see the repro sketch below).
2. `batch_norm_use_channels_last_kernels` doesn't guarantee that `suggest_memory_format` will actually return channels last, so use plain `empty_like` instead so that `grad_input`'s strides always match `input`'s.
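
A minimal repro sketch of problem 1, adapted from the test added in this PR (shapes and values are arbitrary; it assumes a CUDA build, since this path dispatches to the CUDA channels-last kernel):

import torch

device = 'cuda'
saved_input = torch.rand(2, 3, 2, 1, device=device)   # NCHW, contiguous
grad_output = torch.rand(2, 3, 2, 1, device=device)
mean = torch.rand(3, device=device)
invstd = torch.rand(3, device=device)
weight = torch.rand(3, device=device)
sum_dy = torch.rand(3, device=device)
sum_dy_xmu = torch.rand(3, device=device)
count = torch.tensor([5, 5, 5], dtype=torch.int32, device=device)

reference = torch.batch_norm_backward_elemt(
    grad_output, saved_input, mean, invstd, weight, sum_dy, sum_dy_xmu, count)

# Channels-last grad_output with a still-contiguous input: before this fix the
# channels-last kernel read `saved_input` with the wrong strides.
actual = torch.batch_norm_backward_elemt(
    grad_output.contiguous(memory_format=torch.channels_last),
    saved_input, mean, invstd, weight, sum_dy, sum_dy_xmu, count)
assert torch.allclose(actual, reference)  # fails before the fix, passes after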

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64100

Reviewed By: mruberry

Differential Revision: D30622127

Pulled By: ngimel

fbshipit-source-id: e28cc57215596817f1432fcdd6c49d69acfedcf2
Author: Peter Bell, 2021-08-30 12:14:09 -07:00 (committed by Facebook GitHub Bot)
Commit: 5b0dfd0f8a (parent: ac99d63f83)
3 changed files with 49 additions and 3 deletions


@@ -648,7 +648,9 @@ Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, c
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
 
-  if (at::cuda::detail::canUse32BitIndexMath(self) && batch_norm_use_channels_last_kernels(self)){
+  if (at::cuda::detail::canUse32BitIndexMath(self) &&
+      batch_norm_use_channels_last_kernels(self) &&
+      batch_norm_use_channels_last_kernels(input)) {
     return batch_norm_backward_elemt_channels_last_cuda_template(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
   }


@@ -1649,7 +1649,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
   const auto stride = input.sizes()[1];
   const auto reduction_size = input.numel() / stride;
 
-  at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format());
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
 
   dim3 block;
   dim3 grid;
@@ -1716,7 +1717,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
   const auto reduction_size = input.numel() / stride;
   auto norm_fct = 1.0 / reduction_size;
 
-  at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format());
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
 
   dim3 block;
   dim3 grid;
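
For context on the `empty_like` change: `at::empty_like` (and `torch.empty_like`) defaults to `preserve_format`, which copies the source tensor's strides whenever it is non-overlapping and dense, so `grad_input` always matches `input`'s layout even when `suggest_memory_format` would disagree. A rough Python illustration using the shape from the new test (illustrative only, not part of the PR):

import torch

x = torch.rand(2, 3, 2, 1).contiguous(memory_format=torch.channels_last)
# Default memory_format=torch.preserve_format copies x's channels-last strides.
assert torch.empty_like(x).stride() == x.stride()
# Forcing a different format yields a layout the channels-last kernel would
# write into with the wrong element ordering.
assert torch.empty_like(x, memory_format=torch.contiguous_format).stride() != x.stride()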


@@ -11192,6 +11192,48 @@ class TestNN(NNTestCase):
                 self.assertEqual(layer.state_dict()[key].device, converted_layer.state_dict()[key].device)
                 self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key])
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    def test_sync_batchnorm_backward_elemt(self):
+        device = 'cuda'
+        saved_input = torch.rand(2, 3, 2, 1, device=device)
+        grad_output = torch.rand(2, 3, 2, 1, device=device)
+        mean = torch.rand(3, device=device)
+        invstd = torch.rand(3, device=device)
+        weight = torch.rand(3, device=device)
+        sum_dy = torch.rand(3, device=device)
+        sum_dy_xmu = torch.rand(3, device=device)
+        count_tensor = torch.tensor([5, 5, 5], dtype=torch.int32, device=device)
+
+        gI_contiguous = torch.batch_norm_backward_elemt(
+            grad_output,
+            saved_input,
+            mean,
+            invstd,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count_tensor
+        )
+
+        # Test batch_norm_backward_elemt gives the same answer for all
+        # combinations of contiguous and channels_last inputs
+        for a, b in [
+            (torch.channels_last, torch.contiguous_format),
+            (torch.contiguous_format, torch.channels_last),
+            (torch.channels_last, torch.channels_last),
+        ]:
+            gI_actual = torch.batch_norm_backward_elemt(
+                grad_output.contiguous(memory_format=a),
+                saved_input.contiguous(memory_format=b),
+                mean,
+                invstd,
+                weight,
+                sum_dy,
+                sum_dy_xmu,
+                count_tensor
+            )
+            self.assertEqual(gI_actual, gI_contiguous)
+
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
     def test_sync_batchnorm_accuracy_cuda(self):
         # The target of this test is to test the functionality and accuracy of