mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Remove a number of fixed skips (#103162)
Also adds `PYTORCH_TEST_WITH_AOT_EAGER` to distinguish errors coming from aot_autograd and not inductor (not tested in CI, but useful for local debugging). Pull Request resolved: https://github.com/pytorch/pytorch/pull/103162 Approved by: https://github.com/desertfire
This commit is contained in:
parent
3c896a5adb
commit
40d70ba7ed
3 changed files with 7 additions and 29 deletions
|
|
@ -13,7 +13,7 @@ from torch.testing._internal.common_device_type import (
|
|||
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
|
||||
gradgradcheck, skipIfTorchInductor)
|
||||
gradgradcheck)
|
||||
from unittest.mock import patch, call
|
||||
|
||||
|
||||
|
|
@ -324,7 +324,6 @@ class TestModule(TestCase):
|
|||
self._traverse_obj(obj, inner_zero_grad)
|
||||
|
||||
@modules(module_db)
|
||||
@skipIfTorchInductor("to be fixed")
|
||||
def test_non_contiguous_tensors(self, device, dtype, module_info, training):
|
||||
# Check modules work with non-contiguous tensors
|
||||
|
||||
|
|
@ -488,7 +487,6 @@ class TestModule(TestCase):
|
|||
@toleranceOverride({torch.float32: tol(5e-2, 0),
|
||||
torch.float64: tol(4e-4, 0)})
|
||||
@modules(module_db)
|
||||
@skipIfTorchInductor("to be fixed")
|
||||
def test_cpu_gpu_parity(self, device, dtype, module_info, training):
|
||||
# TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
|
||||
# nicer way for eval mode only.
|
||||
|
|
@ -580,7 +578,6 @@ class TestModule(TestCase):
|
|||
|
||||
@with_tf32_off
|
||||
@modules(module_db)
|
||||
@skipIfTorchInductor("to be fixed")
|
||||
def test_memory_format(self, device, dtype, module_info, training):
|
||||
is_sm86or80 = device.startswith("cuda") and (torch.cuda.get_device_capability(0) == (8, 6)
|
||||
or torch.cuda.get_device_capability(0) == (8, 0))
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from torch.testing import make_tensor
|
|||
from torch.testing._internal.common_utils import (
|
||||
TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
|
||||
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
|
||||
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, skipIfTorchInductor, slowTest,
|
||||
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest,
|
||||
TEST_WITH_CROSSREF, skipIfTorchDynamo,
|
||||
skipCUDAMemoryLeakCheckIf, BytesIOContext,
|
||||
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
|
||||
|
|
@ -519,7 +519,6 @@ class TestTorchDeviceType(TestCase):
|
|||
|
||||
# collected tests of ops that used scalar_check in Declarations.cwrap for
|
||||
# correctness
|
||||
@skipIfTorchInductor("segfaults")
|
||||
def test_scalar_check(self, device):
|
||||
zero_d = torch.randn((), device=device)
|
||||
one_d = torch.randn((1,), device=device)
|
||||
|
|
@ -981,7 +980,6 @@ class TestTorchDeviceType(TestCase):
|
|||
torch.set_default_tensor_type(default_type)
|
||||
|
||||
# TODO: this test should be in test_nn.py
|
||||
@skipIfTorchInductor("Please convert all Tensors to FakeTensors")
|
||||
def test_conv_transposed_backward_agnostic_to_memory_format(self, device):
|
||||
in_channels = 64
|
||||
out_channels = 128
|
||||
|
|
@ -3260,7 +3258,6 @@ else:
|
|||
|
||||
# FIXME: move to test indexing
|
||||
@onlyCPU
|
||||
@skipIfTorchInductor("FIXME")
|
||||
def test_errors_index_copy(self, device):
|
||||
# We do not test the GPU as the CUDA_ASSERT would break the CUDA context
|
||||
idx_dim = 8
|
||||
|
|
@ -3643,7 +3640,6 @@ else:
|
|||
@dtypes(*floating_and_complex_types())
|
||||
@dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@skipIfTorchInductor("FIXME")
|
||||
def test_scatter_reduce_non_unique_index(self, device, dtype):
|
||||
height = 2
|
||||
width = 2
|
||||
|
|
@ -4401,7 +4397,6 @@ else:
|
|||
|
||||
# FIXME: move to test distributions
|
||||
@onlyCUDA
|
||||
@skipIfTorchInductor("out_wrapper does not check devices correctly")
|
||||
def test_multinomial_device_constrain(self, device):
|
||||
x = torch.empty(3, device="cpu")
|
||||
y = torch.empty(3, device=device)
|
||||
|
|
@ -4574,7 +4569,6 @@ else:
|
|||
y = ndhwc.permute(0, 1, 4, 3, 2).permute(0, 1, 4, 3, 2)
|
||||
self.assertTrue(y.is_contiguous(memory_format=torch.channels_last_3d))
|
||||
|
||||
@skipIfTorchInductor("To be supported")
|
||||
def test_memory_format_propagation_rules(self, device):
|
||||
|
||||
contiguous = torch.rand(10, 3, 5, 5, device=device)
|
||||
|
|
@ -4618,7 +4612,6 @@ else:
|
|||
self.assertEqual(ambiguous.stride(), result.stride())
|
||||
|
||||
@skipIfMps
|
||||
@skipIfTorchInductor("To be supported")
|
||||
def test_memory_format_empty_like(self, device):
|
||||
def test_helper(x, memory_format):
|
||||
xc = x.contiguous(memory_format=memory_format)
|
||||
|
|
@ -4661,7 +4654,6 @@ else:
|
|||
x.is_contiguous(memory_format=torch.channels_last_3d), x_rep.is_contiguous(memory_format=torch.channels_last_3d))
|
||||
|
||||
# FIXME: make this a elementwise unary and elementwise binary OpInfo test
|
||||
@skipIfTorchInductor("To be supported")
|
||||
def test_memory_format_operators(self, device):
|
||||
def _chunk_op(x, y):
|
||||
x1, x2 = x.chunk(2, dim=1)
|
||||
|
|
@ -4874,7 +4866,6 @@ else:
|
|||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
|
||||
@skipIfTorchInductor("pin_memory isn't yet supported in TorchInductor")
|
||||
def test_pin_memory_from_constructor(self, device):
|
||||
def _get_like(t, **kwargs):
|
||||
return [
|
||||
|
|
@ -5096,7 +5087,6 @@ else:
|
|||
x = x.permute(permutation)
|
||||
self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())
|
||||
|
||||
@skipIfTorchInductor("To be supported")
|
||||
def test_memory_format_to(self, device):
|
||||
def get_generator(memory_format, shape):
|
||||
def input_generator_fn(device):
|
||||
|
|
@ -5114,7 +5104,6 @@ else:
|
|||
self._test_memory_format_transformations(
|
||||
device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)
|
||||
|
||||
@skipIfTorchInductor("To be supported")
|
||||
def test_memory_format_type(self, device):
|
||||
def get_generator(memory_format, shape):
|
||||
def input_generator_fn(device):
|
||||
|
|
@ -5132,7 +5121,6 @@ else:
|
|||
self._test_memory_format_transformations(
|
||||
device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)
|
||||
|
||||
@skipIfTorchInductor("To be supported")
|
||||
def test_memory_format_clone(self, device):
|
||||
def get_generator(memory_format, shape):
|
||||
def input_generator_fn(device):
|
||||
|
|
@ -5176,7 +5164,6 @@ else:
|
|||
self._test_memory_format_transformations(
|
||||
device, get_generator(mf, shape), transformation_fn, mf, compare_data=False, default_is_preserve=True)
|
||||
|
||||
@skipIfTorchInductor("To be supported")
|
||||
def test_memory_format_type_shortcuts(self, device):
|
||||
def get_generator(memory_format, shape, dtype):
|
||||
def input_generator_fn(device):
|
||||
|
|
@ -5210,7 +5197,6 @@ else:
|
|||
device, get_generator(mf, shape, torch.float64), get_fn('float'), mf, default_is_preserve=True)
|
||||
|
||||
@onlyCUDA
|
||||
@skipIfTorchInductor("To be supported")
|
||||
def test_memory_format_cpu_and_cuda_ops(self, device):
|
||||
def get_generator(memory_format, shape):
|
||||
def input_generator_fn(device):
|
||||
|
|
@ -5411,7 +5397,6 @@ else:
|
|||
# FIXME: get PyTorch/XLA to run test_testing
|
||||
# This test should ideally be in test_testing.py,
|
||||
# but since pytorch/xla runs tests from test_torch.py, we have it here.
|
||||
@skipIfTorchInductor("random_.from needs to be renamed")
|
||||
def test_assertRaisesRegex_ignore_msg_non_native_device(self, device):
|
||||
# Verify that self.assertRaisesRegex only checks the Error and ignores
|
||||
# message for non-native devices.
|
||||
|
|
@ -5586,7 +5571,6 @@ class TestDevicePrecision(TestCase):
|
|||
|
||||
# FIXME: moved to indexing test suite
|
||||
@deviceCountAtLeast(1)
|
||||
@skipIfTorchInductor("FIXME")
|
||||
def test_advancedindex_mixed_cpu_devices(self, devices) -> None:
|
||||
def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
|
||||
# test getitem
|
||||
|
|
@ -5620,7 +5604,6 @@ class TestDevicePrecision(TestCase):
|
|||
test(x, ia, ib)
|
||||
|
||||
@deviceCountAtLeast(1)
|
||||
@skipIfTorchInductor("FIXME")
|
||||
def test_advancedindex_mixed_devices_error(self, devices) -> None:
|
||||
def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
|
||||
# test getitem
|
||||
|
|
@ -7423,7 +7406,6 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
self.assertRaises(RuntimeError, lambda: x.new(z.storage()))
|
||||
|
||||
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
|
||||
@skipIfTorchInductor("pin_memory isn't yet supported in TorchInductor")
|
||||
def test_pin_memory(self):
|
||||
x = torch.randn(3, 5)
|
||||
self.assertFalse(x.is_pinned())
|
||||
|
|
@ -7643,7 +7625,6 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
weight), torch.tensor(bias), 1, epsilon, True)
|
||||
torch.testing.assert_close(expected_norm, actual_norm)
|
||||
|
||||
@skipIfTorchInductor("To be supported")
|
||||
def test_memory_format(self):
|
||||
def test_helper(x, memory_format):
|
||||
y = x.contiguous(memory_format=memory_format)
|
||||
|
|
@ -7834,7 +7815,6 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
|
||||
self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
|
||||
|
||||
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/97414")
|
||||
def test_upsample_nearest2d_meta(self):
|
||||
# TODO: the out tests cannot be triggered by test_nn.py because
|
||||
# we don't actually do out= arguments for nn functions, so there
|
||||
|
|
@ -8180,7 +8160,6 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
self.assertEqual(y[:, 40], range(4000, 4100))
|
||||
|
||||
# FIXME: Port to a more appropriate test suite
|
||||
@skipIfTorchInductor("FIXME")
|
||||
def test_copy_broadcast(self):
|
||||
torch.zeros(5, 6).copy_(torch.zeros(6))
|
||||
self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30)))
|
||||
|
|
@ -8194,7 +8173,6 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
# storage to a single storage would cause RuntimeError to be thrown
|
||||
self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6)))
|
||||
|
||||
@skipIfTorchInductor("FIXME")
|
||||
def test_copy_float16(self):
|
||||
# Check that fbgemm code no longer reads memory out of bounds, see
|
||||
# copy_impl and fbgemm::Float16ToFloat_ref.
|
||||
|
|
@ -8378,7 +8356,6 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
s0 = t0.as_subclass(BadSubTensor)
|
||||
|
||||
# FIXME: Port to a test suite that better fits slicing
|
||||
@skipIfTorchInductor("FIXME")
|
||||
def test_slice(self):
|
||||
empty = torch.empty(0, 4)
|
||||
x = torch.arange(0., 16).view(4, 4)
|
||||
|
|
|
|||
|
|
@ -1034,7 +1034,9 @@ class CrossRefMode(torch.overrides.TorchFunctionMode):
|
|||
|
||||
# Run PyTorch tests with TorchDynamo
|
||||
TEST_WITH_TORCHINDUCTOR = os.getenv('PYTORCH_TEST_WITH_INDUCTOR') == '1'
|
||||
TEST_WITH_TORCHDYNAMO = os.getenv('PYTORCH_TEST_WITH_DYNAMO') == '1' or TEST_WITH_TORCHINDUCTOR
|
||||
# AOT_EAGER not tested in ci, useful for debugging
|
||||
TEST_WITH_AOT_EAGER = os.getenv('PYTORCH_TEST_WITH_AOT_EAGER') == '1'
|
||||
TEST_WITH_TORCHDYNAMO = os.getenv('PYTORCH_TEST_WITH_DYNAMO') == '1' or TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER
|
||||
|
||||
if TEST_WITH_TORCHDYNAMO:
|
||||
import torch._dynamo
|
||||
|
|
@ -2263,6 +2265,8 @@ class TestCase(expecttest.TestCase):
|
|||
super_run = super().run
|
||||
if TEST_WITH_TORCHINDUCTOR:
|
||||
super_run = torch._dynamo.optimize("inductor")(super_run)
|
||||
elif TEST_WITH_AOT_EAGER:
|
||||
super_run = torch._dynamo.optimize("aot_eager")(super_run)
|
||||
elif TEST_WITH_TORCHDYNAMO:
|
||||
# TorchDynamo optimize annotation
|
||||
super_run = torch._dynamo.optimize("eager")(super_run)
|
||||
|
|
|
|||
Loading…
Reference in a new issue