Add oneDNN support for Half LSTM on CPU (#132607)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132607
Approved by: https://github.com/jgong5, https://github.com/peterbell10
This commit is contained in:
parent 41e36e2b46
commit 17e9c2d1e7
2 changed files with 42 additions and 13 deletions
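For context, a minimal sketch of what this commit enables (illustrative shapes and names; assumes a CPU whose oneDNN build passes the FP16 device check that the diff below gates on):

    import torch

    # Half-precision LSTM inference on CPU; with this commit the call can
    # dispatch to oneDNN instead of the generic reference kernels.
    model = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=2).eval().half()
    x = torch.randn(5, 3, 10, dtype=torch.half)
    h0 = torch.zeros(2, 3, 20, dtype=torch.half)
    c0 = torch.zeros(2, 3, 20, dtype=torch.half)

    with torch.no_grad():  # the new gate requires grad mode off for kHalf
        out, (hn, cn) = model(x, (h0, c0))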
aten/src/ATen/native/RNN.cpp

@@ -15,6 +15,9 @@
 #include <torch/custom_class.h>
 #include <torch/library.h>
 #include <ATen/Config.h>
+#if AT_MKLDNN_ENABLED()
+#include <ATen/native/mkldnn/Utils.h>
+#endif
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -97,7 +100,10 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) {
   };
   return input.options().backend() == at::Backend::CPU &&
       is_cpu_backend(params) && is_cpu_backend(hx) &&
-      (input.scalar_type() == kFloat || input.scalar_type() == kBFloat16) &&
+      (input.scalar_type() == kFloat ||
+       (input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) ||
+       (input.scalar_type() == kHalf && !at::GradMode::is_enabled() &&
+        mkldnn_fp16_device_check())) &&
       input.numel() != 0;
 #endif
   return false;
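The new condition admits three dtypes: float32 unconditionally, bfloat16 when mkldnn_bf16_device_check() passes, and half only when grad mode is off (oneDNN has no FP16 LSTM training) and mkldnn_fp16_device_check() passes. A hedged Python paraphrase of the gate, for illustration only (it omits the is_cpu_backend checks on params/hx; the _is_mkldnn_*_supported ops are the Python-visible counterparts of the device checks):

    import torch

    def would_use_onednn(x: torch.Tensor) -> bool:
        # paraphrase of use_mkldnn above, not the actual dispatch code
        if x.device.type != "cpu" or x.numel() == 0:
            return False
        if x.dtype == torch.float32:
            return True
        if x.dtype == torch.bfloat16:
            return torch.ops.mkldnn._is_mkldnn_bf16_supported()
        if x.dtype == torch.float16:
            # half additionally requires GradMode off: inference only
            return (not torch.is_grad_enabled()) and \
                torch.ops.mkldnn._is_mkldnn_fp16_supported()
        return False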
test/test_mkldnn.py

@@ -1469,24 +1469,30 @@ class TestMkldnn(TestCase):
         params_list = list(params_dict.values())
         return params_list
 
-    def _cast_dtype(self, input, bf16):
-        if bf16:
+    def _cast_dtype(self, input, dtype):
+        if dtype == torch.bfloat16:
             input = input.to(torch.bfloat16)
+        elif dtype == torch.half:
+            input = input.to(torch.half)
         return input
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def test_lstm(self):
         seed = 2023
         torch.manual_seed(seed)
 
         params_list = self._lstm_params_list()
         for dtype in types:
-            bf16 = True if dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported() else False
+            bf16 = dtype == torch.bfloat16
+            fp16 = dtype == torch.half
             rtol = 1.3e-6
             atol = 1e-5
 
             if bf16:
                 rtol = 0.02
                 atol = 0.02
+            if fp16:
+                rtol = 1e-3
+                atol = 1e-3
             for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \
                     in itertools.product(*params_list):
                 num_directions = 2 if bidirectional else 1
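Restated as a standalone sketch (tolerance values copied from the hunk; the helper generalizes the test's _cast_dtype, which the test also applies to modules since Module.to(dtype) works the same way):

    import torch

    # dtype-dependent tolerances used when comparing against the fp32 reference
    LSTM_TOLERANCES = {
        torch.float32:  {"rtol": 1.3e-6, "atol": 1e-5},
        torch.bfloat16: {"rtol": 0.02, "atol": 0.02},
        torch.float16:  {"rtol": 1e-3, "atol": 1e-3},
    }

    def cast_dtype(obj, dtype):
        # only the low-precision dtypes under test trigger a cast
        if dtype in (torch.bfloat16, torch.half):
            return obj.to(dtype)
        return obj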
@@ -1496,7 +1502,9 @@ class TestMkldnn(TestCase):
                 input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32)
                 h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                 c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
+
+                if fp16:
+                    # TODO: add training support once oneDNN supports FP16 LSTM training
+                    training = False
                 model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional,
                                       bias=bias, dropout=dropout, batch_first=batch_first).float()
                 model.train() if training else model.eval()
@@ -1510,15 +1518,25 @@ class TestMkldnn(TestCase):
 
                 model1 = copy.deepcopy(model)
                 model2 = copy.deepcopy(model)
-                with torch.cpu.amp.autocast(enabled=bf16, dtype=torch.bfloat16), torch.no_grad() if not training else nullcontext():
+                with torch.no_grad() if not training else nullcontext():
                     with torch.backends.mkldnn.flags(enabled=False):
                         torch.manual_seed(seed)
-                        output1, (hn1, cn1) = self._cast_dtype(model1, bf16)(self._cast_dtype(input1, bf16),
-                                                                             (self._cast_dtype(h1, bf16),
-                                                                              self._cast_dtype(c1, bf16)))
+                        output1, (hn1, cn1) = self._cast_dtype(model1, dtype)(
+                            self._cast_dtype(input1, dtype),
+                            (
+                                self._cast_dtype(h1, dtype),
+                                self._cast_dtype(c1, dtype),
+                            ),
+                        )
 
                     torch.manual_seed(seed)
-                    output2, (hn2, cn2) = model2(input2, (h2, c2))
+                    output2, (hn2, cn2) = self._cast_dtype(model2, dtype)(
+                        self._cast_dtype(input2, dtype),
+                        (
+                            self._cast_dtype(h2, dtype),
+                            self._cast_dtype(c2, dtype),
+                        ),
+                    )
                     self.assertEqual(output1, output2, rtol=rtol, atol=atol)
                     self.assertEqual(hn1, hn2, rtol=rtol, atol=atol)
                     self.assertEqual(cn1, cn2, rtol=rtol, atol=atol)
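The comparison pattern above, isolated into a sketch (hypothetical helper; torch.testing.assert_close recurses into the nested output tuples):

    import torch

    def compare_onednn_vs_reference(model, args, seed=2023, rtol=1e-3, atol=1e-3):
        # re-seed before each run so stochastic ops (e.g. dropout in
        # training mode) draw identical numbers on both paths
        torch.manual_seed(seed)
        with torch.backends.mkldnn.flags(enabled=False):
            ref = model(*args)  # reference path, oneDNN disabled
        torch.manual_seed(seed)
        out = model(*args)  # oneDNN path (when use_mkldnn admits it)
        torch.testing.assert_close(out, ref, rtol=rtol, atol=atol)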
@@ -1533,8 +1551,13 @@ class TestMkldnn(TestCase):
 
                     self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol)
                     for name, para in model1.named_parameters():
-                        self.assertEqual(para, self._cast_dtype(getattr(model2, name), bf16))
-                        self.assertEqual(para.grad, self._cast_dtype(getattr(model2, name).grad, bf16), rtol=rtol, atol=atol)
+                        self.assertEqual(para, getattr(model2, name))
+                        self.assertEqual(
+                            para.grad,
+                            getattr(model2, name).grad,
+                            rtol=rtol,
+                            atol=atol,
+                        )
 
                 with torch.backends.mkldnn.flags(enabled=False):
                     torch.manual_seed(seed)
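And the parameter/gradient check above, as a self-contained sketch (hypothetical helper name):

    import torch

    def assert_params_and_grads_match(m1, m2, rtol, atol):
        # values must agree at default tolerances, gradients within the
        # dtype-dependent ones, mirroring the loop in the hunk above
        params2 = dict(m2.named_parameters())
        for name, p1 in m1.named_parameters():
            torch.testing.assert_close(p1, params2[name])
            torch.testing.assert_close(p1.grad, params2[name].grad,
                                       rtol=rtol, atol=atol)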