Export torch.arange to ONNX (#22601)

Summary:
This overlaps with https://github.com/pytorch/pytorch/pull/21716 on the Caffe2 NonZero export; whichever PR is merged first, the other will be rebased accordingly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22601

Reviewed By: zrphercule

Differential Revision: D16224660

Pulled By: houseroad

fbshipit-source-id: dbfd1b8776cb626601e0bf83b3fcca291806e653
Author: BowenBao
Date:   2019-07-22 20:20:03 -07:00
Committed by: Facebook Github Bot
Parent: 06d11f0434
Commit: eb5137a5d1

7 changed files with 192 additions and 11 deletions


@@ -365,7 +365,8 @@ Caffe2Backend::get_special_operators() const {
       {"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
       {"RandomNormal", &Caffe2Backend::CreateRandomNormal},
       {"RandomNormalLike", &Caffe2Backend::CreateRandomNormal},
-      {"Where", &Caffe2Backend::CreateWhereOp}};
+      {"Where", &Caffe2Backend::CreateWhereOp},
+      {"NonZero", &Caffe2Backend::CreateNonZeroOp}};
   return kSpecialOperators;
 }
@@ -598,6 +599,30 @@ Caffe2Ops Caffe2Backend::CreateWhereOp(
   return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
 }
 
+Caffe2Ops Caffe2Backend::CreateNonZeroOp(
+    OnnxNode* onnx_node,
+    const ConversionContext& ctx) {
+  // Native Caffe2 doesn't support NonZero; fall back to ATen.
+  // ATen nonzero is equivalent to Transpose(ONNX::NonZero).
+  const auto& node = onnx_node->node;
+  onnx::NodeProto converted;
+  converted.CopyFrom(onnx_node->node);
+  auto nonzero_output = dummy_->NewDummyName();
+  converted.set_output(0, nonzero_output);
+  converted.set_op_type("ATen");
+  onnx::AttributeProto* attr = converted.add_attribute();
+  attr->set_name("operator");
+  attr->set_s("nonzero");
+  OnnxNode new_node(converted);
+  auto ret = CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
+
+  auto* c2_transpose = ret.ops.Add();
+  BuildOperator(c2_transpose, "Transpose", {nonzero_output}, {onnx_node->node.output(0)});
+  return ret;
+}
+
 Caffe2Ops Caffe2Backend::CreateReciprocal(
     OnnxNode* onnx_node,
     const ConversionContext& ctx) {
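The comment in the converter above rests on a layout difference: ONNX NonZero returns indices with shape [rank, num_nonzero], while ATen nonzero returns [num_nonzero, rank], so transposing the ATen result reproduces the ONNX output. A minimal eager-mode sketch of that relationship (illustrative only, not part of the diff):

    import torch

    x = torch.tensor([[1, 0], [0, 3]])
    aten_idx = torch.nonzero(x)   # shape [num_nonzero, rank]; each row is a coordinate
    onnx_idx = aten_idx.t()       # shape [rank, num_nonzero]; the layout ONNX NonZero specifies
    assert torch.equal(onnx_idx.t(), aten_idx)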


@@ -238,6 +238,8 @@ class CAFFE2_API Caffe2Backend {
   Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
 
+  Caffe2Ops CreateNonZeroOp(OnnxNode* onnx_node, const ConversionContext& ctx);
+
   Caffe2Ops CreateBatchNormalization(
       OnnxNode* onnx_node,
       const ConversionContext& ctx);


@@ -157,18 +157,15 @@ class TestONNXOpset(TestCase):
                 super(MyModule, self).__init__()
 
             def forward(self, x):
-                return torch._dim_arange(x, 1)
+                return x - 1
 
         module = MyModule()
-        ops_8 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
+        ops_8 = [{"op_name" : "Constant"},
                  {"op_name" : "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
-                 {"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
-                 {"op_name" : "Range"}]
-        ops_9 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
-                 {"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
-                 {"op_name" : "Range"}]
+                 {"op_name" : "Sub"}]
+        ops_9 = [{"op_name" : "Constant"}, {"op_name" : "Sub"}]
         ops = {8 : ops_8, 9 : ops_9}
-        x = torch.ones(5, 6)
+        x = torch.ones(5, 6, dtype=torch.long)
         check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
 
     def test_slice(self):


@@ -1782,6 +1782,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.run_model_test(CeilModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
     def test__dim_arange(self):
         class DimArange(torch.nn.Module):
             def forward(self, input):
@@ -1790,6 +1791,60 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         x = torch.ones(5, 6)
         self.run_model_test(DimArange(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_end(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
+                            example_outputs=(outputs,))
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
+
+        self.run_model_test(ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_start_end(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
+                            example_outputs=(outputs,))
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
+
+        self.run_model_test(ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_start_end_step(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
+                            example_outputs=(outputs,))
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
+
+        self.run_model_test(ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
+
     def test_log2(self):
         class Log2Model(torch.nn.Module):
             def forward(self, input):

@@ -246,6 +246,66 @@ class TestONNXRuntime(unittest.TestCase):
         y = torch.randn(4, 1, requires_grad=True)
         self.run_test(model, (x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_end(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_test(ArangeScript(), x)
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
+
+        self.run_test(ArangeModel(), x)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_start_end(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_test(ArangeScript(), x)
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
+
+        self.run_test(ArangeModel(), x)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_start_end_step(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_test(ArangeScript(), x)
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
+
+        self.run_test(ArangeModel(), x)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test__dim_arange(self):
+        class DimArange(torch.nn.Module):
+            def forward(self, input):
+                return torch._dim_arange(input, 1)
+
+        x = torch.ones(5, 6)
+        self.run_test(DimArange(), x)
+
     def test_gt(self):
         class GreaterModel(torch.nn.Module):
             def forward(self, input, other):

@@ -39,7 +39,7 @@ import warnings
 
 black_listed_operators = [
     "nonzero", "where", "scatter", "scatter_add", "erf", "sign", "isnan", "gather",
-    "masked_fill"
+    "arange", "masked_fill"
 ]
 
 for black_listed_op in black_listed_operators:

@@ -1486,7 +1486,11 @@ rnn_relu = _one_hidden_rnn('RNN_RELU')
 
 def _dim_arange(g, like, dim):
     like_shape = g.op('Shape', like)
     stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
-    return g.op("_caffe2::Range", stop)
+    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
+        return g.op("_caffe2::Range", stop)
+    else:
+        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
+        return arange(g, stop, 4, None, None, None)
 
 
 def detach(g, input):
@@ -1668,6 +1672,44 @@ def logsumexp(g, input, dim, keepdim):
     return g.op('ReduceLogSumExp', input, axes_i=dim, keepdims_i=keepdim)
 
 
+def arange(g, *args):
+    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
+        return g.op("ATen", *args, operator_s="arange")
+
+    def _get_arange_dtype(dtype):
+        dtype = sym_help._maybe_get_const(dtype, 'i')
+        if sym_help._is_value(dtype):
+            dtype = 4  # default to int64
+        return dtype
+
+    if len(args) == 5:
+        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
+        dtype = _get_arange_dtype(args[1])
+        end = g.op("Unsqueeze", args[0], axes_i=[0])
+        arange_tensor = g.op("Squeeze", nonzero(g, ones(g, end, dtype, *(args[2:]))), axes_i=[1])
+        return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
+    elif len(args) == 6:
+        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
+        dtype = _get_arange_dtype(args[2])
+        end = g.op("Unsqueeze", args[1], axes_i=[0])
+        start = g.op("Unsqueeze", args[0], axes_i=[0])
+        range_tensor = g.op("Sub", end, start)
+        arange_tensor = g.op("Add", g.op("Squeeze", nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), axes_i=[1]), start)
+        return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
+    elif len(args) == 7:
+        # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
+        dtype = _get_arange_dtype(args[3])
+        step = g.op("Unsqueeze", args[2], axes_i=[0])
+        end = g.op("Unsqueeze", args[1], axes_i=[0])
+        start = g.op("Unsqueeze", args[0], axes_i=[0])
+        range_tensor = g.op("Div", g.op("Sub", end, start), step)
+        arange_tensor = g.op("Squeeze", nonzero(g, ones(g, range_tensor, dtype, *(args[4:]))), axes_i=[1])
+        arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start)
+        return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
+    else:
+        raise NotImplementedError("Unknown aten::arange signature taking " + str(len(args)) + " arguments.")
+
+
 def masked_fill(g, self, mask, value):
     mask = _cast_Bool(g, mask, False)
     value = sym_help._maybe_get_scalar(value)
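ONNX opset 9 has no Range operator, so the symbolic above synthesizes the sequence: it builds a ones tensor of the requested length, takes NonZero of it (whose squeezed indices are exactly 0..n-1), then multiplies by the step and adds the start before casting to the requested dtype. A rough eager-mode sketch of the same construction (illustrative only; assumes integer start/end/step, not part of the diff):

    import torch

    def arange_via_nonzero(start, end, step=1):
        n = (end - start) // step                        # range length, as in Div(Sub(end, start), step)
        base = torch.nonzero(torch.ones(n)).squeeze(1)   # indices 0..n-1, mirroring Squeeze(NonZero(Ones))
        return base * step + start                       # Add(Mul(base, step), start)

    print(arange_via_nonzero(2, 14, 3))   # tensor([ 2,  5,  8, 11])
    print(torch.arange(2, 14, 3))         # tensor([ 2,  5,  8, 11])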