mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[export] allow bit shift builtin ops (#145802)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145802 Approved by: https://github.com/pianpwk
This commit is contained in:
parent
f4ca98950e
commit
50f834f134
2 changed files with 20 additions and 0 deletions
|
|
@ -11564,6 +11564,24 @@ class GraphModule(torch.nn.Module):
|
|||
ref_res = module(*dyn_inp)
|
||||
self.assertEqual(export_res, ref_res)
|
||||
|
||||
@testing.expectedFailureSerDer  # T202237665
@testing.expectedFailureSerDerNonStrict
def test_dynamic_lr_shift(self):
    """Bit-shift builtins applied to a symbolic size survive export.

    A module computes left/right shifts of its input's dynamic first
    dimension; after export, each of ``operator.lshift`` and
    ``operator.rshift`` must appear exactly once as a ``call_function``
    node in the exported graph.
    """

    class ShiftModule(torch.nn.Module):
        def forward(self, x):
            half = x.shape[0] >> 1
            double = x.shape[0] << 1
            return x[:half], x[:double]

    spec = {"x": {0: Dim("N", min=5, max=10)}}
    example_inputs = (torch.randn(8),)
    ep = export(ShiftModule(), example_inputs, dynamic_shapes=spec)

    # Each shift builtin should be preserved as exactly one graph node.
    for builtin in (operator.lshift, operator.rshift):
        matches = [
            node
            for node in ep.graph.nodes
            if node.op == "call_function" and node.target == builtin
        ]
        self.assertEqual(len(matches), 1)
|
||||
|
||||
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
||||
class TestOneOffModelExportResult(TestCase):
|
||||
|
|
|
|||
|
|
@ -134,6 +134,8 @@ class Verifier(metaclass=_VerifierMeta):
|
|||
operator.pow,
|
||||
operator.neg,
|
||||
operator.abs,
|
||||
operator.lshift,
|
||||
operator.rshift,
|
||||
math.ceil,
|
||||
math.floor,
|
||||
math.trunc,
|
||||
|
|
|
|||
Loading…
Reference in a new issue