Disable slow gradcheck for nn.Transformer ModuleInfo (#145531)
Fixes https://github.com/pytorch/pytorch/issues/117140
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145531
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #145520
This commit is contained in:
parent 9e0ee152e5
commit c7ca1df37e
2 changed files with 19 additions and 4 deletions
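For context before the diff (this snippet is illustrative and not part of the commit): the new gradcheck_fast_mode knob ultimately controls the fast_mode flag of PyTorch's gradient checker. In slow mode, gradcheck builds the full numerical Jacobian by perturbing every input element, which is prohibitively expensive for inputs as large as nn.Transformer's; in fast mode it only compares analytical and numerical values of v^T J u along random directions. A minimal sketch against torch.autograd.gradcheck directly (the tiny Linear module and shapes are arbitrary choices for illustration):

    import torch

    mod = torch.nn.Linear(4, 3).double()
    x = torch.randn(2, 4, dtype=torch.double, requires_grad=True)

    # Slow (default) mode perturbs each element of x to build the dense Jacobian.
    torch.autograd.gradcheck(lambda t: mod(t), (x,), fast_mode=False)

    # Fast mode checks v^T J u for random vectors u and v, so the cost no
    # longer scales with x.numel().
    torch.autograd.gradcheck(lambda t: mod(t), (x,), fast_mode=True)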
@@ -482,11 +482,19 @@ class TestModule(TestCase):
                 output_flattened = torch.utils._pytree.tree_leaves(output)
                 return output_flattened
 
+            def do_check(flat_input):
+                self.assertTrue(
+                    check(
+                        fn_to_gradcheck,
+                        flat_input,
+                        nondet_tol=gradcheck_nondet_tol,
+                        fast_mode=module_info.gradcheck_fast_mode
+                    ))
+
             # check total derivative
             grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
             flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
 
-            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+            do_check(flat_input)
 
             # check partial derivatives
             old_params_requires_grad = [p.requires_grad for p in params]
@@ -501,14 +509,14 @@ class TestModule(TestCase):
                 p.requires_grad = old
                 grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                 flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+                do_check(flat_input)
                 p.requires_grad = False
 
             for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
                 obj.requires_grad = old
                 grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                 flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+                do_check(flat_input)
                 obj.requires_grad = False
 
     @modules(module_db, allowed_dtypes=[torch.double])

@@ -222,6 +222,9 @@ class ModuleInfo:
                  # channels last output
                  train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
                  module_error_inputs_func=None,  # Function to generate module inputs that error
+                 gradcheck_fast_mode=None,  # Whether to use the fast implementation for gradcheck/gradgradcheck.
+                                            # When set to None, defers to the default value provided by the wrapper
+                                            # function around gradcheck (testing._internal.common_utils.gradcheck)
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func
@@ -234,6 +237,7 @@ class ModuleInfo:
         self.module_memformat_affects_out = module_memformat_affects_out
         self.train_and_eval_differ = train_and_eval_differ
         self.module_error_inputs_func = module_error_inputs_func
+        self.gradcheck_fast_mode = gradcheck_fast_mode
         self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin)
 
     def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
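Putting the two files together: a ModuleInfo entry now carries gradcheck_fast_mode, and the do_check helper added to TestModule above forwards it into the gradcheck call. A self-contained approximation of that flow (the SimpleNamespace stand-in, the Linear module, and the direct call to torch.autograd.gradcheck instead of the test-suite wrapper are assumptions made only to keep the sketch runnable):

    from types import SimpleNamespace

    import torch

    # Stand-in for a ModuleInfo entry; the real class lives in
    # torch.testing._internal.common_modules and has many more fields.
    module_info = SimpleNamespace(gradcheck_fast_mode=True)

    mod = torch.nn.Linear(8, 8).double()
    x = torch.randn(2, 8, dtype=torch.double, requires_grad=True)

    def fn_to_gradcheck(inp):
        return mod(inp)

    # Mirrors the shape of do_check: the per-module flag is forwarded to the
    # checker rather than being hard-coded at every call site.
    assert torch.autograd.gradcheck(
        fn_to_gradcheck,
        (x,),
        nondet_tol=0.0,
        fast_mode=module_info.gradcheck_fast_mode,
    )

When gradcheck_fast_mode is left at its default of None, the comment added above says the wrapper around gradcheck (testing._internal.common_utils.gradcheck) supplies the default, so existing ModuleInfo entries keep their previous behavior.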
@@ -4179,6 +4183,9 @@ module_db: list[ModuleInfo] = [
     ),
     ModuleInfo(torch.nn.Transformer,
                module_inputs_func=module_inputs_torch_nn_Transformer,
+               # Inputs are too large to run with slow gradcheck
+               # https://github.com/pytorch/pytorch/issues/117140
+               gradcheck_fast_mode=True,
                decorators=[
                    # Not implemented for SDPA backward derivative
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
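As the new ModuleInfo comment notes, the flag applies to gradgradcheck as well. For nn.Transformer the decorator above skips test_gradgrad because the SDPA backward derivative is not implemented, but for modules whose second-order checks do run, torch.autograd.gradgradcheck accepts the same switch. An illustrative sketch (module and shapes are arbitrary):

    import torch

    mod = torch.nn.Linear(4, 3).double()
    x = torch.randn(2, 4, dtype=torch.double, requires_grad=True)

    # gradgradcheck verifies second-order gradients and accepts the same
    # fast_mode switch as gradcheck.
    torch.autograd.gradgradcheck(lambda t: mod(t), (x,), fast_mode=True)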