mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
bitwise_and ONNX support (#12189)
* bitwise_and ONNX support * whitespace lint
This commit is contained in:
parent
89bf6c9b5d
commit
9bca8405aa
2 changed files with 22 additions and 1 deletions
|
|
@ -158,7 +158,7 @@ hand_implemented = {
|
|||
"aten::ne.Tensor_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
|
||||
"aten::eq.Tensor_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"),
|
||||
"aten::eq.Scalar_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"),
|
||||
"aten::bitwise_and.Tensor_out": MakeTorchFallback(),
|
||||
"aten::bitwise_and.Tensor_out": And("self", "other"), # This generates a fallback for all but Bool, as expected.
|
||||
"aten::masked_select": GatherND("self", Transpose(NonZero(Expand("mask", Shape("self"))))),
|
||||
"aten::_local_scalar_dense": MakeTorchFallback(), # This function extracts a scalar value from
|
||||
# a tensor with exactly one value; there's no need to try to do this on an ORT device.
|
||||
|
|
|
|||
|
|
@ -272,6 +272,27 @@ class OrtOpTests(unittest.TestCase):
|
|||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
assert cpu_result.dim() == ort_result.dim()
|
||||
|
||||
def test_bitwise_and(self):
|
||||
device = self.get_device()
|
||||
cpu_a = torch.tensor([[0], [1], [1]], dtype=bool)
|
||||
cpu_b = torch.tensor([[1], [0], [1]], dtype=bool)
|
||||
ort_a = cpu_a.to(device)
|
||||
ort_b = cpu_b.to(device)
|
||||
cpu_result = torch.bitwise_and(cpu_a, cpu_b)
|
||||
ort_result = torch.bitwise_and(ort_a, ort_b)
|
||||
assert torch.equal(cpu_result, ort_result.cpu())
|
||||
|
||||
def test_bitwise_and_fallback(self):
|
||||
device = self.get_device()
|
||||
# use randint because bitwise_and is not supported on floats
|
||||
cpu_a = torch.randint(200, (3, 4))
|
||||
cpu_b = torch.randint(200, (3, 4))
|
||||
ort_a = cpu_a.to(device)
|
||||
ort_b = cpu_b.to(device)
|
||||
cpu_result = torch.bitwise_and(cpu_a, cpu_b)
|
||||
ort_result = torch.bitwise_and(ort_a, ort_b)
|
||||
assert torch.equal(cpu_result, ort_result.cpu())
|
||||
|
||||
# @parameterized.expand generate test methods for ops and using name_func we renaming the test to be test_{ops}
|
||||
@parameterized.expand(ops, name_func=rename_func_to_op)
|
||||
def test_op(self, test_name, tensor_test=torch.rand(6)):
|
||||
|
|
|
|||
Loading…
Reference in a new issue