Fix backward for reshape() on jagged layout NT (#117137)
Provides symbolic C++-side `reshape_as()` / `reshape()` decomps for jagged layout NTs to make the backwards pass work.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117137
Approved by: https://github.com/soulitzer
parent e10cfdd895
commit f70aeb4ffd
5 changed files with 44 additions and 18 deletions
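For context, a minimal sketch (not part of the diff) of the behavior this change enables, assuming a jagged-layout NT built with `torch.nested.nested_tensor` as in the updated test below:

```python
import torch

# Jagged-layout nested tensor with requires_grad; reshape() now decomposes to
# view() (contiguous case) or contiguous().view() (non-contiguous case), so
# autograd can propagate gradients back through it.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 10), torch.randn(3, 10)],
    layout=torch.jagged,
    requires_grad=True,
)
out = nt.reshape(-1, -1, 5, 2)       # split the trailing dim of 10 into (5, 2)
out.backward(torch.ones_like(out))   # previously failed for jagged layout
print(nt.grad is not None)           # gradients flow back after this fix
```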
@@ -4826,7 +4826,7 @@
   device_guard: False
   dispatch:
     CompositeImplicitAutograd: reshape_symint
-    CompositeImplicitAutogradNestedTensor: reshape_nested
+    CompositeImplicitAutogradNestedTensor: reshape_nested_symint

 - func: _reshape_copy(Tensor self, SymInt[] size) -> Tensor
   variants: function
@@ -894,7 +894,26 @@ Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
   }
 }

+Tensor reshape_nested_symint(const Tensor& self, SymIntArrayRef proposed_shape) {
+  // Jagged layout NT decomp
+  if (self.layout() == at::kJagged) {
+    // TODO: Expand decomp to handle other viewable cases
+    bool viewable = self.is_contiguous();
+    return (
+        viewable ? self.view_symint(proposed_shape) :
+        self.clone(at::MemoryFormat::Contiguous).view_symint(proposed_shape)
+    );
+  }
+
+  return reshape_nested(self, C10_AS_INTARRAYREF_SLOW(proposed_shape));
+}
+
 Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
+  // Jagged layout NT decomp
+  if (self.layout() == at::kJagged) {
+    return self.reshape_symint(other.sym_sizes());
+  }
+
   auto other_ptr = get_nested_tensor_impl(other);
   // TODO: this is to reproduce other_ptr->opt_sizes_
   // if an accessor is provided in the future, can replace this
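A rough illustration of the two paths in `reshape_nested_symint` above (a sketch, not from the PR, assuming the same `torch.jagged` construction as the tests): a contiguous jagged NT reshapes to a true view, while a non-contiguous one falls back to a contiguous copy.

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 5, 2), torch.randn(3, 5, 2)], layout=torch.jagged)

# Contiguous input: the decomp takes the view_symint() branch.
view = nt.reshape(-1, -1, 10)
print(view._is_view() and view._base is nt)   # True: a real view

# Non-contiguous input: the decomp clones to contiguous first, then views.
noncontig = nt.transpose(-1, -2)              # not contiguous anymore
copy = noncontig.reshape(-1, -1, 10)
print(copy.is_contiguous())                   # True: a contiguous copy
```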
@@ -75,5 +75,7 @@ C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_nested_layer_norm_inputs(
   return std::make_pair(M, N);
 }

+Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape);
+
 } // namespace native
 } // namespace at
@@ -3138,22 +3138,41 @@ class TestNestedTensorSubclass(TestCase):
         self.assertEqual(nt.shape[:2], view.shape[:2])

     @xfailIfTorchDynamo
-    def test_reshape_decomp(self, device):
+    @parametrize("requires_grad", [False, True])
+    def test_reshape_decomp(self, device, requires_grad):
         # contiguous NT should result in view
         nt = random_nt_from_dims(
-            [3, None, 10], device=device, dtype=torch.float32, layout=torch.jagged)
+            [3, None, 10],
+            device=device,
+            dtype=torch.float32,
+            layout=torch.jagged,
+            requires_grad=requires_grad
+        )
         view = nt.reshape(-1, -1, 5, 2)
         self.assertEqual(view.shape[:2], nt.shape[:2])
         self.assertTrue(view._is_view() and view._base is nt)
+        # make sure gradients flow back
+        if requires_grad:
+            view.backward(torch.ones_like(view))
+            self.assertEqual(nt.grad, torch.ones_like(nt))

         # non-contiguous NT should result in contiguous copy
         nt = random_nt_from_dims(
-            [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged)
+            [3, None, 5, 2],
+            device=device,
+            dtype=torch.float32,
+            layout=torch.jagged,
+            requires_grad=requires_grad
+        )
         nt_noncontig = nt.transpose(-1, -2)
         self.assertFalse(nt_noncontig.is_contiguous())
         copy = nt_noncontig.reshape(-1, -1, 10)
         self.assertTrue(copy.is_contiguous())
         self.assertEqual(copy.shape[:2], nt.shape[:2])
+        # make sure gradients flow back
+        if requires_grad:
+            copy.backward(torch.ones_like(copy))
+            self.assertEqual(nt.grad, torch.ones_like(nt))

     def test_flatten_decomp(self, device):
         nt = random_nt_from_dims(
@@ -280,20 +280,6 @@ def jagged_torch_function(func, *args, **kwargs):
     if func is torch._C._nn.scaled_dot_product_attention:
         return jagged_scaled_dot_product_attention(*args, **kwargs)

-    # Handle reshape() / reshape_as() here because they're CompositeImplicit.
-    # TODO: Do the full view determination logic based on computeStride()
-    if func.__name__ == "reshape":
-        inp = args[0]
-        shape = args[1:]
-
-        return inp.view(shape) if inp.is_contiguous() else inp.contiguous().view(shape)
-
-    if func.__name__ == "reshape_as":
-        inp = args[0]
-        other = args[1]
-
-        return inp.reshape(*other.shape)
-
     # Handle flatten() here because it's CompositeImplicit.
     if func.__name__ == "flatten":