Add AtenOp at:bitwise_or (#9662)

* Add AtenOp at:bitwise_or

* Specify overload name for bitwise_or

* undo unnecessary import

* set output element type to BOOL

* Add broadcasting support

* Fix test

Co-authored-by: Gani Nazirov <ganaziro@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Gani Nazirov <ganaziro@microsoft.com>
This commit is contained in:
Gani Nazirov 2021-12-13 14:36:15 -08:00 committed by GitHub
parent ad99dff298
commit c82160bbd0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 3 deletions

View file

@ -189,6 +189,7 @@ class SymbolicShapeInference:
}
self.aten_op_dispatcher_ = {
'aten::embedding': self._infer_Gather,
'aten::bitwise_or': self._infer_aten_bitwise_or,
'aten::diagonal': self._infer_aten_diagonal,
'aten::max_pool2d_with_indices': self._infer_aten_pool2d,
'aten::multinomial': self._infer_aten_multinomial,
@ -1080,6 +1081,16 @@ class SymbolicShapeInference:
helper.make_tensor_value_info(o, vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(sympy_shape)))
def _infer_aten_bitwise_or(self, node):
shape0 = self._get_shape(node, 0)
shape1 = self._get_shape(node, 1)
new_shape = self._broadcast_shapes(shape0, shape1)
t0 = self.known_vi_[node.input[0]]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type,
new_shape))
def _infer_aten_diagonal(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
rank = len(sympy_shape)

View file

@ -73,6 +73,10 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
output.setType(output_type)
return output
@register_symbolic('bitwise_or')
def bitwise_or(g, self, other):
return g.op("com.microsoft::ATenOp", self, other,
name_s='aten::bitwise_or', overload_name_s='Tensor')
@register_symbolic('diagonal')
def diagonal(g, self, offset, dim1, dim2):

View file

@ -664,7 +664,7 @@ def test_gradient_correctness():
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
@pytest.mark.parametrize("indices", ([[ 2, 3, -1, -1],[0, 1, -1, -1]],
@pytest.mark.parametrize("indices", ([[ 2, 3, -1, -1],[0, 1, -1, -1]],
[[ 2, 3, 4, 4],[ 0, 1, 4, 4]]))
def test_scatternd_correctness(device, indices):
class NeuralNetScatterND(torch.nn.Module):
@ -685,7 +685,7 @@ def test_scatternd_correctness(device, indices):
rerouted_output = torch.tensor([[0.],[0.],[0.],[0.],[0.]], device=device)
dispatch_mask = torch.tensor(indices, device=device)
expert_output = torch.tensor([[[0.3817],[0.9625],[0.9625],[0.9625]],[[0.3817],[0.9625],[0.9625],[0.9625]]], device=device)
pt_prediction = run_step(pt_model, rerouted_output, dispatch_mask, expert_output)
ort_prediction = run_step(ort_model, rerouted_output, dispatch_mask, expert_output)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5)
@ -1035,6 +1035,35 @@ def test_gradient_correctness_argmax_unfold():
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
@pytest.mark.parametrize("high", [1, 2, 10])
def test_correctness_argmax_bitwise_or(high):
N, D, H, M = 16, 256, 128, 4
device = 'cuda'
class NeuralNetBitwiseOr(torch.nn.Module):
def __init__(self, high):
super(NeuralNetBitwiseOr, self).__init__()
self.other = torch.randint(0, high, (N, D, H), device=device)
def forward(self, input):
return torch.bitwise_or(self.other, input)
pt_model = NeuralNetBitwiseOr(high).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
def run_step(model, input):
prediction = model(input)
return prediction
for _ in range(10):
# this also tests broadcasting
pt_input = torch.randint(-10, 10, (M, N, D, H), device=device)
ort_input = copy.deepcopy(pt_input)
pt_prediction = run_step(pt_model, pt_input)
ort_prediction = run_step(ort_model, ort_input)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
@pytest.mark.parametrize("offset", [-1, 0, 1])
@pytest.mark.parametrize("dim1, dim2", ([0, 1], [0, 2], [1, 2], [2, 0]))
def test_gradient_correctness_argmax_diagonal(offset, dim1, dim2):
@ -4534,7 +4563,7 @@ def test_sigmoid_grad_opset13():
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = old_opset
assert ortmodule.ONNX_OPSET_VERSION == 13
ortmodule.ONNX_OPSET_VERSION = old_opst_cst
@pytest.mark.parametrize("opset_version", [12, 13, 14])
def test_opset_version_change(opset_version):
device = 'cuda'