From f804b65d4e822aedfe7814b42dc9b2e1bc6d18d7 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 20 Apr 2021 22:58:06 -0700 Subject: [PATCH] [ONNX] Update repeat_interleave symbolic (#54312) (#56165) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56165 Add implementation for cases when - interleaving happens along dim which consist of dynamic axes Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D27866137 Pulled By: SplitInfinity fbshipit-source-id: 7fef1b2c614f2e24a677b7ca0886bb37bd0ab479 --- test/onnx/test_pytorch_onnx_onnxruntime.py | 42 +++++++++ torch/onnx/symbolic_opset11.py | 103 +++++++++++++++++++++ torch/onnx/symbolic_opset9.py | 6 +- 3 files changed, 150 insertions(+), 1 deletion(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 5992d7e075a..653e7b3d25e 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -3835,6 +3835,48 @@ class TestONNXRuntime(unittest.TestCase): x = torch.tensor([[1, 2], [3, 4]]) self.run_test(RepeatsDimsModel2(), (x,)) + @skipIfUnsupportedMinOpsetVersion(11) + def test_dynamic_repeat_interleave(self): + class SingleDynamicModel(torch.nn.Module): + def forward(self, x): + repeats = torch.tensor(4) + return torch.repeat_interleave(x, repeats, dim=1) + + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) + another_x = torch.tensor([[7, 8], [5, 6]]) + self.run_test(SingleDynamicModel(), x, test_with_inputs=[another_x], + input_names=['input_1'], dynamic_axes={'input_1' : {1 : 'w'}}) + + class NegDynamicModel(torch.nn.Module): + def forward(self, x): + repeats = torch.tensor(4) + return torch.repeat_interleave(x, repeats, dim=-1) + + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) + another_x = torch.tensor([[7, 8], [5, 6]]) + self.run_test(NegDynamicModel(), x, test_with_inputs=[another_x], + input_names=['input_1'], dynamic_axes={'input_1' : {1 : 'w'}}) + + class SingleDynamicModel2(torch.nn.Module): + def forward(self, x): + repeats = torch.tensor([4]) + return torch.repeat_interleave(x, repeats, dim=0) + + x = torch.tensor([[1, 2], [3, 4]]) + another_x = torch.tensor([[7, 8], [5, 6]]) + self.run_test(SingleDynamicModel2(), x, test_with_inputs=[another_x], + input_names=['input_1'], dynamic_axes={'input_1' : {0 : 'h'}}) + + class AllDynamicModel(torch.nn.Module): + def forward(self, x): + repeats = torch.tensor([4]) + return torch.repeat_interleave(x, repeats, dim=0) + + x = torch.tensor([[1, 2, 4, 16], [3, 9, 27, 81], [2, 3, 5, 7]]) + another_x = torch.tensor([[7, 8], [5, 6]]) + self.run_test(AllDynamicModel(), x, test_with_inputs=[another_x], + input_names=['input_1'], dynamic_axes={'input_1' : {0 : 'h', 1 : 'w'}}) + def test_view(self): class ViewModel(torch.nn.Module): def forward(self, input): diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 7596a26b8a0..050b9526f1c 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -840,3 +840,106 @@ def prim_ConstantChunk(g, self, chunks, dim): res.append(g.op("Slice", self, start, end, axis)) start = end return res + +def repeat_interleave(g, self, repeats, dim=None): + from torch.onnx.symbolic_opset9 import reshape + input = self + final_dim = dim + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if sym_help._is_none(dim): + input = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1]))) + dim = 0 + else: + dim = sym_help._maybe_get_scalar(dim) + + repeats_dim = sym_help._get_tensor_rank(repeats) + repeats_sizes = sym_help._get_tensor_sizes(repeats) + input_sizes = sym_help._get_tensor_sizes(input) + if repeats_dim is None: + raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown ' + 'repeats rank.') + if repeats_sizes is None: + raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown ' + 'repeats size.') + if input_sizes is None: + raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown ' + 'input size.') + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + output_sizes = input_sizes.copy() + perm_i = [0] + for idx, input_size in enumerate(input_sizes): + perm_i.append(idx + 1) + if input_size is None: + output_sizes[idx], input_sizes[idx] = 0, -1 + perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0] + + # Cases when repeats is a single value tensor and dim has unknown input size + if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0: + if not sym_help._is_tensor(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + reps = sym_help._size_helper(g, input, dim) + reps = unsqueeze(g, reps, 0) + repeats = g.op("Expand", repeats, reps) + # There are cases when the repeats are 1-d tensor with multiple repeats, but dim + # provided along one of the dynamic axes provided. A simple example would be + # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 + # Now, repeat interleaving can be performed in pytorch when the value of * matches + # with the number of elements in repeat, for example if * -> 2, number of repeats + # should be 2 as well. + else: + return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim) + + reps_like = g.op("ConstantOfShape", g.op("Shape", repeats), + value_t=torch.tensor([1], dtype=torch.long)) + r_splits = split(g, repeats, reps_like, 0) + i_splits = split(g, input, reps_like, dim) + + output_sizes[dim], input_sizes[dim] = -1, 1 + + # Create a loop to iterate over each value along the dimension + # and perform individual interleaving using the repeats tensor + # Loop is of the following pattern + # input (trip_count, cond) + # int trip_count = ...; + # bool cond = ...; + # for (int i=0; i < trip_count && cond; ++i) { + # cond = ...; + # } + + # Loop conditions + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=9) + loop_len = reps + loop = g.op("Loop", loop_len, loop_condition) + + # Loop inputs + loop_block = _add_block(loop.node()) + block_input_iter = _add_input_to_block(loop_block) + cond = _add_input_to_block(loop_block) + + r_split = loop_block.op("SequenceAt", r_splits, block_input_iter) + i_split = loop_block.op("SequenceAt", i_splits, block_input_iter) + + i_split = unsqueeze(loop_block, i_split, dim + 1) + r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])), + r_split, + loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))] + r_concat = loop_block.op("Concat", *r_concat, axis_i=0) + i_split = expand(loop_block, i_split, r_concat, None) + i_split = reshape(loop_block, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes))) + + # Loop outputs + cond_out = loop_block.op("Cast", loop_condition, to_i=9) + _add_output_to_block(loop_block, cond_out) + _add_output_to_block(loop_block, i_split) + loop_out = loop.node().output() + + # In this loop, the outputs are scan outputs and are concatenated along + # the zero'th dimension (by default). In order to avoid this and concatenate + # along the dimension provided, some post-processing is required + loop_out = g.op("Transpose", loop_out, perm_i=perm_i) + return reshape(g, loop_out, g.op("Constant", value_t=torch.LongTensor(output_sizes))) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 2145cf065b1..9bd9412cdc5 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -2004,13 +2004,17 @@ def repeat_interleave(g, self, repeats, dim=None): if not sym_help._is_tensor(repeats): repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) if input_sizes[dim] == 0: - raise NotImplementedError("Unsupported repeat_interleave along dimension with unknown input size") + return sym_help._onnx_opset_unsupported_detailed('repeat_interleave', 9, 11, + 'Unsupported along dimension with unknown input size') else: reps = input_sizes[dim] repeats = expand(g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None) # Cases where repeats is a 1 dim Tensor elif repeats_dim == 1: + if input_sizes[dim] == 0: + return sym_help._onnx_opset_unsupported_detailed('repeat_interleave', 9, 11, + 'Unsupported along dimension with unknown input size') assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim" reps = repeats_sizes[0] else: