From 9bca8405aa8cd1be3aab5fb1bc7a374a37d80835 Mon Sep 17 00:00:00 2001 From: msftlincoln <107071614+msftlincoln@users.noreply.github.com> Date: Fri, 15 Jul 2022 12:59:56 -0400 Subject: [PATCH] bitwise_and ONNX support (#12189) * bitwise_and ONNX support * whitespace lint --- .../orttraining/eager/opgen/opgen/atenops.py | 2 +- orttraining/orttraining/eager/test/ort_ops.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index 0dd125030d..62b8678b33 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -158,7 +158,7 @@ hand_implemented = { "aten::ne.Tensor_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"), "aten::eq.Tensor_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"), "aten::eq.Scalar_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"), - "aten::bitwise_and.Tensor_out": MakeTorchFallback(), + "aten::bitwise_and.Tensor_out": And("self", "other"), # This generates a fallback for all but Bool, as expected. "aten::masked_select": GatherND("self", Transpose(NonZero(Expand("mask", Shape("self"))))), "aten::_local_scalar_dense": MakeTorchFallback(), # This function extracts a scalar value from # a tensor with exactly one value; there's no need to try to do this on an ORT device. diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index 597911efa9..d278b865d6 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -272,6 +272,27 @@ class OrtOpTests(unittest.TestCase): assert torch.allclose(cpu_result, ort_result.cpu()) assert cpu_result.dim() == ort_result.dim() + def test_bitwise_and(self): + device = self.get_device() + cpu_a = torch.tensor([[0], [1], [1]], dtype=bool) + cpu_b = torch.tensor([[1], [0], [1]], dtype=bool) + ort_a = cpu_a.to(device) + ort_b = cpu_b.to(device) + cpu_result = torch.bitwise_and(cpu_a, cpu_b) + ort_result = torch.bitwise_and(ort_a, ort_b) + assert torch.equal(cpu_result, ort_result.cpu()) + + def test_bitwise_and_fallback(self): + device = self.get_device() + # use randint because bitwise_and is not supported on floats + cpu_a = torch.randint(200, (3, 4)) + cpu_b = torch.randint(200, (3, 4)) + ort_a = cpu_a.to(device) + ort_b = cpu_b.to(device) + cpu_result = torch.bitwise_and(cpu_a, cpu_b) + ort_result = torch.bitwise_and(ort_a, ort_b) + assert torch.equal(cpu_result, ort_result.cpu()) + # @parameterized.expand generate test methods for ops and using name_func we renaming the test to be test_{ops} @parameterized.expand(ops, name_func=rename_func_to_op) def test_op(self, test_name, tensor_test=torch.rand(6)):