From e3df6a7c8adb8a3c596e021b48890ea9d6324020 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 27 Mar 2023 06:12:43 +0000 Subject: [PATCH] [Dynamo] Unspec int list if enabling dynamic_shapes (#97557) Fixes #97348 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97557 Approved by: https://github.com/ezyang, https://github.com/jansel --- test/dynamo/test_misc.py | 18 +++++++++++++++++- torch/_dynamo/variables/builder.py | 8 ++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 61a5bbab56a..7af5bcd880c 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -32,7 +32,7 @@ from torch._dynamo.testing import ( unsupported, ) -from torch._dynamo.utils import CompileProfiler, ifunspec +from torch._dynamo.utils import CompileProfiler, ifdyn, ifunspec from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize from torch.ao.quantization.qconfig import QConfig @@ -4536,6 +4536,22 @@ class MiscTests(torch._dynamo.test_case.TestCase): torch.randn([4, 4]), torch.randn([4, 4]), (4, 4) ) + def test_int_list(self): + # if dynamic_shapes == True: unspec int list + # if dynamic_shapes == False: spec int list + def fn(x, y): + return torch.sin(x + y[1] % 2) + + x = torch.randn(6) + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + for i in range(10, 25, 3): + y = [i, i + 1, i + 2] + ref = fn(x, y) + res = opt_fn(x, y) + self.assertTrue(same(ref, res)) + self.assertEqual(cnt.frame_count, ifunspec(ifdyn(1, 5), 5)) + # specifically test for tensor.attribute -> torch.something() def test_real_imag_tensor_attribute(self): def fn(x, y): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index bc6721545a7..bb033301d63 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -582,8 +582,12 @@ class VariableBuilder: def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): # One can index a tensor with a list/tuple. Therefore, we need to # have a stricter match. - if istype(value, (tuple, list)) and all( - [isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value] + if ( + istype(value, (tuple, list)) + and all( + [isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value] + ) + and not config.dynamic_shapes ): guards = self.make_guards(GuardBuilder.EQUALS_MATCH) else: