cleanup NCHWc transformer (#8479)

This commit is contained in:
Tracy Sharpe 2021-07-27 15:39:10 -07:00 committed by GitHub
parent 3850755feb
commit 7d47175f76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -93,7 +93,7 @@ class NchwcTransformerImpl {
// Stores the proto shape for the NCHWc output.
const NchwcArgument::Shape shape_;
NchwcArgument(Node& output_node, NodeArg* output_nchwc_arg, size_t original_uses, size_t channels, const NchwcArgument::Shape& shape)
NchwcArgument(Node& output_node, NodeArg* output_nchwc_arg, size_t original_uses, int64_t channels, const NchwcArgument::Shape& shape)
: output_node_(output_node),
nchwc_arg_(output_nchwc_arg),
starting_original_uses_(original_uses),
@ -103,8 +103,13 @@ class NchwcTransformerImpl {
}
};
NchwcArgument* LookupNchwcArgument(NodeArg* arg) {
auto it = nchwc_args_.find(arg);
return (it != nchwc_args_.end()) ? it->second.get() : nullptr;
}
size_t RemoveOutputEdges(Node& node);
void CreateNchwcArgument(Node& node, Node& nchwc_node, size_t channels, const NchwcArgument::Shape& shape);
void CreateNchwcArgument(Node& node, Node& nchwc_node, int64_t channels, const NchwcArgument::Shape& shape);
void FuseNchwcArgument(Node& node, const NchwcArgument& nchwc_arg);
void InsertReorderInput(Node& node);
@ -112,7 +117,7 @@ class NchwcTransformerImpl {
const NchwcArgument::Shape& input_shape,
NchwcArgument::Shape& output_shape,
const ONNX_NAMESPACE::TensorProto* filter_shape);
Node& InsertReshape(NodeArg* input_arg, NodeArg* output_arg, int64_t channels, bool split_channels);
Node& InsertReshape(NodeArg* input_arg, NodeArg* output_arg, bool split_channels);
void TransformConv(Node& node);
void TransformPool(Node& node);
@ -146,10 +151,10 @@ class NchwcTransformerImpl {
// NCHWc block size, so multiple nodes can share the NCHWc biases.
std::unordered_map<NodeArg*, NodeArg*> aligned_biases_;
// Stores a mapping of shape initializers for use by Reshape when splitting
// or unsplitting the channels dimension of a tensor.
std::unordered_map<int64_t, NodeArg*> reshape_split_;
std::unordered_map<int64_t, NodeArg*> reshape_unsplit_;
// Stores the shape initializers for Reshape to split or unsplit the channels
// dimension of a tensor.
NodeArg* reshape_split_arg_{nullptr};
NodeArg* reshape_unsplit_arg_{nullptr};
// Tracks the last Transpose node and output NodeArg that transposed from
// NHWC to NCHW format.
@ -172,7 +177,7 @@ size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
void NchwcTransformerImpl::CreateNchwcArgument(Node& node,
Node& nchwc_node,
size_t channels,
int64_t channels,
const NchwcArgument::Shape& shape) {
size_t original_uses = RemoveOutputEdges(node);
@ -194,7 +199,7 @@ void NchwcTransformerImpl::FuseNchwcArgument(Node& node, const NchwcArgument& nc
auto& nchwc_node = nchwc_arg.output_node_;
auto* output_nchwc_arg = nchwc_node.MutableOutputDefs()[0];
nchwc_args_[output_original_arg] =
std::make_unique<NchwcArgument>(nchwc_node, output_nchwc_arg, original_uses, static_cast<size_t>(nchwc_arg.channels_), nchwc_arg.shape_);
std::make_unique<NchwcArgument>(nchwc_node, output_nchwc_arg, original_uses, nchwc_arg.channels_, nchwc_arg.shape_);
}
void NchwcTransformerImpl::InsertReorderInput(Node& node) {
@ -484,18 +489,17 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
NchwcArgument::Shape output_shape(output_defs[0]);
if (do_reorder_input) {
auto it = nchwc_args_.find(input_defs[0]);
if (it == nchwc_args_.end()) {
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
if (nchwc_input == nullptr) {
InsertReorderInput(nchwc_node);
} else {
auto* nchwc_input = it->second.get();
nchwc_node.MutableInputDefs()[0] = nchwc_input->nchwc_arg_;
nchwc_input->remaining_original_uses_--;
ConvPoolShapeInference(node, nchwc_input->shape_, output_shape, conv_W_tensor_proto);
}
}
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(output_channels), output_shape);
CreateNchwcArgument(node, nchwc_node, output_channels, output_shape);
removed_nodes_.push_front(node.Index());
}
@ -540,54 +544,40 @@ void NchwcTransformerImpl::TransformPool(Node& node) {
NchwcArgument::Shape output_shape(output_defs[0]);
auto it = nchwc_args_.find(input_defs[0]);
if (it == nchwc_args_.end()) {
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
if (nchwc_input == nullptr) {
InsertReorderInput(nchwc_node);
} else {
auto* nchwc_input = it->second.get();
nchwc_node.MutableInputDefs()[0] = nchwc_input->nchwc_arg_;
nchwc_input->remaining_original_uses_--;
ConvPoolShapeInference(node, nchwc_input->shape_, output_shape, nullptr);
}
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(channels), output_shape);
CreateNchwcArgument(node, nchwc_node, channels, output_shape);
removed_nodes_.push_front(node.Index());
}
Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg,
NodeArg* output_arg,
int64_t channels,
bool split_channels) {
const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1);
// Reuse the shape initializer across reshapes for the same channel configuration.
auto& shape_arg_map = split_channels ? reshape_split_ : reshape_unsplit_;
NodeArg* shape_arg = shape_arg_map[nchwc_channels];
auto& shape_arg = split_channels ? reshape_split_arg_ : reshape_unsplit_arg_;
if (shape_arg == nullptr) {
ONNX_NAMESPACE::TensorProto shape_tensor_proto;
shape_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
shape_tensor_proto.set_name(graph_.GenerateNodeArgName("Reshape"));
// Passthrough the batch dimension.
// Passthrough the batch and spatial dimensions. Compute the channel count based
// on the remaining tensor size and whether this is a split or unsplit.
shape_tensor_proto.add_int64_data(0);
if (split_channels) {
shape_tensor_proto.add_int64_data(nchwc_channels / nchwc_block_size);
} else {
shape_tensor_proto.add_int64_data(nchwc_channels);
}
// Passthrough the spatial dimensions.
shape_tensor_proto.add_int64_data(-1);
for (int i = 0; i < kNchwcSpatialDims; i++) {
shape_tensor_proto.add_int64_data(0);
}
if (split_channels) {
shape_tensor_proto.add_int64_data(nchwc_block_size);
shape_tensor_proto.add_dims(kNchwcDims + 1);
} else {
shape_tensor_proto.add_dims(kNchwcDims);
shape_tensor_proto.add_int64_data(static_cast<int64_t>(MlasNchwcGetBlockSize()));
}
shape_tensor_proto.add_dims(split_channels ? kNchwcDims + 1 : kNchwcDims);
shape_arg = &graph_utils::AddInitializer(graph_, shape_tensor_proto);
shape_arg_map[nchwc_channels] = shape_arg;
}
Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"),
@ -609,11 +599,11 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
size_t input_defs_count = input_defs.size();
nchwc_inputs.reserve(input_defs_count);
for (size_t i = 0; i < input_defs_count; i++) {
auto it = nchwc_args_.find(input_defs[i]);
if (it == nchwc_args_.end()) {
auto* nchwc_input = LookupNchwcArgument(input_defs[i]);
if (nchwc_input == nullptr) {
return;
}
nchwc_inputs.push_back(it->second.get());
nchwc_inputs.push_back(nchwc_input);
}
auto* nchwc_input_0 = nchwc_inputs[0];
@ -697,7 +687,7 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
}
}
CreateNchwcArgument(node, node, static_cast<size_t>(nchwc_input_0->channels_), nchwc_input_0->shape_);
CreateNchwcArgument(node, node, nchwc_input_0->channels_, nchwc_input_0->shape_);
return;
}
@ -719,7 +709,7 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
for (size_t n = 0; n < input_defs_count; n++) {
std::string reshape_input_def_name = graph_.GenerateNodeArgName("reshape");
auto* reshape_input_arg = &graph_.GetOrCreateNodeArg(reshape_input_def_name, nullptr);
InsertReshape(nchwc_inputs[n]->nchwc_arg_, reshape_input_arg, channels, true);
InsertReshape(nchwc_inputs[n]->nchwc_arg_, reshape_input_arg, true);
input_defs[n] = reshape_input_arg;
nchwc_inputs[n]->remaining_original_uses_--;
@ -727,11 +717,11 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
std::string output_reshaped_def_name = graph_.GenerateNodeArgName("reshape");
auto* output_reshaped_arg = &graph_.GetOrCreateNodeArg(output_reshaped_def_name, nullptr);
Node& nchwc_node = InsertReshape(output_reshaped_arg, output_defs[0], channels, false);
Node& nchwc_node = InsertReshape(output_reshaped_arg, output_defs[0], false);
NchwcArgument::Shape output_shape(output_defs[0]);
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(channels), output_shape);
CreateNchwcArgument(node, nchwc_node, channels, output_shape);
output_defs[0] = output_reshaped_arg;
return;
}
@ -755,17 +745,17 @@ void NchwcTransformerImpl::TransformConcat(Node& node) {
nchwc_inputs.reserve(input_defs_count);
int64_t total_channels = 0;
for (size_t i = 0; i < input_defs_count; i++) {
auto it = nchwc_args_.find(input_defs[i]);
if (it == nchwc_args_.end()) {
auto* nchwc_input = LookupNchwcArgument(input_defs[i]);
if (nchwc_input == nullptr) {
return;
}
// Verify that the logical number of channels is block aligned.
int64_t input_channels = it->second->channels_;
int64_t input_channels = nchwc_input->channels_;
if ((input_channels % nchwc_block_size) != 0) {
return;
}
total_channels += input_channels;
nchwc_inputs.push_back(it->second.get());
nchwc_inputs.push_back(nchwc_input);
}
// Update the node to directly use the NCHWc inputs directly and decrement
@ -780,7 +770,7 @@ void NchwcTransformerImpl::TransformConcat(Node& node) {
NchwcArgument::Shape output_shape = nchwc_inputs[0]->shape_;
output_shape.dims_[1] = output_defs[0];
CreateNchwcArgument(node, node, static_cast<size_t>(total_channels), output_shape);
CreateNchwcArgument(node, node, total_channels, output_shape);
}
// After doing a Conv/Add fusion, there may be an activation node that could now
@ -789,9 +779,8 @@ void NchwcTransformerImpl::TransformConcat(Node& node) {
void NchwcTransformerImpl::TransformActivation(Node& node) {
auto& input_defs = node.MutableInputDefs();
auto it = nchwc_args_.find(input_defs[0]);
if (it != nchwc_args_.end()) {
auto& nchwc_input = it->second;
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
if (nchwc_input != nullptr) {
input_defs[0] = nchwc_input->nchwc_arg_;
nchwc_input->remaining_original_uses_--;
@ -805,7 +794,7 @@ void NchwcTransformerImpl::TransformActivation(Node& node) {
FuseNchwcArgument(node, *nchwc_input);
removed_nodes_.push_front(node.Index());
} else {
CreateNchwcArgument(node, node, static_cast<size_t>(nchwc_input->channels_), nchwc_input->shape_);
CreateNchwcArgument(node, node, nchwc_input->channels_, nchwc_input->shape_);
}
}
}
@ -823,11 +812,10 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
}
// Don't transform the node if the input is not already in NCHWc format.
auto it = nchwc_args_.find(input_defs[0]);
if (it == nchwc_args_.end()) {
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
if (nchwc_input == nullptr) {
return;
}
auto* nchwc_input = it->second.get();
// Require that BatchNormalization-7 uses spatial normalization.
const auto* spatial_attr = graph_utils::GetNodeAttribute(node, "spatial");
@ -926,7 +914,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
nchwc_input->remaining_original_uses_--;
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(channels), nchwc_input->shape_);
CreateNchwcArgument(node, nchwc_node, channels, nchwc_input->shape_);
removed_nodes_.push_front(node.Index());
}
@ -935,11 +923,10 @@ void NchwcTransformerImpl::TransformTransposeToNhwc(Node& node) {
auto& output_defs = node.MutableOutputDefs();
// Don't transform the node if the input is not already in NCHWc format.
auto it = nchwc_args_.find(input_defs[0]);
if (it == nchwc_args_.end()) {
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
if (nchwc_input == nullptr) {
return;
}
auto* nchwc_input = it->second.get();
const auto* perm_attr = graph_utils::GetNodeAttribute(node, "perm");
if (perm_attr == nullptr || perm_attr->ints_size() != 4) {
@ -976,11 +963,10 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
auto& output_defs = node.MutableOutputDefs();
// Don't transform the node if the input is not already in NCHWc format.
auto it = nchwc_args_.find(input_defs[0]);
if (it == nchwc_args_.end()) {
auto* nchwc_input = LookupNchwcArgument(input_defs[0]);
if (nchwc_input == nullptr) {
return;
}
auto* nchwc_input = it->second.get();
// Support nearest (default) and linear modes.
const auto* mode_attr = graph_utils::GetNodeAttribute(node, "mode");
@ -1127,7 +1113,7 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
NchwcArgument::Shape output_shape(output_defs[0]);
CreateNchwcArgument(node, nchwc_node, static_cast<size_t>(nchwc_input->channels_), output_shape);
CreateNchwcArgument(node, nchwc_node, nchwc_input->channels_, output_shape);
removed_nodes_.push_front(node.Index());
}