Revert "Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)"

This reverts commit 2337d8d062.

Reverted https://github.com/pytorch/pytorch/pull/112119 on behalf of https://github.com/PaliC due to still breaking trt tests :( refer to diff ([comment](https://github.com/pytorch/pytorch/pull/112119#issuecomment-1795496395))
Author: PyTorch MergeBot
Date: 2023-11-06 17:01:50 +00:00
Parent: 679ca510b0
Commit: a1d1b73a7c
9 changed files with 67 additions and 66 deletions
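For context, the change being reverted swapped the overload packet `torch.ops.aten.sym_size` for the specific overload `torch.ops.aten.sym_size.int` in the proxy-tensor size/stride slots and in the expected FX graphs below. A minimal sketch of the distinction (illustrative only, not part of this commit; assumes a PyTorch build where both objects exist):

```python
import torch

x = torch.randn(2, 3)

# OpOverloadPacket: the name that groups every registered overload.
packet = torch.ops.aten.sym_size
# OpOverload: one concrete schema resolved from that packet.
overload = torch.ops.aten.sym_size.int

# Both are callable and return the size of dim 0 here...
print(packet(x, 0), overload(x, 0))  # 2 2
# ...but they are distinct objects, so graph nodes targeting one do not
# compare equal to the other, which is why the expected graphs below change.
print(packet is overload)  # False
```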

Changed file 1 of 9:

@@ -1790,7 +1790,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
 has_sym_size = False
 for node in gm.graph.nodes:
-if node.target is torch.ops.aten.sym_size.int:
+if node.target is torch.ops.aten.sym_size:
 has_sym_size = True
 self.assertTrue(has_sym_size)
@@ -3192,19 +3192,19 @@ def forward(self, x):
 arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
 arg0_1 = arg0
 slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
-sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
-sub = sym_size_int - 1
+sym_size = torch.ops.aten.sym_size(arg0_1, 0)
+sub = sym_size - 1
 slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None
-sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 2)
-slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int_1); slice_2 = None
+sym_size_1 = torch.ops.aten.sym_size(arg0_1, 2)
+slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_1); slice_2 = None
 slice_4 = torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3); slice_3 = None
-sub_1 = sym_size_int - 2
+sub_1 = sym_size - 2
 slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1); sub_1 = None
-slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int_1); slice_5 = None
+slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_1); slice_5 = None
 slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3); slice_6 = None
-sub_2 = sym_size_int - 3; sym_size_int = None
+sub_2 = sym_size - 3; sym_size = None
 slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2); arg0_1 = sub_2 = None
-slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int_1); slice_8 = sym_size_int_1 = None
+slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_1); slice_8 = sym_size_1 = None
 slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3); slice_9 = None
 return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""",
 )

Changed file 2 of 9:

@@ -2628,8 +2628,8 @@ class TestPartitioning(AOTTestCase):
 self.assertEqual(str(fw_output[0]), "sum_1")
 # make sure we don't do the suboptimal thing of saving the bigger primals input to sum,
 # rather than saving the sizes of the primals input for use in backward expand
-self.assertEqual(str(fw_output[1]), "sym_size_int")
-self.assertEqual(str(fw_output[2]), "sym_size_int_1")
+self.assertEqual(str(fw_output[1]), "sym_size")
+self.assertEqual(str(fw_output[2]), "sym_size_1")
 inp = [
 torch.randn(10, requires_grad=True),

Changed file 3 of 9:

@@ -1255,8 +1255,8 @@ def forward(self, arg0_1):
 self.assertExpectedInline(gm.code.strip(), """\
 def forward(self, x_1):
-sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
-eq = sym_size_int == 4; sym_size_int = None
+sym_size = torch.ops.aten.sym_size(x_1, 0)
+eq = sym_size == 4; sym_size = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
 conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
@@ -1515,8 +1515,8 @@ def forward(self, arg0_1, arg1_1):
 gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(3, 4))
 self.assertExpectedInline(gm.code.strip(), """\
 def forward(self, x_1):
-sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
-eq = sym_size_int == 4; sym_size_int = None
+sym_size = torch.ops.aten.sym_size(x_1, 0)
+eq = sym_size == 4; sym_size = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
 conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None

Changed file 4 of 9:

@@ -1468,10 +1468,10 @@ class TestCustomOp(CustomOpTestCaseBase):
 gm.code.strip(),
 """\
 def forward(self, x_1):
-sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
-sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
-sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
-numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
+sym_size = torch.ops.aten.sym_size(x_1, 0)
+sym_size_1 = torch.ops.aten.sym_size(x_1, 1)
+sym_size_2 = torch.ops.aten.sym_size(x_1, 2)
+numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size, sym_size_1, sym_size_2]); x_1 = sym_size = sym_size_1 = sym_size_2 = None
 return numpy_view_copy""", # noqa: B950
 )

Changed file 5 of 9:

@@ -453,10 +453,10 @@ class TestPySymInt(TestCase):
 self.assertEqual(len(gm.shape_env.guards), 0)
 self.assertExpectedInline(gm.code.strip(), """\
 def forward(self, x_1):
-sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
-eq = sym_size_int == 5
-sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
-sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
+sym_size = torch.ops.aten.sym_size(x_1, 0)
+eq = sym_size == 5
+sym_size_1 = torch.ops.aten.sym_size(x_1, 1); x_1 = None
+sym_ite = torch.sym_ite(eq, sym_size, sym_size_1); eq = sym_size = sym_size_1 = None
 return sym_ite""")
 r1 = gm(torch.ones(4, 5))
 self.assertIsInstance(r1, int)
@@ -606,12 +606,12 @@ def forward(self, x_1):
 class f(torch.nn.Module):
 def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]):
 # No stacktrace found for following nodes
-sym_size_int: Sym(s0) = torch.ops.aten.sym_size.int(a_1, 0)
-sym_size_int_1: Sym(s2) = torch.ops.aten.sym_size.int(b_1, 0)
-add: Sym(s0 + s2) = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
-sym_size_int_2: Sym(s1) = torch.ops.aten.sym_size.int(a_1, 1)
-sym_size_int_3: Sym(s1) = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
-add_1: Sym(2*s1) = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
+sym_size: Sym(s0) = torch.ops.aten.sym_size(a_1, 0)
+sym_size_1: Sym(s2) = torch.ops.aten.sym_size(b_1, 0)
+add: Sym(s0 + s2) = sym_size + sym_size_1; sym_size = sym_size_1 = None
+sym_size_2: Sym(s1) = torch.ops.aten.sym_size(a_1, 1)
+sym_size_3: Sym(s1) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
+add_1: Sym(2*s1) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
 new_empty: f32[s0 + s2, 2*s1] = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
 native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
 getitem: f32[s0 + s2, 2*s1] = native_dropout[0]

Changed file 6 of 9:

@@ -1,6 +1,6 @@
 # Owner(s): ["module: ProxyTensor"]
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests
 import torch
 import unittest
 import warnings
@@ -747,6 +747,9 @@ class TestGenericProxyTensorFake(TestGenericProxyTensor):
 tracing_mode = "fake"
+@xfail_inherited_tests([
+    "test_make_fx_overloads",
+])
 class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
 tracing_mode = "symbolic"
@@ -930,8 +933,8 @@ class TestSymbolicTracing(TestCase):
 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
 self.assertExpectedInline(r, """\
 def forward(self, x_1, y_1):
-sym_size_int = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
-resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]); x_1 = sym_size_int = None
+sym_size = torch.ops.aten.sym_size(y_1, 0); y_1 = None
+resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]); x_1 = sym_size = None
 return None""")
@@ -1065,8 +1068,8 @@ def forward(self, a_1, b_1):
 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
 self.assertExpectedInline(r, """\
 def forward(self, a_1):
-sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
-mul = sym_size_int * 2; sym_size_int = None
+sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
+mul = sym_size * 2; sym_size = None
 empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
 return empty""")
@@ -1125,8 +1128,8 @@ def forward(self, a_1):
 self.assertExpectedInline(
 r, """\
 def forward(self, x_1):
-sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
-scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size_int = None
+sym_size = torch.ops.aten.sym_size(x_1, 0)
+scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size = None
 select = torch.ops.aten.select.int(x_1, 0, 0)
 copy_ = torch.ops.aten.copy_.default(select, scalar_tensor); select = scalar_tensor = None
 return x_1""" # noqa: B950
@@ -1171,21 +1174,21 @@ def forward(self, crop_camera_1, mask_1):
 select = torch.ops.aten.select.int(eye, 0, 0)
 select_1 = torch.ops.aten.select.int(select, 0, 0); select = None
 copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None
-sym_size_int = torch.ops.aten.sym_size.int(index, 0)
-expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
-view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]); expand = None
-sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
-sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
-expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]); index = None
-view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
+sym_size = torch.ops.aten.sym_size(index, 0)
+expand = torch.ops.aten.expand.default(eye, [sym_size, 3, 3])
+view = torch.ops.aten.view.default(expand, [sym_size, 3, 3]); expand = None
+sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 1)
+sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 2)
+expand_1 = torch.ops.aten.expand.default(index, [sym_size, sym_size_1, sym_size_2]); index = None
+view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_1, sym_size_2]); expand_1 = sym_size_1 = sym_size_2 = None
 bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
-view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
-mul = sym_size_int * 3
+view_2 = torch.ops.aten.view.default(bmm, [sym_size, 3, 3]); bmm = None
+mul = sym_size * 3
 view_3 = torch.ops.aten.view.default(view_2, [mul, 3]); view_2 = mul = None
 mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
-view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
+view_4 = torch.ops.aten.view.default(mm, [sym_size, 3, 3]); mm = sym_size = None
 index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = None
-return None""") # noqa: B950
+return None""")
 def test_unbacked_slice(self):
 def f(x, m):
@@ -1247,8 +1250,8 @@ def forward(self, images_1, handedness_1, valid_1):
 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
 self.assertExpectedInline(r, """\
 def forward(self, a_1):
-sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
-neg = -sym_size_int; sym_size_int = None
+sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
+neg = -sym_size; sym_size = None
 add = neg + 10; neg = None
 empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
 return empty""")
@@ -1361,8 +1364,8 @@ def forward(self, lengths_1, values_1):
 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
 self.assertExpectedInline(r, """\
 def forward(self, a_1):
-sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
-pow_1 = sym_size_int ** 0.5; sym_size_int = None
+sym_size = torch.ops.aten.sym_size(a_1, 0)
+pow_1 = sym_size ** 0.5; sym_size = None
 div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
 return div""")
@@ -1374,15 +1377,15 @@ def forward(self, a_1):
 r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
 self.assertExpectedInline(r, """\
 def forward(self, a_1):
-sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
-div = torch.ops.aten.div.Tensor(a_1, sym_size_int); a_1 = sym_size_int = None
+sym_size = torch.ops.aten.sym_size(a_1, 0)
+div = torch.ops.aten.div.Tensor(a_1, sym_size); a_1 = sym_size = None
 return div""")
 r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
 self.assertExpectedInline(r, """\
 def forward(self, a_1):
-sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
-sym_float = torch.sym_float(sym_size_int); sym_size_int = None
+sym_size = torch.ops.aten.sym_size(a_1, 0)
+sym_float = torch.sym_float(sym_size); sym_size = None
 div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
 return div""")

Changed file 7 of 9:

@@ -2096,9 +2096,9 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
 fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
 self.assertExpectedInline(fx_g.code.strip(), """\
 def forward(self, x_1):
-sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
-sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
-return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""")
+sym_size = torch.ops.aten.sym_size(x_1, 0)
+sym_size_1 = torch.ops.aten.sym_size(x_1, 1); x_1 = None
+return ((sym_size, sym_size_1), (sym_size, sym_size_1))""")
 def test_data_ptr_respects_numel_slow_path(self):
 data = torch.randn(6, 2)

Changed file 8 of 9:

@@ -545,10 +545,8 @@ class TorchVariable(VariableTracker):
 ),
 **options,
 )
-# TODO: These special cases shouldn't be necessary; we should
-# generically support torch.ops that return int
 elif (
-self.value in [torch.ops.aten.sym_size, torch.ops.aten.sym_size.int]
+self.value is torch.ops.aten.sym_size
 and len(args) == 2
 and len(kwargs) == 0
 and isinstance(args[0], TensorVariable)
@@ -556,7 +554,7 @@ class TorchVariable(VariableTracker):
 # we see this when retracing already traced code
 return args[0].call_method(tx, "size", [args[1]], {})
 elif (
-self.value is [torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int]
+self.value is torch.ops.aten.sym_stride
 and len(args) == 2
 and len(kwargs) == 0
 and isinstance(args[0], TensorVariable)

Changed file 9 of 9:

@@ -170,13 +170,13 @@ def track_tensor(tensor, proxy, *, constant, tracer):
 # (so that if we have multiple tracers at the same time, they
 # don't clobber each other.)
 for i, s in enumerate(tensor.shape):
-try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size.int(proxy, i), x), i)
+try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size(proxy, i), x), i)
 for i, s in enumerate(tensor.stride()):
-try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride.int(proxy, i), x), i)
+try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride(proxy, i), x), i)
-try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel.default(proxy), x))
-try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset.default(proxy), x))
+try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel(proxy), x))
+try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset(proxy), x))
 set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
 def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
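As a rough illustration of where these slots surface (a hypothetical repro sketch, not part of this commit), tracing a function that reads a dynamic dimension records a size query in the generated graph; with this revert applied the node targets the overload packet rather than the `.int` overload:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # Reading a symbolic dimension forces a sym_size node into the graph.
    return x.new_empty(x.shape[0])

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(4))
# With this revert, gm.code contains torch.ops.aten.sym_size(x_1, 0);
# before the revert it contained torch.ops.aten.sym_size.int(x_1, 0).
print(gm.code)
```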