enable BFloat16 mkldnn_convolution on both contiguous and channels last memory format (#55864)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55864

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D27941367

Pulled By: VitalyFedyunin

fbshipit-source-id: c6bcb73c41652cc0aca11c1d1e0697a8a2fa43ad
(cherry picked from commit 3fc0b992a7dccbc31042dc35afec9ae3dc59a05a)
This commit is contained in:
parent d23619b030
commit dbfb9a823d

4 changed files with 67 additions and 50 deletions
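Note: a minimal usage sketch of what this change enables — a BFloat16 convolution on CPU in both contiguous (NCHW) and channels-last (NHWC) memory formats. It assumes a CPU/oneDNN build with BFloat16 support; on other hardware the bf16 path may fall back to a slower kernel or be unsupported.

import copy
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to(torch.bfloat16)
x = torch.randn(2, 3, 32, 32, dtype=torch.bfloat16)

y_nchw = conv(x)  # contiguous (NCHW) input

conv_cl = copy.deepcopy(conv).to(memory_format=torch.channels_last)
y_nhwc = conv_cl(x.to(memory_format=torch.channels_last))  # channels-last (NHWC)

print(y_nchw.dtype, y_nhwc.shape)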
@@ -19,6 +19,10 @@
 #include <nnpack.h>
 #endif
 
+#if AT_MKLDNN_ENABLED()
+#include <ATen/native/mkldnn/Utils.h>
+#endif
+
 constexpr int MIOPEN_DIM_MAX = 5;
 
 namespace at { namespace native {
@@ -228,6 +232,9 @@ auto ConvParams::use_mkldnn(const at::Tensor& input, const at::Tensor& weight) c
   if (!at::globalContext().userEnabledMkldnn()) {
     return false;
   }
+  if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) {
+    return true;
+  }
   return (input.is_mkldnn()) || // input is mkldnn Tensor
     (input.device().is_cpu() &&
      input.scalar_type() == kFloat && // only on CPU Float Tensors
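Note: with the added early return, a CPU BFloat16 input is routed to oneDNN whenever the user flag and the hardware check (mkldnn_bf16_device_check(), typically an AVX-512 capability test) both pass. A hypothetical Python mirror of the predicate, for illustration only — this is not a real API, and the real function has further conditions that the hunk truncates:

import torch

def would_use_mkldnn(x: torch.Tensor) -> bool:
    # Mirrors at::globalContext().userEnabledMkldnn()
    if not torch.backends.mkldnn.enabled:
        return False
    # New BFloat16 branch: CPU tensor + hardware support -> oneDNN path
    if x.device.type == "cpu" and x.dtype == torch.bfloat16:
        return True  # assuming mkldnn_bf16_device_check() passes on this CPU
    # Pre-existing float32 CPU condition (simplified; more checks follow in the C++)
    return x.device.type == "cpu" and x.dtype == torch.float32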
@@ -200,7 +200,8 @@ void slow_conv_dilated_all_cpu_template(
   std::vector<int64_t> dims(dim);
   std::iota(dims.begin(), dims.end(), 1);
 
-  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Long, input.scalar_type(), "slow_conv_dilated<>", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Long, at::ScalarType::BFloat16, input.scalar_type(), "slow_conv_dilated<>", [&] {
     // For each elt in batch, do:
     for (const auto elt : c10::irange(batchSize)) {
       // Matrix multiply per output:
@@ -275,12 +276,12 @@ void slow_conv_dilated_all_cpu_template(
           /* m=*/columns.size(1),
           /* n=*/nOutputPlane,
           /* k=*/columns.size(0),
-          /* alpha=*/1,
+          /* alpha=*/static_cast<scalar_t>(1),
           /* A=*/columns.data_ptr<scalar_t>(),
           /* lda=*/columns.size(1),
           /* B=*/weight.data_ptr<scalar_t>(),
           /* ldb=*/columns.size(0),
-          /* beta=*/1,
+          /* beta=*/static_cast<scalar_t>(1),
           /* C=*/output_n.data_ptr<scalar_t>(),
           /* ldc=*/columns.size(1));
@@ -319,12 +320,12 @@ void slow_conv_dilated_all_cpu_template(
           /* m=*/columns.size(1),
           /* n=*/columns.size(0),
           /* k=*/nOutputPlane,
-          /* alpha=*/1,
+          /* alpha=*/static_cast<scalar_t>(1),
           /* A=*/grad_output_n.data_ptr<scalar_t>(),
           /* lda=*/columns.size(1),
           /* B=*/weight.data_ptr<scalar_t>(),
           /* ldb=*/columns.size(0),
-          /* beta=*/0,
+          /* beta=*/static_cast<scalar_t>(0),
           /* C=*/columns.data_ptr<scalar_t>(),
           /* ldc=*/columns.size(1));
       // Unpack columns back into input:
@@ -384,12 +385,12 @@ void slow_conv_dilated_all_cpu_template(
           /* m=*/columns.size(0),
           /* n=*/nOutputPlane,
           /* k=*/columns.size(1),
-          /* alpha=*/scale,
+          /* alpha=*/static_cast<scalar_t>(scale),
           /* A=*/columns.data_ptr<scalar_t>(),
           /* lda=*/columns.size(1),
           /* B=*/grad_output_n.data_ptr<scalar_t>(),
           /* ldb=*/columns.size(1),
-          /* beta=*/1,
+          /* beta=*/static_cast<scalar_t>(1),
           /* C=*/grad_weight.data_ptr<scalar_t>(),
           /* ldc=*/columns.size(0));
     }
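Note: in the three GEMM call sites above, the alpha/beta scalars are now written as static_cast<scalar_t>(...) so the templated BLAS wrapper also instantiates cleanly for BFloat16, which the widened AT_DISPATCH_FLOATING_TYPES_AND2 now feeds in. A small sketch of exercising this non-oneDNN path from Python, under the assumption that with oneDNN disabled a dilated CPU convolution reaches the slow dilated kernel (exact routing is build-dependent):

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 16, 16, dtype=torch.bfloat16)
w = torch.randn(4, 2, 3, 3, dtype=torch.bfloat16)

torch.backends.mkldnn.enabled = False   # skip the oneDNN path
y = F.conv2d(x, w, dilation=2)          # dilation > 1 -> dilated CPU kernel
torch.backends.mkldnn.enabled = True

print(y.dtype, y.shape)                 # torch.bfloat16, [1, 4, 12, 12]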
@@ -241,47 +241,6 @@ class TestMkldnn(TestCase):
     def test_conv3d(self):
         self._test_conv_base(dim=3)
 
-    def test_conv2d_nhwc(self):
-        conv_module = torch.nn.Conv2d
-        input_shapes = (224, 224)
-        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
-        for train, bias, dilation, groups in options:
-            N = torch.randint(3, 10, (1,)).item()
-            M = torch.randint(1, 3, (1,)).item() * groups
-            C = torch.randint(1, 3, (1,)).item() * groups
-            x_shape = (N, C) + input_shapes
-            x = torch.randn(x_shape, dtype=torch.float32)
-            # conv1: mkldnn conv2d in contiguous memory format (nchw)
-            # conv2: mkldnn conv2d in channels last memory format (nhwc)
-            conv1 = conv_module(in_channels=C,
-                                out_channels=M,
-                                kernel_size=3,
-                                stride=2,
-                                padding=1,
-                                dilation=dilation,
-                                bias=bias,
-                                groups=groups).float()
-            conv2 = copy.deepcopy(conv1).to(memory_format=torch.channels_last)
-            x1 = x.clone()
-            x2 = x.clone().to(memory_format=torch.channels_last)
-            if train:
-                x1.requires_grad_()
-                x2.requires_grad_()
-            y1 = conv1(x1)
-            y2 = conv2(x2)
-            self.assertEqual(y1, y2)
-            if train:
-                y1.sum().backward()
-                y2.sum().backward()
-                self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
-                self.assertEqual(conv1.weight.grad,
-                                 conv2.weight.grad,
-                                 atol=1e-3,
-                                 rtol=1e-3)
-                if bias:
-                    self.assertEqual(conv1.bias.grad, conv2.bias.grad)
-                self.assertEqual(x1.grad, x2.grad)
-
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def _test_conv_bf16_base(self, dim):
         conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
@@ -324,6 +283,56 @@ class TestMkldnn(TestCase):
     def test_conv3d_bf16(self):
         self._test_conv_bf16_base(dim=3)
 
+    def _test_conv2d_nhwc_base(self, dtype):
+        conv_module = torch.nn.Conv2d
+        input_shapes = (224, 224)
+        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
+        for train, bias, dilation, groups in options:
+            N = torch.randint(3, 10, (1,)).item()
+            M = torch.randint(1, 3, (1,)).item() * groups
+            C = torch.randint(1, 3, (1,)).item() * groups
+            x_shape = (N, C) + input_shapes
+            x = torch.randn(x_shape, dtype=dtype)
+            # conv1: mkldnn conv2d in contiguous memory format (nchw)
+            # conv2: mkldnn conv2d in channels last memory format (nhwc)
+            conv1 = conv_module(in_channels=C,
+                                out_channels=M,
+                                kernel_size=3,
+                                stride=2,
+                                padding=1,
+                                dilation=dilation,
+                                bias=bias,
+                                groups=groups).to(dtype=dtype)
+            conv2 = copy.deepcopy(conv1).to(memory_format=torch.channels_last)
+            x1 = x.clone()
+            x2 = x.clone().to(memory_format=torch.channels_last)
+            if train:
+                x1.requires_grad_()
+                x2.requires_grad_()
+            y1 = conv1(x1)
+            y2 = conv2(x2)
+            self.assertEqual(y1, y2)
+            if train:
+                y1.sum().backward()
+                y2.sum().backward()
+                self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
+                self.assertEqual(conv1.weight.grad,
+                                 conv2.weight.grad,
+                                 atol=1e-3,
+                                 rtol=1e-3)
+                if bias:
+                    self.assertEqual(conv1.bias.grad, conv2.bias.grad)
+                self.assertEqual(x1.grad, x2.grad)
+
+    def test_conv2d_nhwc(self):
+        self._test_conv2d_nhwc_base(dtype=torch.float32)
+
+    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
+    def test_conv2d_nhwc_bf16(self):
+        # when has_bf16_support() returns false, bf16 CPU conv will fall back to thnn impl
+        if has_bf16_support():
+            self._test_conv2d_nhwc_base(dtype=torch.bfloat16)
+
     def test_conv2d_legacy_jit_model(self):
         """
         MKLDNN integration used to serialize models with 5d weight for grouped
@@ -12232,7 +12232,7 @@ op_db: List[OpInfo] = [
     OpInfo('nn.functional.conv1d',
            aliases=('conv1d',),
            aten_name='conv1d',
-           dtypes=floating_and_complex_types_and(torch.int64),
+           dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16),
           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
                                                       *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
           sample_inputs_func=sample_inputs_conv1d,
@@ -12257,7 +12257,7 @@ op_db: List[OpInfo] = [
     OpInfo('nn.functional.conv2d',
           aliases=('conv2d',),
           aten_name='conv2d',
-          dtypes=floating_and_complex_types_and(torch.int64),
+          dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16),
          dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
                                                      *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
          sample_inputs_func=partial(sample_inputs_conv2d),
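Note: adding torch.bfloat16 to the OpInfo dtypes means the conv1d/conv2d OpInfo tests now generate BFloat16 CPU samples. A small sketch of the newly covered call (actual dtype support still depends on the CPU backend in use):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 64, dtype=torch.bfloat16)
w = torch.randn(8, 4, 3, dtype=torch.bfloat16)

y = F.conv1d(x, w, padding=1)
print(y.dtype, y.shape)  # torch.bfloat16, [1, 8, 64]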