Fix backward for reshape() on jagged layout NT (#117137)

Provides symbolic C++-side `reshape_as()` / `reshape()` decomps for jagged layout NTs to make the backwards pass work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117137
Approved by: https://github.com/soulitzer
Joel Schlosser 2024-01-10 15:00:38 -05:00 committed by PyTorch MergeBot
parent e10cfdd895
commit f70aeb4ffd
5 changed files with 44 additions and 18 deletions
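
For context, a minimal sketch of what this fix enables, mirroring the new test below; it builds the jagged-layout NT via the public torch.nested.nested_tensor API instead of the test helper random_nt_from_dims:

import torch

# Jagged-layout nested tensor (batch of 3, ragged dim, trailing dim 10) that tracks gradients.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 10), torch.randn(5, 10), torch.randn(3, 10)],
    layout=torch.jagged,
    requires_grad=True,
)

# On a contiguous jagged NT, reshape() decomposes to a view (-1 keeps the batch/ragged dims).
view = nt.reshape(-1, -1, 5, 2)

# With this change, backward through the reshape works for jagged layout.
view.backward(torch.ones_like(view))
assert nt.grad is not None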


@@ -4826,7 +4826,7 @@
   device_guard: False
   dispatch:
     CompositeImplicitAutograd: reshape_symint
-    CompositeImplicitAutogradNestedTensor: reshape_nested
+    CompositeImplicitAutogradNestedTensor: reshape_nested_symint
 
 - func: _reshape_copy(Tensor self, SymInt[] size) -> Tensor
   variants: function


@@ -894,7 +894,26 @@ Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
   }
 }
 
+Tensor reshape_nested_symint(const Tensor& self, SymIntArrayRef proposed_shape) {
+  // Jagged layout NT decomp
+  if (self.layout() == at::kJagged) {
+    // TODO: Expand decomp to handle other viewable cases
+    bool viewable = self.is_contiguous();
+    return (
+        viewable ? self.view_symint(proposed_shape) :
+        self.clone(at::MemoryFormat::Contiguous).view_symint(proposed_shape)
+    );
+  }
+  return reshape_nested(self, C10_AS_INTARRAYREF_SLOW(proposed_shape));
+}
+
 Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
+  // Jagged layout NT decomp
+  if (self.layout() == at::kJagged) {
+    return self.reshape_symint(other.sym_sizes());
+  }
   auto other_ptr = get_nested_tensor_impl(other);
   // TODO: this is to reproduce other_ptr->opt_sizes_
   // if an accessor is provided in the future, can replace this
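
In Python terms, the jagged-layout branch of reshape_nested_symint above amounts to the following rough sketch (reshape_jagged is a hypothetical name, not part of the PR; the real logic runs in C++ through the dispatcher):

def reshape_jagged(inp, shape):
    # View when the input is contiguous; otherwise materialize a contiguous
    # copy first and view that, matching the clone(MemoryFormat::Contiguous) branch.
    return inp.view(shape) if inp.is_contiguous() else inp.contiguous().view(shape)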


@@ -75,5 +75,7 @@ C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_nested_layer_norm_inputs(
   return std::make_pair(M, N);
 }
 
+Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape);
+
 } // namespace native
 } // namespace at


@@ -3138,22 +3138,41 @@ class TestNestedTensorSubclass(TestCase):
         self.assertEqual(nt.shape[:2], view.shape[:2])
 
     @xfailIfTorchDynamo
-    def test_reshape_decomp(self, device):
+    @parametrize("requires_grad", [False, True])
+    def test_reshape_decomp(self, device, requires_grad):
         # contiguous NT should result in view
         nt = random_nt_from_dims(
-            [3, None, 10], device=device, dtype=torch.float32, layout=torch.jagged)
+            [3, None, 10],
+            device=device,
+            dtype=torch.float32,
+            layout=torch.jagged,
+            requires_grad=requires_grad
+        )
         view = nt.reshape(-1, -1, 5, 2)
         self.assertEqual(view.shape[:2], nt.shape[:2])
         self.assertTrue(view._is_view() and view._base is nt)
+        # make sure gradients flow back
+        if requires_grad:
+            view.backward(torch.ones_like(view))
+            self.assertEqual(nt.grad, torch.ones_like(nt))
 
         # non-contiguous NT should result in contiguous copy
         nt = random_nt_from_dims(
-            [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged)
+            [3, None, 5, 2],
+            device=device,
+            dtype=torch.float32,
+            layout=torch.jagged,
+            requires_grad=requires_grad
+        )
         nt_noncontig = nt.transpose(-1, -2)
         self.assertFalse(nt_noncontig.is_contiguous())
         copy = nt_noncontig.reshape(-1, -1, 10)
         self.assertTrue(copy.is_contiguous())
         self.assertEqual(copy.shape[:2], nt.shape[:2])
+        # make sure gradients flow back
+        if requires_grad:
+            copy.backward(torch.ones_like(copy))
+            self.assertEqual(nt.grad, torch.ones_like(nt))
 
     def test_flatten_decomp(self, device):
         nt = random_nt_from_dims(


@@ -280,20 +280,6 @@ def jagged_torch_function(func, *args, **kwargs):
     if func is torch._C._nn.scaled_dot_product_attention:
         return jagged_scaled_dot_product_attention(*args, **kwargs)
 
-    # Handle reshape() / reshape_as() here because they're CompositeImplicit.
-    # TODO: Do the full view determination logic based on computeStride()
-    if func.__name__ == "reshape":
-        inp = args[0]
-        shape = args[1:]
-        return inp.view(shape) if inp.is_contiguous() else inp.contiguous().view(shape)
-
-    if func.__name__ == "reshape_as":
-        inp = args[0]
-        other = args[1]
-        return inp.reshape(*other.shape)
-
     # Handle flatten() here because it's CompositeImplicit.
     if func.__name__ == "flatten":