diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index eef93413912..51e4df1412a 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -1473,8 +1473,15 @@ def forward(self, primals_1, primals_2):
         self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True)
         self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True)
-        self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)
-        with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses"
+        ):
+            self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses"
+        ):
             self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)
 
         # Mutations in the backward are allowed as long as the mutated object does not require grad
@@ -1573,7 +1580,7 @@ def forward(self, primals_1, primals_2):
         def inp_callable1(req_grad):
             base = torch.ones(4, 4, requires_grad=req_grad)
             x = base.add(1)
-            # create two non-contiguous views that share storage, but are actually non-overlapping
+            # create two views that share storage, but are actually non-overlapping
             a = x[0:2]
             b = x[2:4]
             return [base], [a, b]
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index d35e76e21c2..c04f724227b 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -917,6 +917,13 @@ def forward(self, x_a_1, x_b_1, y_1):
     return (mul, mul_1, add)
     """)
 
+    # See https://github.com/pytorch/pytorch/issues/117794
+    def test_return_and_correct_aliasing_gives_correct_stride(self):
+        t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
+        x = torch.randn(2, 2)
+        # slicing should result in the same stride for TwoTensor as a dense tensor would give
+        self.assertEqual(t[:, 0].stride(), x[:, 0].stride())
+
     def test_make_wrapper_subclass_propagates_metadata(self) -> None:
         class WrapperTensor(torch.Tensor):
             elem: torch.Tensor
diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
index 4ff5c1e89ba..a7097591390 100644
--- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
+++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
@@ -121,6 +121,21 @@ def run_functionalized_fw_and_collect_metadata(
        # Inspect the state of the input tensor functional wrapper to detect input mutation info
        # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
        for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
+            # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
+            # strides between the functionalized arg inner tensors and non-functionalized arg inner
+            # tensors. This is a problem as the inner tensor stride change may not be reflected
+            # correctly in the outer tensor, so disallow this for now.
+            mutates_data = has_data_mutation(f_arg)
+            if (
+                mutates_data
+                and not arg.is_contiguous()
+                and is_traceable_wrapper_subclass(arg)
+            ):
+                raise RuntimeError(
+                    "Mutations on non-contiguous inputs are currently not allowed on "
+                    "tensor subclasses"
+                )
+
            if not isinstance(arg, Tensor):
                new_arg = arg
            else:
@@ -135,7 +150,6 @@ def run_functionalized_fw_and_collect_metadata(
            mutates_storage_metadata = has_metadata_mutation(
                f_arg, arg, check_only_storage_mutation=True
            )
-            mutates_data = has_data_mutation(f_arg)
            mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(
                f_arg
            )
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index 434a88ab292..8f2413be6b5 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -268,11 +268,12 @@ and output of type {type(ret)}. But expected types to match."""
        if isinstance(ret, list):
            for r in ret:
-                torch.ops.aten.set_.source_Storage_storage_offset(r, arg.untyped_storage(), r.storage_offset(), r.shape)
+                torch.ops.aten.set_.source_Storage_storage_offset(
+                    r, arg.untyped_storage(), r.storage_offset(), r.shape, r.stride())
        else:
            assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
            torch.ops.aten.set_.source_Storage_storage_offset(
-                ret, arg.untyped_storage(), ret.storage_offset(), ret.shape
+                ret, arg.untyped_storage(), ret.storage_offset(), ret.shape, ret.stride()
            )
    finally:
        torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
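
For context, a minimal standalone sketch (not part of the patch) of the behavior the new `test_return_and_correct_aliasing_gives_correct_stride` test pins down. It assumes the `TwoTensor` test subclass from `torch.testing._internal.two_tensor` and a build that includes this change:

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

# TwoTensor routes each op to both inner tensors and relies on
# return_and_correct_aliasing to fix up aliasing metadata on view ops.
t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
dense = torch.randn(2, 2)

# With the view's stride now forwarded to aten.set_, a column slice of the
# subclass reports the same non-contiguous stride as a plain tensor slice.
assert t[:, 0].stride() == dense[:, 0].stride()  # both (2,)
```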