mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Context: For a while, slow gradcheck CI was skipping nearly all tests, which hid the fact that it should have been failing and timing out (10+ h runtime for TestGradients). The CI configuration has since been fixed, revealing the test failures. This PR re-enables slow gradcheck CI and makes it pass again. This PR: - makes slow and failing tests run in fast gradcheck mode only - reduces the input size in slow gradcheck mode, only for unary/binary ufuncs (the alternative would be to skip the test entirely) - skips entire test files on the slow gradcheck runner if they don't use gradcheck (test_ops, test_meta, test_decomp, test_ops_jit) - reduces the input size for some ops Follow-ups: 1. Investigate slow mode failures https://github.com/pytorch/pytorch/issues/80411 2. See if we can re-enable slow gradcheck for some of the slow tests by reducing the sizes of their inputs The following tests fail in slow mode; they now run in fast mode only: ``` test_fn_fwgrad_bwgrad___rmod___cuda_float64 test_fn_fwgrad_bwgrad_linalg_householder_product_cuda_complex128 test_fn_fwgrad_bwgrad__masked_prod_cuda_complex128 test_fn_fwgrad_bwgrad__masked_prod_cuda_float64 test_fn_fwgrad_bwgrad_linalg_matrix_power_cuda_complex128 test_fn_fwgrad_bwgrad_cat_cuda_complex128 test_fn_fwgrad_bwgrad_linalg_lu_factor_ex_cuda_float64 test_fn_fwgrad_bwgrad_copysign_cuda_float64 test_fn_fwgrad_bwgrad_cholesky_inverse_cuda_complex128 test_fn_fwgrad_bwgrad_float_power_cuda_complex128 test_fn_fwgrad_bwgrad_fmod_cuda_float64 test_fn_fwgrad_bwgrad_float_power_cuda_float64 test_fn_fwgrad_bwgrad_linalg_lu_cuda_float64 test_fn_fwgrad_bwgrad_remainder_cuda_float64 test_fn_fwgrad_bwgrad_repeat_cuda_complex128 test_fn_fwgrad_bwgrad_prod_cuda_complex128 test_fn_fwgrad_bwgrad_slice_scatter_cuda_float64 test_fn_fwgrad_bwgrad_tile_cuda_complex128 test_fn_fwgrad_bwgrad_pow_cuda_float64 test_fn_fwgrad_bwgrad_pow_cuda_complex128 test_fn_fwgrad_bwgrad_fft_* test_fn_fwgrad_bwgrad_zero__cuda_complex128 
test_fn_gradgrad_linalg_lu_factor_cuda_float64 test_fn_grad_div_trunc_rounding_cuda_float64 test_fn_grad_div_floor_rounding_cuda_float64 ``` Marks the OpInfos for the following ops that run slowly in slow gradcheck as `fast_gradcheck` only (the left column represents runtime in seconds): ``` 0 918.722 test_fn_fwgrad_bwgrad_nn_functional_conv_transpose3d_cuda_float64 1 795.042 test_fn_fwgrad_bwgrad_nn_functional_unfold_cuda_complex128 2 583.63 test_fn_fwgrad_bwgrad_nn_functional_max_pool3d_cuda_float64 3 516.946 test_fn_fwgrad_bwgrad_svd_cuda_complex128 4 503.179 test_fn_fwgrad_bwgrad_linalg_svd_cuda_complex128 5 460.985 test_fn_fwgrad_bwgrad_linalg_lu_cuda_complex128 6 401.04 test_fn_fwgrad_bwgrad_linalg_lstsq_grad_oriented_cuda_complex128 7 353.671 test_fn_fwgrad_bwgrad_nn_functional_max_pool2d_cuda_float64 8 321.903 test_fn_fwgrad_bwgrad_nn_functional_gaussian_nll_loss_cuda_float64 9 307.951 test_fn_fwgrad_bwgrad_stft_cuda_complex128 10 266.104 test_fn_fwgrad_bwgrad_svd_lowrank_cuda_float64 11 221.032 test_fn_fwgrad_bwgrad_istft_cuda_complex128 12 183.741 test_fn_fwgrad_bwgrad_lu_unpack_cuda_complex128 13 132.019 test_fn_fwgrad_bwgrad_nn_functional_unfold_cuda_float64 14 125.343 test_fn_fwgrad_bwgrad_nn_functional_pad_constant_cuda_complex128 15 124.2 test_fn_fwgrad_bwgrad_kron_cuda_complex128 16 123.721 test_fn_fwgrad_bwgrad_pca_lowrank_cuda_float64 17 121.074 test_fn_fwgrad_bwgrad_nn_functional_max_unpool3d_cuda_float64 18 119.387 test_fn_fwgrad_bwgrad_rot90_cuda_complex128 19 112.889 test_fn_fwgrad_bwgrad__masked_normalize_cuda_complex128 20 107.541 test_fn_fwgrad_bwgrad_dist_cuda_complex128 21 106.727 test_fn_fwgrad_bwgrad_diff_cuda_complex128 22 104.588 test_fn_fwgrad_bwgrad__masked_cumprod_cuda_complex128 23 100.135 test_fn_fwgrad_bwgrad_nn_functional_feature_alpha_dropout_with_train_cuda_float64 24 88.359 test_fn_fwgrad_bwgrad_mH_cuda_complex128 25 86.214 test_fn_fwgrad_bwgrad_nn_functional_max_unpool2d_cuda_float64 26 83.037 
test_fn_fwgrad_bwgrad_nn_functional_bilinear_cuda_float64 27 79.987 test_fn_fwgrad_bwgrad__masked_cumsum_cuda_complex128 28 77.822 test_fn_fwgrad_bwgrad_diag_embed_cuda_complex128 29 76.256 test_fn_fwgrad_bwgrad_mT_cuda_complex128 30 74.039 test_fn_fwgrad_bwgrad_linalg_lu_solve_cuda_complex128 ``` ``` 0 334.142 test_fn_fwgrad_bwgrad_unfold_cuda_complex128 1 312.791 test_fn_fwgrad_bwgrad_linalg_lu_factor_cuda_complex128 2 121.963 test_fn_fwgrad_bwgrad_nn_functional_max_unpool3d_cuda_float64 3 108.085 test_fn_fwgrad_bwgrad_diff_cuda_complex128 4 89.418 test_fn_fwgrad_bwgrad_nn_functional_max_unpool2d_cuda_float64 5 72.231 test_fn_fwgrad_bwgrad___rdiv___cuda_complex128 6 69.433 test_fn_fwgrad_bwgrad___getitem___cuda_complex128 7 68.582 test_fn_fwgrad_bwgrad_ldexp_cuda_complex128 8 68.572 test_fn_fwgrad_bwgrad_linalg_pinv_cuda_complex128 9 67.585 test_fn_fwgrad_bwgrad_nn_functional_glu_cuda_float64 10 66.567 test_fn_fwgrad_bwgrad_lu_cuda_float64 ``` ``` 0 630.13 test_fn_gradgrad_nn_functional_conv2d_cuda_complex128 1 81.086 test_fn_gradgrad_linalg_solve_triangular_cuda_complex128 2 71.332 test_fn_gradgrad_norm_cuda_complex128 3 64.308 test_fn_gradgrad__masked_std_cuda_complex128 4 59.519 test_fn_gradgrad_div_no_rounding_mode_cuda_complex128 5 58.836 test_fn_gradgrad_nn_functional_adaptive_avg_pool3 ``` Reduces the sizes of the inputs for: - diff - diag_embed Pull Request resolved: https://github.com/pytorch/pytorch/pull/80514 Approved by: https://github.com/albanD
266 lines
13 KiB
Python
266 lines
13 KiB
Python
# Owner(s): ["module: unknown"]
|
|
|
|
from functools import partial, wraps
|
|
from itertools import chain
|
|
import torch
|
|
|
|
from torch.testing._internal.common_utils import \
|
|
(TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck, is_slow_gradcheck_env)
|
|
from torch.testing._internal.common_methods_invocations import op_db
|
|
from torch.testing._internal.common_device_type import \
|
|
(instantiate_device_type_tests, ops, OpDTypes)
|
|
|
|
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)

# gradcheck needs double precision, so this `ops` decorator parametrizes each
# test over the op's supported dtypes, limited to float64 and complex128.
_gradcheck_ops = partial(
    ops,
    dtypes=OpDTypes.supported,
    allowed_dtypes=[torch.double, torch.cdouble],
)
|
|
|
|
class TestGradients(TestCase):
    """Gradient tests for every OpInfo in ``op_db``.

    Each test is parametrized (through ``_gradcheck_ops``) over an op and one of
    its double-precision dtypes. It uses ``gradcheck``/``gradgradcheck`` to
    verify the op's reverse-mode, forward-mode and forward-over-reverse
    derivatives on the sample inputs the OpInfo provides.
    """
    exact_dtype = True

    # Copies inputs to inplace operations to avoid inplace modifications
    # to leaves requiring gradient
    def _get_safe_inplace(self, inplace_variant):
        """Return a wrapper that calls ``inplace_variant`` on a clone of its first argument.

        ``functools.wraps`` sets ``__wrapped__`` on the wrapper. ``_check_helper``
        uses that attribute to recognise wrapped in-place variants.
        """
        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            return inplace_variant(t.clone(), *args, **kwargs)

        return _fn

    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
                      check_batched_grad=None, check_batched_forward_grad=False):
        """Run a gradient check of ``variant`` on every sample input of ``op``.

        Args:
            device: device the sample inputs are created on.
            dtype: dtype of the sample inputs.
            op: the OpInfo under test.
            variant: callable to check (the op, or a safe in-place wrapper).
                ``None`` skips the test.
            check: ``'gradcheck'`` for a first-order check, ``'bwgrad_bwgrad'``
                for reverse-over-reverse gradgradcheck, or ``'fwgrad_bwgrad'``
                for forward-over-reverse gradgradcheck.
            check_forward_ad, check_backward_ad, check_batched_forward_grad:
                forwarded to ``gradcheck``. Only used when ``check == 'gradcheck'``.
                ``check_forward_ad`` must stay False for the gradgrad checks.
            check_batched_grad: forwarded to ``gradcheck``. ``None`` means
                "use ``op.check_batched_grad``".

        Skips the test when ``variant`` is None or ``op`` does not support
        ``dtype`` on this device type. Fails the test when a check fails.
        """
        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
        # NB: check_backward_ad does not affect gradgradcheck (always True)
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def is_inplace(variant):
            # Look through the clone-first wrapper built by _get_safe_inplace.
            if hasattr(variant, "__wrapped__"):
                return variant.__wrapped__ is op.get_inplace()
            return variant is op.get_inplace()

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex

        # In the slow-gradcheck environment, request only small sample inputs
        # (presumably to keep slow-mode runtime bounded).
        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
                                   small_inputs_only=is_slow_gradcheck_env())

        for sample in samples:
            # In-place variants cannot broadcast their input, so skip such samples.
            if sample.broadcasts_input and is_inplace(variant):
                continue

            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
            # and tensors passed as kwargs. The following creates a function that accepts just
            # the tensors that require grad as varargs, and then recomposes them back into the
            # original input.

            def _grad_tensors(inp):
                """Return the tensors in ``inp`` that require grad.

                ``inp`` may be a single value or a tensor list. The traversal
                order matches the one ``_input_recomposition_helper`` uses to
                consume them.
                """
                if is_iterable_of_tensors(inp):
                    return [x for x in inp if isinstance(x, torch.Tensor) and x.requires_grad]
                if isinstance(inp, torch.Tensor) and inp.requires_grad:
                    return [inp]
                return []

            # Creates gradcheck inputs by identifying tensors requiring grad.
            # Tensor lists are flattened wherever they appear (input, args or
            # kwargs), exactly as fn() below recomposes them. Otherwise the
            # varargs indices would drift out of sync.
            gradcheck_args = tuple(chain.from_iterable(
                _grad_tensors(x) for x in chain((sample.input,), sample.args, sample.kwargs.values())))

            def _input_recomposition_helper(inputs, inp, input_idx):
                """Rebuild ``inp``, taking gradcheck tensors from ``inputs``.

                Every tensor that requires grad, either standalone or inside a
                tensor list, is replaced by ``inputs[input_idx]``, and
                ``input_idx`` is advanced. All other values are kept unchanged.
                Returns ``(rebuilt_value, next_input_idx)``.
                """
                if is_iterable_of_tensors(inp):
                    tensor_list = []
                    for x in inp:
                        if isinstance(x, torch.Tensor) and x.requires_grad:
                            tensor_list.append(inputs[input_idx])
                            input_idx = input_idx + 1
                        else:
                            tensor_list.append(x)
                    return tensor_list, input_idx
                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
                    return inputs[input_idx], input_idx + 1
                else:
                    return inp, input_idx

            def fn(*inputs):
                """Call ``variant`` with the gradcheck tensors put back into the sample's structure."""
                # Puts inputs back into sample properly
                positional_args = []
                input_idx = 0
                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
                positional_args.append(inp)

                for x in sample.args:
                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
                    positional_args.append(inp)

                # Recreates kwargs
                kwargs = {}
                for k, v in sample.kwargs.items():
                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
                    kwargs[k] = inp

                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            if check == 'gradcheck':
                if check_batched_grad is None:
                    check_batched_grad = op.check_batched_grad
                self.assertTrue(gradcheck(fn, gradcheck_args,
                                          check_batched_grad=check_batched_grad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode,
                                          check_forward_ad=check_forward_ad,
                                          check_backward_ad=check_backward_ad,
                                          check_undefined_grad=True,
                                          check_batched_forward_grad=check_batched_forward_grad))
            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                # Run once with contiguous and once with non-contiguous grad outputs.
                for gen_non_contig_grad_outputs in (False, True):
                    kwargs = {
                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
                        "check_batched_grad": op.check_batched_gradgrad,
                        "check_grad_dtypes": True,
                        "nondet_tol": op.gradcheck_nondet_tol,
                        "fast_mode": op.gradcheck_fast_mode
                    }
                    if check == "fwgrad_bwgrad":
                        kwargs["check_fwd_over_rev"] = True
                        kwargs["check_rev_over_rev"] = False
                        kwargs["check_batched_grad"] = False
                        kwargs["check_undefined_grad"] = False

                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
            else:
                self.assertTrue(False, msg="Unknown check requested!")

    def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
                          check_batched_grad=None, check_batched_forward_grad=False):
        """Run the first-order ``'gradcheck'`` mode of ``_check_helper`` on ``variant``."""
        return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
                                  check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
                                  check_batched_forward_grad=check_batched_forward_grad)

    def _skip_helper(self, op, device, dtype):
        """Skip the test if ``op`` has no autograd support for ``dtype`` on this device type."""
        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest("Skipped! Op doesn't support autograd for this dtype.")
        if not op.supports_autograd and not op.supports_forward_ad:
            self.skipTest("Skipped! autograd not supported.")

    # Tests that gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_grad(self, device, dtype, op):
        # This is verified by test_dtypes in test_ops.py
        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest("Skipped! Dtype is not in supported backward dtypes!")
        else:
            self._grad_test_helper(device, dtype, op, op.get_op())

    # Method grad (and gradgrad, see below) tests are disabled since they're
    # costly and redundant with function grad (and gradgrad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_grad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._grad_test_helper(device, dtype, op, op.get_method())

    @_gradcheck_ops(op_db)
    def test_inplace_grad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant:
            self.skipTest("Op has no inplace variant!")

        # Verifies an operation doesn't support inplace autograd if it claims not to
        if not op.supports_inplace_autograd:
            inplace = self._get_safe_inplace(op.get_inplace())
            for sample in op.sample_inputs(device, dtype, requires_grad=True):
                if sample.broadcasts_input:
                    continue
                with self.assertRaises(Exception):
                    # Unpack the sample. Passing the SampleInput object itself
                    # fails inside the wrapper (it calls ``.clone()`` on it).
                    # That would make this assertion pass regardless of the op.
                    result = inplace(sample.input, *sample.args, **sample.kwargs)
                    result.sum().backward()
        else:
            self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

    # Test that gradients of gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.supports_gradgrad:
            self.skipTest("Op claims it doesn't support gradgrad. This is not verified.")
        else:
            self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')

    # Test that forward-over-reverse gradgrad is computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_fwgrad_bwgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        if op.supports_fwgrad_bwgrad:
            self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
        else:
            # An op that claims no forward-over-reverse support must actually raise.
            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = ("Running forward-over-backward gradgrad for an OP that does not support it did not "
                        "raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.")
            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")

    # Test that gradients of gradients are properly raising
    @_gradcheck_ops(op_db)
    def test_fn_fail_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if op.supports_gradgrad:
            self.skipTest("Skipped! Operation does support gradgrad")

        err_msg = r"derivative for .* is not implemented"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')

    # Method gradgrad (and grad, see above) tests are disabled since they're
    # costly and redundant with function gradgrad (and grad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_gradgrad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._check_helper(device, dtype, op, op.get_method(), 'bwgrad_bwgrad')

    @_gradcheck_ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant or not op.supports_inplace_autograd:
            self.skipTest("Skipped! Operation does not support inplace autograd.")
        self._check_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad")

    def _forward_grad_helper(self, device, dtype, op, variant, is_inplace):
        """Check forward-mode AD of ``variant`` (reverse mode disabled).

        If ``op.supports_forward_ad`` is False, the check is instead expected to
        raise ``NotImplementedError`` saying forward AD is unsupported.
        """
        # TODO: clean up how attributes are passed to gradcheck from OpInfos
        def call_grad_test_helper():
            # In-place and out-of-place variants have separate batched-forward-grad flags.
            check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or
                                          (op.check_inplace_batched_forward_grad and is_inplace))
            self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False,
                                   check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad)
        if op.supports_forward_ad:
            call_grad_test_helper()
        else:
            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = ("Running forward AD for an OP that does not support it did not "
                        "raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                call_grad_test_helper()

    @_gradcheck_ops(op_db)
    def test_forward_mode_AD(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False)

    @_gradcheck_ops(op_db)
    def test_inplace_forward_mode_AD(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        if not op.inplace_variant or not op.supports_inplace_autograd:
            self.skipTest("Skipped! Operation does not support inplace autograd.")

        self._forward_grad_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True)
|
|
|
|
|
|
# Instantiate the device-specific variants of TestGradients into this
# module's namespace so the test runner can discover them.
instantiate_device_type_tests(TestGradients, globals())

if __name__ == "__main__":
    run_tests()
|