diff --git a/test/test_transformers.py b/test/test_transformers.py index 68fb89697a4..6066650e55c 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -5,10 +5,17 @@ import torch import torch.nn as nn import torch.nn.functional as F import unittest +from unittest.mock import patch from torch.testing._internal.common_nn import NNTestCase from torch.testing._internal.common_utils import ( - TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests, freeze_rng_state) + TEST_FAIRSEQ, + run_tests, + parametrize, + instantiate_parametrized_tests, + freeze_rng_state, + TEST_WITH_CROSSREF +) from torch.testing._internal.common_cuda import TEST_CUDA if TEST_FAIRSEQ: @@ -724,6 +731,59 @@ class TestTransformers(NNTestCase): if dropout_p == 0.0 or device == 'cpu': self.assertEqual(actual, expected) + @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref') + @torch.no_grad() + def test_mask_check_fastpath(self): + """ + Test that fastpath is executed independently of the mask that is passed. + If the passed mask is left aligned or mask_check=False, test that nested tensors are used (sparsity fastpath), + otherwise use fastpath with traditional tensors. 
+ """ + + x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float) + + def _test_fastpath(model, mask, mock_return_value, nested_tensors=True): + with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock: + fastpath_mock.return_value = mock_return_value + model(x, src_key_padding_mask=mask) + + # If mock was called, fastpath was taken + self.assertTrue(fastpath_mock.called) + + # If mock was called with nested tensors, sparsity fastpath was taken + for call_args, _ in fastpath_mock.call_args_list: + self.assertEqual(call_args[0].is_nested, nested_tensors) + + encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True) + + model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True) + model.eval() + + aligned_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool) + not_aligned_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool) + nested_tensor_return_value = torch.nested_tensor([torch.ones((2, 2), dtype=torch.float)]) + tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float) + + # Left aligned mask results in sparsity fastpath + _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True) + + # Not aligned mask results in fastpath with regular (non-nested) tensors + _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False) + + model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True) + model.eval() + + # If nested tensor is disabled, the regular (non-nested) fastpath is always taken + _test_fastpath(model, aligned_mask, tensor_return_value, nested_tensors=False) + _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False) + + + model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False) + model.eval() + + # Mask check disabled results in sparsity fastpath, independently of the mask + _test_fastpath(model, aligned_mask, nested_tensor_return_value, 
nested_tensors=True) + _test_fastpath(model, not_aligned_mask, nested_tensor_return_value, nested_tensors=True) # TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for # cross device / dtype testing.