mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56165 Add implementation for cases when - interleaving happens along dim which consist of dynamic axes Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D27866137 Pulled By: SplitInfinity fbshipit-source-id: 7fef1b2c614f2e24a677b7ca0886bb37bd0ab479
This commit is contained in:
parent
9986b109d2
commit
f804b65d4e
3 changed files with 150 additions and 1 deletions
|
|
@ -3835,6 +3835,48 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
x = torch.tensor([[1, 2], [3, 4]])
|
||||
self.run_test(RepeatsDimsModel2(), (x,))
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_dynamic_repeat_interleave(self):
|
||||
class SingleDynamicModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
repeats = torch.tensor(4)
|
||||
return torch.repeat_interleave(x, repeats, dim=1)
|
||||
|
||||
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
|
||||
another_x = torch.tensor([[7, 8], [5, 6]])
|
||||
self.run_test(SingleDynamicModel(), x, test_with_inputs=[another_x],
|
||||
input_names=['input_1'], dynamic_axes={'input_1' : {1 : 'w'}})
|
||||
|
||||
class NegDynamicModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
repeats = torch.tensor(4)
|
||||
return torch.repeat_interleave(x, repeats, dim=-1)
|
||||
|
||||
x = torch.tensor([[1, 2, 4], [3, 4, 7]])
|
||||
another_x = torch.tensor([[7, 8], [5, 6]])
|
||||
self.run_test(NegDynamicModel(), x, test_with_inputs=[another_x],
|
||||
input_names=['input_1'], dynamic_axes={'input_1' : {1 : 'w'}})
|
||||
|
||||
class SingleDynamicModel2(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
repeats = torch.tensor([4])
|
||||
return torch.repeat_interleave(x, repeats, dim=0)
|
||||
|
||||
x = torch.tensor([[1, 2], [3, 4]])
|
||||
another_x = torch.tensor([[7, 8], [5, 6]])
|
||||
self.run_test(SingleDynamicModel2(), x, test_with_inputs=[another_x],
|
||||
input_names=['input_1'], dynamic_axes={'input_1' : {0 : 'h'}})
|
||||
|
||||
class AllDynamicModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
repeats = torch.tensor([4])
|
||||
return torch.repeat_interleave(x, repeats, dim=0)
|
||||
|
||||
x = torch.tensor([[1, 2, 4, 16], [3, 9, 27, 81], [2, 3, 5, 7]])
|
||||
another_x = torch.tensor([[7, 8], [5, 6]])
|
||||
self.run_test(AllDynamicModel(), x, test_with_inputs=[another_x],
|
||||
input_names=['input_1'], dynamic_axes={'input_1' : {0 : 'h', 1 : 'w'}})
|
||||
|
||||
def test_view(self):
|
||||
class ViewModel(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
|
|
|
|||
|
|
@ -840,3 +840,106 @@ def prim_ConstantChunk(g, self, chunks, dim):
|
|||
res.append(g.op("Slice", self, start, end, axis))
|
||||
start = end
|
||||
return res
|
||||
|
||||
def repeat_interleave(g, self, repeats, dim=None):
|
||||
from torch.onnx.symbolic_opset9 import reshape
|
||||
input = self
|
||||
final_dim = dim
|
||||
# if dim is None flatten
|
||||
# By default, use the flattened input array, and return a flat output array
|
||||
if sym_help._is_none(dim):
|
||||
input = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1])))
|
||||
dim = 0
|
||||
else:
|
||||
dim = sym_help._maybe_get_scalar(dim)
|
||||
|
||||
repeats_dim = sym_help._get_tensor_rank(repeats)
|
||||
repeats_sizes = sym_help._get_tensor_sizes(repeats)
|
||||
input_sizes = sym_help._get_tensor_sizes(input)
|
||||
if repeats_dim is None:
|
||||
raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown '
|
||||
'repeats rank.')
|
||||
if repeats_sizes is None:
|
||||
raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown '
|
||||
'repeats size.')
|
||||
if input_sizes is None:
|
||||
raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown '
|
||||
'input size.')
|
||||
# Handle cases where dim is negative
|
||||
if dim < 0:
|
||||
dim += len(input_sizes)
|
||||
|
||||
output_sizes = input_sizes.copy()
|
||||
perm_i = [0]
|
||||
for idx, input_size in enumerate(input_sizes):
|
||||
perm_i.append(idx + 1)
|
||||
if input_size is None:
|
||||
output_sizes[idx], input_sizes[idx] = 0, -1
|
||||
perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0]
|
||||
|
||||
# Cases when repeats is a single value tensor and dim has unknown input size
|
||||
if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0:
|
||||
if not sym_help._is_tensor(repeats):
|
||||
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
|
||||
reps = sym_help._size_helper(g, input, dim)
|
||||
reps = unsqueeze(g, reps, 0)
|
||||
repeats = g.op("Expand", repeats, reps)
|
||||
# There are cases when the repeats are 1-d tensor with multiple repeats, but dim
|
||||
# provided along one of the dynamic axes provided. A simple example would be
|
||||
# input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
|
||||
# Now, repeat interleaving can be performed in pytorch when the value of * matches
|
||||
# with the number of elements in repeat, for example if * -> 2, number of repeats
|
||||
# should be 2 as well.
|
||||
else:
|
||||
return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
|
||||
|
||||
reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
|
||||
value_t=torch.tensor([1], dtype=torch.long))
|
||||
r_splits = split(g, repeats, reps_like, 0)
|
||||
i_splits = split(g, input, reps_like, dim)
|
||||
|
||||
output_sizes[dim], input_sizes[dim] = -1, 1
|
||||
|
||||
# Create a loop to iterate over each value along the dimension
|
||||
# and perform individual interleaving using the repeats tensor
|
||||
# Loop is of the following pattern
|
||||
# input (trip_count, cond)
|
||||
# int trip_count = ...;
|
||||
# bool cond = ...;
|
||||
# for (int i=0; i < trip_count && cond; ++i) {
|
||||
# cond = ...;
|
||||
# }
|
||||
|
||||
# Loop conditions
|
||||
loop_condition = g.op("Constant", value_t=torch.tensor(1))
|
||||
loop_condition = g.op("Cast", loop_condition, to_i=9)
|
||||
loop_len = reps
|
||||
loop = g.op("Loop", loop_len, loop_condition)
|
||||
|
||||
# Loop inputs
|
||||
loop_block = _add_block(loop.node())
|
||||
block_input_iter = _add_input_to_block(loop_block)
|
||||
cond = _add_input_to_block(loop_block)
|
||||
|
||||
r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
|
||||
i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
|
||||
|
||||
i_split = unsqueeze(loop_block, i_split, dim + 1)
|
||||
r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
|
||||
r_split,
|
||||
loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
|
||||
r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
|
||||
i_split = expand(loop_block, i_split, r_concat, None)
|
||||
i_split = reshape(loop_block, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes)))
|
||||
|
||||
# Loop outputs
|
||||
cond_out = loop_block.op("Cast", loop_condition, to_i=9)
|
||||
_add_output_to_block(loop_block, cond_out)
|
||||
_add_output_to_block(loop_block, i_split)
|
||||
loop_out = loop.node().output()
|
||||
|
||||
# In this loop, the outputs are scan outputs and are concatenated along
|
||||
# the zero'th dimension (by default). In order to avoid this and concatenate
|
||||
# along the dimension provided, some post-processing is required
|
||||
loop_out = g.op("Transpose", loop_out, perm_i=perm_i)
|
||||
return reshape(g, loop_out, g.op("Constant", value_t=torch.LongTensor(output_sizes)))
|
||||
|
|
|
|||
|
|
@ -2004,13 +2004,17 @@ def repeat_interleave(g, self, repeats, dim=None):
|
|||
if not sym_help._is_tensor(repeats):
|
||||
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
|
||||
if input_sizes[dim] == 0:
|
||||
raise NotImplementedError("Unsupported repeat_interleave along dimension with unknown input size")
|
||||
return sym_help._onnx_opset_unsupported_detailed('repeat_interleave', 9, 11,
|
||||
'Unsupported along dimension with unknown input size')
|
||||
else:
|
||||
reps = input_sizes[dim]
|
||||
repeats = expand(g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None)
|
||||
|
||||
# Cases where repeats is a 1 dim Tensor
|
||||
elif repeats_dim == 1:
|
||||
if input_sizes[dim] == 0:
|
||||
return sym_help._onnx_opset_unsupported_detailed('repeat_interleave', 9, 11,
|
||||
'Unsupported along dimension with unknown input size')
|
||||
assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim"
|
||||
reps = repeats_sizes[0]
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in a new issue