diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4960417abdb..bbdd4f8e1d4 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4826,7 +4826,7 @@
   device_guard: False
   dispatch:
     CompositeImplicitAutograd: reshape_symint
-    CompositeImplicitAutogradNestedTensor: reshape_nested
+    CompositeImplicitAutogradNestedTensor: reshape_nested_symint
 
 - func: _reshape_copy(Tensor self, SymInt[] size) -> Tensor
   variants: function
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index c4bc824fdb3..442660ef8d7 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -894,7 +894,26 @@ Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
   }
 }
 
+Tensor reshape_nested_symint(const Tensor& self, SymIntArrayRef proposed_shape) {
+  // Jagged layout NT decomp
+  if (self.layout() == at::kJagged) {
+    // TODO: Expand decomp to handle other viewable cases
+    bool viewable = self.is_contiguous();
+    return (
+        viewable ? self.view_symint(proposed_shape) :
+        self.clone(at::MemoryFormat::Contiguous).view_symint(proposed_shape)
+    );
+  }
+
+  return reshape_nested(self, C10_AS_INTARRAYREF_SLOW(proposed_shape));
+}
+
 Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
+  // Jagged layout NT decomp
+  if (self.layout() == at::kJagged) {
+    return self.reshape_symint(other.sym_sizes());
+  }
+
   auto other_ptr = get_nested_tensor_impl(other);
   // TODO: this is to reproduce other_ptr->opt_sizes_
   // if an accessor is provided in the future, can replace this
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.h b/aten/src/ATen/native/nested/NestedTensorMath.h
index 82699859909..068cc6b51ee 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.h
+++ b/aten/src/ATen/native/nested/NestedTensorMath.h
@@ -75,5 +75,7 @@ C10_ALWAYS_INLINE std::pair _check_nested_layer_norm_inputs(
   return std::make_pair(M, N);
 }
 
+Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape);
+
 } // namespace native
 } // namespace at
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index e9deadbb251..65cdc0791e6 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -3138,22 +3138,41 @@ class TestNestedTensorSubclass(TestCase):
         self.assertEqual(nt.shape[:2], view.shape[:2])
 
     @xfailIfTorchDynamo
-    def test_reshape_decomp(self, device):
+    @parametrize("requires_grad", [False, True])
+    def test_reshape_decomp(self, device, requires_grad):
         # contiguous NT should result in view
         nt = random_nt_from_dims(
-            [3, None, 10], device=device, dtype=torch.float32, layout=torch.jagged)
+            [3, None, 10],
+            device=device,
+            dtype=torch.float32,
+            layout=torch.jagged,
+            requires_grad=requires_grad
+        )
         view = nt.reshape(-1, -1, 5, 2)
         self.assertEqual(view.shape[:2], nt.shape[:2])
         self.assertTrue(view._is_view() and view._base is nt)
+        # make sure gradients flow back
+        if requires_grad:
+            view.backward(torch.ones_like(view))
+            self.assertEqual(nt.grad, torch.ones_like(nt))
 
         # non-contiguous NT should result in contiguous copy
         nt = random_nt_from_dims(
-            [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged)
+            [3, None, 5, 2],
+            device=device,
+            dtype=torch.float32,
+            layout=torch.jagged,
+            requires_grad=requires_grad
+        )
         nt_noncontig = nt.transpose(-1, -2)
         self.assertFalse(nt_noncontig.is_contiguous())
         copy = nt_noncontig.reshape(-1, -1, 10)
         self.assertTrue(copy.is_contiguous())
         self.assertEqual(copy.shape[:2], nt.shape[:2])
+        # make sure gradients flow back
+        if requires_grad:
+            copy.backward(torch.ones_like(copy))
+            self.assertEqual(nt.grad, torch.ones_like(nt))
 
     def test_flatten_decomp(self, device):
         nt = random_nt_from_dims(
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index b91812a3c2d..3aacd3eb2f4 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -280,20 +280,6 @@ def jagged_torch_function(func, *args, **kwargs):
     if func is torch._C._nn.scaled_dot_product_attention:
         return jagged_scaled_dot_product_attention(*args, **kwargs)
 
-    # Handle reshape() / reshape_as() here because they're CompositeImplicit.
-    # TODO: Do the full view determination logic based on computeStride()
-    if func.__name__ == "reshape":
-        inp = args[0]
-        shape = args[1:]
-
-        return inp.view(shape) if inp.is_contiguous() else inp.contiguous().view(shape)
-
-    if func.__name__ == "reshape_as":
-        inp = args[0]
-        other = args[1]
-
-        return inp.reshape(*other.shape)
-
     # Handle flatten() here because it's CompositeImplicit.
     if func.__name__ == "flatten":