From d7f45fc575fc77fa4829038fda9cad8a90b79c64 Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Wed, 15 Jan 2025 11:11:32 -0800
Subject: [PATCH] dynamic shape support for interpolate(antialias=True)
 backward (#141198)

Fixes https://github.com/pytorch/pytorch/issues/141187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141198
Approved by: https://github.com/ezyang, https://github.com/Chillee
ghstack dependencies: #141161
---
 test/functorch/test_aotdispatch.py            |  3 ---
 test/test_fake_tensor.py                      | 24 ++++++++++++++++++-
 torch/_meta_registrations.py                  | 28 +++++++++++++++++++++
 .../_internal/common_methods_invocations.py   |  1 -
 4 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 57b15d41587..50ef291417b 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -6501,9 +6501,6 @@ symbolic_aot_autograd_failures = {
         "nn.functional.nll_loss", ""
     ),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail("trace", ""),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail(
-        "_upsample_bilinear2d_aa"
-    ),  # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
     decorate(
         "linalg.householder_product",
         decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index f6304380197..e94bf27a978 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1954,7 +1954,6 @@ class FakeTensorDispatchCache(TestCase):
             extract_tensor_metadata(res4),
         )
 
-    @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_wrapper_tensor_subclass_different_device(self):
         class DifferentDeviceTensor(torch.Tensor):
@@ -2007,6 +2006,29 @@ class FakeTensorDispatchCache(TestCase):
         assert isinstance(fake_wrapped_a, DifferentDeviceTensor)
         self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu)
 
+    def test__upsample_bilinear2d_aa_backward_dynamic_shapes(self):
+        def f(x):
+            return torch.nn.functional.interpolate(
+                x,
+                size=[256, 256],
+                mode='bilinear',
+                align_corners=False,
+                antialias=True,
+            )
+
+        shape_env = ShapeEnv()
+        fake_m = FakeTensorMode(shape_env=shape_env)
+        x = fake_m.from_tensor(
+            torch.randn(1, 3, 2005, 1920, requires_grad=True),
+            symbolic_context=StatelessSymbolicContext(
+                dynamic_sizes=[DimDynamic.STATIC, DimDynamic.STATIC, DimDynamic.DYNAMIC, DimDynamic.DYNAMIC],
+                constraint_sizes=[None, None, None, None],
+            ),
+        )
+        with fake_m, enable_python_dispatcher():
+            out = f(x)
+            out.sum().backward()
+        self.assertEqual(x.shape, x.grad.shape)
+
     def test_cache_tuple_outputs(self):
         """
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 8108c427880..9a4022c88d6 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -6535,6 +6535,34 @@ def meta_upsample_bimode2d_aa(
     )
 
 
+@register_meta([aten._upsample_bilinear2d_aa_backward.default])
+def meta_upsample_bimode2d_aa_backward(
+    grad_output,
+    output_size,
+    input_size,
+    align_corners,
+    scales_h=None,
+    scales_w=None,
+):
+    full_output_size = upsample_common_check(
+        input_size, output_size, num_spatial_dims=2
+    )
+    torch._check(
+        grad_output.ndim == 4,
+        lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
+    )
+    for i in range(4):
+        torch._check(
+            grad_output.shape[i] == full_output_size[i],
+            lambda: f"""
+Expected grad_output to have the same shape as output; output.size({i}) = {full_output_size[i]}
+but got grad_output_size({i}) = {grad_output.size(i)}""",
+        )
+    return grad_output.new_empty(input_size).to(
+        memory_format=utils.suggest_memory_format(grad_output)
+    )
+
+
 # From aten/src/ATen/native/cuda/AmpKernels.cu
 @register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
 def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 0b6d5f9660b..c6597318737 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15660,7 +15660,6 @@ op_db: List[OpInfo] = [
             skips=(
                 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
                 DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
-                DecorateInfo(unittest.expectedFailure, 'TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive'),
                 DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
                 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
             )),
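
Note for reviewers: a minimal end-to-end sketch of what this patch unblocks.
It is illustrative, not part of the patch: it mirrors the interpolate call and
shapes from the new fake-tensor test, and uses torch.compile with
torch._dynamo.mark_dynamic as one way of exercising the symbolic-shape path
that previously failed in the _upsample_bilinear2d_aa backward meta kernel.

    import torch

    def f(x):
        # Antialiased bilinear resize; its backward dispatches to
        # aten._upsample_bilinear2d_aa_backward, whose meta kernel this
        # patch adds.
        return torch.nn.functional.interpolate(
            x, size=[256, 256], mode="bilinear",
            align_corners=False, antialias=True,
        )

    x = torch.randn(1, 3, 2005, 1920, requires_grad=True)
    # Mark the spatial dims dynamic so tracing hits the symbolic-shape path.
    torch._dynamo.mark_dynamic(x, 2)
    torch._dynamo.mark_dynamic(x, 3)
    out = torch.compile(f)(x)
    out.sum().backward()
    assert x.grad.shape == x.shape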