mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Add AtenOp at:bitwise_or (#9662)
* Add AtenOp at:bitwise_or
* Specify overload name for bitwise_or
* undo unnecessary import
* set output element type to BOOL
* Add broadcasting support
* Fix test

Co-authored-by: Gani Nazirov <ganaziro@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Gani Nazirov <ganaziro@microsoft.com>
This commit is contained in:
parent
ad99dff298
commit
c82160bbd0
3 changed files with 47 additions and 3 deletions
|
|
@ -189,6 +189,7 @@ class SymbolicShapeInference:
|
|||
}
|
||||
self.aten_op_dispatcher_ = {
|
||||
'aten::embedding': self._infer_Gather,
|
||||
'aten::bitwise_or': self._infer_aten_bitwise_or,
|
||||
'aten::diagonal': self._infer_aten_diagonal,
|
||||
'aten::max_pool2d_with_indices': self._infer_aten_pool2d,
|
||||
'aten::multinomial': self._infer_aten_multinomial,
|
||||
|
|
@ -1080,6 +1081,16 @@ class SymbolicShapeInference:
|
|||
helper.make_tensor_value_info(o, vi.type.tensor_type.elem_type,
|
||||
get_shape_from_sympy_shape(sympy_shape)))
|
||||
|
||||
def _infer_aten_bitwise_or(self, node):
    """Fill in the output value info for an ``aten::bitwise_or`` ATenOp node.

    The output shape is whatever ``self._broadcast_shapes`` returns for the
    shapes of inputs 0 and 1; the output element type is copied from the
    value info recorded for input 0.

    Args:
        node: The node being inferred; its first two inputs and first
            output are read.
    """
    lhs_shape, rhs_shape = [self._get_shape(node, idx) for idx in (0, 1)]
    broadcast_shape = self._broadcast_shapes(lhs_shape, rhs_shape)

    out_name = node.output[0]
    first_input_vi = self.known_vi_[node.input[0]]
    output_vi = self.known_vi_[out_name]

    # Only the first operand's element type is consulted for the result.
    elem_type = first_input_vi.type.tensor_type.elem_type
    inferred_vi = helper.make_tensor_value_info(out_name, elem_type, broadcast_shape)

    # Update the existing entry in place rather than rebinding it.
    output_vi.CopyFrom(inferred_vi)
|
||||
|
||||
def _infer_aten_diagonal(self, node):
|
||||
sympy_shape = self._get_sympy_shape(node, 0)
|
||||
rank = len(sympy_shape)
|
||||
|
|
|
|||
|
|
@ -73,6 +73,10 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
|
|||
output.setType(output_type)
|
||||
return output
|
||||
|
||||
@register_symbolic('bitwise_or')
def bitwise_or(g, self, other):
    """Emit a ``com.microsoft::ATenOp`` node for ``bitwise_or``.

    The node takes ``self`` and ``other`` as inputs and carries the
    attributes ``name = 'aten::bitwise_or'`` and ``overload_name = 'Tensor'``.

    Args:
        g: Graph object providing the ``op`` factory used to emit the node.
        self: First operand.
        other: Second operand.

    Returns:
        Whatever ``g.op`` returns for the emitted node.
    """
    aten_attrs = {
        'name_s': 'aten::bitwise_or',
        'overload_name_s': 'Tensor',
    }
    return g.op("com.microsoft::ATenOp", self, other, **aten_attrs)
|
||||
|
||||
@register_symbolic('diagonal')
|
||||
def diagonal(g, self, offset, dim1, dim2):
|
||||
|
|
|
|||
|
|
@ -664,7 +664,7 @@ def test_gradient_correctness():
|
|||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cpu', 'cuda'])
|
||||
@pytest.mark.parametrize("indices", ([[ 2, 3, -1, -1],[0, 1, -1, -1]],
|
||||
@pytest.mark.parametrize("indices", ([[ 2, 3, -1, -1],[0, 1, -1, -1]],
|
||||
[[ 2, 3, 4, 4],[ 0, 1, 4, 4]]))
|
||||
def test_scatternd_correctness(device, indices):
|
||||
class NeuralNetScatterND(torch.nn.Module):
|
||||
|
|
@ -685,7 +685,7 @@ def test_scatternd_correctness(device, indices):
|
|||
rerouted_output = torch.tensor([[0.],[0.],[0.],[0.],[0.]], device=device)
|
||||
dispatch_mask = torch.tensor(indices, device=device)
|
||||
expert_output = torch.tensor([[[0.3817],[0.9625],[0.9625],[0.9625]],[[0.3817],[0.9625],[0.9625],[0.9625]]], device=device)
|
||||
|
||||
|
||||
pt_prediction = run_step(pt_model, rerouted_output, dispatch_mask, expert_output)
|
||||
ort_prediction = run_step(ort_model, rerouted_output, dispatch_mask, expert_output)
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5)
|
||||
|
|
@ -1035,6 +1035,35 @@ def test_gradient_correctness_argmax_unfold():
|
|||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
@pytest.mark.parametrize("high", [1, 2, 10])
def test_correctness_argmax_bitwise_or(high):
    """Compare ORTModule and PyTorch outputs of ``torch.bitwise_or``.

    A module holding a random integer tensor of shape (16, 256, 128) is
    OR-ed with random inputs of shape (4, 16, 256, 128), so the operands
    have different ranks. Ten random inputs are checked on the 'cuda' device.
    """
    batch, rows, cols, lead = 16, 256, 128, 4
    device = 'cuda'
    other_shape = (batch, rows, cols)
    input_shape = (lead,) + other_shape

    class NeuralNetBitwiseOr(torch.nn.Module):
        def __init__(self, high):
            super().__init__()
            # Fixed right-hand operand, drawn once from [0, high).
            self.other = torch.randint(0, high, other_shape, device=device)

        def forward(self, input):
            return torch.bitwise_or(self.other, input)

    pt_model = NeuralNetBitwiseOr(high).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    def run_step(model, input):
        return model(input)

    for _ in range(10):
        # this also tests broadcasting
        pt_input = torch.randint(-10, 10, input_shape, device=device)
        ort_input = copy.deepcopy(pt_input)

        pt_prediction = run_step(pt_model, pt_input)
        ort_prediction = run_step(ort_model, ort_input)

        _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
|
||||
@pytest.mark.parametrize("offset", [-1, 0, 1])
|
||||
@pytest.mark.parametrize("dim1, dim2", ([0, 1], [0, 2], [1, 2], [2, 0]))
|
||||
def test_gradient_correctness_argmax_diagonal(offset, dim1, dim2):
|
||||
|
|
@ -4534,7 +4563,7 @@ def test_sigmoid_grad_opset13():
|
|||
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = old_opset
|
||||
assert ortmodule.ONNX_OPSET_VERSION == 13
|
||||
ortmodule.ONNX_OPSET_VERSION = old_opst_cst
|
||||
|
||||
|
||||
@pytest.mark.parametrize("opset_version", [12, 13, 14])
|
||||
def test_opset_version_change(opset_version):
|
||||
device = 'cuda'
|
||||
|
|
|
|||
Loading…
Reference in a new issue