mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Export torch.arange to ONNX (#22601)
Summary: Some overlap with https://github.com/pytorch/pytorch/pull/21716 regarding caffe2 nonzero. Will rebase the other one accordingly whichever gets merged first. Pull Request resolved: https://github.com/pytorch/pytorch/pull/22601 Reviewed By: zrphercule Differential Revision: D16224660 Pulled By: houseroad fbshipit-source-id: dbfd1b8776cb626601e0bf83b3fcca291806e653
This commit is contained in:
parent
06d11f0434
commit
eb5137a5d1
7 changed files with 192 additions and 11 deletions
|
|
@ -365,7 +365,8 @@ Caffe2Backend::get_special_operators() const {
|
|||
{"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
|
||||
{"RandomNormal", &Caffe2Backend::CreateRandomNormal},
|
||||
{"RandomNormalLike", &Caffe2Backend::CreateRandomNormal},
|
||||
{"Where", &Caffe2Backend::CreateWhereOp}};
|
||||
{"Where", &Caffe2Backend::CreateWhereOp},
|
||||
{"NonZero", &Caffe2Backend::CreateNonZeroOp}};
|
||||
return kSpecialOperators;
|
||||
}
|
||||
|
||||
|
|
@ -598,6 +599,30 @@ Caffe2Ops Caffe2Backend::CreateWhereOp(
|
|||
return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
|
||||
}
|
||||
|
||||
// Converts an ONNX NonZero node into Caffe2 operators.
//
// The node is copied, retargeted to the ATen "nonzero" operator, and converted
// through CommonOnnxNodeToCaffe2Ops. The ATen result is written to a freshly
// generated dummy blob, and a trailing Transpose operator writes the final
// value to the output name of the original ONNX node.
//
// onnx_node: the ONNX NonZero node to convert (not modified).
// ctx:       conversion context forwarded to CommonOnnxNodeToCaffe2Ops.
// Returns the ops produced for the ATen node plus the appended Transpose.
Caffe2Ops Caffe2Backend::CreateNonZeroOp(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  // Native Caffe2 doesn't support NonZero, fallback to ATen.
  // ATen nonzero is equivalent to Transpose(ONNX::NonZero).
  const auto& node = onnx_node->node;

  onnx::NodeProto converted;
  converted.CopyFrom(node);

  // The ATen op must not write to the real output name: that name is reserved
  // for the Transpose appended below.
  const auto nonzero_output = dummy_->NewDummyName();
  converted.set_output(0, nonzero_output);
  converted.set_op_type("ATen");
  onnx::AttributeProto* attr = converted.add_attribute();
  attr->set_name("operator");
  attr->set_s("nonzero");
  OnnxNode new_node(converted);
  auto ret = CommonOnnxNodeToCaffe2Ops(&new_node, ctx);

  // Transpose the ATen result into the layout ONNX NonZero consumers expect.
  auto* c2_transpose = ret.ops.Add();
  BuildOperator(c2_transpose, "Transpose", {nonzero_output}, {node.output(0)});
  return ret;
}
|
||||
|
||||
Caffe2Ops Caffe2Backend::CreateReciprocal(
|
||||
OnnxNode* onnx_node,
|
||||
const ConversionContext& ctx) {
|
||||
|
|
|
|||
|
|
@ -238,6 +238,8 @@ class CAFFE2_API Caffe2Backend {
|
|||
|
||||
Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
|
||||
|
||||
// Converts an ONNX NonZero node by rewriting it to the ATen "nonzero"
// operator and appending a Transpose that produces the node's original output.
Caffe2Ops CreateNonZeroOp(OnnxNode* onnx_node, const ConversionContext& ctx);
|
||||
|
||||
Caffe2Ops CreateBatchNormalization(
|
||||
OnnxNode* onnx_node,
|
||||
const ConversionContext& ctx);
|
||||
|
|
|
|||
|
|
@ -157,18 +157,15 @@ class TestONNXOpset(TestCase):
|
|||
super(MyModule, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch._dim_arange(x, 1)
|
||||
return x - 1
|
||||
|
||||
module = MyModule()
|
||||
ops_8 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
|
||||
ops_8 = [{"op_name" : "Constant"},
|
||||
{"op_name" : "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
|
||||
{"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
|
||||
{"op_name" : "Range"}]
|
||||
ops_9 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
|
||||
{"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
|
||||
{"op_name" : "Range"}]
|
||||
{"op_name" : "Sub"}]
|
||||
ops_9 = [{"op_name" : "Constant"}, {"op_name" : "Sub"}]
|
||||
ops = {8 : ops_8, 9 : ops_9}
|
||||
x = torch.ones(5, 6)
|
||||
x = torch.ones(5, 6, dtype=torch.long)
|
||||
check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
|
||||
|
||||
def test_slice(self):
|
||||
|
|
|
|||
|
|
@ -1782,6 +1782,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
|||
x = torch.randn(1, 2, 3, 4, requires_grad=True)
|
||||
self.run_model_test(CeilModel(), train=False, input=x, batch_size=BATCH_SIZE)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
|
||||
def test__dim_arange(self):
|
||||
class DimArange(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
|
|
@ -1790,6 +1791,60 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
|||
x = torch.ones(5, 6)
|
||||
self.run_model_test(DimArange(), train=False, input=x, batch_size=BATCH_SIZE)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_end(self):
    # Covers torch.arange(end) with a length taken from the input's first
    # dimension, once as a ScriptModule and once as a plain nn.Module.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

    sample = torch.randn(3, 4, requires_grad=True)

    # The scripted variant is run eagerly first; its result is handed to
    # run_model_test as example_outputs.
    expected = ArangeScript()(sample)
    self.run_model_test(
        ArangeScript(),
        train=False,
        input=(sample,),
        batch_size=BATCH_SIZE,
        example_outputs=(expected,),
    )

    self.run_model_test(
        ArangeModel(),
        train=False,
        input=(sample,),
        batch_size=BATCH_SIZE,
    )
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end(self):
    # Covers torch.arange(start, end) with a constant start of 2 and an end
    # derived from the input's first dimension, for both a ScriptModule and a
    # plain nn.Module.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

    sample = torch.randn(3, 4, requires_grad=True)

    # The scripted variant is run eagerly first; its result is handed to
    # run_model_test as example_outputs.
    expected = ArangeScript()(sample)
    self.run_model_test(
        ArangeScript(),
        train=False,
        input=(sample,),
        batch_size=BATCH_SIZE,
        example_outputs=(expected,),
    )

    self.run_model_test(
        ArangeModel(),
        train=False,
        input=(sample,),
        batch_size=BATCH_SIZE,
    )
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end_step(self):
    # Covers torch.arange(start, end, step). The step is a.size(1) and the
    # span (end - start) is a.size(0) * a.size(1), so the span is an exact
    # multiple of the step.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

    sample = torch.randn(3, 4, requires_grad=True)

    # The scripted variant is run eagerly first; its result is handed to
    # run_model_test as example_outputs.
    expected = ArangeScript()(sample)
    self.run_model_test(
        ArangeScript(),
        train=False,
        input=(sample,),
        batch_size=BATCH_SIZE,
        example_outputs=(expected,),
    )

    self.run_model_test(
        ArangeModel(),
        train=False,
        input=(sample,),
        batch_size=BATCH_SIZE,
    )
|
||||
|
||||
def test_log2(self):
|
||||
class Log2Model(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
|
|
|
|||
|
|
@ -246,6 +246,66 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
y = torch.randn(4, 1, requires_grad=True)
|
||||
self.run_test(model, (x, y))
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_end(self):
    # Covers torch.arange(end) with a length taken from the input's first
    # dimension, once as a ScriptModule and once as a plain nn.Module.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

    x = torch.randn(3, 4, requires_grad=True)
    # run_test is only given the model and the input, so no eager result is
    # precomputed here.
    self.run_test(ArangeScript(), x)

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a

    self.run_test(ArangeModel(), x)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end(self):
    # Covers torch.arange(start, end) with a constant start of 2 and an end
    # derived from the input's first dimension, for both a ScriptModule and a
    # plain nn.Module.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

    x = torch.randn(3, 4, requires_grad=True)
    # run_test is only given the model and the input, so no eager result is
    # precomputed here.
    self.run_test(ArangeScript(), x)

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a

    self.run_test(ArangeModel(), x)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
def test_arange_start_end_step(self):
    # Covers torch.arange(start, end, step). The step is a.size(1) and the
    # span (end - start) is a.size(0) * a.size(1), so the span is an exact
    # multiple of the step.
    class ArangeScript(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

    x = torch.randn(3, 4, requires_grad=True)
    # run_test is only given the model and the input, so no eager result is
    # precomputed here.
    self.run_test(ArangeScript(), x)

    class ArangeModel(torch.nn.Module):
        def forward(self, a):
            return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a

    self.run_test(ArangeModel(), x)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
def test__dim_arange(self):
    # Exercises torch._dim_arange along dimension 1 of a (5, 6) tensor of ones.
    class DimArange(torch.nn.Module):
        def forward(self, input):
            return torch._dim_arange(input, 1)

    sample = torch.ones(5, 6)
    model = DimArange()
    self.run_test(model, sample)
|
||||
|
||||
def test_gt(self):
|
||||
class GreaterModel(torch.nn.Module):
|
||||
def forward(self, input, other):
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ import warnings
|
|||
|
||||
# Names of the operators on the black list. Every entry must be followed by a
# comma: two adjacent string literals without one are silently concatenated
# into a single (wrong) name.
black_listed_operators = [
    "nonzero", "where", "scatter", "scatter_add", "erf", "sign", "isnan", "gather",
    "arange", "masked_fill"
]
|
||||
|
||||
for black_listed_op in black_listed_operators:
|
||||
|
|
|
|||
|
|
@ -1486,7 +1486,11 @@ rnn_relu = _one_hidden_rnn('RNN_RELU')
|
|||
def _dim_arange(g, like, dim):
    # Symbolic for _dim_arange: emits a range whose length is the size of
    # `like` along dimension `dim`.
    #
    # g:    graph the ops are added to.
    # like: value whose shape is inspected.
    # dim:  index of the dimension whose size becomes the end of the range.
    like_shape = g.op('Shape', like)
    stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
    # Only the ATen-fallback export type uses the _caffe2::Range op; an
    # unconditional return here would make the branch below unreachable.
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("_caffe2::Range", stop)
    else:
        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        return arange(g, stop, 4, None, None, None)
|
||||
|
||||
|
||||
def detach(g, input):
|
||||
|
|
@ -1668,6 +1672,44 @@ def logsumexp(g, input, dim, keepdim):
|
|||
return g.op('ReduceLogSumExp', input, axes_i=dim, keepdims_i=keepdim)
|
||||
|
||||
|
||||
def arange(g, *args):
    # Symbolic for aten::arange, dispatching on the number of arguments:
    #   5 -> arange(end, dtype, layout, device, pin_memory)
    #   6 -> arange(start, end, dtype, layout, device, pin_memory)
    #   7 -> arange(start, end, step, dtype, layout, device, pin_memory)
    # With the ONNX_ATEN_FALLBACK export type the call is forwarded unchanged
    # to an ATen op. Otherwise the sequence 0..n-1 is obtained as the indices
    # of the non-zero entries of a ones tensor of length n, then shifted and
    # scaled as needed and cast to the requested dtype.
    # Raises NotImplementedError for any other argument count.
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", *args, operator_s="arange")

    def _resolve_dtype(dtype_arg):
        # Use the constant dtype when there is one; if it is still a graph
        # value after constant extraction, default to 4 (int64).
        resolved = sym_help._maybe_get_const(dtype_arg, 'i')
        return 4 if sym_help._is_value(resolved) else resolved

    def _as_1d(scalar):
        # Add a leading axis to a scalar.
        return g.op("Unsqueeze", scalar, axes_i=[0])

    def _zero_based_indices(length, dtype, tensor_options):
        # Indices of the non-zero entries of ones(length), with axis 1 squeezed.
        filled = ones(g, length, dtype, *tensor_options)
        return g.op("Squeeze", nonzero(g, filled), axes_i=[1])

    def _cast_to(tensor, dtype):
        return g.op("Cast", tensor, to_i=sym_help.scalar_type_to_onnx[dtype])

    arg_count = len(args)

    if arg_count == 5:
        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _resolve_dtype(args[1])
        end = _as_1d(args[0])
        indices = _zero_based_indices(end, dtype, args[2:])
        return _cast_to(indices, dtype)

    if arg_count == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _resolve_dtype(args[2])
        end = _as_1d(args[1])
        start = _as_1d(args[0])
        length = g.op("Sub", end, start)
        indices = _zero_based_indices(length, dtype, args[3:])
        shifted = g.op("Add", indices, start)
        return _cast_to(shifted, dtype)

    if arg_count == 7:
        # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _resolve_dtype(args[3])
        step = _as_1d(args[2])
        end = _as_1d(args[1])
        start = _as_1d(args[0])
        # NOTE(review): when (end - start) is not an exact multiple of step
        # this Div presumably rounds down for integer inputs, which would make
        # the result one element short - TODO confirm.
        length = g.op("Div", g.op("Sub", end, start), step)
        indices = _zero_based_indices(length, dtype, args[4:])
        scaled = g.op("Add", g.op("Mul", indices, step), start)
        return _cast_to(scaled, dtype)

    raise NotImplementedError("Unknown aten::arange signature taking " + str(arg_count) + " arguments.")
|
||||
|
||||
|
||||
def masked_fill(g, self, mask, value):
|
||||
mask = _cast_Bool(g, mask, False)
|
||||
value = sym_help._maybe_get_scalar(value)
|
||||
|
|
|
|||
Loading…
Reference in a new issue