From dbfb9a823d338312e1c441e2f88d20171d920c71 Mon Sep 17 00:00:00 2001
From: mingfeima
Date: Mon, 2 May 2022 15:17:49 -0700
Subject: [PATCH] enable BFloat16 mkldnn_convolution on both contiguous and channels last memory format (#55864)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55864

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D27941367

Pulled By: VitalyFedyunin

fbshipit-source-id: c6bcb73c41652cc0aca11c1d1e0697a8a2fa43ad
(cherry picked from commit 3fc0b992a7dccbc31042dc35afec9ae3dc59a05a)
---
 aten/src/ATen/native/Convolution.cpp          |  7 ++
 .../ATen/native/NaiveDilatedConvolution.cpp   | 15 +-
 test/test_mkldnn.py                           | 91 ++++++++++---------
 .../_internal/common_methods_invocations.py   |  4 +-
 4 files changed, 67 insertions(+), 50 deletions(-)

diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index ea040e241d6..7c6cccea353 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -19,6 +19,10 @@
 #include <nnpack.h>
 #endif

+#if AT_MKLDNN_ENABLED()
+#include <ATen/native/mkldnn/Utils.h>
+#endif
+
 constexpr int MIOPEN_DIM_MAX = 5;

 namespace at { namespace native {
@@ -228,6 +232,9 @@ auto ConvParams::use_mkldnn(const at::Tensor& input, const at::Tensor& weight) c
   if (!at::globalContext().userEnabledMkldnn()) {
     return false;
   }
+  if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) {
+    return true;
+  }
   return (input.is_mkldnn()) || // input is mkldnn Tensor
     (input.device().is_cpu() &&
      input.scalar_type() == kFloat && // only on CPU Float Tensors
diff --git a/aten/src/ATen/native/NaiveDilatedConvolution.cpp b/aten/src/ATen/native/NaiveDilatedConvolution.cpp
index 68eaa372b7e..e6b7b96ea71 100644
--- a/aten/src/ATen/native/NaiveDilatedConvolution.cpp
+++ b/aten/src/ATen/native/NaiveDilatedConvolution.cpp
@@ -200,7 +200,8 @@ void slow_conv_dilated_all_cpu_template(
   std::vector<int64_t> dims(dim);
   std::iota(dims.begin(), dims.end(), 1);

-
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(), "slow_conv_dilated<>", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Long, at::ScalarType::BFloat16, input.scalar_type(), "slow_conv_dilated<>", [&] {
   // For each elt in batch, do:
   for (const auto elt : c10::irange(batchSize)) {
     // Matrix multiply per output:
@@ -275,12 +276,12 @@ void slow_conv_dilated_all_cpu_template(
           /* m=*/columns.size(1),
           /* n=*/nOutputPlane,
           /* k=*/columns.size(0),
-          /* alpha=*/1,
+          /* alpha=*/static_cast<scalar_t>(1),
           /* A=*/columns.data_ptr<scalar_t>(),
           /* lda=*/columns.size(1),
           /* B=*/weight.data_ptr<scalar_t>(),
           /* ldb=*/columns.size(0),
-          /* beta=*/1,
+          /* beta=*/static_cast<scalar_t>(1),
           /* C=*/output_n.data_ptr<scalar_t>(),
           /* ldc=*/columns.size(1));

@@ -319,12 +320,12 @@ void slow_conv_dilated_all_cpu_template(
           /* m=*/columns.size(1),
           /* n=*/columns.size(0),
           /* k=*/nOutputPlane,
-          /* alpha=*/1,
+          /* alpha=*/static_cast<scalar_t>(1),
           /* A=*/grad_output_n.data_ptr<scalar_t>(),
           /* lda=*/columns.size(1),
           /* B=*/weight.data_ptr<scalar_t>(),
           /* ldb=*/columns.size(0),
-          /* beta=*/0,
+          /* beta=*/static_cast<scalar_t>(0),
           /* C=*/columns.data_ptr<scalar_t>(),
           /* ldc=*/columns.size(1));
       // Unpack columns back into input:
@@ -384,12 +385,12 @@ void slow_conv_dilated_all_cpu_template(
           /* m=*/columns.size(0),
           /* n=*/nOutputPlane,
           /* k=*/columns.size(1),
-          /* alpha=*/scale,
+          /* alpha=*/static_cast<scalar_t>(scale),
           /* A=*/columns.data_ptr<scalar_t>(),
           /* lda=*/columns.size(1),
           /* B=*/grad_output_n.data_ptr<scalar_t>(),
           /* ldb=*/columns.size(1),
-          /* beta=*/1,
+          /* beta=*/static_cast<scalar_t>(1),
           /* C=*/grad_weight.data_ptr<scalar_t>(),
           /* ldc=*/columns.size(0));
       }
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index 529481654bf..0f466ec6007 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -241,47 +241,6 @@ class TestMkldnn(TestCase):
     def test_conv3d(self):
         self._test_conv_base(dim=3)

-    def test_conv2d_nhwc(self):
-        conv_module = torch.nn.Conv2d
-        input_shapes = (224, 224)
-        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
-        for train, bias, dilation, groups in
options: - N = torch.randint(3, 10, (1,)).item() - M = torch.randint(1, 3, (1,)).item() * groups - C = torch.randint(1, 3, (1,)).item() * groups - x_shape = (N, C) + input_shapes - x = torch.randn(x_shape, dtype=torch.float32) - # conv1: mkldnn conv2d in contiguous memory format (nchw) - # conv2: mkldnn conv2d in channels last memory format (nhwc) - conv1 = conv_module(in_channels=C, - out_channels=M, - kernel_size=3, - stride=2, - padding=1, - dilation=dilation, - bias=bias, - groups=groups).float() - conv2 = copy.deepcopy(conv1).to(memory_format=torch.channels_last) - x1 = x.clone() - x2 = x.clone().to(memory_format=torch.channels_last) - if train: - x1.requires_grad_() - x2.requires_grad_() - y1 = conv1(x1) - y2 = conv2(x2) - self.assertEqual(y1, y2) - if train: - y1.sum().backward() - y2.sum().backward() - self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last)) - self.assertEqual(conv1.weight.grad, - conv2.weight.grad, - atol=1e-3, - rtol=1e-3) - if bias: - self.assertEqual(conv1.bias.grad, conv2.bias.grad) - self.assertEqual(x1.grad, x2.grad) - @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") def _test_conv_bf16_base(self, dim): conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} @@ -324,6 +283,56 @@ class TestMkldnn(TestCase): def test_conv3d_bf16(self): self._test_conv_bf16_base(dim=3) + def _test_conv2d_nhwc_base(self, dtype): + conv_module = torch.nn.Conv2d + input_shapes = (224, 224) + options = itertools.product([True, False], [True, False], [1, 2], [1, 4]) + for train, bias, dilation, groups in options: + N = torch.randint(3, 10, (1,)).item() + M = torch.randint(1, 3, (1,)).item() * groups + C = torch.randint(1, 3, (1,)).item() * groups + x_shape = (N, C) + input_shapes + x = torch.randn(x_shape, dtype=dtype) + # conv1: mkldnn conv2d in contiguous memory format (nchw) + # conv2: mkldnn conv2d in channels last memory format (nhwc) + conv1 = conv_module(in_channels=C, + out_channels=M, + kernel_size=3, + 
stride=2, + padding=1, + dilation=dilation, + bias=bias, + groups=groups).to(dtype=dtype) + conv2 = copy.deepcopy(conv1).to(memory_format=torch.channels_last) + x1 = x.clone() + x2 = x.clone().to(memory_format=torch.channels_last) + if train: + x1.requires_grad_() + x2.requires_grad_() + y1 = conv1(x1) + y2 = conv2(x2) + self.assertEqual(y1, y2) + if train: + y1.sum().backward() + y2.sum().backward() + self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last)) + self.assertEqual(conv1.weight.grad, + conv2.weight.grad, + atol=1e-3, + rtol=1e-3) + if bias: + self.assertEqual(conv1.bias.grad, conv2.bias.grad) + self.assertEqual(x1.grad, x2.grad) + + def test_conv2d_nhwc(self): + self._test_conv2d_nhwc_base(dtype=torch.float32) + + @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") + def test_conv2d_nhwc_bf16(self): + # when has_bf16_support() returns false, bf16 CPU conv will fall back to thnn impl + if has_bf16_support(): + self._test_conv2d_nhwc_base(dtype=torch.bfloat16) + def test_conv2d_legacy_jit_model(self): """ MKLDNN integration used to serialize models with 5d weight for grouped diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b472b0c80ba..8b7c65907bf 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12232,7 +12232,7 @@ op_db: List[OpInfo] = [ OpInfo('nn.functional.conv1d', aliases=('conv1d',), aten_name='conv1d', - dtypes=floating_and_complex_types_and(torch.int64), + dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []), sample_inputs_func=sample_inputs_conv1d, @@ -12257,7 +12257,7 @@ op_db: List[OpInfo] = [ OpInfo('nn.functional.conv2d', aliases=('conv2d',), aten_name='conv2d', - dtypes=floating_and_complex_types_and(torch.int64), + 
dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []), sample_inputs_func=partial(sample_inputs_conv2d),