From bc6ef809bb81046e6922e6eef496fce567dc4ecd Mon Sep 17 00:00:00 2001 From: Tracy Sharpe <42477615+tracysh@users.noreply.github.com> Date: Thu, 8 Apr 2021 09:57:23 -0700 Subject: [PATCH] NCHWc: avoid buffer reordering around Add nodes (#7279) Use Reshape to handle more NCHWc Add cases without ReorderInput/ReorderOutput. --- .../core/optimizer/nchwc_transformer.cc | 208 +++++++++++++----- .../test/optimizer/nchwc_optimizer_test.cc | 36 ++- 2 files changed, 189 insertions(+), 55 deletions(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 8017d39864..fe37e6579f 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -91,7 +91,7 @@ class NchwcTransformerImpl { const int64_t channels_; // Stores the proto shape for the NCHWc output. - NchwcArgument::Shape shape_; + const NchwcArgument::Shape shape_; NchwcArgument(Node& output_node, NodeArg* output_nchwc_arg, size_t original_uses, size_t channels, const NchwcArgument::Shape& shape) : output_node_(output_node), @@ -112,6 +112,7 @@ class NchwcTransformerImpl { const NchwcArgument::Shape& input_shape, NchwcArgument::Shape& output_shape, const ONNX_NAMESPACE::TensorProto* filter_shape); + Node& InsertReshape(NodeArg* input_arg, NodeArg* output_arg, int64_t channels, bool split_channels); void TransformConv(Node& node); void TransformPool(Node& node); @@ -143,6 +144,11 @@ class NchwcTransformerImpl { // Stores a mapping of NodeArg biases that have already been aligned to the // NCHWc block size, so multiple nodes can share the NCHWc biases. std::unordered_map aligned_biases_; + + // Stores a mapping of shape initializers for use by Reshape when splitting + // or unsplitting the channels dimension of a tensor. + std::unordered_map reshape_split_; + std::unordered_map reshape_unsplit_; }; size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) { @@ -510,11 +516,55 @@ void NchwcTransformerImpl::TransformPool(Node& node) { removed_nodes_.push_front(node.Index()); } -// The existing Add/Sum operator implementations can be used with tensors -// in NCHWc format if the tensor shapes are exactly the same (elementwise -// add). +Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg, + NodeArg* output_arg, + int64_t channels, + bool split_channels) { + const int64_t nchwc_block_size = static_cast(MlasNchwcGetBlockSize()); + const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1); + + // Reuse the shape initializer across reshapes for the same channel configuration. + auto& shape_arg_map = split_channels ? reshape_split_ : reshape_unsplit_; + NodeArg* shape_arg = shape_arg_map[nchwc_channels]; + if (shape_arg == nullptr) { + ONNX_NAMESPACE::TensorProto shape_tensor_proto; + shape_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + shape_tensor_proto.set_name(graph_.GenerateNodeArgName("Reshape")); + // Passthrough the batch dimension. + shape_tensor_proto.add_int64_data(0); + if (split_channels) { + shape_tensor_proto.add_int64_data(nchwc_channels / nchwc_block_size); + } else { + shape_tensor_proto.add_int64_data(nchwc_channels); + } + // Passthrough the spatial dimensions. + for (int i = 0; i < kNchwcSpatialDims; i++) { + shape_tensor_proto.add_int64_data(0); + } + if (split_channels) { + shape_tensor_proto.add_int64_data(nchwc_block_size); + shape_tensor_proto.add_dims(kNchwcDims + 1); + } else { + shape_tensor_proto.add_dims(kNchwcDims); + } + + shape_arg = &graph_utils::AddInitializer(graph_, shape_tensor_proto); + shape_arg_map[nchwc_channels] = shape_arg; + } + + Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"), + "Reshape", + "Reshape", + {input_arg, shape_arg}, + {output_arg}); + reshape_node.SetExecutionProviderType(kCpuExecutionProvider); + + return reshape_node; +} + void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) { auto& input_defs = node.MutableInputDefs(); + auto& output_defs = node.MutableOutputDefs(); // Verify that all of the inputs to this operator are from NCHWc outputs. std::vector nchwc_inputs; @@ -528,74 +578,124 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) { nchwc_inputs.push_back(it->second.get()); } - // Test if all of the NCHWc inputs have a compatible shape. auto* nchwc_input_0 = nchwc_inputs[0]; - auto* nchwc_input_0_shape = input_defs[0]->Shape(); + const int64_t channels = nchwc_inputs[0]->channels_; + + // Test if all of the NCHWc inputs have an equal shape. + bool all_shapes_match = true; + auto* input_0_shape = input_defs[0]->Shape(); for (size_t n = 1; n < input_defs_count; n++) { auto* nchwc_input_n = nchwc_inputs[n]; + // Require that all inputs have the same logical number of channels. + if (nchwc_input_n->channels_ != channels) { + return; + } for (int i = 0; i < kNchwcDims; i++) { // Test if this dimension is derived from the same NodeArg. if (!nchwc_input_0->shape_.IsDimEqual(nchwc_input_n->shape_, i)) { // Check if ONNX shape inferencing has computed a precise dimension value. - auto* nchwc_input_n_shape = input_defs[n]->Shape(); - if ((nchwc_input_0_shape == nullptr) || (nchwc_input_n_shape == nullptr)) { - return; - } - auto& nchwc_input_0_dim = nchwc_input_0_shape->dim(i); - auto& nchwc_input_n_dim = nchwc_input_n_shape->dim(i); - if (!utils::HasDimValue(nchwc_input_0_dim) || - !utils::HasDimValue(nchwc_input_n_dim) || - (nchwc_input_0_dim.dim_value() <= 0) || - (nchwc_input_0_dim.dim_value() != nchwc_input_n_dim.dim_value())) { - return; + auto* input_n_shape = input_defs[n]->Shape(); + if ((input_0_shape == nullptr) || (input_n_shape == nullptr)) { + all_shapes_match = false; + } else { + auto& input_0_dim = input_0_shape->dim(i); + auto& input_n_dim = input_n_shape->dim(i); + if (!utils::HasDimValue(input_0_dim) || + !utils::HasDimValue(input_n_dim) || + (input_0_dim.dim_value() <= 0) || + (input_0_dim.dim_value() != input_n_dim.dim_value())) { + if (!utils::HasDimParam(input_0_dim) || + !utils::HasDimParam(input_n_dim) || + (input_0_dim.dim_param() != input_n_dim.dim_param())) { + all_shapes_match = false; + } + } } } } } - // Update the node to directly use the NCHWc inputs directly and decrement - // the original use counts of the NCHWc inputs. - for (size_t n = 0; n < input_defs_count; n++) { - input_defs[n] = nchwc_inputs[n]->nchwc_arg_; - nchwc_inputs[n]->remaining_original_uses_--; - } + if (all_shapes_match) { + // Update the node to directly use the NCHWc inputs directly and decrement + // the original use counts of the NCHWc inputs. + for (size_t n = 0; n < input_defs_count; n++) { + input_defs[n] = nchwc_inputs[n]->nchwc_arg_; + nchwc_inputs[n]->remaining_original_uses_--; + } - // If one of the inputs to the Add/Sum node is a NCHWc convolution, then - // attempt to fuse the addition into the convolution itself. - if (add_node && input_defs_count == 2) { - for (size_t n = 0; n < 2; n++) { - auto* nchwc_input_n = nchwc_inputs[n]; - auto& nchwc_node = nchwc_input_n->output_node_; - auto& nchwc_input_defs = nchwc_node.MutableInputDefs(); - auto& nchwc_input_args_count = nchwc_node.MutableInputArgsCount(); - size_t nchwc_input_defs_count = nchwc_input_defs.size(); - // Check if this is a single use NCHWc convolution that hasn't already - // been fused with another Add/Sum node. The Add/Sum can also only be - // fused if the convolution isn't itself fused with an activation. - if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) && - (nchwc_input_defs_count < 4) && (nchwc_input_args_count.size() < 4) && - (nchwc_input_n->starting_original_uses_ == 1) && - (graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) { - // Feed the output of the other NCHWc node into the selected convolution - // node. - nchwc_input_defs.resize(4); - nchwc_input_args_count.resize(4); - if (nchwc_input_defs_count < 3) { - // The optional bias parameter is empty so set to an empty string. - nchwc_input_defs[2] = &graph_.GetOrCreateNodeArg("", nullptr); - nchwc_input_args_count[2] = 1; + // If one of the inputs to the Add/Sum node is a NCHWc convolution, then + // attempt to fuse the addition into the convolution itself. + if (add_node && input_defs_count == 2) { + for (size_t n = 0; n < 2; n++) { + auto* nchwc_input_n = nchwc_inputs[n]; + auto& nchwc_node = nchwc_input_n->output_node_; + auto& nchwc_input_defs = nchwc_node.MutableInputDefs(); + auto& nchwc_input_args_count = nchwc_node.MutableInputArgsCount(); + size_t nchwc_input_defs_count = nchwc_input_defs.size(); + // Check if this is a single use NCHWc convolution that hasn't already + // been fused with another Add/Sum node. The Add/Sum can also only be + // fused if the convolution isn't itself fused with an activation. + if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) && + (nchwc_input_defs_count < 4) && (nchwc_input_args_count.size() < 4) && + (nchwc_input_n->starting_original_uses_ == 1) && + (graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) { + // Feed the output of the other NCHWc node into the selected convolution + // node. + nchwc_input_defs.resize(4); + nchwc_input_args_count.resize(4); + if (nchwc_input_defs_count < 3) { + // The optional bias parameter is empty so set to an empty string. + nchwc_input_defs[2] = &graph_.GetOrCreateNodeArg("", nullptr); + nchwc_input_args_count[2] = 1; + } + nchwc_input_defs[3] = nchwc_inputs[n ^ 1]->output_node_.MutableOutputDefs()[0]; + nchwc_input_args_count[3] = 1; + + FuseNchwcArgument(node, *nchwc_input_n); + removed_nodes_.push_front(node.Index()); + return; } - nchwc_input_defs[3] = nchwc_inputs[n ^ 1]->output_node_.MutableOutputDefs()[0]; - nchwc_input_args_count[3] = 1; - - FuseNchwcArgument(node, *nchwc_input_n); - removed_nodes_.push_front(node.Index()); - return; } } + + CreateNchwcArgument(node, node, nchwc_input_0->channels_, nchwc_input_0->shape_); + return; } - CreateNchwcArgument(node, node, nchwc_input_0->channels_, nchwc_input_0->shape_); + if (add_node) { + // The input shapes cannot be shown to be identical, but the channel dimension + // is the same. Reshape the tensors to explicitly use the true NCHWc shape in + // order to perform the binary operation. + // + // Typically, both tensors are of the same shape at inferencing time, but this + // could not be proven using symbolic dimensions. This reshaping avoids the + // alternative of reordering the tensors back to NCHW. + // + // This optimization is restricted to Add/Sum nodes. Mul nodes would also work + // using this code, however the common case here is multiplying a NxCxHxW + // matrix by a NxCx1x1 vector. The implementation of Mul does not currently + // vectorize well for the case of broadcasting a NCHWc sized channel block. + // This case should be handled by a ScaleShift kernel that can broadcast a + // multiply/add vector (BatchNormalization can also be reimplemented with this). + for (size_t n = 0; n < input_defs_count; n++) { + std::string reshape_input_def_name = graph_.GenerateNodeArgName("reshape"); + auto* reshape_input_arg = &graph_.GetOrCreateNodeArg(reshape_input_def_name, nullptr); + InsertReshape(nchwc_inputs[n]->nchwc_arg_, reshape_input_arg, channels, true); + + input_defs[n] = reshape_input_arg; + nchwc_inputs[n]->remaining_original_uses_--; + } + + std::string output_reshaped_def_name = graph_.GenerateNodeArgName("reshape"); + auto* output_reshaped_arg = &graph_.GetOrCreateNodeArg(output_reshaped_def_name, nullptr); + Node& nchwc_node = InsertReshape(output_reshaped_arg, output_defs[0], channels, false); + + NchwcArgument::Shape output_shape(output_defs[0]); + + CreateNchwcArgument(node, nchwc_node, channels, output_shape); + output_defs[0] = output_reshaped_arg; + return; + } } void NchwcTransformerImpl::TransformConcat(Node& node) { diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 83b624f695..6958735a15 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -165,7 +165,7 @@ struct NchwcTestHelper { void NchwcOptimizerTester(const std::function& build_test_case, const std::function& check_nchwc_graph, - int opset_version = 12) { + int opset_version = 13) { // Ignore the test if NCHWc is not supported by the platform. if (MlasNchwcGetBlockSize() <= 1) { return; @@ -673,6 +673,40 @@ TEST(NchwcOptimizerTests, ConvBinary) { } } +TEST(NchwcOptimizerTests, ConvBinaryBroadcast) { + auto test_case = [&](const std::string& op_type) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 32, 25, 21}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* pool_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv_output_arg, {32, 32, 3, 3}); + helper.AddNode("GlobalAveragePool", {input_arg}, {pool_output_arg}); + helper.AddNode(op_type, {conv_output_arg, pool_output_arg}, {output_arg}); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.GlobalAveragePool"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["Reshape"], 3); + EXPECT_EQ(op_to_count[op_type], 1); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph); + }; + + // Verify that the optimizer keeps the inputs to the binary operator as NCHWc + // and only reorders the output of the binary operator. + std::vector op_types{"Add", "Sum"}; + for (auto& op_type : op_types) { + test_case(op_type); + } +} + TEST(NchwcOptimizerTests, ConvConcat) { auto test_case = [&](int axis, int channel_count, int reorder_output_count) { auto build_test_case = [&](NchwcTestHelper& helper) {