mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add fastpath test for mask check flag (#82999)
Summary: Check that the fastpath is taken — and which type (sparsity fastpath or normal) — for a mask that is left aligned and for one that is not. Test Plan: buck test caffe2/test:test_transformers Differential Revision: D38259928 Pull Request resolved: https://github.com/pytorch/pytorch/pull/82999 Approved by: https://github.com/jbschlosser
This commit is contained in:
parent
b60dc2eb43
commit
dfc97df64d
1 changed file with 61 additions and 1 deletion
|
|
@ -5,10 +5,17 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from torch.testing._internal.common_nn import NNTestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests, freeze_rng_state)
|
||||
TEST_FAIRSEQ,
|
||||
run_tests,
|
||||
parametrize,
|
||||
instantiate_parametrized_tests,
|
||||
freeze_rng_state,
|
||||
TEST_WITH_CROSSREF
|
||||
)
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA
|
||||
|
||||
if TEST_FAIRSEQ:
|
||||
|
|
@ -724,6 +731,59 @@ class TestTransformers(NNTestCase):
|
|||
if dropout_p == 0.0 or device == 'cpu':
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
@unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
@torch.no_grad()
def test_mask_check_fastpath(self):
    """
    Test that the fastpath is executed independently of the mask that is passed.

    If the passed mask is left aligned or mask_check=False, check that nested
    tensors are used (sparsity fastpath); otherwise check that the fastpath is
    taken with traditional (dense) tensors.
    """

    # Single batch of three 2-dim tokens; float input for the encoder.
    x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)

    def _test_fastpath(model, mask, mock_return_value, nested_tensors=True):
        # Intercept the C++ fastpath entry point so we can observe whether it
        # was invoked and with what kind of tensor (nested vs. dense).
        with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
            fastpath_mock.return_value = mock_return_value
            model(x, src_key_padding_mask=mask)

            # If the mock was called, the fastpath was taken.
            self.assertTrue(fastpath_mock.called)

            # If the mock was called with nested tensors, the sparsity
            # fastpath was taken.
            for call_args, _ in fastpath_mock.call_args_list:
                self.assertEqual(call_args[0].is_nested, nested_tensors)

    encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)

    model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
    model.eval()

    # A left-aligned padding mask (all padding at the right end) qualifies
    # for the sparsity fastpath; a non-aligned one does not.
    aligned_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool)
    not_aligned_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool)
    nested_tensor_return_value = torch.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
    tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float)

    # Left aligned mask results in sparsity fastpath.
    _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True)

    # Not aligned mask results in fastpath (dense tensors).
    _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False)

    model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
    model.eval()

    # If nested tensor is disabled, the dense fastpath is always taken.
    _test_fastpath(model, aligned_mask, tensor_return_value, nested_tensors=False)
    _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False)

    model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
    model.eval()

    # Mask check disabled results in sparsity fastpath, independently of the mask.
    _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True)
    _test_fastpath(model, not_aligned_mask, nested_tensor_return_value, nested_tensors=True)
|
||||
|
||||
# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
|
||||
# cross device / dtype testing.
|
||||
|
|
|
|||
Loading…
Reference in a new issue