Add fastpath test for mask check flag (#82999)

Summary: Check that the fastpath is taken, and which variant is used (sparsity fastpath or normal fastpath), both for a mask that is left-aligned and for one that is not.

Test Plan: buck test caffe2/test:test_transformers

Differential Revision: D38259928

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82999
Approved by: https://github.com/jbschlosser
This commit is contained in:
Yoav Navon 2022-08-12 00:04:43 +00:00 committed by PyTorch MergeBot
parent b60dc2eb43
commit dfc97df64d

View file

@ -5,10 +5,17 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest
from unittest.mock import patch
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests, freeze_rng_state)
TEST_FAIRSEQ,
run_tests,
parametrize,
instantiate_parametrized_tests,
freeze_rng_state,
TEST_WITH_CROSSREF
)
from torch.testing._internal.common_cuda import TEST_CUDA
if TEST_FAIRSEQ:
@ -724,6 +731,59 @@ class TestTransformers(NNTestCase):
if dropout_p == 0.0 or device == 'cpu':
self.assertEqual(actual, expected)
@unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
@torch.no_grad()
def test_mask_check_fastpath(self):
    """
    Verify that the transformer-encoder fastpath is taken no matter which
    src_key_padding_mask is supplied.  A left-aligned mask (or mask_check=False)
    should route through nested tensors (the sparsity fastpath); any other mask
    should still hit the fastpath, just with ordinary dense tensors.
    """
    src = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)

    def assert_fastpath(encoder, padding_mask, stub_result, expect_nested=True):
        # Stub out the native fastpath kernel so we can observe whether it is
        # invoked and what kind of tensor it receives.
        with patch('torch._transformer_encoder_layer_fwd') as fwd_stub:
            fwd_stub.return_value = stub_result
            encoder(src, src_key_padding_mask=padding_mask)
            # The stub being hit at all means the fastpath was taken.
            self.assertTrue(fwd_stub.called)
            # A nested first positional argument means the sparsity fastpath
            # (nested tensors) was taken.
            for positional, _ in fwd_stub.call_args_list:
                self.assertEqual(positional[0].is_nested, expect_nested)

    layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
    aligned = torch.Tensor([[0, 0, 1]]).to(torch.bool)
    misaligned = torch.Tensor([[1, 0, 1]]).to(torch.bool)
    nested_stub = torch.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
    dense_stub = torch.ones((1, 3, 2), dtype=torch.float)

    encoder = torch.nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
    encoder.eval()
    # Left-aligned mask -> sparsity fastpath (nested tensors).
    assert_fastpath(encoder, aligned, nested_stub, expect_nested=True)
    # Misaligned mask -> ordinary fastpath (dense tensors).
    assert_fastpath(encoder, misaligned, dense_stub, expect_nested=False)

    encoder = torch.nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
    encoder.eval()
    # With nested tensors disabled, the ordinary fastpath is always used.
    assert_fastpath(encoder, aligned, dense_stub, expect_nested=False)
    assert_fastpath(encoder, misaligned, dense_stub, expect_nested=False)

    encoder = torch.nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
    encoder.eval()
    # With the mask check disabled, the sparsity fastpath is used regardless
    # of mask alignment.
    assert_fastpath(encoder, aligned, nested_stub, expect_nested=True)
    assert_fastpath(encoder, misaligned, nested_stub, expect_nested=True)
# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.