mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix to keep stride in return_and_correct_aliasing() (#117860)
Fixes #117794
Fix tripped the assert here: 86dedebeaf/torch/utils/_python_dispatch.py (L216)
From investigation: I found that functionalization of an in-place op (`mul_` in this test case) results in the strides of `TwoTensor`'s `a` / `b` components being mutated to be contiguous. This is not reflected in the outer tensor, causing the assert to be tripped.
After discussion with Brian, I address this in this PR by disallowing input mutations on non-contiguous tensor subclass inputs for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117860
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
fa77829126
commit
e7eab2f07e
4 changed files with 35 additions and 6 deletions
|
|
@ -1473,8 +1473,15 @@ def forward(self, primals_1, primals_2):
|
|||
|
||||
self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True)
|
||||
self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True)
|
||||
self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)
|
||||
with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Mutations on non-contiguous inputs are currently not allowed on tensor subclasses"
|
||||
):
|
||||
self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Mutations on non-contiguous inputs are currently not allowed on tensor subclasses"
|
||||
):
|
||||
self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)
|
||||
|
||||
# Mutations in the backward are allowed as long as the mutated object does not require grad
|
||||
|
|
@ -1573,7 +1580,7 @@ def forward(self, primals_1, primals_2):
|
|||
def inp_callable1(req_grad):
|
||||
base = torch.ones(4, 4, requires_grad=req_grad)
|
||||
x = base.add(1)
|
||||
# create two non-contiguous views that share storage, but are actually non-overlapping
|
||||
# create two views that share storage, but are actually non-overlapping
|
||||
a = x[0:2]
|
||||
b = x[2:4]
|
||||
return [base], [a, b]
|
||||
|
|
|
|||
|
|
@ -917,6 +917,13 @@ def forward(self, x_a_1, x_b_1, y_1):
|
|||
return (mul, mul_1, add)
|
||||
""")
|
||||
|
||||
# See https://github.com/pytorch/pytorch/issues/117794
|
||||
def test_return_and_correct_aliasing_gives_correct_stride(self):
|
||||
t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
|
||||
x = torch.randn(2, 2)
|
||||
# slicing should result in the same stride for TwoTensor as a dense tensor would give
|
||||
self.assertEqual(t[:, 0].stride(), x[:, 0].stride())
|
||||
|
||||
def test_make_wrapper_subclass_propagates_metadata(self) -> None:
|
||||
class WrapperTensor(torch.Tensor):
|
||||
elem: torch.Tensor
|
||||
|
|
|
|||
|
|
@ -121,6 +121,21 @@ def run_functionalized_fw_and_collect_metadata(
|
|||
# Inspect the state of the input tensor functional wrapper to detect input mutation info
|
||||
# If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
|
||||
for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
|
||||
# NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
|
||||
# strides between the functionalized arg inner tensors and non-functionalized arg inner
|
||||
# tensors. This is a problem as the inner tensor stride change may not be reflected
|
||||
# correctly in the outer tensor, so disallow this for now.
|
||||
mutates_data = has_data_mutation(f_arg)
|
||||
if (
|
||||
mutates_data
|
||||
and not arg.is_contiguous()
|
||||
and is_traceable_wrapper_subclass(arg)
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Mutations on non-contiguous inputs are currently not allowed on "
|
||||
"tensor subclasses"
|
||||
)
|
||||
|
||||
if not isinstance(arg, Tensor):
|
||||
new_arg = arg
|
||||
else:
|
||||
|
|
@ -135,7 +150,6 @@ def run_functionalized_fw_and_collect_metadata(
|
|||
mutates_storage_metadata = has_metadata_mutation(
|
||||
f_arg, arg, check_only_storage_mutation=True
|
||||
)
|
||||
mutates_data = has_data_mutation(f_arg)
|
||||
mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(
|
||||
f_arg
|
||||
)
|
||||
|
|
|
|||
|
|
@ -268,11 +268,12 @@ and output of type {type(ret)}. But expected types to match."""
|
|||
|
||||
if isinstance(ret, list):
|
||||
for r in ret:
|
||||
torch.ops.aten.set_.source_Storage_storage_offset(r, arg.untyped_storage(), r.storage_offset(), r.shape)
|
||||
torch.ops.aten.set_.source_Storage_storage_offset(
|
||||
r, arg.untyped_storage(), r.storage_offset(), r.shape, r.stride())
|
||||
else:
|
||||
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
|
||||
torch.ops.aten.set_.source_Storage_storage_offset(
|
||||
ret, arg.untyped_storage(), ret.storage_offset(), ret.shape
|
||||
ret, arg.untyped_storage(), ret.storage_offset(), ret.shape, ret.stride()
|
||||
)
|
||||
finally:
|
||||
torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
|
||||
|
|
|
|||
Loading…
Reference in a new issue