def _infer_aten_bitwise_or(self, node):
    """Fill in the value info of an ``aten::bitwise_or`` node's first output.

    The output shape is whatever ``self._broadcast_shapes`` returns for the
    shapes of the node's first two inputs; the output element type is copied
    from the already-known value info of the first input.
    """
    lhs_shape, rhs_shape = (self._get_shape(node, idx) for idx in (0, 1))
    out_shape = self._broadcast_shapes(lhs_shape, rhs_shape)

    out_name = node.output[0]
    # Element type is taken from input 0 only; input 1 is not consulted.
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    self.known_vi_[out_name].CopyFrom(
        helper.make_tensor_value_info(out_name, elem_type, out_shape))
@register_symbolic('bitwise_or')
def bitwise_or(g, self, other):
    """Symbolic registered under 'bitwise_or'.

    Emits one ``com.microsoft::ATenOp`` node through ``g.op`` with ``self``
    and ``other`` as its inputs, passing ``name_s='aten::bitwise_or'`` and
    ``overload_name_s='Tensor'`` as keyword arguments, and returns the result
    of that call.
    """
    aten_attrs = {'name_s': 'aten::bitwise_or', 'overload_name_s': 'Tensor'}
    return g.op("com.microsoft::ATenOp", self, other, **aten_attrs)
@pytest.mark.parametrize("high", [1, 2, 10])
def test_correctness_argmax_bitwise_or(high):
    """Check that ORTModule matches PyTorch for ``torch.bitwise_or``.

    A module holding a fixed random integer tensor of shape (N, D, H) is run
    against ten random integer inputs of shape (M, N, D, H); the ORTModule
    prediction is compared with the PyTorch prediction on every iteration.
    """
    N, D, H, M = 16, 256, 128, 4
    device = 'cuda'

    class NeuralNetBitwiseOr(torch.nn.Module):
        """OR-s its input with a random tensor drawn once at construction."""

        def __init__(self, high):
            super().__init__()
            # Values are drawn from [0, high) directly on the target device.
            self.other = torch.randint(0, high, (N, D, H), device=device)

        def forward(self, input):
            return torch.bitwise_or(self.other, input)

    reference_model = NeuralNetBitwiseOr(high).to(device)
    ort_wrapped = ORTModule(copy.deepcopy(reference_model))

    # The input carries an extra leading M axis compared with `other`,
    # so this also tests broadcasting.
    input_shape = (M, N, D, H)
    for _ in range(10):
        reference_input = torch.randint(-10, 10, input_shape, device=device)
        ort_input = copy.deepcopy(reference_input)

        expected = reference_model(reference_input)
        actual = ort_wrapped(ort_input)

        _test_helpers.assert_values_are_close(actual, expected)