mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Dynamo] Unspec int list if enabling dynamic_shapes (#97557)
Fixes #97348 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97557 Approved by: https://github.com/ezyang, https://github.com/jansel
This commit is contained in:
parent
542fb0b1fa
commit
e3df6a7c8a
2 changed files with 23 additions and 3 deletions
|
|
@ -32,7 +32,7 @@ from torch._dynamo.testing import (
|
|||
unsupported,
|
||||
)
|
||||
|
||||
from torch._dynamo.utils import CompileProfiler, ifunspec
|
||||
from torch._dynamo.utils import CompileProfiler, ifdyn, ifunspec
|
||||
from torch.ao.quantization import MinMaxObserver
|
||||
from torch.ao.quantization.fake_quantize import FakeQuantize
|
||||
from torch.ao.quantization.qconfig import QConfig
|
||||
|
|
@ -4536,6 +4536,22 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
torch.randn([4, 4]), torch.randn([4, 4]), (4, 4)
|
||||
)
|
||||
|
||||
def test_int_list(self):
|
||||
# if dynamic_shapes == True: unspec int list
|
||||
# if dynamic_shapes == False: spec int list
|
||||
def fn(x, y):
|
||||
return torch.sin(x + y[1] % 2)
|
||||
|
||||
x = torch.randn(6)
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
||||
for i in range(10, 25, 3):
|
||||
y = [i, i + 1, i + 2]
|
||||
ref = fn(x, y)
|
||||
res = opt_fn(x, y)
|
||||
self.assertTrue(same(ref, res))
|
||||
self.assertEqual(cnt.frame_count, ifunspec(ifdyn(1, 5), 5))
|
||||
|
||||
# specifically test for tensor.attribute -> torch.something()
|
||||
def test_real_imag_tensor_attribute(self):
|
||||
def fn(x, y):
|
||||
|
|
|
|||
|
|
@ -582,8 +582,12 @@ class VariableBuilder:
|
|||
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
|
||||
# One can index a tensor with a list/tuple. Therefore, we need to
|
||||
# have a stricter match.
|
||||
if istype(value, (tuple, list)) and all(
|
||||
[isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value]
|
||||
if (
|
||||
istype(value, (tuple, list))
|
||||
and all(
|
||||
[isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value]
|
||||
)
|
||||
and not config.dynamic_shapes
|
||||
):
|
||||
guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in a new issue