Fix to keep stride in return_and_correct_aliasing() (#117860)

Fixes #117794

The fix tripped the assert here: 86dedebeaf/torch/utils/_python_dispatch.py (L216)
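
For context, a minimal sketch of the stride mismatch from the linked issue (a sketch only; it assumes `TwoTensor` from `torch.testing._internal.two_tensor`, which is what the new test below also uses):

    import torch
    from torch.testing._internal.two_tensor import TwoTensor

    t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
    x = torch.randn(2, 2)

    # A column slice is a non-contiguous view; the subclass slice should report
    # the same stride as the equivalent dense slice.
    print(x[:, 0].stride())  # (2,)
    print(t[:, 0].stride())  # before this fix, this could come back as the contiguous stride instead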

From investigation, I found that functionalization of an in-place op (`mul_` in this test case) results in the strides of `TwoTensor`'s `a` / `b` components being mutated to contiguous. This change is not reflected in the outer tensor, which causes the assert to be tripped.

After discussion with Brian, this PR addresses that by disallowing input mutations on non-contiguous tensor subclass inputs for now.
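
For illustration, a hedged sketch of the new behavior (not taken from the test suite; it assumes `TwoTensor` as above and drives AOTAutograd directly via `functorch.compile.aot_function`, so the exact entry point may differ from how this is hit in practice):

    import torch
    from functorch.compile import aot_function, nop
    from torch.testing._internal.two_tensor import TwoTensor

    def f(x):
        x.mul_(2)  # in-place data mutation of a graph input
        return x + 1

    # A transposed TwoTensor is a non-contiguous traceable wrapper subclass input.
    inp = TwoTensor(torch.randn(4, 4), torch.randn(4, 4)).t()

    try:
        aot_function(f, nop)(inp)
    except RuntimeError as e:
        # Expected after this PR:
        # "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses"
        print(e)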
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117860
Approved by: https://github.com/bdhirsh
Joel Schlosser 2024-02-21 11:59:06 -05:00 committed by PyTorch MergeBot
parent fa77829126
commit e7eab2f07e
4 changed files with 35 additions and 6 deletions


@@ -1473,8 +1473,15 @@ def forward(self, primals_1, primals_2):
         self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True)
         self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True)
-        self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)
-        with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses"
+        ):
+            self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses"
+        ):
             self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)

     # Mutations in the backward are allowed as long as the mutated object does not require grad
@@ -1573,7 +1580,7 @@ def forward(self, primals_1, primals_2):
         def inp_callable1(req_grad):
             base = torch.ones(4, 4, requires_grad=req_grad)
             x = base.add(1)
-            # create two non-contiguous views that share storage, but are actually non-overlapping
+            # create two views that share storage, but are actually non-overlapping
             a = x[0:2]
             b = x[2:4]
             return [base], [a, b]


@@ -917,6 +917,13 @@ def forward(self, x_a_1, x_b_1, y_1):
     return (mul, mul_1, add)
 """)

+    # See https://github.com/pytorch/pytorch/issues/117794
+    def test_return_and_correct_aliasing_gives_correct_stride(self):
+        t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
+        x = torch.randn(2, 2)
+        # slicing should result in the same stride for TwoTensor as a dense tensor would give
+        self.assertEqual(t[:, 0].stride(), x[:, 0].stride())
+
     def test_make_wrapper_subclass_propagates_metadata(self) -> None:
         class WrapperTensor(torch.Tensor):
             elem: torch.Tensor


@@ -121,6 +121,21 @@ def run_functionalized_fw_and_collect_metadata(
         # Inspect the state of the input tensor functional wrapper to detect input mutation info
         # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
         for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
+            # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
+            # strides between the functionalized arg inner tensors and non-functionalized arg inner
+            # tensors. This is a problem as the inner tensor stride change may not be reflected
+            # correctly in the outer tensor, so disallow this for now.
+            mutates_data = has_data_mutation(f_arg)
+            if (
+                mutates_data
+                and not arg.is_contiguous()
+                and is_traceable_wrapper_subclass(arg)
+            ):
+                raise RuntimeError(
+                    "Mutations on non-contiguous inputs are currently not allowed on "
+                    "tensor subclasses"
+                )
+
             if not isinstance(arg, Tensor):
                 new_arg = arg
             else:
@@ -135,7 +150,6 @@ def run_functionalized_fw_and_collect_metadata(
             mutates_storage_metadata = has_metadata_mutation(
                 f_arg, arg, check_only_storage_mutation=True
             )
-            mutates_data = has_data_mutation(f_arg)
             mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(
                 f_arg
             )


@@ -268,11 +268,12 @@ and output of type {type(ret)}. But expected types to match."""
             if isinstance(ret, list):
                 for r in ret:
-                    torch.ops.aten.set_.source_Storage_storage_offset(r, arg.untyped_storage(), r.storage_offset(), r.shape)
+                    torch.ops.aten.set_.source_Storage_storage_offset(
+                        r, arg.untyped_storage(), r.storage_offset(), r.shape, r.stride())
             else:
                 assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
                 torch.ops.aten.set_.source_Storage_storage_offset(
-                    ret, arg.untyped_storage(), ret.storage_offset(), ret.shape
+                    ret, arg.untyped_storage(), ret.storage_offset(), ret.shape, ret.stride()
                 )
         finally:
             torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)