[export] allow bit shift builtin ops (#145802)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145802
Approved by: https://github.com/pianpwk
This commit is contained in:
Colin Peppler 2025-01-28 09:27:44 -08:00 committed by PyTorch MergeBot
parent f4ca98950e
commit 50f834f134
2 changed files with 20 additions and 0 deletions

View file

@ -11564,6 +11564,24 @@ class GraphModule(torch.nn.Module):
ref_res = module(*dyn_inp)
self.assertEqual(export_res, ref_res)
@testing.expectedFailureSerDer  # T202237665
@testing.expectedFailureSerDerNonStrict
def test_dynamic_lr_shift(self):
    """Exporting a model that bit-shifts a dynamic dim must keep the
    symbolic `<<`/`>>` as operator.lshift/operator.rshift call_function
    nodes in the exported graph — exactly one node per shift direction.
    """

    class Module(torch.nn.Module):
        def forward(self, x):
            # Slice bounds derived from the symbolic batch size via bit shifts.
            half = x.shape[0] >> 1
            double = x.shape[0] << 1
            return x[:half], x[:double]

    shapes = {"x": {0: Dim("N", min=5, max=10)}}
    example_inputs = (torch.randn(8),)
    ep = export(Module(), example_inputs, dynamic_shapes=shapes)

    # Each shift builtin should survive export as a single call_function node.
    for shift in (operator.lshift, operator.rshift):
        matches = [
            node
            for node in ep.graph.nodes
            if node.op == "call_function" and node.target == shift
        ]
        self.assertEqual(len(matches), 1)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):

View file

@ -134,6 +134,8 @@ class Verifier(metaclass=_VerifierMeta):
operator.pow,
operator.neg,
operator.abs,
operator.lshift,
operator.rshift,
math.ceil,
math.floor,
math.trunc,