Add oneDNN support for Half LSTM on CPU (#132607)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132607
Approved by: https://github.com/jgong5, https://github.com/peterbell10
CaoE 2024-08-27 20:15:41 -07:00 committed by PyTorch MergeBot
parent 41e36e2b46
commit 17e9c2d1e7
2 changed files with 42 additions and 13 deletions
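What this change enables, as a minimal usage sketch (assumes a CPU where oneDNN's FP16 device check passes; the FP16 path is inference-only, so it must run with gradients disabled):

```python
import torch

# Half (float16) LSTM inference on CPU, now eligible for the oneDNN path.
# FP16 support here is inference-only, so gradients must be disabled.
lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=2).half()
x = torch.randn(5, 3, 10, dtype=torch.half)  # (seq_len, batch, input_size)
with torch.no_grad():
    out, (hn, cn) = lstm(x)
print(out.shape, out.dtype)  # torch.Size([5, 3, 20]) torch.float16
```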

aten/src/ATen/native/RNN.cpp

@@ -15,6 +15,9 @@
 #include <torch/custom_class.h>
 #include <torch/library.h>
 #include <ATen/Config.h>
+#if AT_MKLDNN_ENABLED()
+#include <ATen/native/mkldnn/Utils.h>
+#endif
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -97,7 +100,10 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) {
   };
   return input.options().backend() == at::Backend::CPU &&
       is_cpu_backend(params) && is_cpu_backend(hx) &&
-      (input.scalar_type() == kFloat || input.scalar_type() == kBFloat16) &&
+      (input.scalar_type() == kFloat ||
+       (input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) ||
+       (input.scalar_type() == kHalf && !at::GradMode::is_enabled() &&
+        mkldnn_fp16_device_check())) &&
       input.numel() != 0;
 #endif
   return false;
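The new gate keeps fp32 on the oneDNN path unconditionally, admits bf16 only when the CPU passes the bf16 device check, and admits fp16 only when the CPU passes the fp16 device check and grad mode is off (FP16 LSTM is inference-only); the device-check helpers come from the newly included mkldnn/Utils.h. An illustrative Python paraphrase, not the dispatch code itself (the private `_is_mkldnn_*_supported` ops mirror the C++ device checks, and the real check also requires the params and hidden states to be on CPU):

```python
import torch

def would_use_onednn_lstm(x: torch.Tensor) -> bool:
    # Illustrative paraphrase of use_mkldnn() above, not the real dispatcher.
    if x.device.type != "cpu" or x.numel() == 0:
        return False
    if x.dtype == torch.float32:
        return True
    if x.dtype == torch.bfloat16:
        return torch.ops.mkldnn._is_mkldnn_bf16_supported()
    if x.dtype == torch.half:
        # FP16 LSTM is inference-only: grad mode must be disabled.
        return (not torch.is_grad_enabled()
                and torch.ops.mkldnn._is_mkldnn_fp16_supported())
    return False
```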

test/test_mkldnn.py

@@ -1469,24 +1469,30 @@ class TestMkldnn(TestCase):
         params_list = list(params_dict.values())
         return params_list
 
-    def _cast_dtype(self, input, bf16):
-        if bf16:
+    def _cast_dtype(self, input, dtype):
+        if dtype == torch.bfloat16:
             input = input.to(torch.bfloat16)
+        elif dtype == torch.half:
+            input = input.to(torch.half)
         return input
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def test_lstm(self):
         seed = 2023
         torch.manual_seed(seed)
         params_list = self._lstm_params_list()
         for dtype in types:
-            bf16 = True if dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported() else False
+            bf16 = dtype == torch.bfloat16
+            fp16 = dtype == torch.half
             rtol = 1.3e-6
             atol = 1e-5
             if bf16:
                 rtol = 0.02
                 atol = 0.02
+            if fp16:
+                rtol = 1e-3
+                atol = 1e-3
             for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \
                     in itertools.product(*params_list):
                 num_directions = 2 if bidirectional else 1
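For reference, the per-dtype tolerances the test settles on (fp32 defaults, looser bounds for the reduced-precision types), collected in one place:

```python
import torch

# Comparison tolerances chosen in the loop above, one entry per dtype.
TOLERANCES = {
    torch.float32:  {"rtol": 1.3e-6, "atol": 1e-5},
    torch.bfloat16: {"rtol": 2e-2,   "atol": 2e-2},
    torch.half:     {"rtol": 1e-3,   "atol": 1e-3},
}
```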
@@ -1496,7 +1502,9 @@ class TestMkldnn(TestCase):
                 input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32)
                 h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                 c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
+                if fp16:
+                    # TODO: add training support when oneDNN supports LSTM FP16 training
+                    training = False
                 model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional,
                                       bias=bias, dropout=dropout, batch_first=batch_first).float()
                 model.train() if training else model.eval()
@@ -1510,15 +1518,25 @@ class TestMkldnn(TestCase):
                 model1 = copy.deepcopy(model)
                 model2 = copy.deepcopy(model)
-                with torch.cpu.amp.autocast(enabled=bf16, dtype=torch.bfloat16), torch.no_grad() if not training else nullcontext():
+                with torch.no_grad() if not training else nullcontext():
                     with torch.backends.mkldnn.flags(enabled=False):
                         torch.manual_seed(seed)
-                        output1, (hn1, cn1) = self._cast_dtype(model1, bf16)(self._cast_dtype(input1, bf16),
-                                                                             (self._cast_dtype(h1, bf16),
-                                                                              self._cast_dtype(c1, bf16)))
+                        output1, (hn1, cn1) = self._cast_dtype(model1, dtype)(
+                            self._cast_dtype(input1, dtype),
+                            (
+                                self._cast_dtype(h1, dtype),
+                                self._cast_dtype(c1, dtype),
+                            ),
+                        )
                     torch.manual_seed(seed)
-                    output2, (hn2, cn2) = model2(input2, (h2, c2))
+                    output2, (hn2, cn2) = self._cast_dtype(model2, dtype)(
+                        self._cast_dtype(input2, dtype),
+                        (
+                            self._cast_dtype(h2, dtype),
+                            self._cast_dtype(c2, dtype),
+                        ),
+                    )
                     self.assertEqual(output1, output2, rtol=rtol, atol=atol)
                     self.assertEqual(hn1, hn2, rtol=rtol, atol=atol)
                     self.assertEqual(cn1, cn2, rtol=rtol, atol=atol)
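Condensed, the pattern above is: clone the model, run one copy with oneDNN disabled as the reference and the other with it enabled, then compare within the dtype's tolerance. A minimal standalone sketch of that pattern (assumes an FP16-capable CPU; the test additionally seeds, casts, and sweeps configurations):

```python
import copy
import torch

torch.manual_seed(2023)
model = torch.nn.LSTM(10, 20).half()
x = torch.randn(5, 3, 10, dtype=torch.half)
ref, fast = copy.deepcopy(model), copy.deepcopy(model)
with torch.no_grad():
    with torch.backends.mkldnn.flags(enabled=False):
        ref_out, _ = ref(x)   # reference: native CPU path
    onednn_out, _ = fast(x)   # oneDNN path (when eligible)
torch.testing.assert_close(onednn_out, ref_out, rtol=1e-3, atol=1e-3)
```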
@@ -1533,8 +1551,13 @@ class TestMkldnn(TestCase):
                         self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol)
                         for name, para in model1.named_parameters():
-                            self.assertEqual(para, self._cast_dtype(getattr(model2, name), bf16))
-                            self.assertEqual(para.grad, self._cast_dtype(getattr(model2, name).grad, bf16), rtol=rtol, atol=atol)
+                            self.assertEqual(para, getattr(model2, name))
+                            self.assertEqual(
+                                para.grad,
+                                getattr(model2, name).grad,
+                                rtol=rtol,
+                                atol=atol,
+                            )
                 with torch.backends.mkldnn.flags(enabled=False):
                     torch.manual_seed(seed)