fix functionalization regression introduced by ProxyTorchDispatchMode, migrate testing to make_fx (#80416)

`ProxyTorchDispatchMode` was added recently as part of `make_fx`, which was secretly causing the meta tensor calls used inside of functionalization to get baked into the graph. It also wasn't caught because the functionalization tests in core don't use `make_fx`, and the tests in functorch aren't as comprehensive.

Now that `make_fx` is in core, I also ported the functionalization test infra over to use it, which would have caught the regression. This also makes the tests cleaner, since mode-based tracing lets us pick up factory functions in the trace output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80416
Approved by: https://github.com/ezyang, https://github.com/albanD
This commit is contained in:
Brian Hirsh 2022-07-11 15:11:57 -07:00 committed by PyTorch MergeBot
parent 5f8c2076df
commit f84b30f790
3 changed files with 374 additions and 177 deletions

View file

@ -23,6 +23,19 @@ $ops_headers
namespace at {
namespace functionalization {
// This keyset is used by functionalization when it calls into meta kernels
// to accurately propagate stride metadata.
// Exclude any modes: the purpose of calling into meta kernels is only as an implementation
// detail to perform shape inference, and we don't want any modal keys to run.
// Specifically, we want to prevent functionalization and Python modes from running.
// Excludes all functorch transform keys plus the dynamic-layer mode keys and
// the Python (torch dispatch mode) key, so meta-kernel shape inference cannot
// be intercepted and traced (e.g. by ProxyTorchDispatchMode).
constexpr auto exclude_keys_for_meta_dispatch =
c10::functorch_transforms_ks |
c10::DispatchKeySet({
c10::DispatchKey::FuncTorchDynamicLayerBackMode,
c10::DispatchKey::FuncTorchDynamicLayerFrontMode,
c10::DispatchKey::Python
});
inline Tensor to_meta(const Tensor& t) {
return at::native::empty_strided_meta(t.sizes(), t.strides(),

View file

@ -2,8 +2,9 @@
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, capture_logs, log_input
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, capture_logs
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
import logging
@ -58,19 +59,26 @@ class InplaceLoggingTensor(LoggingTensorReentrant):
class TestFunctionalization(TestCase):
def get_logs(self, func, inpt, *, reapply_views=False):
input_clone_logging = LoggingTensor(inpt.clone())
input_functional_logging = torch._to_functional_tensor(input_clone_logging)
with capture_logs() as logs:
log_input("input", input_clone_logging)
# We can unify testing and use functionalize() here instead
# if/when functorch moves into core.
def _functionalize(self, f, *, reapply_views: bool):
def wrapped(a):
input_functional = torch._to_functional_tensor(a)
torch._enable_functionalization(reapply_views=reapply_views)
try:
func(input_functional_logging)
out = f(input_functional)
finally:
torch._disable_functionalization()
return logs
torch._sync(input_functional)
tree_map(torch._sync, out)
out_unwrapped = tree_map(torch._from_functional_tensor, out)
return out_unwrapped
return wrapped
def get_logs(self, func, inpt, *, reapply_views=False):
traced_f = make_fx(self._functionalize(func, reapply_views=reapply_views))(inpt)
return traced_f.code
def assert_functionalization(self, func, inpt, *, reapply_views=False):
input_clone = inpt.clone()
@ -124,15 +132,19 @@ class TestFunctionalization(TestCase):
return y
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.view_copy.default($0, [4, 2])
$2 = torch._ops.aten.add.Tensor($1, tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]]))
$3 = torch._ops.aten.view_copy.default($2, [4, 2])
$4 = torch._ops.aten.mul.Tensor($3, $3)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1); view_copy_default_1 = None
return add_tensor
""")
def test_simple_out(self):
def f(x):
@ -145,14 +157,19 @@ $4 = torch._ops.aten.mul.Tensor($3, $3)""")
return w
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.view_copy.default($0, [4, 2])
$2 = torch._ops.aten.add.Tensor($1, tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]]))
$3 = torch._ops.aten.mul.Tensor($2, $2)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
empty_1 = torch.ops.aten.empty.SymInt([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
return mul_tensor
""")
def test_multi_out(self):
def f(x):
@ -164,9 +181,18 @@ $3 = torch._ops.aten.mul.Tensor($2, $2)""")
return out_max
self.assert_functionalization(f, torch.arange(8, dtype=torch.float32))
logs = self.get_logs(f, torch.arange(8, dtype=torch.float32))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1, $2 = torch._ops.aten.aminmax.default($0, dim=0)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
empty_1 = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
return getitem
""")
def test_tensor_ctr(self):
def f(x):
@ -186,13 +212,18 @@ $1, $2 = torch._ops.aten.aminmax.default($0, dim=0)""")
return y
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.view_copy.default($0, [4, 2])
$2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]]))""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(a_1, fill_scalar); a_1 = fill_scalar = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
return view_copy_default_1
""")
# Some ops that are mutable are neither inplace nor out= ops.
# They also need special handling.
@ -201,9 +232,20 @@ $2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)
logs = self.get_logs(f, torch.ones(1))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper_functional.default($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""") # noqa: B950
self.assertExpectedInline(logs, """\
def forward(self, a_1):
_fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0); a_1 = None
getitem = _fused_moving_avg_obs_fq_helper_functional_default[0]
getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1]
getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2]
getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3]
getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4]
getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None
return (getitem, getitem_1)
""") # noqa: B950
def test_as_strided(self):
def f(x):
@ -212,10 +254,16 @@ $1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper_functio
return x
self.assert_functionalization(f, torch.ones(9))
logs = self.get_logs(f, torch.ones(9))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.as_strided_copy.default($0, [2], [2], 1)
$2 = torch._ops.aten.add.Tensor($1, 1)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); a_1 = add_tensor = None
return as_strided_scatter_default
""")
def test_tensor_list_composite(self):
def f(x):
@ -224,11 +272,14 @@ $2 = torch._ops.aten.add.Tensor($1, 1)""")
return y
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.block_diag.default([LoggingTensor(tensor([[1., 1.],
[1., 1.]])), LoggingTensor(tensor([[1., 1.],
[1., 1.]]))])""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
block_diag_default = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None
return block_diag_default
""")
def test_cat(self):
def f(x):
@ -237,10 +288,15 @@ $1 = torch._ops.aten.block_diag.default([LoggingTensor(tensor([[1., 1.],
return out
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.cat.default([LoggingTensor(tensor([[1., 1.],
[1., 1.]]))])""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.SymInt([0], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat_default
""")
def test_diagonal(self):
def f(x):
@ -252,12 +308,19 @@ $1 = torch._ops.aten.cat.default([LoggingTensor(tensor([[1., 1.],
return z
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.diagonal_copy.default($0)
$2 = torch._ops.aten.add.Tensor($1, tensor([1., 1.]))
$3 = torch._ops.aten.diagonal_scatter.default($0, $2)
$4 = torch._ops.aten.mul.Tensor($3, $3)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1)
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); a_1 = add_tensor = None
mul_tensor = torch.ops.aten.mul.Tensor(diagonal_scatter_default, diagonal_scatter_default); diagonal_scatter_default = None
return mul_tensor
""")
def test_diagonal_mutated_input(self):
def f(x):
@ -280,15 +343,26 @@ $4 = torch._ops.aten.mul.Tensor($3, $3)""")
return y3
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1, $2 = torch._ops.aten.split_copy.Tensor($0, 2)
$3 = torch._ops.aten.diagonal_copy.default($2)
$4 = torch._ops.aten.add.Tensor($3, tensor([1., 1.]))
$5, $6 = torch._ops.aten.split_copy.Tensor($0, 2)
$7 = torch._ops.aten.diagonal_scatter.default($6, $4)
$8 = torch._ops.aten.slice_scatter.default($0, $7, 0, 2, 4)
$9 = torch._ops.aten.mul.Tensor($8, $8)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
split_copy_tensor = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem = split_copy_tensor[0]
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
split_copy_tensor_1 = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem_2 = split_copy_tensor_1[0]
getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None
slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); a_1 = diagonal_scatter_default = None
mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default); slice_scatter_default = None
return add_tensor
""") # noqa: B950
def test_view_inplace(self):
def f(x):
@ -300,11 +374,22 @@ $9 = torch._ops.aten.mul.Tensor($8, $8)""")
return x
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.transpose_copy.int($0, 1, 0)
$2 = torch._ops.aten.select_copy.int($1, 0, 0)
$3 = torch._ops.aten.add.Tensor($2, tensor([1., 1., 1., 1.]))""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
transpose_copy_int = torch.ops.aten.transpose_copy.int(a_1, 1, 0)
select_copy_int = torch.ops.aten.select_copy.int(transpose_copy_int, 0, 0); transpose_copy_int = None
add_tensor = torch.ops.aten.add.Tensor(select_copy_int, fill_scalar); select_copy_int = fill_scalar = None
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None
select_scatter_default = torch.ops.aten.select_scatter.default(transpose_copy_int_1, add_tensor, 0, 0); transpose_copy_int_1 = add_tensor = None
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(select_scatter_default, 1, 0); select_scatter_default = None
transpose_copy_int_3 = torch.ops.aten.transpose_copy.int(transpose_copy_int_2, 1, 0); transpose_copy_int_2 = None
return transpose_copy_int_3
""") # noqa: B950
def test_optional_tensor_list(self):
def f(x):
@ -317,10 +402,20 @@ $3 = torch._ops.aten.add.Tensor($2, tensor([1., 1., 1., 1.]))""")
return y
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.view_copy.default($0, [8])
$2 = torch._ops.aten.index_put.default($1, [tensor([0, 1, 2, 3])], tensor([0., 1., 2., 3.]))""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
view_copy_default = torch.ops.aten.view_copy.default(a_1, [8]); a_1 = None
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
arange = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
empty_1 = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
arange_1 = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None
view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2])
return index_put_default
""") # noqa: B950
def test_scalars(self):
def f(x):
@ -333,12 +428,20 @@ $2 = torch._ops.aten.index_put.default($1, [tensor([0, 1, 2, 3])], tensor([0., 1
return z
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.view_copy.default($0, [4, 2])
$2 = torch._ops.aten.add.Tensor($1, 1)
$3 = torch._ops.aten.mul.Tensor($2, 2)
$4 = torch._ops.aten.div.Tensor($3, 1)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2)
div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
return div_tensor
""")
@skipIfTorchDynamo("Test does not work with TorchDynamo")
def test_metadata_change(self):
@ -348,10 +451,16 @@ $4 = torch._ops.aten.div.Tensor($3, 1)""")
return x.ge_(0)
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.ge.Scalar($0, 0)
$2 = torch._ops.aten._to_copy.default($1, dtype=torch.float32, layout=torch.strided)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
ge_scalar = torch.ops.aten.ge.Scalar(a_1, 0); a_1 = None
_to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
_tensor_constant0 = self._tensor_constant0
return _tensor_constant0
""")
def test_only_one_view(self):
def f(x):
@ -360,9 +469,14 @@ $2 = torch._ops.aten._to_copy.default($1, dtype=torch.float32, layout=torch.stri
# so there should be a total of 1 op in the output trace.
return x.view(4, 2)
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.view_copy.default($0, [4, 2])""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
return view_copy_default
""")
def test_everything(self):
def f(x):
@ -380,35 +494,43 @@ $1 = torch._ops.aten.view_copy.default($0, [4, 2])""")
return z2
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, $0)
$2 = torch._ops.aten.view_copy.default($1, [8])
$3 = torch._ops.aten._reshape_alias_copy.default($2, [2, 4], [4, 1])
$4 = torch._ops.aten.transpose_copy.int($3, 1, 0)
$5 = torch._ops.aten.unsqueeze_copy.default($4, 0)
$6 = torch._ops.aten.squeeze_copy.default($5)
$7, $8 = torch._ops.aten.split_copy.Tensor($6, 2)
$9 = torch._ops.aten.add.Tensor($7, tensor([[1., 1.],
[1., 1.]]))
$10 = torch._ops.aten.select_copy.int($3, 0, 0)
$11 = torch._ops.aten.clone.default($9, memory_format=torch.contiguous_format)
$12 = torch._ops.aten._unsafe_view.default($11, [4])
$13 = torch._ops.aten.view_copy.default($1, [8])
$14 = torch._ops.aten._reshape_alias_copy.default($13, [2, 4], [4, 1])
$15 = torch._ops.aten.transpose_copy.int($14, 1, 0)
$16 = torch._ops.aten.unsqueeze_copy.default($15, 0)
$17 = torch._ops.aten.squeeze_copy.default($16)
$18 = torch._ops.aten.slice_scatter.default($17, $9, 0, 0, 2)
$19 = torch._ops.aten.unsqueeze_copy.default($18, 0)
$20 = torch._ops.aten.squeeze_copy.dim($19, 0)
$21 = torch._ops.aten.transpose_copy.int($20, 1, 0)
$22 = torch._ops.aten._reshape_alias_copy.default($21, [8], [1])
$23 = torch._ops.aten.view_copy.default($22, [4, 2])
$24 = torch._ops.aten.view_copy.default($23, [8])
$25 = torch._ops.aten._reshape_alias_copy.default($24, [2, 4], [4, 1])
$26 = torch._ops.aten.select_copy.int($25, 0, 0)
$27 = torch._ops.aten.add.Tensor($26, $12)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [8])
_reshape_alias_copy_default = torch.ops.aten._reshape_alias_copy.default(view_copy_default, [2, 4], [4, 1]); view_copy_default = None
transpose_copy_int = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default, 1, 0)
unsqueeze_copy_default = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int, 0); transpose_copy_int = None
squeeze_copy_default = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default); unsqueeze_copy_default = None
split_copy_tensor = torch.ops.aten.split_copy.Tensor(squeeze_copy_default, 2); squeeze_copy_default = None
getitem = split_copy_tensor[0]
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
add_tensor_1 = torch.ops.aten.add.Tensor(getitem, fill_scalar); getitem = fill_scalar = None
select_copy_int = torch.ops.aten.select_copy.int(_reshape_alias_copy_default, 0, 0); _reshape_alias_copy_default = None
clone_default = torch.ops.aten.clone.default(add_tensor_1, memory_format = torch.contiguous_format)
_unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [8]); add_tensor = None
_reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_1, [2, 4], [4, 1]); view_copy_default_1 = None
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default_1, 1, 0); _reshape_alias_copy_default_1 = None
unsqueeze_copy_default_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int_1, 0); transpose_copy_int_1 = None
squeeze_copy_default_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default_1); unsqueeze_copy_default_1 = None
slice_scatter_default = torch.ops.aten.slice_scatter.default(squeeze_copy_default_1, add_tensor_1, 0, 0, 2); squeeze_copy_default_1 = None
unsqueeze_copy_default_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter_default, 0); slice_scatter_default = None
squeeze_copy_dim = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_default_2, 0); unsqueeze_copy_default_2 = None
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_dim, 1, 0); squeeze_copy_dim = None
_reshape_alias_copy_default_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_int_2, [8], [1]); transpose_copy_int_2 = None
view_copy_default_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_default_2, [4, 2]); _reshape_alias_copy_default_2 = None
view_copy_default_3 = torch.ops.aten.view_copy.default(view_copy_default_2, [8]); view_copy_default_2 = None
_reshape_alias_copy_default_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_3, [2, 4], [4, 1]); view_copy_default_3 = None
select_copy_int_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_default_3, 0, 0); _reshape_alias_copy_default_3 = None
add_tensor_2 = torch.ops.aten.add.Tensor(select_copy_int_1, _unsafe_view_default); select_copy_int_1 = _unsafe_view_default = None
return add_tensor_1
""") # noqa: B950
def test_reapply_views_simple(self):
def f(x):
@ -419,15 +541,19 @@ $27 = torch._ops.aten.add.Tensor($26, $12)""")
return y
self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)
logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True)
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.view.default($0, [4, 2])
$2 = torch._ops.aten.add.Tensor($1, tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]]))
$3 = torch._ops.aten.view.default($2, [4, 2])
$4 = torch._ops.aten.mul.Tensor($3, $3)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
add_tensor = torch.ops.aten.add.Tensor(view_default, fill_scalar); view_default = fill_scalar = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1); view_default_1 = None
return add_tensor
""")
def test_aliases_maintained_after_pass_when_reapplying_views(self):
def f(x):
@ -467,34 +593,70 @@ $4 = torch._ops.aten.mul.Tensor($3, $3)""")
# to() is a composite op that noops when the dtype/shape match, so nothing gets logged.
# self.assert_functionalization(f, torch.ones(2))
logs = self.get_logs(f, torch.ones(2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0)
$2 = torch._ops.aten.add.Tensor($1, $0)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero.default(empty); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor
""")
# Test 2: copy_() with same dtype, different shape
self.assert_functionalization(f, torch.ones(1))
logs = self.get_logs(f, torch.ones(1))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0)
$2 = torch._ops.aten.add.Tensor($1, $0)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero.default(empty); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor
""")
# Test 3: copy_() with different dtype, same shape
self.assert_functionalization(f, torch.ones(2, dtype=torch.long))
logs = self.get_logs(f, torch.ones(2, dtype=torch.long))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0)
$2 = torch._ops.aten.add.Tensor($1, $0)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero.default(empty); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor
""")
# Test 4: copy_() with different dtype, different shape
self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
logs = self.get_logs(f, torch.ones(1, dtype=torch.long))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0)
$2 = torch._ops.aten.add.Tensor($1, $0)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero.default(empty); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor
""")
def test_fill_(self):
def f(x):
@ -505,11 +667,17 @@ $2 = torch._ops.aten.add.Tensor($1, $0)""")
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, $0)
$2 = torch._ops.aten.diagonal_copy.default($1)
$3 = torch._ops.aten.fill.Scalar($2, 0)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(add_tensor)
fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None
return diagonal_scatter_default
""")
def test_resize_smaller(self):
def f(w):
@ -524,22 +692,27 @@ $3 = torch._ops.aten.fill.Scalar($2, 0)""")
self.assert_functionalization(f, torch.ones(8, 2))
logs = self.get_logs(f, torch.ones(8, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, 1)
$2 = torch._ops.aten.view_copy.default($1, [4, 4])
$3 = torch._ops.aten.resize.default($2, [3, 3])
$4 = torch._ops.aten.as_strided_copy.default($2, [3, 3], [3, 1])
$5 = torch._ops.aten.view_copy.default($4, [-1])
$6 = torch._ops.aten.add.Tensor($5, 1)
$7 = torch._ops.aten.view_copy.default($1, [4, 4])
$8 = torch._ops.aten.as_strided_copy.default($7, [3, 3], [3, 1])
$9 = torch._ops.aten.view_copy.default($6, [3, 3])
$10 = torch._ops.aten.as_strided_scatter.default($7, $9, [3, 3], [3, 1])
$11 = torch._ops.aten.view_copy.default($10, [8, 2])
$12 = torch._ops.aten.view_copy.default($11, [4, 4])
$13 = torch._ops.aten.as_strided_copy.default($12, [3, 3], [3, 1])
$14 = torch._ops.aten.add.Tensor($13, 1)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [4, 4])
resize_default = torch.ops.aten.resize.default(view_copy_default, [3, 3])
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(view_copy_default, [3, 3], [3, 1]); view_copy_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(as_strided_copy_default, [-1]); as_strided_copy_default = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1); view_copy_default_1 = None
view_copy_default_2 = torch.ops.aten.view_copy.default(add_tensor, [4, 4]); add_tensor = None
as_strided_copy_default_1 = torch.ops.aten.as_strided_copy.default(view_copy_default_2, [3, 3], [3, 1])
view_copy_default_3 = torch.ops.aten.view_copy.default(add_tensor_1, [3, 3]); add_tensor_1 = None
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(view_copy_default_2, view_copy_default_3, [3, 3], [3, 1]); view_copy_default_2 = view_copy_default_3 = None
view_copy_default_4 = torch.ops.aten.view_copy.default(as_strided_scatter_default, [8, 2]); as_strided_scatter_default = None
view_copy_default_5 = torch.ops.aten.view_copy.default(view_copy_default_4, [4, 4]); view_copy_default_4 = None
as_strided_copy_default_2 = torch.ops.aten.as_strided_copy.default(view_copy_default_5, [3, 3], [3, 1]); view_copy_default_5 = None
add_tensor_2 = torch.ops.aten.add.Tensor(as_strided_copy_default_2, 1); as_strided_copy_default_2 = None
return add_tensor_2
""") # noqa: B950
def test_resize_larger_valid(self):
def f(x):
@ -560,14 +733,19 @@ $14 = torch._ops.aten.add.Tensor($13, 1)""")
self.assert_functionalization(f, torch.ones(8, 2))
logs = self.get_logs(f, torch.ones(8, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, 1)
$2 = torch._ops.aten.resize.default($1, [5, 5])
$3 = torch._ops.aten.view_copy.default($2, [25])
$4 = torch._ops.aten.fill.Scalar($3, 1)
$5 = torch._ops.aten.view_copy.default($4, [5, 5])
$6 = torch._ops.aten.add.Tensor($5, 1)""")
self.assertExpectedInline(logs, """\
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize_default = torch.ops.aten.resize.default(add_tensor, [5, 5]); add_tensor = None
view_copy_default = torch.ops.aten.view_copy.default(resize_default, [25]); resize_default = None
fill_scalar = torch.ops.aten.fill.Scalar(view_copy_default, 1); view_copy_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1)
return (view_copy_default_1, add_tensor_1)
""")
def test_resize_larger_invalid(self):
def f(x):

View file

@ -363,7 +363,8 @@ def emit_view_functionalization_body(
);
{return_type} reference_tensor_output;
{{
at::AutoDispatchSkipFunctionalize guard;
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
{meta_conversion_str}
reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
}}
@ -391,12 +392,16 @@ def emit_view_functionalization_body(
return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
}}
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
{return_type} tmp_output;
{return_type} reference_tensor_output;
{{
at::AutoDispatchSkipFunctionalize guard;
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
{meta_conversion_str}
reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
}}
{return_type} tmp_output;
{{
at::AutoDispatchSkipFunctionalize guard;
if (reapply_views) {{
tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
}} else {{
@ -603,8 +608,9 @@ def emit_inplace_functionalization_body(
// Before converting the mutable op to its functional variant, run meta tensors through the original op.
// This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
// (We can only do this for inplace ops today though, because they technically all support meta tensors).
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize guard;
at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
}}
{unwrap_tensor_args_str}