diff --git a/test/export/test_export.py b/test/export/test_export.py index 9b070141ac3..cac1ccf148a 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -11564,6 +11564,24 @@ class GraphModule(torch.nn.Module): ref_res = module(*dyn_inp) self.assertEqual(export_res, ref_res) + @testing.expectedFailureSerDer # T202237665 + @testing.expectedFailureSerDerNonStrict + def test_dynamic_lr_shift(self): + class Module(torch.nn.Module): + def forward(self, x): + rshift = x.shape[0] >> 1 + lshift = x.shape[0] << 1 + return x[:rshift], x[:lshift] + + dynamic_shapes = {"x": {0: Dim("N", min=5, max=10)}} + inp = (torch.randn(8),) + ep = export(Module(), inp, dynamic_shapes=dynamic_shapes) + for op in (operator.lshift, operator.rshift): + shift_op = [ + n for n in ep.graph.nodes if n.op == "call_function" and n.target == op + ] + self.assertEqual(len(shift_op), 1) + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestOneOffModelExportResult(TestCase): diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index ad5380b04c9..06cd2b657a4 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -134,6 +134,8 @@ class Verifier(metaclass=_VerifierMeta): operator.pow, operator.neg, operator.abs, + operator.lshift, + operator.rshift, math.ceil, math.floor, math.trunc,