From 17e9c2d1e7cf9bfb4c0481ff1ef224eeab414d21 Mon Sep 17 00:00:00 2001
From: CaoE
Date: Tue, 27 Aug 2024 20:15:41 -0700
Subject: [PATCH] Add oneDNN support for Half LSTM on CPU (#132607)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132607
Approved by: https://github.com/jgong5, https://github.com/peterbell10
---
 aten/src/ATen/native/RNN.cpp |  8 +++++-
 test/test_mkldnn.py          | 47 +++++++++++++++++++++++++++---------
 2 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index d1bb95f8332..9db7b4cb7da 100644
--- a/aten/src/ATen/native/RNN.cpp
+++ b/aten/src/ATen/native/RNN.cpp
@@ -15,6 +15,9 @@
 #include
 #include
 #include
+#if AT_MKLDNN_ENABLED()
+#include <ATen/native/mkldnn/Utils.h>
+#endif
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
@@ -97,7 +100,10 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) {
   };
   return input.options().backend() == at::Backend::CPU &&
       is_cpu_backend(params) && is_cpu_backend(hx) &&
-      (input.scalar_type() == kFloat || input.scalar_type() == kBFloat16) &&
+      (input.scalar_type() == kFloat ||
+       (input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) ||
+       (input.scalar_type() == kHalf && !at::GradMode::is_enabled() &&
+        mkldnn_fp16_device_check())) &&
       input.numel() != 0;
 #endif
   return false;
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index b61e50fb4f6..5f192d7c349 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -1469,24 +1469,30 @@ class TestMkldnn(TestCase):
         params_list = list(params_dict.values())
         return params_list
 
-    def _cast_dtype(self, input, bf16):
-        if bf16:
+    def _cast_dtype(self, input, dtype):
+        if dtype == torch.bfloat16:
             input = input.to(torch.bfloat16)
+        elif dtype == torch.half:
+            input = input.to(torch.half)
         return input
 
-    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def test_lstm(self):
         seed = 2023
         torch.manual_seed(seed)
 
         params_list = self._lstm_params_list()
         for dtype in types:
-            bf16 = True if dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported() else False
+            bf16 = dtype == torch.bfloat16
+            fp16 = dtype == torch.half
             rtol = 1.3e-6
             atol = 1e-5
+
             if bf16:
                 rtol = 0.02
                 atol = 0.02
+            if fp16:
+                rtol = 1e-3
+                atol = 1e-3
             for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \
                     in itertools.product(*params_list):
                 num_directions = 2 if bidirectional else 1
@@ -1496,7 +1502,9 @@
                 input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32)
                 h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
                 c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32)
-
+                if fp16:
+                    # TODO: add training support when oneDNN supports LSTM FP16 training
+                    training = False
                 model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional,
                                       bias=bias, dropout=dropout, batch_first=batch_first).float()
                 model.train() if training else model.eval()
@@ -1510,15 +1518,25 @@
                 model1 = copy.deepcopy(model)
                 model2 = copy.deepcopy(model)
-                with torch.cpu.amp.autocast(enabled=bf16, dtype=torch.bfloat16), torch.no_grad() if not training else nullcontext():
+                with torch.no_grad() if not training else nullcontext():
                     with torch.backends.mkldnn.flags(enabled=False):
                         torch.manual_seed(seed)
-                        output1, (hn1, cn1) = self._cast_dtype(model1, bf16)(self._cast_dtype(input1, bf16),
-                                                                             (self._cast_dtype(h1, bf16),
-                                                                              self._cast_dtype(c1, bf16)))
+                        output1, (hn1, cn1) = self._cast_dtype(model1, dtype)(
+                            self._cast_dtype(input1, dtype),
+                            (
+                                self._cast_dtype(h1, dtype),
+                                self._cast_dtype(c1, dtype),
+                            ),
+                        )
                     torch.manual_seed(seed)
-                    output2, (hn2, cn2) = model2(input2, (h2, c2))
+                    output2, (hn2, cn2) = self._cast_dtype(model2, dtype)(
+                        self._cast_dtype(input2, dtype),
+                        (
+                            self._cast_dtype(h2, dtype),
+                            self._cast_dtype(c2, dtype),
+                        ),
+                    )
                     self.assertEqual(output1, output2, rtol=rtol, atol=atol)
                     self.assertEqual(hn1, hn2, rtol=rtol, atol=atol)
                     self.assertEqual(cn1, cn2, rtol=rtol, atol=atol)
@@ -1533,8 +1551,13 @@
                         self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol)
                         for name, para in model1.named_parameters():
-                            self.assertEqual(para, self._cast_dtype(getattr(model2, name), bf16))
-                            self.assertEqual(para.grad, self._cast_dtype(getattr(model2, name).grad, bf16), rtol=rtol, atol=atol)
+                            self.assertEqual(para, getattr(model2, name))
+                            self.assertEqual(
+                                para.grad,
+                                getattr(model2, name).grad,
+                                rtol=rtol,
+                                atol=atol,
+                            )
                 with torch.backends.mkldnn.flags(enabled=False):
                     torch.manual_seed(seed)
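
Usage note (not part of the applied diff): a minimal sketch of the inference-only FP16 path this patch enables. It assumes a oneDNN-enabled CPU build whose hardware passes the FP16 device check (queryable via torch.ops.mkldnn._is_mkldnn_fp16_supported(), if available in your build); the model and tensor names are illustrative only. Because use_mkldnn() accepts kHalf only when grad mode is disabled, the oneDNN kernel is reached under torch.no_grad():

    import torch

    # Half-precision LSTM, inference only (FP16 training is still TODO upstream).
    lstm = torch.nn.LSTM(input_size=10, hidden_size=20, num_layers=2).half().eval()
    x = torch.randn(5, 3, 10, dtype=torch.half)   # (seq_len, batch, input_size)
    h0 = torch.randn(2, 3, 20, dtype=torch.half)  # (num_layers, batch, hidden_size)
    c0 = torch.randn(2, 3, 20, dtype=torch.half)

    # kHalf is gated on !at::GradMode::is_enabled(), so disable grad to hit oneDNN.
    with torch.no_grad():
        out, (hn, cn) = lstm(x, (h0, c0))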