From 7d47175f76a0dc71268ba41f2f4c37a05fbda4a8 Mon Sep 17 00:00:00 2001 From: Tracy Sharpe <42477615+tracysh@users.noreply.github.com> Date: Tue, 27 Jul 2021 15:39:10 -0700 Subject: [PATCH] cleanup NCHWc transformer (#8479) --- .../core/optimizer/nchwc_transformer.cc | 112 ++++++++---------- 1 file changed, 49 insertions(+), 63 deletions(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 609e80c4be..f264121104 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -93,7 +93,7 @@ class NchwcTransformerImpl { // Stores the proto shape for the NCHWc output. const NchwcArgument::Shape shape_; - NchwcArgument(Node& output_node, NodeArg* output_nchwc_arg, size_t original_uses, size_t channels, const NchwcArgument::Shape& shape) + NchwcArgument(Node& output_node, NodeArg* output_nchwc_arg, size_t original_uses, int64_t channels, const NchwcArgument::Shape& shape) : output_node_(output_node), nchwc_arg_(output_nchwc_arg), starting_original_uses_(original_uses), @@ -103,8 +103,13 @@ class NchwcTransformerImpl { } }; + NchwcArgument* LookupNchwcArgument(NodeArg* arg) { + auto it = nchwc_args_.find(arg); + return (it != nchwc_args_.end()) ? it->second.get() : nullptr; + } + size_t RemoveOutputEdges(Node& node); - void CreateNchwcArgument(Node& node, Node& nchwc_node, size_t channels, const NchwcArgument::Shape& shape); + void CreateNchwcArgument(Node& node, Node& nchwc_node, int64_t channels, const NchwcArgument::Shape& shape); void FuseNchwcArgument(Node& node, const NchwcArgument& nchwc_arg); void InsertReorderInput(Node& node); @@ -112,7 +117,7 @@ class NchwcTransformerImpl { const NchwcArgument::Shape& input_shape, NchwcArgument::Shape& output_shape, const ONNX_NAMESPACE::TensorProto* filter_shape); - Node& InsertReshape(NodeArg* input_arg, NodeArg* output_arg, int64_t channels, bool split_channels); + Node& InsertReshape(NodeArg* input_arg, NodeArg* output_arg, bool split_channels); void TransformConv(Node& node); void TransformPool(Node& node); @@ -146,10 +151,10 @@ class NchwcTransformerImpl { // NCHWc block size, so multiple nodes can share the NCHWc biases. std::unordered_map aligned_biases_; - // Stores a mapping of shape initializers for use by Reshape when splitting - // or unsplitting the channels dimension of a tensor. - std::unordered_map reshape_split_; - std::unordered_map reshape_unsplit_; + // Stores the shape initializers for Reshape to split or unsplit the channels + // dimension of a tensor. + NodeArg* reshape_split_arg_{nullptr}; + NodeArg* reshape_unsplit_arg_{nullptr}; // Tracks the last Transpose node and output NodeArg that transposed from // NHWC to NCHW format. @@ -172,7 +177,7 @@ size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) { void NchwcTransformerImpl::CreateNchwcArgument(Node& node, Node& nchwc_node, - size_t channels, + int64_t channels, const NchwcArgument::Shape& shape) { size_t original_uses = RemoveOutputEdges(node); @@ -194,7 +199,7 @@ void NchwcTransformerImpl::FuseNchwcArgument(Node& node, const NchwcArgument& nc auto& nchwc_node = nchwc_arg.output_node_; auto* output_nchwc_arg = nchwc_node.MutableOutputDefs()[0]; nchwc_args_[output_original_arg] = - std::make_unique(nchwc_node, output_nchwc_arg, original_uses, static_cast(nchwc_arg.channels_), nchwc_arg.shape_); + std::make_unique(nchwc_node, output_nchwc_arg, original_uses, nchwc_arg.channels_, nchwc_arg.shape_); } void NchwcTransformerImpl::InsertReorderInput(Node& node) { @@ -484,18 +489,17 @@ void NchwcTransformerImpl::TransformConv(Node& node) { NchwcArgument::Shape output_shape(output_defs[0]); if (do_reorder_input) { - auto it = nchwc_args_.find(input_defs[0]); - if (it == nchwc_args_.end()) { + auto* nchwc_input = LookupNchwcArgument(input_defs[0]); + if (nchwc_input == nullptr) { InsertReorderInput(nchwc_node); } else { - auto* nchwc_input = it->second.get(); nchwc_node.MutableInputDefs()[0] = nchwc_input->nchwc_arg_; nchwc_input->remaining_original_uses_--; ConvPoolShapeInference(node, nchwc_input->shape_, output_shape, conv_W_tensor_proto); } } - CreateNchwcArgument(node, nchwc_node, static_cast(output_channels), output_shape); + CreateNchwcArgument(node, nchwc_node, output_channels, output_shape); removed_nodes_.push_front(node.Index()); } @@ -540,54 +544,40 @@ void NchwcTransformerImpl::TransformPool(Node& node) { NchwcArgument::Shape output_shape(output_defs[0]); - auto it = nchwc_args_.find(input_defs[0]); - if (it == nchwc_args_.end()) { + auto* nchwc_input = LookupNchwcArgument(input_defs[0]); + if (nchwc_input == nullptr) { InsertReorderInput(nchwc_node); } else { - auto* nchwc_input = it->second.get(); nchwc_node.MutableInputDefs()[0] = nchwc_input->nchwc_arg_; nchwc_input->remaining_original_uses_--; ConvPoolShapeInference(node, nchwc_input->shape_, output_shape, nullptr); } - CreateNchwcArgument(node, nchwc_node, static_cast(channels), output_shape); + CreateNchwcArgument(node, nchwc_node, channels, output_shape); removed_nodes_.push_front(node.Index()); } Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg, NodeArg* output_arg, - int64_t channels, bool split_channels) { - const int64_t nchwc_block_size = static_cast(MlasNchwcGetBlockSize()); - const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1); - - // Reuse the shape initializer across reshapes for the same channel configuration. - auto& shape_arg_map = split_channels ? reshape_split_ : reshape_unsplit_; - NodeArg* shape_arg = shape_arg_map[nchwc_channels]; + auto& shape_arg = split_channels ? reshape_split_arg_ : reshape_unsplit_arg_; if (shape_arg == nullptr) { ONNX_NAMESPACE::TensorProto shape_tensor_proto; shape_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); shape_tensor_proto.set_name(graph_.GenerateNodeArgName("Reshape")); - // Passthrough the batch dimension. + // Passthrough the batch and spatial dimensions. Compute the channel count based + // on the remaining tensor size and whether this is a split or unsplit. shape_tensor_proto.add_int64_data(0); - if (split_channels) { - shape_tensor_proto.add_int64_data(nchwc_channels / nchwc_block_size); - } else { - shape_tensor_proto.add_int64_data(nchwc_channels); - } - // Passthrough the spatial dimensions. + shape_tensor_proto.add_int64_data(-1); for (int i = 0; i < kNchwcSpatialDims; i++) { shape_tensor_proto.add_int64_data(0); } if (split_channels) { - shape_tensor_proto.add_int64_data(nchwc_block_size); - shape_tensor_proto.add_dims(kNchwcDims + 1); - } else { - shape_tensor_proto.add_dims(kNchwcDims); + shape_tensor_proto.add_int64_data(static_cast(MlasNchwcGetBlockSize())); } + shape_tensor_proto.add_dims(split_channels ? kNchwcDims + 1 : kNchwcDims); shape_arg = &graph_utils::AddInitializer(graph_, shape_tensor_proto); - shape_arg_map[nchwc_channels] = shape_arg; } Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"), @@ -609,11 +599,11 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) { size_t input_defs_count = input_defs.size(); nchwc_inputs.reserve(input_defs_count); for (size_t i = 0; i < input_defs_count; i++) { - auto it = nchwc_args_.find(input_defs[i]); - if (it == nchwc_args_.end()) { + auto* nchwc_input = LookupNchwcArgument(input_defs[i]); + if (nchwc_input == nullptr) { return; } - nchwc_inputs.push_back(it->second.get()); + nchwc_inputs.push_back(nchwc_input); } auto* nchwc_input_0 = nchwc_inputs[0]; @@ -697,7 +687,7 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) { } } - CreateNchwcArgument(node, node, static_cast(nchwc_input_0->channels_), nchwc_input_0->shape_); + CreateNchwcArgument(node, node, nchwc_input_0->channels_, nchwc_input_0->shape_); return; } @@ -719,7 +709,7 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) { for (size_t n = 0; n < input_defs_count; n++) { std::string reshape_input_def_name = graph_.GenerateNodeArgName("reshape"); auto* reshape_input_arg = &graph_.GetOrCreateNodeArg(reshape_input_def_name, nullptr); - InsertReshape(nchwc_inputs[n]->nchwc_arg_, reshape_input_arg, channels, true); + InsertReshape(nchwc_inputs[n]->nchwc_arg_, reshape_input_arg, true); input_defs[n] = reshape_input_arg; nchwc_inputs[n]->remaining_original_uses_--; @@ -727,11 +717,11 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) { std::string output_reshaped_def_name = graph_.GenerateNodeArgName("reshape"); auto* output_reshaped_arg = &graph_.GetOrCreateNodeArg(output_reshaped_def_name, nullptr); - Node& nchwc_node = InsertReshape(output_reshaped_arg, output_defs[0], channels, false); + Node& nchwc_node = InsertReshape(output_reshaped_arg, output_defs[0], false); NchwcArgument::Shape output_shape(output_defs[0]); - CreateNchwcArgument(node, nchwc_node, static_cast(channels), output_shape); + CreateNchwcArgument(node, nchwc_node, channels, output_shape); output_defs[0] = output_reshaped_arg; return; } @@ -755,17 +745,17 @@ void NchwcTransformerImpl::TransformConcat(Node& node) { nchwc_inputs.reserve(input_defs_count); int64_t total_channels = 0; for (size_t i = 0; i < input_defs_count; i++) { - auto it = nchwc_args_.find(input_defs[i]); - if (it == nchwc_args_.end()) { + auto* nchwc_input = LookupNchwcArgument(input_defs[i]); + if (nchwc_input == nullptr) { return; } // Verify that the logical number of channels is block aligned. - int64_t input_channels = it->second->channels_; + int64_t input_channels = nchwc_input->channels_; if ((input_channels % nchwc_block_size) != 0) { return; } total_channels += input_channels; - nchwc_inputs.push_back(it->second.get()); + nchwc_inputs.push_back(nchwc_input); } // Update the node to directly use the NCHWc inputs directly and decrement @@ -780,7 +770,7 @@ void NchwcTransformerImpl::TransformConcat(Node& node) { NchwcArgument::Shape output_shape = nchwc_inputs[0]->shape_; output_shape.dims_[1] = output_defs[0]; - CreateNchwcArgument(node, node, static_cast(total_channels), output_shape); + CreateNchwcArgument(node, node, total_channels, output_shape); } // After doing a Conv/Add fusion, there may be an activation node that could now @@ -789,9 +779,8 @@ void NchwcTransformerImpl::TransformConcat(Node& node) { void NchwcTransformerImpl::TransformActivation(Node& node) { auto& input_defs = node.MutableInputDefs(); - auto it = nchwc_args_.find(input_defs[0]); - if (it != nchwc_args_.end()) { - auto& nchwc_input = it->second; + auto* nchwc_input = LookupNchwcArgument(input_defs[0]); + if (nchwc_input != nullptr) { input_defs[0] = nchwc_input->nchwc_arg_; nchwc_input->remaining_original_uses_--; @@ -805,7 +794,7 @@ void NchwcTransformerImpl::TransformActivation(Node& node) { FuseNchwcArgument(node, *nchwc_input); removed_nodes_.push_front(node.Index()); } else { - CreateNchwcArgument(node, node, static_cast(nchwc_input->channels_), nchwc_input->shape_); + CreateNchwcArgument(node, node, nchwc_input->channels_, nchwc_input->shape_); } } } @@ -823,11 +812,10 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { } // Don't transform the node if the input is not already in NCHWc format. - auto it = nchwc_args_.find(input_defs[0]); - if (it == nchwc_args_.end()) { + auto* nchwc_input = LookupNchwcArgument(input_defs[0]); + if (nchwc_input == nullptr) { return; } - auto* nchwc_input = it->second.get(); // Require that BatchNormalization-7 uses spatial normalization. const auto* spatial_attr = graph_utils::GetNodeAttribute(node, "spatial"); @@ -926,7 +914,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { nchwc_input->remaining_original_uses_--; - CreateNchwcArgument(node, nchwc_node, static_cast(channels), nchwc_input->shape_); + CreateNchwcArgument(node, nchwc_node, channels, nchwc_input->shape_); removed_nodes_.push_front(node.Index()); } @@ -935,11 +923,10 @@ void NchwcTransformerImpl::TransformTransposeToNhwc(Node& node) { auto& output_defs = node.MutableOutputDefs(); // Don't transform the node if the input is not already in NCHWc format. - auto it = nchwc_args_.find(input_defs[0]); - if (it == nchwc_args_.end()) { + auto* nchwc_input = LookupNchwcArgument(input_defs[0]); + if (nchwc_input == nullptr) { return; } - auto* nchwc_input = it->second.get(); const auto* perm_attr = graph_utils::GetNodeAttribute(node, "perm"); if (perm_attr == nullptr || perm_attr->ints_size() != 4) { @@ -976,11 +963,10 @@ void NchwcTransformerImpl::TransformResize(Node& node) { auto& output_defs = node.MutableOutputDefs(); // Don't transform the node if the input is not already in NCHWc format. - auto it = nchwc_args_.find(input_defs[0]); - if (it == nchwc_args_.end()) { + auto* nchwc_input = LookupNchwcArgument(input_defs[0]); + if (nchwc_input == nullptr) { return; } - auto* nchwc_input = it->second.get(); // Support nearest (default) and linear modes. const auto* mode_attr = graph_utils::GetNodeAttribute(node, "mode"); @@ -1127,7 +1113,7 @@ void NchwcTransformerImpl::TransformResize(Node& node) { NchwcArgument::Shape output_shape(output_defs[0]); - CreateNchwcArgument(node, nchwc_node, static_cast(nchwc_input->channels_), output_shape); + CreateNchwcArgument(node, nchwc_node, nchwc_input->channels_, output_shape); removed_nodes_.push_front(node.Index()); }