mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "Use OpOverload instead of OpOverloadPacket for size/stride/etc slots (#112119)"
This reverts commit 2337d8d062.
Reverted https://github.com/pytorch/pytorch/pull/112119 on behalf of https://github.com/PaliC due to still breaking trt tests :( refer to diff ([comment](https://github.com/pytorch/pytorch/pull/112119#issuecomment-1795496395))
This commit is contained in:
parent
679ca510b0
commit
a1d1b73a7c
9 changed files with 67 additions and 66 deletions
|
|
@ -1790,7 +1790,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
has_sym_size = False
|
||||
for node in gm.graph.nodes:
|
||||
if node.target is torch.ops.aten.sym_size.int:
|
||||
if node.target is torch.ops.aten.sym_size:
|
||||
has_sym_size = True
|
||||
|
||||
self.assertTrue(has_sym_size)
|
||||
|
|
@ -3192,19 +3192,19 @@ def forward(self, x):
|
|||
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
arg0_1 = arg0
|
||||
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
|
||||
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
|
||||
sub = sym_size_int - 1
|
||||
sym_size = torch.ops.aten.sym_size(arg0_1, 0)
|
||||
sub = sym_size - 1
|
||||
slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 2)
|
||||
slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int_1); slice_2 = None
|
||||
sym_size_1 = torch.ops.aten.sym_size(arg0_1, 2)
|
||||
slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_1); slice_2 = None
|
||||
slice_4 = torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3); slice_3 = None
|
||||
sub_1 = sym_size_int - 2
|
||||
sub_1 = sym_size - 2
|
||||
slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1); sub_1 = None
|
||||
slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int_1); slice_5 = None
|
||||
slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_1); slice_5 = None
|
||||
slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3); slice_6 = None
|
||||
sub_2 = sym_size_int - 3; sym_size_int = None
|
||||
sub_2 = sym_size - 3; sym_size = None
|
||||
slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2); arg0_1 = sub_2 = None
|
||||
slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int_1); slice_8 = sym_size_int_1 = None
|
||||
slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_1); slice_8 = sym_size_1 = None
|
||||
slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3); slice_9 = None
|
||||
return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2628,8 +2628,8 @@ class TestPartitioning(AOTTestCase):
|
|||
self.assertEqual(str(fw_output[0]), "sum_1")
|
||||
# make sure we don't do the suboptimal thing of saving the bigger primals input to sum,
|
||||
# rather than saving the sizes of the primals input for use in backward expand
|
||||
self.assertEqual(str(fw_output[1]), "sym_size_int")
|
||||
self.assertEqual(str(fw_output[2]), "sym_size_int_1")
|
||||
self.assertEqual(str(fw_output[1]), "sym_size")
|
||||
self.assertEqual(str(fw_output[2]), "sym_size_1")
|
||||
|
||||
inp = [
|
||||
torch.randn(10, requires_grad=True),
|
||||
|
|
|
|||
|
|
@ -1255,8 +1255,8 @@ def forward(self, arg0_1):
|
|||
|
||||
self.assertExpectedInline(gm.code.strip(), """\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
eq = sym_size_int == 4; sym_size_int = None
|
||||
sym_size = torch.ops.aten.sym_size(x_1, 0)
|
||||
eq = sym_size == 4; sym_size = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
|
||||
|
|
@ -1515,8 +1515,8 @@ def forward(self, arg0_1, arg1_1):
|
|||
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(3, 4))
|
||||
self.assertExpectedInline(gm.code.strip(), """\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
eq = sym_size_int == 4; sym_size_int = None
|
||||
sym_size = torch.ops.aten.sym_size(x_1, 0)
|
||||
eq = sym_size == 4; sym_size = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
|
||||
|
|
|
|||
|
|
@ -1468,10 +1468,10 @@ class TestCustomOp(CustomOpTestCaseBase):
|
|||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
|
||||
numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
|
||||
sym_size = torch.ops.aten.sym_size(x_1, 0)
|
||||
sym_size_1 = torch.ops.aten.sym_size(x_1, 1)
|
||||
sym_size_2 = torch.ops.aten.sym_size(x_1, 2)
|
||||
numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size, sym_size_1, sym_size_2]); x_1 = sym_size = sym_size_1 = sym_size_2 = None
|
||||
return numpy_view_copy""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -453,10 +453,10 @@ class TestPySymInt(TestCase):
|
|||
self.assertEqual(len(gm.shape_env.guards), 0)
|
||||
self.assertExpectedInline(gm.code.strip(), """\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
eq = sym_size_int == 5
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
|
||||
sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
|
||||
sym_size = torch.ops.aten.sym_size(x_1, 0)
|
||||
eq = sym_size == 5
|
||||
sym_size_1 = torch.ops.aten.sym_size(x_1, 1); x_1 = None
|
||||
sym_ite = torch.sym_ite(eq, sym_size, sym_size_1); eq = sym_size = sym_size_1 = None
|
||||
return sym_ite""")
|
||||
r1 = gm(torch.ones(4, 5))
|
||||
self.assertIsInstance(r1, int)
|
||||
|
|
@ -606,12 +606,12 @@ def forward(self, x_1):
|
|||
class f(torch.nn.Module):
|
||||
def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]):
|
||||
# No stacktrace found for following nodes
|
||||
sym_size_int: Sym(s0) = torch.ops.aten.sym_size.int(a_1, 0)
|
||||
sym_size_int_1: Sym(s2) = torch.ops.aten.sym_size.int(b_1, 0)
|
||||
add: Sym(s0 + s2) = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
|
||||
sym_size_int_2: Sym(s1) = torch.ops.aten.sym_size.int(a_1, 1)
|
||||
sym_size_int_3: Sym(s1) = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
|
||||
add_1: Sym(2*s1) = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
|
||||
sym_size: Sym(s0) = torch.ops.aten.sym_size(a_1, 0)
|
||||
sym_size_1: Sym(s2) = torch.ops.aten.sym_size(b_1, 0)
|
||||
add: Sym(s0 + s2) = sym_size + sym_size_1; sym_size = sym_size_1 = None
|
||||
sym_size_2: Sym(s1) = torch.ops.aten.sym_size(a_1, 1)
|
||||
sym_size_3: Sym(s1) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
|
||||
add_1: Sym(2*s1) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
|
||||
new_empty: f32[s0 + s2, 2*s1] = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
|
||||
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
|
||||
getitem: f32[s0 + s2, 2*s1] = native_dropout[0]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Owner(s): ["module: ProxyTensor"]
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests
|
||||
import torch
|
||||
import unittest
|
||||
import warnings
|
||||
|
|
@ -747,6 +747,9 @@ class TestGenericProxyTensorFake(TestGenericProxyTensor):
|
|||
tracing_mode = "fake"
|
||||
|
||||
|
||||
@xfail_inherited_tests([
|
||||
"test_make_fx_overloads",
|
||||
])
|
||||
class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
|
||||
tracing_mode = "symbolic"
|
||||
|
||||
|
|
@ -930,8 +933,8 @@ class TestSymbolicTracing(TestCase):
|
|||
r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, x_1, y_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None
|
||||
resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]); x_1 = sym_size_int = None
|
||||
sym_size = torch.ops.aten.sym_size(y_1, 0); y_1 = None
|
||||
resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]); x_1 = sym_size = None
|
||||
return None""")
|
||||
|
||||
|
||||
|
|
@ -1065,8 +1068,8 @@ def forward(self, a_1, b_1):
|
|||
r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, a_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
|
||||
mul = sym_size_int * 2; sym_size_int = None
|
||||
sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
|
||||
mul = sym_size * 2; sym_size = None
|
||||
empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
|
||||
return empty""")
|
||||
|
||||
|
|
@ -1125,8 +1128,8 @@ def forward(self, a_1):
|
|||
self.assertExpectedInline(
|
||||
r, """\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size_int = None
|
||||
sym_size = torch.ops.aten.sym_size(x_1, 0)
|
||||
scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size = None
|
||||
select = torch.ops.aten.select.int(x_1, 0, 0)
|
||||
copy_ = torch.ops.aten.copy_.default(select, scalar_tensor); select = scalar_tensor = None
|
||||
return x_1""" # noqa: B950
|
||||
|
|
@ -1171,21 +1174,21 @@ def forward(self, crop_camera_1, mask_1):
|
|||
select = torch.ops.aten.select.int(eye, 0, 0)
|
||||
select_1 = torch.ops.aten.select.int(select, 0, 0); select = None
|
||||
copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None
|
||||
sym_size_int = torch.ops.aten.sym_size.int(index, 0)
|
||||
expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
|
||||
view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]); expand = None
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
|
||||
expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]); index = None
|
||||
view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
|
||||
sym_size = torch.ops.aten.sym_size(index, 0)
|
||||
expand = torch.ops.aten.expand.default(eye, [sym_size, 3, 3])
|
||||
view = torch.ops.aten.view.default(expand, [sym_size, 3, 3]); expand = None
|
||||
sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 1)
|
||||
sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 2)
|
||||
expand_1 = torch.ops.aten.expand.default(index, [sym_size, sym_size_1, sym_size_2]); index = None
|
||||
view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_1, sym_size_2]); expand_1 = sym_size_1 = sym_size_2 = None
|
||||
bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
|
||||
view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
|
||||
mul = sym_size_int * 3
|
||||
view_2 = torch.ops.aten.view.default(bmm, [sym_size, 3, 3]); bmm = None
|
||||
mul = sym_size * 3
|
||||
view_3 = torch.ops.aten.view.default(view_2, [mul, 3]); view_2 = mul = None
|
||||
mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
|
||||
view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
|
||||
view_4 = torch.ops.aten.view.default(mm, [sym_size, 3, 3]); mm = sym_size = None
|
||||
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = None
|
||||
return None""") # noqa: B950
|
||||
return None""")
|
||||
|
||||
def test_unbacked_slice(self):
|
||||
def f(x, m):
|
||||
|
|
@ -1247,8 +1250,8 @@ def forward(self, images_1, handedness_1, valid_1):
|
|||
r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, a_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0); a_1 = None
|
||||
neg = -sym_size_int; sym_size_int = None
|
||||
sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
|
||||
neg = -sym_size; sym_size = None
|
||||
add = neg + 10; neg = None
|
||||
empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
|
||||
return empty""")
|
||||
|
|
@ -1361,8 +1364,8 @@ def forward(self, lengths_1, values_1):
|
|||
r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, a_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
|
||||
pow_1 = sym_size_int ** 0.5; sym_size_int = None
|
||||
sym_size = torch.ops.aten.sym_size(a_1, 0)
|
||||
pow_1 = sym_size ** 0.5; sym_size = None
|
||||
div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
|
||||
return div""")
|
||||
|
||||
|
|
@ -1374,15 +1377,15 @@ def forward(self, a_1):
|
|||
r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, a_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
|
||||
div = torch.ops.aten.div.Tensor(a_1, sym_size_int); a_1 = sym_size_int = None
|
||||
sym_size = torch.ops.aten.sym_size(a_1, 0)
|
||||
div = torch.ops.aten.div.Tensor(a_1, sym_size); a_1 = sym_size = None
|
||||
return div""")
|
||||
|
||||
r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, a_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
|
||||
sym_float = torch.sym_float(sym_size_int); sym_size_int = None
|
||||
sym_size = torch.ops.aten.sym_size(a_1, 0)
|
||||
sym_float = torch.sym_float(sym_size); sym_size = None
|
||||
div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
|
||||
return div""")
|
||||
|
||||
|
|
|
|||
|
|
@ -2096,9 +2096,9 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
|
|||
fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
|
||||
self.assertExpectedInline(fx_g.code.strip(), """\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
|
||||
return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""")
|
||||
sym_size = torch.ops.aten.sym_size(x_1, 0)
|
||||
sym_size_1 = torch.ops.aten.sym_size(x_1, 1); x_1 = None
|
||||
return ((sym_size, sym_size_1), (sym_size, sym_size_1))""")
|
||||
|
||||
def test_data_ptr_respects_numel_slow_path(self):
|
||||
data = torch.randn(6, 2)
|
||||
|
|
|
|||
|
|
@ -545,10 +545,8 @@ class TorchVariable(VariableTracker):
|
|||
),
|
||||
**options,
|
||||
)
|
||||
# TODO: These special cases shouldn't be necessary; we should
|
||||
# generically support torch.ops that return int
|
||||
elif (
|
||||
self.value in [torch.ops.aten.sym_size, torch.ops.aten.sym_size.int]
|
||||
self.value is torch.ops.aten.sym_size
|
||||
and len(args) == 2
|
||||
and len(kwargs) == 0
|
||||
and isinstance(args[0], TensorVariable)
|
||||
|
|
@ -556,7 +554,7 @@ class TorchVariable(VariableTracker):
|
|||
# we see this when retracing already traced code
|
||||
return args[0].call_method(tx, "size", [args[1]], {})
|
||||
elif (
|
||||
self.value is [torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int]
|
||||
self.value is torch.ops.aten.sym_stride
|
||||
and len(args) == 2
|
||||
and len(kwargs) == 0
|
||||
and isinstance(args[0], TensorVariable)
|
||||
|
|
|
|||
|
|
@ -170,13 +170,13 @@ def track_tensor(tensor, proxy, *, constant, tracer):
|
|||
# (so that if we have multiple tracers at the same time, they
|
||||
# don't clobber each other.)
|
||||
for i, s in enumerate(tensor.shape):
|
||||
try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size.int(proxy, i), x), i)
|
||||
try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size(proxy, i), x), i)
|
||||
|
||||
for i, s in enumerate(tensor.stride()):
|
||||
try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride.int(proxy, i), x), i)
|
||||
try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride(proxy, i), x), i)
|
||||
|
||||
try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel.default(proxy), x))
|
||||
try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset.default(proxy), x))
|
||||
try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel(proxy), x))
|
||||
try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset(proxy), x))
|
||||
set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
|
||||
|
||||
def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
|
||||
|
|
|
|||
Loading…
Reference in a new issue