enable BFloat16 mkldnn_convolution on both contiguous and channels last memory format (#55864)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55864

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D27941367

Pulled By: VitalyFedyunin

fbshipit-source-id: c6bcb73c41652cc0aca11c1d1e0697a8a2fa43ad
(cherry picked from commit 3fc0b992a7dccbc31042dc35afec9ae3dc59a05a)
This commit is contained in:
mingfeima 2022-05-02 15:17:49 -07:00 committed by PyTorch MergeBot
parent d23619b030
commit dbfb9a823d
4 changed files with 67 additions and 50 deletions

View file

@@ -19,6 +19,10 @@
#include <nnpack.h>
#endif
#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/Utils.h>
#endif
constexpr int MIOPEN_DIM_MAX = 5;
namespace at { namespace native {
@@ -228,6 +232,9 @@ auto ConvParams::use_mkldnn(const at::Tensor& input, const at::Tensor& weight) c
if (!at::globalContext().userEnabledMkldnn()) {
return false;
}
if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) {
return true;
}
return (input.is_mkldnn()) || // input is mkldnn Tensor
(input.device().is_cpu() &&
input.scalar_type() == kFloat && // only on CPU Float Tensors

View file

@@ -200,7 +200,8 @@ void slow_conv_dilated_all_cpu_template(
std::vector<int64_t> dims(dim);
std::iota(dims.begin(), dims.end(), 1);
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(), "slow_conv_dilated<>", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Long, at::ScalarType::BFloat16, input.scalar_type(), "slow_conv_dilated<>", [&] {
// For each elt in batch, do:
for (const auto elt : c10::irange(batchSize)) {
// Matrix multiply per output:
@@ -275,12 +276,12 @@ void slow_conv_dilated_all_cpu_template(
/* m=*/columns.size(1),
/* n=*/nOutputPlane,
/* k=*/columns.size(0),
/* alpha=*/1,
/* alpha=*/static_cast<scalar_t>(1),
/* A=*/columns.data_ptr<scalar_t>(),
/* lda=*/columns.size(1),
/* B=*/weight.data_ptr<scalar_t>(),
/* ldb=*/columns.size(0),
/* beta=*/1,
/* beta=*/static_cast<scalar_t>(1),
/* C=*/output_n.data_ptr<scalar_t>(),
/* ldc=*/columns.size(1));
@@ -319,12 +320,12 @@ void slow_conv_dilated_all_cpu_template(
/* m=*/columns.size(1),
/* n=*/columns.size(0),
/* k=*/nOutputPlane,
/* alpha=*/1,
/* alpha=*/static_cast<scalar_t>(1),
/* A=*/grad_output_n.data_ptr<scalar_t>(),
/* lda=*/columns.size(1),
/* B=*/weight.data_ptr<scalar_t>(),
/* ldb=*/columns.size(0),
/* beta=*/0,
/* beta=*/static_cast<scalar_t>(0),
/* C=*/columns.data_ptr<scalar_t>(),
/* ldc=*/columns.size(1));
// Unpack columns back into input:
@@ -384,12 +385,12 @@ void slow_conv_dilated_all_cpu_template(
/* m=*/columns.size(0),
/* n=*/nOutputPlane,
/* k=*/columns.size(1),
/* alpha=*/scale,
/* alpha=*/static_cast<scalar_t>(scale),
/* A=*/columns.data_ptr<scalar_t>(),
/* lda=*/columns.size(1),
/* B=*/grad_output_n.data_ptr<scalar_t>(),
/* ldb=*/columns.size(1),
/* beta=*/1,
/* beta=*/static_cast<scalar_t>(1),
/* C=*/grad_weight.data_ptr<scalar_t>(),
/* ldc=*/columns.size(0));
}

View file

@@ -241,47 +241,6 @@ class TestMkldnn(TestCase):
def test_conv3d(self):
self._test_conv_base(dim=3)
def test_conv2d_nhwc(self):
conv_module = torch.nn.Conv2d
input_shapes = (224, 224)
options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
for train, bias, dilation, groups in options:
N = torch.randint(3, 10, (1,)).item()
M = torch.randint(1, 3, (1,)).item() * groups
C = torch.randint(1, 3, (1,)).item() * groups
x_shape = (N, C) + input_shapes
x = torch.randn(x_shape, dtype=torch.float32)
# conv1: mkldnn conv2d in contiguous memory format (nchw)
# conv2: mkldnn conv2d in channels last memory format (nhwc)
conv1 = conv_module(in_channels=C,
out_channels=M,
kernel_size=3,
stride=2,
padding=1,
dilation=dilation,
bias=bias,
groups=groups).float()
conv2 = copy.deepcopy(conv1).to(memory_format=torch.channels_last)
x1 = x.clone()
x2 = x.clone().to(memory_format=torch.channels_last)
if train:
x1.requires_grad_()
x2.requires_grad_()
y1 = conv1(x1)
y2 = conv2(x2)
self.assertEqual(y1, y2)
if train:
y1.sum().backward()
y2.sum().backward()
self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(conv1.weight.grad,
conv2.weight.grad,
atol=1e-3,
rtol=1e-3)
if bias:
self.assertEqual(conv1.bias.grad, conv2.bias.grad)
self.assertEqual(x1.grad, x2.grad)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
def _test_conv_bf16_base(self, dim):
conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
@@ -324,6 +283,56 @@ class TestMkldnn(TestCase):
def test_conv3d_bf16(self):
self._test_conv_bf16_base(dim=3)
def _test_conv2d_nhwc_base(self, dtype):
conv_module = torch.nn.Conv2d
input_shapes = (224, 224)
options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
for train, bias, dilation, groups in options:
N = torch.randint(3, 10, (1,)).item()
M = torch.randint(1, 3, (1,)).item() * groups
C = torch.randint(1, 3, (1,)).item() * groups
x_shape = (N, C) + input_shapes
x = torch.randn(x_shape, dtype=dtype)
# conv1: mkldnn conv2d in contiguous memory format (nchw)
# conv2: mkldnn conv2d in channels last memory format (nhwc)
conv1 = conv_module(in_channels=C,
out_channels=M,
kernel_size=3,
stride=2,
padding=1,
dilation=dilation,
bias=bias,
groups=groups).to(dtype=dtype)
conv2 = copy.deepcopy(conv1).to(memory_format=torch.channels_last)
x1 = x.clone()
x2 = x.clone().to(memory_format=torch.channels_last)
if train:
x1.requires_grad_()
x2.requires_grad_()
y1 = conv1(x1)
y2 = conv2(x2)
self.assertEqual(y1, y2)
if train:
y1.sum().backward()
y2.sum().backward()
self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(conv1.weight.grad,
conv2.weight.grad,
atol=1e-3,
rtol=1e-3)
if bias:
self.assertEqual(conv1.bias.grad, conv2.bias.grad)
self.assertEqual(x1.grad, x2.grad)
def test_conv2d_nhwc(self):
self._test_conv2d_nhwc_base(dtype=torch.float32)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
def test_conv2d_nhwc_bf16(self):
# when has_bf16_support() returns false, bf16 CPU conv will fall back to thnn impl
if has_bf16_support():
self._test_conv2d_nhwc_base(dtype=torch.bfloat16)
def test_conv2d_legacy_jit_model(self):
"""
MKLDNN integration used to serialize models with 5d weight for grouped

View file

@@ -12232,7 +12232,7 @@ op_db: List[OpInfo] = [
OpInfo('nn.functional.conv1d',
aliases=('conv1d',),
aten_name='conv1d',
dtypes=floating_and_complex_types_and(torch.int64),
dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
sample_inputs_func=sample_inputs_conv1d,
@@ -12257,7 +12257,7 @@ op_db: List[OpInfo] = [
OpInfo('nn.functional.conv2d',
aliases=('conv2d',),
aten_name='conv2d',
dtypes=floating_and_complex_types_and(torch.int64),
dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
sample_inputs_func=partial(sample_inputs_conv2d),