diff --git a/caffe2/onnx/onnx_exporter.cc b/caffe2/onnx/onnx_exporter.cc
index 8857898ac44..b64a91254f4 100644
--- a/caffe2/onnx/onnx_exporter.cc
+++ b/caffe2/onnx/onnx_exporter.cc
@@ -298,6 +298,7 @@ OnnxExporter::get_special_operators() const {
           {"AveragePool", &OnnxExporter::CreateConvPoolNodes},
           {"FC", &OnnxExporter::CreateGemmNodes},
           {"Concat", &OnnxExporter::CreateConcatNodes},
+          {"MergeDim", &OnnxExporter::CreateMergeDimNodes},
           {"LRN", &OnnxExporter::CreateLrnNodes},
           {"Reshape", &OnnxExporter::CreateReshapeNodes},
           {"Slice", &OnnxExporter::CreateSliceNodes},
@@ -746,6 +747,48 @@ ConvertedResult OnnxExporter::CreateConcatNodes(
   return result;
 }
 
+// Converts a Caffe2 MergeDim operator into an ONNX Reshape followed by a
+// Squeeze. The Reshape target shape is [1, -1, 0, ..., 0] with one entry per
+// input dimension: -1 makes Reshape infer the merged size dim(0) * dim(1),
+// and every trailing 0 keeps the matching input dimension unchanged. The
+// Squeeze on axis 0 then drops the leading unit dimension.
+//
+// `shapes` must hold the shape of the input blob (shapes.at throws
+// otherwise), and the input needs at least two dimensions.
+ConvertedResult OnnxExporter::CreateMergeDimNodes(
+    const caffe2::OperatorDef& def,
+    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+  const auto& x = def.input(0);
+  const auto& y = def.output(0);
+
+  ConvertedResult result;
+  auto& nodes = result.first;
+  auto& const_tensors = result.second;
+
+  {
+    const auto ndim = shapes.at(x).dims().size();
+    CAFFE_ENFORCE_GE(ndim, 2, "Not enough dims to merge.");
+    std::vector<int64_t> dims(ndim);
+    dims[0] = 1;   // leading unit dim, removed by the Squeeze below
+    dims[1] = -1;  // inferred by Reshape as dim(0) * dim(1)
+    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
+  }
+
+  const auto reshaped = dummy_->NewDummyName();
+  nodes.emplace_back(MakeNode("Reshape",
+        { x, const_tensors.back().name() },
+        { reshaped }));
+
+  nodes.emplace_back(MakeNode("Squeeze",
+        { reshaped },
+        { y },
+        std::vector<AttributeProto>{
+          MakeAttribute("axes", std::vector<int64_t>{ 0 }),
+        }));
+
+  return result;
+}
+
 ConvertedResult OnnxExporter::CreateChannelShuffleNodes(
     const caffe2::OperatorDef& def,
     const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
diff --git a/caffe2/onnx/onnx_exporter.h b/caffe2/onnx/onnx_exporter.h
index 30a52337518..f7f1643e87a 100644
--- a/caffe2/onnx/onnx_exporter.h
+++ b/caffe2/onnx/onnx_exporter.h
@@ -97,6 +97,10 @@ class CAFFE2_API OnnxExporter {
       const caffe2::OperatorDef& def,
       const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
 
+  ConvertedResult CreateMergeDimNodes(
+      const caffe2::OperatorDef& def,
+      const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
   ConvertedResult CreateLrnNodes(
       const caffe2::OperatorDef& def,
       const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
diff --git a/caffe2/operators/prepend_dim_op.cc b/caffe2/operators/prepend_dim_op.cc
index 2796fec3672..9cdc228c2b8 100644
--- a/caffe2/operators/prepend_dim_op.cc
+++ b/caffe2/operators/prepend_dim_op.cc
@@ -25,7 +25,8 @@ OPERATOR_SCHEMA(MergeDim)
 Merge first two dimensions in a single dimension with size dim(0) * dim(1).
 )DOC")
     .Input(0, "data", "An input tensor.")
-    .Output(0, "reshaped", "Reshaped tensor.");
+    .Output(0, "reshaped", "Reshaped tensor.")
+    .InheritOnnxSchema("Reshape");
 
 class GetPrependDimGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
diff --git a/caffe2/python/onnx/tests/c2_ref_test.py b/caffe2/python/onnx/tests/c2_ref_test.py
index 629cc762caf..df4df723f8a 100644
--- a/caffe2/python/onnx/tests/c2_ref_test.py
+++ b/caffe2/python/onnx/tests/c2_ref_test.py
@@ -474,6 +474,35 @@ class TestCaffe2Basic(DownloadingTestCase):
             op_names.append(op.type)
         self.assertEqual(op_names, ['Scale', 'Scale', 'MatMul', 'Add'])
 
+    def test_mergedim(self):
+        # Runs MergeDim natively in Caffe2 and through the exported ONNX model,
+        # then checks that both produce the same outputs.
+        X = np.random.randn(2, 3, 1, 5).astype(np.float32)
+
+        predict_net = caffe2_pb2.NetDef()
+        predict_net.name = 'test-mergedim-net'
+        predict_net.external_input[:] = ['X']
+        predict_net.external_output[:] = ['Y']
+        predict_net.op.extend([
+            core.CreateOperator(
+                'MergeDim',
+                inputs=['X'],
+                outputs=['Y'],
+            ),
+        ])
+        ws, c2_outputs = c2_native_run_net(
+            init_net=None,
+            predict_net=predict_net,
+            inputs=[X])
+
+        onnx_model = c2_onnx.caffe2_net_to_onnx_model(
+            predict_net=predict_net,
+            value_info={
+                'X': (onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[X.dtype], X.shape),
+            })
+        onnx_outputs = c2.run_model(onnx_model, inputs=[X])
+        self.assertSameOutputs(c2_outputs, onnx_outputs)
+
     def test_tensor_filling_ops(self):
         for dtype in [
                 onnx.TensorProto.FLOAT,