bitwise_and ONNX support (#12189)

* bitwise_and ONNX support

* whitespace lint
This commit is contained in:
msftlincoln 2022-07-15 12:59:56 -04:00 committed by GitHub
parent 89bf6c9b5d
commit 9bca8405aa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 1 deletions

View file

@ -158,7 +158,7 @@ hand_implemented = {
"aten::ne.Tensor_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
"aten::eq.Tensor_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"),
"aten::eq.Scalar_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"),
"aten::bitwise_and.Tensor_out": MakeTorchFallback(),
"aten::bitwise_and.Tensor_out": And("self", "other"), # This generates a fallback for all but Bool, as expected.
"aten::masked_select": GatherND("self", Transpose(NonZero(Expand("mask", Shape("self"))))),
"aten::_local_scalar_dense": MakeTorchFallback(), # This function extracts a scalar value from
# a tensor with exactly one value; there's no need to try to do this on an ORT device.

View file

@ -272,6 +272,27 @@ class OrtOpTests(unittest.TestCase):
assert torch.allclose(cpu_result, ort_result.cpu())
assert cpu_result.dim() == ort_result.dim()
def test_bitwise_and(self):
device = self.get_device()
cpu_a = torch.tensor([[0], [1], [1]], dtype=bool)
cpu_b = torch.tensor([[1], [0], [1]], dtype=bool)
ort_a = cpu_a.to(device)
ort_b = cpu_b.to(device)
cpu_result = torch.bitwise_and(cpu_a, cpu_b)
ort_result = torch.bitwise_and(ort_a, ort_b)
assert torch.equal(cpu_result, ort_result.cpu())
def test_bitwise_and_fallback(self):
device = self.get_device()
# use randint because bitwise_and is not supported on floats
cpu_a = torch.randint(200, (3, 4))
cpu_b = torch.randint(200, (3, 4))
ort_a = cpu_a.to(device)
ort_b = cpu_b.to(device)
cpu_result = torch.bitwise_and(cpu_a, cpu_b)
ort_result = torch.bitwise_and(ort_a, ort_b)
assert torch.equal(cpu_result, ort_result.cpu())
# @parameterized.expand generate test methods for ops and using name_func we renaming the test to be test_{ops}
@parameterized.expand(ops, name_func=rename_func_to_op)
def test_op(self, test_name, tensor_test=torch.rand(6)):