fix functionalization regression introduced by ProxyTorchDispatchMode, migrate testing to make_fx (#80416)
`ProxyTorchDispatchMode` was added recently as part of `make_fx`, and it was silently causing the meta tensor calls used inside of functionalization to get baked into the graph. The regression wasn't caught because the functionalization tests in core don't use `make_fx`, and the tests in functorch aren't as comprehensive. Now that `make_fx` is in core, I also ported the functionalization test infra over to use it, which would have caught the regression. This also makes the tests cleaner, since mode-based tracing lets us pick up factory functions in the trace output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80416
Approved by: https://github.com/ezyang, https://github.com/albanD
parent 5f8c2076df
commit f84b30f790
3 changed files with 374 additions and 177 deletions
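The core of the test migration below is a wrapper that runs a function under functionalization and then traces the wrapped function with `make_fx`, so the tests compare against the resulting FX graph code. A minimal standalone sketch of that pattern, assuming the private hooks used in the diff (`torch._to_functional_tensor`, `torch._enable_functionalization`, `torch._sync`, `torch._from_functional_tensor`) behave as on this branch; the helper name `functionalize_for_tracing` and the example function are illustrative, not part of the change:

import torch
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx

def functionalize_for_tracing(f, *, reapply_views=False):
    # Mirrors the new TestFunctionalization._functionalize helper: wrap the input
    # in a functional tensor, run f under functionalization, then sync and unwrap.
    def wrapped(a):
        input_functional = torch._to_functional_tensor(a)
        torch._enable_functionalization(reapply_views=reapply_views)
        try:
            out = f(input_functional)
        finally:
            torch._disable_functionalization()
        torch._sync(input_functional)
        tree_map(torch._sync, out)
        return tree_map(torch._from_functional_tensor, out)
    return wrapped

def g(x):
    # torch.ones is a factory call; with mode-based tracing it shows up in the trace
    tmp = torch.ones(4, 2)
    y = x.view(4, 2)
    y.add_(tmp)
    return y.mul(y)

# The traced graph should contain functionalized ops (view_copy/add/mul) plus the
# empty/fill ops for the factory call, without any baked-in meta-tensor calls.
traced = make_fx(functionalize_for_tracing(g))(torch.ones(4, 2))
print(traced.code)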
@@ -23,6 +23,19 @@ $ops_headers
namespace at {
namespace functionalization {

// This keyset is used by functionalization when it calls into meta kernels
// to accurately propagate stride metadata.
// Exclude any modes: the purpose of calling into meta kernels is only as an implementation
// detail to perform shape inference, and we don't want any modal keys to run.
// Specifically, we want to prevent functionalization and Python modes from running.
constexpr auto exclude_keys_for_meta_dispatch =
  c10::functorch_transforms_ks |
  c10::DispatchKeySet({
    c10::DispatchKey::FuncTorchDynamicLayerBackMode,
    c10::DispatchKey::FuncTorchDynamicLayerFrontMode,
    c10::DispatchKey::Python
  });


inline Tensor to_meta(const Tensor& t) {
  return at::native::empty_strided_meta(t.sizes(), t.strides(),
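This keyset is what the codegen changes at the bottom of this diff plug into the generated functionalization kernels: each call into a meta kernel for shape inference is wrapped in an ExcludeDispatchKeyGuard so that functorch and Python-mode keys (which is how ProxyTorchDispatchMode runs) cannot intercept it. A rough sketch of the block that gen_functionalization_type.py emits, with illustrative placeholder values standing in for the real codegen inputs:

# Placeholder values are illustrative only; the real strings come from the
# operator schema that the functionalization codegen is processing.
return_type = "at::Tensor"
noop_api_name = "view"
meta_conversion_str = "at::Tensor self_meta = to_meta(self);"
meta_call_args = ["self_meta", "size"]

emitted = f"""
{return_type} reference_tensor_output;
{{
  at::AutoDispatchSkipFunctionalize func_guard;
  c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
  {meta_conversion_str}
  reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
}}"""
print(emitted)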
@@ -2,8 +2,9 @@

import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, capture_logs, log_input
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, capture_logs
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx

import logging

@@ -58,19 +59,26 @@ class InplaceLoggingTensor(LoggingTensorReentrant):


class TestFunctionalization(TestCase):

def get_logs(self, func, inpt, *, reapply_views=False):
input_clone_logging = LoggingTensor(inpt.clone())
input_functional_logging = torch._to_functional_tensor(input_clone_logging)

with capture_logs() as logs:
log_input("input", input_clone_logging)
# We can unify testing and use functionalize() here instead
# if/when functorch moves into core.
def _functionalize(self, f, *, reapply_views: bool):
def wrapped(a):
input_functional = torch._to_functional_tensor(a)
torch._enable_functionalization(reapply_views=reapply_views)
try:
func(input_functional_logging)
out = f(input_functional)
finally:
torch._disable_functionalization()
return logs
torch._sync(input_functional)
tree_map(torch._sync, out)
out_unwrapped = tree_map(torch._from_functional_tensor, out)
return out_unwrapped

return wrapped

def get_logs(self, func, inpt, *, reapply_views=False):
traced_f = make_fx(self._functionalize(func, reapply_views=reapply_views))(inpt)
return traced_f.code

def assert_functionalization(self, func, inpt, *, reapply_views=False):
input_clone = inpt.clone()
@@ -124,15 +132,19 @@ class TestFunctionalization(TestCase):
return y
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.view_copy.default($0, [4, 2])
$2 = torch._ops.aten.add.Tensor($1, tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]]))
$3 = torch._ops.aten.view_copy.default($2, [4, 2])
$4 = torch._ops.aten.mul.Tensor($3, $3)""")
self.assertExpectedInline(logs, """\



def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1); view_copy_default_1 = None
return add_tensor
""")

def test_simple_out(self):
def f(x):
@ -145,14 +157,19 @@ $4 = torch._ops.aten.mul.Tensor($3, $3)""")
|
|||
return w
|
||||
self.assert_functionalization(f, torch.ones(4, 2))
|
||||
logs = self.get_logs(f, torch.ones(4, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.view_copy.default($0, [4, 2])
|
||||
$2 = torch._ops.aten.add.Tensor($1, tensor([[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.]]))
|
||||
$3 = torch._ops.aten.mul.Tensor($2, $2)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
|
||||
empty_1 = torch.ops.aten.empty.SymInt([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
|
||||
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
|
||||
return mul_tensor
|
||||
""")
|
||||
|
||||
def test_multi_out(self):
|
||||
def f(x):
|
||||
|
|
@ -164,9 +181,18 @@ $3 = torch._ops.aten.mul.Tensor($2, $2)""")
|
|||
return out_max
|
||||
self.assert_functionalization(f, torch.arange(8, dtype=torch.float32))
|
||||
logs = self.get_logs(f, torch.arange(8, dtype=torch.float32))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1, $2 = torch._ops.aten.aminmax.default($0, dim=0)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
empty_1 = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
|
||||
getitem = aminmax_default[0]
|
||||
getitem_1 = aminmax_default[1]; aminmax_default = None
|
||||
return getitem
|
||||
""")
|
||||
|
||||
def test_tensor_ctr(self):
|
||||
def f(x):
|
||||
|
|
@ -186,13 +212,18 @@ $1, $2 = torch._ops.aten.aminmax.default($0, dim=0)""")
|
|||
return y
|
||||
self.assert_functionalization(f, torch.ones(4, 2))
|
||||
logs = self.get_logs(f, torch.ones(4, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.view_copy.default($0, [4, 2])
|
||||
$2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.]]))""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, fill_scalar); a_1 = fill_scalar = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
|
||||
return view_copy_default_1
|
||||
""")
|
||||
|
||||
# Some ops that are mutable are neither inplace nor out= ops.
|
||||
# They also need special handling.
|
||||
|
|
@ -201,9 +232,20 @@ $2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
|
|||
return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)
|
||||
|
||||
logs = self.get_logs(f, torch.ones(1))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper_functional.default($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""") # noqa: B950
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
_fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0); a_1 = None
|
||||
getitem = _fused_moving_avg_obs_fq_helper_functional_default[0]
|
||||
getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1]
|
||||
getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2]
|
||||
getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3]
|
||||
getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4]
|
||||
getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None
|
||||
return (getitem, getitem_1)
|
||||
""") # noqa: B950
|
||||
|
||||
def test_as_strided(self):
|
||||
def f(x):
|
||||
|
|
@ -212,10 +254,16 @@ $1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper_functio
|
|||
return x
|
||||
self.assert_functionalization(f, torch.ones(9))
|
||||
logs = self.get_logs(f, torch.ones(9))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.as_strided_copy.default($0, [2], [2], 1)
|
||||
$2 = torch._ops.aten.add.Tensor($1, 1)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
|
||||
add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None
|
||||
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); a_1 = add_tensor = None
|
||||
return as_strided_scatter_default
|
||||
""")
|
||||
|
||||
def test_tensor_list_composite(self):
|
||||
def f(x):
|
||||
|
|
@ -224,11 +272,14 @@ $2 = torch._ops.aten.add.Tensor($1, 1)""")
|
|||
return y
|
||||
self.assert_functionalization(f, torch.ones(2, 2))
|
||||
logs = self.get_logs(f, torch.ones(2, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.block_diag.default([LoggingTensor(tensor([[1., 1.],
|
||||
[1., 1.]])), LoggingTensor(tensor([[1., 1.],
|
||||
[1., 1.]]))])""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
block_diag_default = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None
|
||||
return block_diag_default
|
||||
""")
|
||||
|
||||
def test_cat(self):
|
||||
def f(x):
|
||||
|
|
@ -237,10 +288,15 @@ $1 = torch._ops.aten.block_diag.default([LoggingTensor(tensor([[1., 1.],
|
|||
return out
|
||||
self.assert_functionalization(f, torch.ones(2, 2))
|
||||
logs = self.get_logs(f, torch.ones(2, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.cat.default([LoggingTensor(tensor([[1., 1.],
|
||||
[1., 1.]]))])""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.SymInt([0], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
|
||||
return cat_default
|
||||
""")
|
||||
|
||||
def test_diagonal(self):
|
||||
def f(x):
|
||||
|
|
@ -252,12 +308,19 @@ $1 = torch._ops.aten.cat.default([LoggingTensor(tensor([[1., 1.],
|
|||
return z
|
||||
self.assert_functionalization(f, torch.ones(2, 2))
|
||||
logs = self.get_logs(f, torch.ones(2, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.diagonal_copy.default($0)
|
||||
$2 = torch._ops.aten.add.Tensor($1, tensor([1., 1.]))
|
||||
$3 = torch._ops.aten.diagonal_scatter.default($0, $2)
|
||||
$4 = torch._ops.aten.mul.Tensor($3, $3)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1)
|
||||
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
|
||||
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); a_1 = add_tensor = None
|
||||
mul_tensor = torch.ops.aten.mul.Tensor(diagonal_scatter_default, diagonal_scatter_default); diagonal_scatter_default = None
|
||||
return mul_tensor
|
||||
""")
|
||||
|
||||
def test_diagonal_mutated_input(self):
|
||||
def f(x):
|
||||
|
|
@ -280,15 +343,26 @@ $4 = torch._ops.aten.mul.Tensor($3, $3)""")
|
|||
return y3
|
||||
self.assert_functionalization(f, torch.ones(4, 2))
|
||||
logs = self.get_logs(f, torch.ones(4, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1, $2 = torch._ops.aten.split_copy.Tensor($0, 2)
|
||||
$3 = torch._ops.aten.diagonal_copy.default($2)
|
||||
$4 = torch._ops.aten.add.Tensor($3, tensor([1., 1.]))
|
||||
$5, $6 = torch._ops.aten.split_copy.Tensor($0, 2)
|
||||
$7 = torch._ops.aten.diagonal_scatter.default($6, $4)
|
||||
$8 = torch._ops.aten.slice_scatter.default($0, $7, 0, 2, 4)
|
||||
$9 = torch._ops.aten.mul.Tensor($8, $8)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||
split_copy_tensor = torch.ops.aten.split_copy.Tensor(a_1, 2)
|
||||
getitem = split_copy_tensor[0]
|
||||
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
|
||||
split_copy_tensor_1 = torch.ops.aten.split_copy.Tensor(a_1, 2)
|
||||
getitem_2 = split_copy_tensor_1[0]
|
||||
getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None
|
||||
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None
|
||||
slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); a_1 = diagonal_scatter_default = None
|
||||
mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default); slice_scatter_default = None
|
||||
return add_tensor
|
||||
""") # noqa: B950
|
||||
|
||||
def test_view_inplace(self):
|
||||
def f(x):
|
||||
|
|
@ -300,11 +374,22 @@ $9 = torch._ops.aten.mul.Tensor($8, $8)""")
|
|||
return x
|
||||
self.assert_functionalization(f, torch.ones(4, 2))
|
||||
logs = self.get_logs(f, torch.ones(4, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.transpose_copy.int($0, 1, 0)
|
||||
$2 = torch._ops.aten.select_copy.int($1, 0, 0)
|
||||
$3 = torch._ops.aten.add.Tensor($2, tensor([1., 1., 1., 1.]))""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||
transpose_copy_int = torch.ops.aten.transpose_copy.int(a_1, 1, 0)
|
||||
select_copy_int = torch.ops.aten.select_copy.int(transpose_copy_int, 0, 0); transpose_copy_int = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(select_copy_int, fill_scalar); select_copy_int = fill_scalar = None
|
||||
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None
|
||||
select_scatter_default = torch.ops.aten.select_scatter.default(transpose_copy_int_1, add_tensor, 0, 0); transpose_copy_int_1 = add_tensor = None
|
||||
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(select_scatter_default, 1, 0); select_scatter_default = None
|
||||
transpose_copy_int_3 = torch.ops.aten.transpose_copy.int(transpose_copy_int_2, 1, 0); transpose_copy_int_2 = None
|
||||
return transpose_copy_int_3
|
||||
""") # noqa: B950
|
||||
|
||||
def test_optional_tensor_list(self):
|
||||
def f(x):
|
||||
|
|
@ -317,10 +402,20 @@ $3 = torch._ops.aten.add.Tensor($2, tensor([1., 1., 1., 1.]))""")
|
|||
return y
|
||||
self.assert_functionalization(f, torch.ones(4, 2))
|
||||
logs = self.get_logs(f, torch.ones(4, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.view_copy.default($0, [8])
|
||||
$2 = torch._ops.aten.index_put.default($1, [tensor([0, 1, 2, 3])], tensor([0., 1., 2., 3.]))""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [8]); a_1 = None
|
||||
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
arange = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
|
||||
empty_1 = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
arange_1 = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
|
||||
index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2])
|
||||
return index_put_default
|
||||
""") # noqa: B950
|
||||
|
||||
def test_scalars(self):
|
||||
def f(x):
|
||||
|
|
@ -333,12 +428,20 @@ $2 = torch._ops.aten.index_put.default($1, [tensor([0, 1, 2, 3])], tensor([0., 1
|
|||
return z
|
||||
self.assert_functionalization(f, torch.ones(4, 2))
|
||||
logs = self.get_logs(f, torch.ones(4, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.view_copy.default($0, [4, 2])
|
||||
$2 = torch._ops.aten.add.Tensor($1, 1)
|
||||
$3 = torch._ops.aten.mul.Tensor($2, 2)
|
||||
$4 = torch._ops.aten.div.Tensor($3, 1)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
|
||||
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2)
|
||||
div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
|
||||
return div_tensor
|
||||
""")
|
||||
|
||||
@skipIfTorchDynamo("Test does not work with TorchDynamo")
|
||||
def test_metadata_change(self):
|
||||
|
|
@ -348,10 +451,16 @@ $4 = torch._ops.aten.div.Tensor($3, 1)""")
|
|||
return x.ge_(0)
|
||||
self.assert_functionalization(f, torch.ones(4, 2))
|
||||
logs = self.get_logs(f, torch.ones(4, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.ge.Scalar($0, 0)
|
||||
$2 = torch._ops.aten._to_copy.default($1, dtype=torch.float32, layout=torch.strided)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
ge_scalar = torch.ops.aten.ge.Scalar(a_1, 0); a_1 = None
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
|
||||
_tensor_constant0 = self._tensor_constant0
|
||||
return _tensor_constant0
|
||||
""")
|
||||
|
||||
def test_only_one_view(self):
|
||||
def f(x):
|
||||
|
|
@ -360,9 +469,14 @@ $2 = torch._ops.aten._to_copy.default($1, dtype=torch.float32, layout=torch.stri
|
|||
# so there should be a total of 1 op in the output trace.
|
||||
return x.view(4, 2)
|
||||
logs = self.get_logs(f, torch.ones(4, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.view_copy.default($0, [4, 2])""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
|
||||
return view_copy_default
|
||||
""")
|
||||
|
||||
def test_everything(self):
|
||||
def f(x):
|
||||
|
|
@ -380,35 +494,43 @@ $1 = torch._ops.aten.view_copy.default($0, [4, 2])""")
|
|||
return z2
|
||||
self.assert_functionalization(f, torch.ones(4, 2))
|
||||
logs = self.get_logs(f, torch.ones(4, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.add.Tensor($0, $0)
|
||||
$2 = torch._ops.aten.view_copy.default($1, [8])
|
||||
$3 = torch._ops.aten._reshape_alias_copy.default($2, [2, 4], [4, 1])
|
||||
$4 = torch._ops.aten.transpose_copy.int($3, 1, 0)
|
||||
$5 = torch._ops.aten.unsqueeze_copy.default($4, 0)
|
||||
$6 = torch._ops.aten.squeeze_copy.default($5)
|
||||
$7, $8 = torch._ops.aten.split_copy.Tensor($6, 2)
|
||||
$9 = torch._ops.aten.add.Tensor($7, tensor([[1., 1.],
|
||||
[1., 1.]]))
|
||||
$10 = torch._ops.aten.select_copy.int($3, 0, 0)
|
||||
$11 = torch._ops.aten.clone.default($9, memory_format=torch.contiguous_format)
|
||||
$12 = torch._ops.aten._unsafe_view.default($11, [4])
|
||||
$13 = torch._ops.aten.view_copy.default($1, [8])
|
||||
$14 = torch._ops.aten._reshape_alias_copy.default($13, [2, 4], [4, 1])
|
||||
$15 = torch._ops.aten.transpose_copy.int($14, 1, 0)
|
||||
$16 = torch._ops.aten.unsqueeze_copy.default($15, 0)
|
||||
$17 = torch._ops.aten.squeeze_copy.default($16)
|
||||
$18 = torch._ops.aten.slice_scatter.default($17, $9, 0, 0, 2)
|
||||
$19 = torch._ops.aten.unsqueeze_copy.default($18, 0)
|
||||
$20 = torch._ops.aten.squeeze_copy.dim($19, 0)
|
||||
$21 = torch._ops.aten.transpose_copy.int($20, 1, 0)
|
||||
$22 = torch._ops.aten._reshape_alias_copy.default($21, [8], [1])
|
||||
$23 = torch._ops.aten.view_copy.default($22, [4, 2])
|
||||
$24 = torch._ops.aten.view_copy.default($23, [8])
|
||||
$25 = torch._ops.aten._reshape_alias_copy.default($24, [2, 4], [4, 1])
|
||||
$26 = torch._ops.aten.select_copy.int($25, 0, 0)
|
||||
$27 = torch._ops.aten.add.Tensor($26, $12)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [8])
|
||||
_reshape_alias_copy_default = torch.ops.aten._reshape_alias_copy.default(view_copy_default, [2, 4], [4, 1]); view_copy_default = None
|
||||
transpose_copy_int = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default, 1, 0)
|
||||
unsqueeze_copy_default = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int, 0); transpose_copy_int = None
|
||||
squeeze_copy_default = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default); unsqueeze_copy_default = None
|
||||
split_copy_tensor = torch.ops.aten.split_copy.Tensor(squeeze_copy_default, 2); squeeze_copy_default = None
|
||||
getitem = split_copy_tensor[0]
|
||||
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
|
||||
add_tensor_1 = torch.ops.aten.add.Tensor(getitem, fill_scalar); getitem = fill_scalar = None
|
||||
select_copy_int = torch.ops.aten.select_copy.int(_reshape_alias_copy_default, 0, 0); _reshape_alias_copy_default = None
|
||||
clone_default = torch.ops.aten.clone.default(add_tensor_1, memory_format = torch.contiguous_format)
|
||||
_unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [8]); add_tensor = None
|
||||
_reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_1, [2, 4], [4, 1]); view_copy_default_1 = None
|
||||
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default_1, 1, 0); _reshape_alias_copy_default_1 = None
|
||||
unsqueeze_copy_default_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int_1, 0); transpose_copy_int_1 = None
|
||||
squeeze_copy_default_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default_1); unsqueeze_copy_default_1 = None
|
||||
slice_scatter_default = torch.ops.aten.slice_scatter.default(squeeze_copy_default_1, add_tensor_1, 0, 0, 2); squeeze_copy_default_1 = None
|
||||
unsqueeze_copy_default_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter_default, 0); slice_scatter_default = None
|
||||
squeeze_copy_dim = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_default_2, 0); unsqueeze_copy_default_2 = None
|
||||
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_dim, 1, 0); squeeze_copy_dim = None
|
||||
_reshape_alias_copy_default_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_int_2, [8], [1]); transpose_copy_int_2 = None
|
||||
view_copy_default_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_default_2, [4, 2]); _reshape_alias_copy_default_2 = None
|
||||
view_copy_default_3 = torch.ops.aten.view_copy.default(view_copy_default_2, [8]); view_copy_default_2 = None
|
||||
_reshape_alias_copy_default_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_3, [2, 4], [4, 1]); view_copy_default_3 = None
|
||||
select_copy_int_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_default_3, 0, 0); _reshape_alias_copy_default_3 = None
|
||||
add_tensor_2 = torch.ops.aten.add.Tensor(select_copy_int_1, _unsafe_view_default); select_copy_int_1 = _unsafe_view_default = None
|
||||
return add_tensor_1
|
||||
""") # noqa: B950
|
||||
|
||||
def test_reapply_views_simple(self):
|
||||
def f(x):
|
||||
|
|
@ -419,15 +541,19 @@ $27 = torch._ops.aten.add.Tensor($26, $12)""")
|
|||
return y
|
||||
self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)
|
||||
logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True)
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.view.default($0, [4, 2])
|
||||
$2 = torch._ops.aten.add.Tensor($1, tensor([[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.]]))
|
||||
$3 = torch._ops.aten.view.default($2, [4, 2])
|
||||
$4 = torch._ops.aten.mul.Tensor($3, $3)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
|
||||
view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(view_default, fill_scalar); view_default = fill_scalar = None
|
||||
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
|
||||
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1); view_default_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
|
||||
def test_aliases_maintained_after_pass_when_reapplying_views(self):
|
||||
def f(x):
|
||||
|
|
@ -467,34 +593,70 @@ $4 = torch._ops.aten.mul.Tensor($3, $3)""")
|
|||
# to() is a composite op that noops when the dtype/shape match, so nothing gets logged.
|
||||
# self.assert_functionalization(f, torch.ones(2))
|
||||
logs = self.get_logs(f, torch.ones(2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0)
|
||||
$2 = torch._ops.aten.add.Tensor($1, $0)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
zero_default = torch.ops.aten.zero.default(empty); empty = None
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
|
||||
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
|
||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
|
||||
# Test 2: copy_() with same dtype, different shape
|
||||
self.assert_functionalization(f, torch.ones(1))
|
||||
logs = self.get_logs(f, torch.ones(1))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0)
|
||||
$2 = torch._ops.aten.add.Tensor($1, $0)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
zero_default = torch.ops.aten.zero.default(empty); empty = None
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
|
||||
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
|
||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
|
||||
# Test 3: copy_() with different dtype, same shape
|
||||
self.assert_functionalization(f, torch.ones(2, dtype=torch.long))
|
||||
logs = self.get_logs(f, torch.ones(2, dtype=torch.long))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0)
|
||||
$2 = torch._ops.aten.add.Tensor($1, $0)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
zero_default = torch.ops.aten.zero.default(empty); empty = None
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
|
||||
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
|
||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
|
||||
# Test 4: copy_() with different dtype, different shape
|
||||
self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
|
||||
logs = self.get_logs(f, torch.ones(1, dtype=torch.long))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.copy.default(tensor([0., 0.]), $0)
|
||||
$2 = torch._ops.aten.add.Tensor($1, $0)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
zero_default = torch.ops.aten.zero.default(empty); empty = None
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
|
||||
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
|
||||
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
|
||||
return add_tensor
|
||||
""")
|
||||
|
||||
def test_fill_(self):
|
||||
def f(x):
|
||||
|
|
@ -505,11 +667,17 @@ $2 = torch._ops.aten.add.Tensor($1, $0)""")
|
|||
|
||||
self.assert_functionalization(f, torch.ones(2, 2))
|
||||
logs = self.get_logs(f, torch.ones(2, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.add.Tensor($0, $0)
|
||||
$2 = torch._ops.aten.diagonal_copy.default($1)
|
||||
$3 = torch._ops.aten.fill.Scalar($2, 0)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(add_tensor)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None
|
||||
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None
|
||||
return diagonal_scatter_default
|
||||
""")
|
||||
|
||||
def test_resize_smaller(self):
|
||||
def f(w):
|
||||
|
|
@ -524,22 +692,27 @@ $3 = torch._ops.aten.fill.Scalar($2, 0)""")
|
|||
|
||||
self.assert_functionalization(f, torch.ones(8, 2))
|
||||
logs = self.get_logs(f, torch.ones(8, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.add.Tensor($0, 1)
|
||||
$2 = torch._ops.aten.view_copy.default($1, [4, 4])
|
||||
$3 = torch._ops.aten.resize.default($2, [3, 3])
|
||||
$4 = torch._ops.aten.as_strided_copy.default($2, [3, 3], [3, 1])
|
||||
$5 = torch._ops.aten.view_copy.default($4, [-1])
|
||||
$6 = torch._ops.aten.add.Tensor($5, 1)
|
||||
$7 = torch._ops.aten.view_copy.default($1, [4, 4])
|
||||
$8 = torch._ops.aten.as_strided_copy.default($7, [3, 3], [3, 1])
|
||||
$9 = torch._ops.aten.view_copy.default($6, [3, 3])
|
||||
$10 = torch._ops.aten.as_strided_scatter.default($7, $9, [3, 3], [3, 1])
|
||||
$11 = torch._ops.aten.view_copy.default($10, [8, 2])
|
||||
$12 = torch._ops.aten.view_copy.default($11, [4, 4])
|
||||
$13 = torch._ops.aten.as_strided_copy.default($12, [3, 3], [3, 1])
|
||||
$14 = torch._ops.aten.add.Tensor($13, 1)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [4, 4])
|
||||
resize_default = torch.ops.aten.resize.default(view_copy_default, [3, 3])
|
||||
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(view_copy_default, [3, 3], [3, 1]); view_copy_default = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(as_strided_copy_default, [-1]); as_strided_copy_default = None
|
||||
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1); view_copy_default_1 = None
|
||||
view_copy_default_2 = torch.ops.aten.view_copy.default(add_tensor, [4, 4]); add_tensor = None
|
||||
as_strided_copy_default_1 = torch.ops.aten.as_strided_copy.default(view_copy_default_2, [3, 3], [3, 1])
|
||||
view_copy_default_3 = torch.ops.aten.view_copy.default(add_tensor_1, [3, 3]); add_tensor_1 = None
|
||||
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(view_copy_default_2, view_copy_default_3, [3, 3], [3, 1]); view_copy_default_2 = view_copy_default_3 = None
|
||||
view_copy_default_4 = torch.ops.aten.view_copy.default(as_strided_scatter_default, [8, 2]); as_strided_scatter_default = None
|
||||
view_copy_default_5 = torch.ops.aten.view_copy.default(view_copy_default_4, [4, 4]); view_copy_default_4 = None
|
||||
as_strided_copy_default_2 = torch.ops.aten.as_strided_copy.default(view_copy_default_5, [3, 3], [3, 1]); view_copy_default_5 = None
|
||||
add_tensor_2 = torch.ops.aten.add.Tensor(as_strided_copy_default_2, 1); as_strided_copy_default_2 = None
|
||||
return add_tensor_2
|
||||
""") # noqa: B950
|
||||
|
||||
def test_resize_larger_valid(self):
|
||||
def f(x):
|
||||
|
|
@ -560,14 +733,19 @@ $14 = torch._ops.aten.add.Tensor($13, 1)""")
|
|||
|
||||
self.assert_functionalization(f, torch.ones(8, 2))
|
||||
logs = self.get_logs(f, torch.ones(8, 2))
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.aten.add.Tensor($0, 1)
|
||||
$2 = torch._ops.aten.resize.default($1, [5, 5])
|
||||
$3 = torch._ops.aten.view_copy.default($2, [25])
|
||||
$4 = torch._ops.aten.fill.Scalar($3, 1)
|
||||
$5 = torch._ops.aten.view_copy.default($4, [5, 5])
|
||||
$6 = torch._ops.aten.add.Tensor($5, 1)""")
|
||||
self.assertExpectedInline(logs, """\
|
||||
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||
resize_default = torch.ops.aten.resize.default(add_tensor, [5, 5]); add_tensor = None
|
||||
view_copy_default = torch.ops.aten.view_copy.default(resize_default, [25]); resize_default = None
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(view_copy_default, 1); view_copy_default = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None
|
||||
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1)
|
||||
return (view_copy_default_1, add_tensor_1)
|
||||
""")
|
||||
|
||||
def test_resize_larger_invalid(self):
|
||||
def f(x):
|
||||
|
|
@@ -363,7 +363,8 @@ def emit_view_functionalization_body(
);
{return_type} reference_tensor_output;
{{
at::AutoDispatchSkipFunctionalize guard;
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
{meta_conversion_str}
reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
}}
@@ -391,12 +392,16 @@ def emit_view_functionalization_body(
return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
}}
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
{return_type} tmp_output;
{return_type} reference_tensor_output;
{{
at::AutoDispatchSkipFunctionalize guard;
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
{meta_conversion_str}
reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
}}
{return_type} tmp_output;
{{
at::AutoDispatchSkipFunctionalize guard;
if (reapply_views) {{
tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
}} else {{
@@ -603,8 +608,9 @@ def emit_inplace_functionalization_body(
// Before converting the mutable op to its functional variant, run meta tensors through the original op.
// This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
// (We can only do this for inplace ops today though, because they technically all support meta tensors).
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize guard;
at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
}}
{unwrap_tensor_args_str}