Disable slow gradcheck for nn.Transformer ModuleInfo (#145531)

Fixes https://github.com/pytorch/pytorch/issues/117140

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145531
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #145520
Author: soulitzer
Date: 2025-01-24 13:36:24 -05:00
Committed by: PyTorch MergeBot
Parent: 9e0ee152e5
Commit: c7ca1df37e
2 changed files with 19 additions and 4 deletions
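For context: torch.autograd.gradcheck compares analytical gradients against numerical finite differences. In the default slow mode it perturbs every input element to reconstruct the full Jacobian, which is intractable for inputs as large as nn.Transformer's; fast mode instead verifies a random projection of the Jacobian (roughly v^T J u for random u and v) with far fewer function evaluations. A minimal sketch of the difference, using torch.sin as a stand-in for a real module:

    import torch

    x = torch.randn(4, 4, dtype=torch.double, requires_grad=True)

    # Slow (default) mode: perturbs each input element to build the full Jacobian.
    torch.autograd.gradcheck(torch.sin, (x,), fast_mode=False)

    # Fast mode: checks a random projection of the Jacobian instead; this is
    # the behavior that gradcheck_fast_mode=True selects for nn.Transformer.
    torch.autograd.gradcheck(torch.sin, (x,), fast_mode=True)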

test/test_modules.py

@@ -482,11 +482,19 @@ class TestModule(TestCase):
                 output_flattened = torch.utils._pytree.tree_leaves(output)
                 return output_flattened
 
+        def do_check(flat_input):
+            self.assertTrue(
+                check(
+                    fn_to_gradcheck,
+                    flat_input,
+                    nondet_tol=gradcheck_nondet_tol,
+                    fast_mode=module_info.gradcheck_fast_mode
+                ))
+
         # check total derivative
         grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
         flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-        self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+        do_check(flat_input)
 
         # check partial derivatives
         old_params_requires_grad = [p.requires_grad for p in params]
@@ -501,14 +509,14 @@ class TestModule(TestCase):
             p.requires_grad = old
             grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
             flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+            do_check(flat_input)
             p.requires_grad = False
 
         for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
             obj.requires_grad = old
             grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
             flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+            do_check(flat_input)
             obj.requires_grad = False
 
     @modules(module_db, allowed_dtypes=[torch.double])
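The check callable that do_check forwards to is PyTorch's gradcheck wrapper (testing._internal.common_utils.gradcheck), so fast_mode=None, the ModuleInfo default introduced below, means "no per-module opinion" and the wrapper's own default wins. A hedged sketch of that deferral pattern, not the actual common_utils source (the real suite-wide default is configuration-dependent, e.g. slow-gradcheck CI jobs force the slow path):

    import torch

    def check_wrapper(fn, inputs, fast_mode=None, **kwargs):
        # None -> defer to the suite-wide default; an explicit True/False
        # from a ModuleInfo entry overrides it.
        if fast_mode is None:
            fast_mode = True  # assumed default for this sketch only
        return torch.autograd.gradcheck(fn, inputs, fast_mode=fast_mode, **kwargs)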

torch/testing/_internal/common_modules.py

@@ -222,6 +222,9 @@ class ModuleInfo:
                  #   channels last output
                  train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
                  module_error_inputs_func=None,  # Function to generate module inputs that error
+                 gradcheck_fast_mode=None,  # Whether to use the fast implementation for gradcheck/gradgradcheck.
+                                            # When set to None, defers to the default value provided by the wrapper
+                                            # function around gradcheck (testing._internal.common_utils.gradcheck)
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func
@@ -234,6 +237,7 @@ class ModuleInfo:
         self.module_memformat_affects_out = module_memformat_affects_out
         self.train_and_eval_differ = train_and_eval_differ
         self.module_error_inputs_func = module_error_inputs_func
+        self.gradcheck_fast_mode = gradcheck_fast_mode
         self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin)
 
     def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
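With the flag stored on each ModuleInfo, the do_check helper in the first file forwards it into every gradcheck call. A self-contained sketch of that end-to-end flow, flattening inputs and parameters into gradcheck's argument list the way TestModule does (nn.Linear stands in for a large module so the example stays runnable):

    import torch
    import torch.nn.functional as F

    m = torch.nn.Linear(3, 2).to(torch.double)
    x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

    # Inputs and parameters are passed together, as TestModule does after
    # tree-flattening, so gradcheck perturbs all of them; fast_mode here
    # plays the role of gradcheck_fast_mode=True.
    def fn(inp, weight, bias):
        return F.linear(inp, weight, bias)

    assert torch.autograd.gradcheck(fn, (x, m.weight, m.bias), fast_mode=True)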
@@ -4179,6 +4183,9 @@ module_db: list[ModuleInfo] = [
                ),
     ModuleInfo(torch.nn.Transformer,
                module_inputs_func=module_inputs_torch_nn_Transformer,
+               # Inputs are too large to run with slow gradcheck
+               # https://github.com/pytorch/pytorch/issues/117140
+               gradcheck_fast_mode=True,
                decorators=[
                    # Not implemented for SDPA backward derivative
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',