From 0fcdf936e703db544f5285dfa5fa2cb2e267f8a4 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 25 Jul 2022 11:47:44 -0400 Subject: [PATCH] Skip tests that don't call gradcheck in slow gradcheck CI (#82117) Pull Request resolved: https://github.com/pytorch/pytorch/pull/82117 Approved by: https://github.com/kit1980, https://github.com/albanD --- test/quantization/core/test_quantized_op.py | 3 ++- test/quantization/jit/test_quantize_jit.py | 2 ++ test/test_cpp_extensions_jit.py | 8 +++++--- test/test_fx.py | 2 ++ 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 2e34f8a5138..41b735afc75 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -21,7 +21,7 @@ from hypothesis import strategies as st import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, skipIfSlowGradcheckEnv from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2 from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ @@ -130,6 +130,7 @@ def _get_random_tensor_and_q_params(shapes, rand_scale, torch_type): X_scale = 1e-10 return X, X_scale, X_zero_point +@skipIfSlowGradcheckEnv class TestQuantizedOps(TestCase): """Helper function to test quantized activation functions.""" diff --git a/test/quantization/jit/test_quantize_jit.py b/test/quantization/jit/test_quantize_jit.py index 6648bcaa9af..84ab3a723b7 100644 --- a/test/quantization/jit/test_quantize_jit.py +++ b/test/quantization/jit/test_quantize_jit.py @@ -73,6 +73,7 @@ from torch.testing import FileCheck from torch.testing._internal.jit_utils import attrs_with_prefix from torch.testing._internal.jit_utils import get_forward from torch.testing._internal.jit_utils import get_forward_graph +from torch.testing._internal.common_utils import skipIfSlowGradcheckEnv from torch.jit._recursive import wrap_cpp_module @@ -1625,6 +1626,7 @@ class TestQuantizeJitPasses(QuantizationTestCase): torch.jit.save(model, b) +@skipIfSlowGradcheckEnv class TestQuantizeJitOps(QuantizationTestCase): """Test graph mode post training static quantization works for individual ops end to end. diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index 9875f4ee356..e4b1e9e5508 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -15,7 +15,7 @@ import torch import torch.backends.cudnn import torch.utils.cpp_extension from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME -from torch.testing._internal.common_utils import gradcheck +from torch.testing._internal.common_utils import gradcheck, skipIfSlowGradcheckEnv TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None @@ -37,7 +37,8 @@ def remove_build_path(): if os.path.exists(default_build_root): shutil.rmtree(default_build_root) - +# There's only one test that runs gracheck, run slow mode manually +@skipIfSlowGradcheckEnv class TestCppExtensionJIT(common.TestCase): """Tests just-in-time cpp extensions. Don't confuse this with the PyTorch JIT (aka TorchScript). @@ -864,7 +865,8 @@ class TestCppExtensionJIT(common.TestCase): a = torch.randn(5, 5, requires_grad=True) b = torch.randn(5, 5, requires_grad=True) - gradcheck(torch.ops.my.add, [a, b], eps=1e-2) + for fast_mode in (True, False): + gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode) if __name__ == "__main__": diff --git a/test/test_fx.py b/test/test_fx.py index 76ee2ad5a14..eb26dfe471f 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -57,6 +57,7 @@ from torch.testing._internal.common_utils import ( IS_WINDOWS, find_library_location, run_tests, + skipIfSlowGradcheckEnv, ) from torch.testing._internal.jit_utils import JitTestCase @@ -4072,6 +4073,7 @@ TestFunctionalTracing.generate_tests() instantiate_device_type_tests(TestOperatorSignatures, globals()) @skipIfNoTorchVision +@skipIfSlowGradcheckEnv class TestVisionTracing(JitTestCase): def setUp(self): # Checking for mutable operations while tracing is feature flagged