diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc
index 54ade99b5e5..eb692000eca 100644
--- a/caffe2/onnx/backend.cc
+++ b/caffe2/onnx/backend.cc
@@ -365,7 +365,8 @@ Caffe2Backend::get_special_operators() const {
       {"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
       {"RandomNormal", &Caffe2Backend::CreateRandomNormal},
       {"RandomNormalLike", &Caffe2Backend::CreateRandomNormal},
-      {"Where", &Caffe2Backend::CreateWhereOp}};
+      {"Where", &Caffe2Backend::CreateWhereOp},
+      {"NonZero", &Caffe2Backend::CreateNonZeroOp}};
   return kSpecialOperators;
 }
 
@@ -598,6 +599,30 @@ Caffe2Ops Caffe2Backend::CreateWhereOp(
   return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
 }
 
+Caffe2Ops Caffe2Backend::CreateNonZeroOp(
+    OnnxNode* onnx_node,
+    const ConversionContext& ctx) {
+  // Native Caffe2 doesn't support NonZero, fallback to ATen.
+  // ATen nonzero is equivalent to Transpose(ONNX::NonZero).
+  const auto& node = onnx_node->node;
+
+  onnx::NodeProto converted;
+  converted.CopyFrom(onnx_node->node);
+
+  auto nonzero_output = dummy_->NewDummyName();
+  converted.set_output(0, nonzero_output);
+  converted.set_op_type("ATen");
+  onnx::AttributeProto* attr = converted.add_attribute();
+  attr->set_name("operator");
+  attr->set_s("nonzero");
+  OnnxNode new_node(converted);
+  auto ret = CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
+
+  auto* c2_transpose = ret.ops.Add();
+  BuildOperator(c2_transpose, "Transpose", {nonzero_output}, {onnx_node->node.output(0)});
+  return ret;
+}
+
 Caffe2Ops Caffe2Backend::CreateReciprocal(
     OnnxNode* onnx_node,
     const ConversionContext& ctx) {
diff --git a/caffe2/onnx/backend.h b/caffe2/onnx/backend.h
index 8ee33ef2cab..dd6124b76e2 100644
--- a/caffe2/onnx/backend.h
+++ b/caffe2/onnx/backend.h
@@ -238,6 +238,8 @@ class CAFFE2_API Caffe2Backend {
 
   Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
 
+  Caffe2Ops CreateNonZeroOp(OnnxNode* onnx_node, const ConversionContext& ctx);
+
   Caffe2Ops CreateBatchNormalization(
       OnnxNode* onnx_node,
       const ConversionContext& ctx);
diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py
index a0f665e48ab..8caac7c3a67 100644
--- a/test/onnx/test_onnx_opset.py
+++ b/test/onnx/test_onnx_opset.py
@@ -157,18 +157,15 @@ class TestONNXOpset(TestCase):
                 super(MyModule, self).__init__()
 
             def forward(self, x):
-                return torch._dim_arange(x, 1)
+                return x - 1
 
         module = MyModule()
-        ops_8 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
+        ops_8 = [{"op_name" : "Constant"},
                  {"op_name" : "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]},
-                 {"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
-                 {"op_name" : "Range"}]
-        ops_9 = [{"op_name" : "Shape"}, {"op_name" : "Constant"},
-                 {"op_name" : "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]},
-                 {"op_name" : "Range"}]
+                 {"op_name" : "Sub"}]
+        ops_9 = [{"op_name" : "Constant"}, {"op_name" : "Sub"}]
         ops = {8 : ops_8, 9 : ops_9}
-        x = torch.ones(5, 6)
+        x = torch.ones(5, 6, dtype=torch.long)
         check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9])
 
     def test_slice(self):
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 49445cbf864..39d9dbb62e9 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1782,6 +1782,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.run_model_test(CeilModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
     def test__dim_arange(self):
         class DimArange(torch.nn.Module):
             def forward(self, input):
@@ -1790,6 +1791,60 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         x = torch.ones(5, 6)
         self.run_model_test(DimArange(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_end(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
+                            example_outputs=(outputs,))
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
+
+        self.run_model_test(ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_start_end(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
+                            example_outputs=(outputs,))
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
+
+        self.run_model_test(ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_start_end_step(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
+                            example_outputs=(outputs,))
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
+
+        self.run_model_test(ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
+
     def test_log2(self):
         class Log2Model(torch.nn.Module):
             def forward(self, input):
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 81203e0e466..d6542dd0a2d 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -246,6 +246,66 @@ class TestONNXRuntime(unittest.TestCase):
         y = torch.randn(4, 1, requires_grad=True)
         self.run_test(model, (x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_end(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_test(ArangeScript(), x)
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
+
+        self.run_test(ArangeModel(), x)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_start_end(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_test(ArangeScript(), x)
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
+
+        self.run_test(ArangeModel(), x)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_arange_start_end_step(self):
+        class ArangeScript(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
+
+        x = torch.randn(3, 4, requires_grad=True)
+        outputs = ArangeScript()(x)
+        self.run_test(ArangeScript(), x)
+
+        class ArangeModel(torch.nn.Module):
+            def forward(self, a):
+                return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
+
+        self.run_test(ArangeModel(), x)
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test__dim_arange(self):
+        class DimArange(torch.nn.Module):
+            def forward(self, input):
+                return torch._dim_arange(input, 1)
+
+        x = torch.ones(5, 6)
+        self.run_test(DimArange(), x)
+
     def test_gt(self):
         class GreaterModel(torch.nn.Module):
             def forward(self, input, other):
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index 3784506ff23..351272b879b 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -39,7 +39,7 @@ import warnings
 
 black_listed_operators = [
     "nonzero", "where", "scatter", "scatter_add", "erf", "sign", "isnan", "gather",
-    "masked_fill"
+    "arange", "masked_fill"
 ]
 
 for black_listed_op in black_listed_operators:
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index a7e68fe510e..84deae28831 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1486,7 +1486,11 @@ rnn_relu = _one_hidden_rnn('RNN_RELU')
 def _dim_arange(g, like, dim):
     like_shape = g.op('Shape', like)
     stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
-    return g.op("_caffe2::Range", stop)
+    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
+        return g.op("_caffe2::Range", stop)
+    else:
+        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
+        return arange(g, stop, 4, None, None, None)
 
 
 def detach(g, input):
@@ -1668,6 +1672,44 @@ def logsumexp(g, input, dim, keepdim):
     return g.op('ReduceLogSumExp', input, axes_i=dim, keepdims_i=keepdim)
 
 
+def arange(g, *args):
+    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
+        return g.op("ATen", *args, operator_s="arange")
+
+    def _get_arange_dtype(dtype):
+        dtype = sym_help._maybe_get_const(dtype, 'i')
+        if sym_help._is_value(dtype):
+            dtype = 4  # default to int64
+        return dtype
+
+    if len(args) == 5:
+        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
+        dtype = _get_arange_dtype(args[1])
+        end = g.op("Unsqueeze", args[0], axes_i=[0])
+        arange_tensor = g.op("Squeeze", nonzero(g, ones(g, end, dtype, *(args[2:]))), axes_i=[1])
+        return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
+    elif len(args) == 6:
+        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
+        dtype = _get_arange_dtype(args[2])
+        end = g.op("Unsqueeze", args[1], axes_i=[0])
+        start = g.op("Unsqueeze", args[0], axes_i=[0])
+        range_tensor = g.op("Sub", end, start)
+        arange_tensor = g.op("Add", g.op("Squeeze", nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), axes_i=[1]), start)
+        return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
+    elif len(args) == 7:
+        # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
+        dtype = _get_arange_dtype(args[3])
+        step = g.op("Unsqueeze", args[2], axes_i=[0])
+        end = g.op("Unsqueeze", args[1], axes_i=[0])
+        start = g.op("Unsqueeze", args[0], axes_i=[0])
+        range_tensor = g.op("Div", g.op("Sub", end, start), step)
+        arange_tensor = g.op("Squeeze", nonzero(g, ones(g, range_tensor, dtype, *(args[4:]))), axes_i=[1])
+        arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start)
+        return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
+    else:
+        raise NotImplementedError("Unknown aten::arange signature taking " + str(len(args)) + " arguments.")
+
+
 def masked_fill(g, self, mask, value):
     mask = _cast_Bool(g, mask, False)
     value = sym_help._maybe_get_scalar(value)