diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 7ba6b320ad7..160ef9d9a0c 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -708,7 +708,7 @@ Tensor _mkldnn_convolution_transpose( if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); - ideep::convolution_transpose_forward::compute( + ideep::convolution_transpose_forward::compute_v3( x, w, b, @@ -719,9 +719,10 @@ Tensor _mkldnn_convolution_transpose( padding_r(padding, output_padding), dilation.vec(), groups, + use_channels_last, op_attr); } else { - ideep::convolution_transpose_forward::compute( + ideep::convolution_transpose_forward::compute_v3( x, w, output_sizes, @@ -731,6 +732,7 @@ Tensor _mkldnn_convolution_transpose( padding_r(padding, output_padding), dilation.vec(), groups, + use_channels_last, op_attr); } if (input.is_mkldnn()) { @@ -738,7 +740,6 @@ Tensor _mkldnn_convolution_transpose( } else if (!use_channels_last) { return mkldnn_to_dense(MKLDNNTensor(y, input.options())); } else { - TORCH_INTERNAL_ASSERT(y.get_desc().is_nhwc()); return output; } } @@ -923,65 +924,18 @@ Tensor mkldnn_convolution_transpose( IntArrayRef dilation, int64_t groups) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); - const Tensor& bias = *bias_maybe_owned; - - if (input.scalar_type() == ScalarType::BFloat16) { - TORCH_CHECK(mkldnn_bf16_device_check(), - "mkldnn_convolution_transpose: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"); - } - bool use_channels_last = mkldnn_conv_use_channels_last(input, weight); - auto memory_format = mkldnn_convolution_memory_format(input.ndimension(), use_channels_last); - - auto output_sizes = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); - auto output = at::empty({0}, input.options()); - - const ideep::tensor x = itensor_from_tensor(input); - ideep::tensor w = itensor_from_tensor(weight); - // mkldnn transposed convolution has weight in logical order of OIHW or OIDHW, - // while PyTorch has IOHW or IODHW, `._tranpose()` switches strides (no memory copy). - w.transpose_(0, 1); - - ideep::tensor y; - if (use_channels_last) { - output.resize_(output_sizes, memory_format); - y = itensor_from_tensor(output); - } - if (bias.defined()) { - const ideep::tensor b = itensor_from_tensor(bias); - ideep::convolution_transpose_forward::compute( - x, - w, - b, - output_sizes, - y, - stride.vec(), - padding.vec(), - padding_r(padding, output_padding), - dilation.vec(), - groups); - } else { - ideep::convolution_transpose_forward::compute( - x, - w, - output_sizes, - y, - stride.vec(), - padding.vec(), - padding_r(padding, output_padding), - dilation.vec(), - groups); - } - - if (input.is_mkldnn()) { - return MKLDNNTensor(y, input.options()); - } else if (!use_channels_last) { - return mkldnn_to_dense(MKLDNNTensor(y, input.options())); - } else { - return output; - } + return _mkldnn_convolution_transpose( + input, + weight, + bias_opt, + padding, + output_padding, + stride, + dilation, + groups, + use_channels_last + ); } Tensor mkldnn_convolution_transpose_backward_input( diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index 821eec3dcad..a0669967329 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -427,6 +427,16 @@ class TestMkldnn(TestCase): def test_conv_transpose3d(self): self._test_conv_transpose_base(dim=3) + def test_conv_transposed2d_ic1_nhwc(self): + x = torch.ones([1, 1, 2, 2]).contiguous(memory_format=torch.channels_last) + model = torch.nn.ConvTranspose2d(in_channels=1, out_channels=2, kernel_size=[5, 5]).eval() + model.weight.data = torch.ones([1, 2, 5, 5]).contiguous(memory_format=torch.channels_last) + with torch.no_grad(): + y = model(x) + with torch.backends.mkldnn.flags(enabled=False): + y_ref = model(x) + self.assertEqual(y, y_ref) + def test_conv2d_legacy_jit_model(self): """ MKLDNN integration used to serialize models with 5d weight for grouped