NCHWc: avoid buffer reordering around Add nodes (#7279)

Use Reshape to handle more NCHWc Add cases without ReorderInput/ReorderOutput.
This commit is contained in:
Tracy Sharpe 2021-04-08 09:57:23 -07:00 committed by GitHub
parent e14b291ce7
commit bc6ef809bb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 189 additions and 55 deletions

View file

@ -91,7 +91,7 @@ class NchwcTransformerImpl {
const int64_t channels_;
// Stores the proto shape for the NCHWc output.
NchwcArgument::Shape shape_;
const NchwcArgument::Shape shape_;
NchwcArgument(Node& output_node, NodeArg* output_nchwc_arg, size_t original_uses, size_t channels, const NchwcArgument::Shape& shape)
: output_node_(output_node),
@ -112,6 +112,7 @@ class NchwcTransformerImpl {
const NchwcArgument::Shape& input_shape,
NchwcArgument::Shape& output_shape,
const ONNX_NAMESPACE::TensorProto* filter_shape);
Node& InsertReshape(NodeArg* input_arg, NodeArg* output_arg, int64_t channels, bool split_channels);
void TransformConv(Node& node);
void TransformPool(Node& node);
@ -143,6 +144,11 @@ class NchwcTransformerImpl {
// Stores a mapping of NodeArg biases that have already been aligned to the
// NCHWc block size, so multiple nodes can share the NCHWc biases.
std::unordered_map<NodeArg*, NodeArg*> aligned_biases_;
// Stores a mapping of shape initializers for use by Reshape when splitting
// or unsplitting the channels dimension of a tensor.
std::unordered_map<int64_t, NodeArg*> reshape_split_;
std::unordered_map<int64_t, NodeArg*> reshape_unsplit_;
};
size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
@ -510,11 +516,55 @@ void NchwcTransformerImpl::TransformPool(Node& node) {
removed_nodes_.push_front(node.Index());
}
// The existing Add/Sum operator implementations can be used with tensors
// in NCHWc format if the tensor shapes are exactly the same (elementwise
// add).
// Adds a CPU Reshape node that converts between a tensor with a single
// channels dimension and one whose channels dimension has been split into
// (channel blocks, NCHWc block size).
//
// input_arg:      tensor fed to the Reshape node.
// output_arg:     tensor produced by the Reshape node.
// channels:       logical channel count; rounded up to a multiple of the
//                 NCHWc block size before building the target shape.
// split_channels: true to emit the split shape (one extra trailing block
//                 dimension), false to emit the unsplit shape.
//
// Returns the newly added Reshape node. The shape initializer is created once
// per rounded channel count and direction, then shared by later calls.
Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg,
                                          NodeArg* output_arg,
                                          int64_t channels,
                                          bool split_channels) {
  const auto block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
  // Round up to the next multiple of the block size.
  const int64_t padded_channels = (channels + block_size - 1) & ~(block_size - 1);

  // Look up (or create an empty slot for) the cached shape initializer.
  auto& shape_cache = split_channels ? reshape_split_ : reshape_unsplit_;
  NodeArg*& cached_shape_arg = shape_cache[padded_channels];

  if (cached_shape_arg == nullptr) {
    ONNX_NAMESPACE::TensorProto shape_proto;
    shape_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
    shape_proto.set_name(graph_.GenerateNodeArgName("Reshape"));

    // A zero entry passes the corresponding input dimension through.
    // Batch dimension.
    shape_proto.add_int64_data(0);

    // Channels dimension: the block count when splitting, otherwise the full
    // rounded channel count.
    shape_proto.add_int64_data(split_channels ? padded_channels / block_size
                                              : padded_channels);

    // Spatial dimensions.
    for (int i = 0; i < kNchwcSpatialDims; i++) {
      shape_proto.add_int64_data(0);
    }

    // The split form carries the block size as an extra trailing dimension,
    // so the shape tensor is one element longer.
    if (split_channels) {
      shape_proto.add_int64_data(block_size);
      shape_proto.add_dims(kNchwcDims + 1);
    } else {
      shape_proto.add_dims(kNchwcDims);
    }

    cached_shape_arg = &graph_utils::AddInitializer(graph_, shape_proto);
  }

  Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"),
                                      "Reshape",
                                      "Reshape",
                                      {input_arg, cached_shape_arg},
                                      {output_arg});
  reshape_node.SetExecutionProviderType(kCpuExecutionProvider);
  return reshape_node;
}
void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
auto& input_defs = node.MutableInputDefs();
auto& output_defs = node.MutableOutputDefs();
// Verify that all of the inputs to this operator are from NCHWc outputs.
std::vector<NchwcArgument*> nchwc_inputs;
@ -528,74 +578,124 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
nchwc_inputs.push_back(it->second.get());
}
// Test if all of the NCHWc inputs have a compatible shape.
auto* nchwc_input_0 = nchwc_inputs[0];
auto* nchwc_input_0_shape = input_defs[0]->Shape();
const int64_t channels = nchwc_inputs[0]->channels_;
// Test if all of the NCHWc inputs have an equal shape.
bool all_shapes_match = true;
auto* input_0_shape = input_defs[0]->Shape();
for (size_t n = 1; n < input_defs_count; n++) {
auto* nchwc_input_n = nchwc_inputs[n];
// Require that all inputs have the same logical number of channels.
if (nchwc_input_n->channels_ != channels) {
return;
}
for (int i = 0; i < kNchwcDims; i++) {
// Test if this dimension is derived from the same NodeArg.
if (!nchwc_input_0->shape_.IsDimEqual(nchwc_input_n->shape_, i)) {
// Check if ONNX shape inferencing has computed a precise dimension value.
auto* nchwc_input_n_shape = input_defs[n]->Shape();
if ((nchwc_input_0_shape == nullptr) || (nchwc_input_n_shape == nullptr)) {
return;
}
auto& nchwc_input_0_dim = nchwc_input_0_shape->dim(i);
auto& nchwc_input_n_dim = nchwc_input_n_shape->dim(i);
if (!utils::HasDimValue(nchwc_input_0_dim) ||
!utils::HasDimValue(nchwc_input_n_dim) ||
(nchwc_input_0_dim.dim_value() <= 0) ||
(nchwc_input_0_dim.dim_value() != nchwc_input_n_dim.dim_value())) {
return;
auto* input_n_shape = input_defs[n]->Shape();
if ((input_0_shape == nullptr) || (input_n_shape == nullptr)) {
all_shapes_match = false;
} else {
auto& input_0_dim = input_0_shape->dim(i);
auto& input_n_dim = input_n_shape->dim(i);
if (!utils::HasDimValue(input_0_dim) ||
!utils::HasDimValue(input_n_dim) ||
(input_0_dim.dim_value() <= 0) ||
(input_0_dim.dim_value() != input_n_dim.dim_value())) {
if (!utils::HasDimParam(input_0_dim) ||
!utils::HasDimParam(input_n_dim) ||
(input_0_dim.dim_param() != input_n_dim.dim_param())) {
all_shapes_match = false;
}
}
}
}
}
}
// Update the node to directly use the NCHWc inputs directly and decrement
// the original use counts of the NCHWc inputs.
for (size_t n = 0; n < input_defs_count; n++) {
input_defs[n] = nchwc_inputs[n]->nchwc_arg_;
nchwc_inputs[n]->remaining_original_uses_--;
}
if (all_shapes_match) {
// Update the node to use the NCHWc inputs directly and decrement
// the original use counts of the NCHWc inputs.
for (size_t n = 0; n < input_defs_count; n++) {
input_defs[n] = nchwc_inputs[n]->nchwc_arg_;
nchwc_inputs[n]->remaining_original_uses_--;
}
// If one of the inputs to the Add/Sum node is a NCHWc convolution, then
// attempt to fuse the addition into the convolution itself.
if (add_node && input_defs_count == 2) {
for (size_t n = 0; n < 2; n++) {
auto* nchwc_input_n = nchwc_inputs[n];
auto& nchwc_node = nchwc_input_n->output_node_;
auto& nchwc_input_defs = nchwc_node.MutableInputDefs();
auto& nchwc_input_args_count = nchwc_node.MutableInputArgsCount();
size_t nchwc_input_defs_count = nchwc_input_defs.size();
// Check if this is a single use NCHWc convolution that hasn't already
// been fused with another Add/Sum node. The Add/Sum can also only be
// fused if the convolution isn't itself fused with an activation.
if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) &&
(nchwc_input_defs_count < 4) && (nchwc_input_args_count.size() < 4) &&
(nchwc_input_n->starting_original_uses_ == 1) &&
(graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) {
// Feed the output of the other NCHWc node into the selected convolution
// node.
nchwc_input_defs.resize(4);
nchwc_input_args_count.resize(4);
if (nchwc_input_defs_count < 3) {
// The optional bias parameter is empty so set to an empty string.
nchwc_input_defs[2] = &graph_.GetOrCreateNodeArg("", nullptr);
nchwc_input_args_count[2] = 1;
// If one of the inputs to the Add/Sum node is a NCHWc convolution, then
// attempt to fuse the addition into the convolution itself.
if (add_node && input_defs_count == 2) {
for (size_t n = 0; n < 2; n++) {
auto* nchwc_input_n = nchwc_inputs[n];
auto& nchwc_node = nchwc_input_n->output_node_;
auto& nchwc_input_defs = nchwc_node.MutableInputDefs();
auto& nchwc_input_args_count = nchwc_node.MutableInputArgsCount();
size_t nchwc_input_defs_count = nchwc_input_defs.size();
// Check if this is a single use NCHWc convolution that hasn't already
// been fused with another Add/Sum node. The Add/Sum can also only be
// fused if the convolution isn't itself fused with an activation.
if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) &&
(nchwc_input_defs_count < 4) && (nchwc_input_args_count.size() < 4) &&
(nchwc_input_n->starting_original_uses_ == 1) &&
(graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) {
// Feed the output of the other NCHWc node into the selected convolution
// node.
nchwc_input_defs.resize(4);
nchwc_input_args_count.resize(4);
if (nchwc_input_defs_count < 3) {
// The optional bias parameter is empty so set to an empty string.
nchwc_input_defs[2] = &graph_.GetOrCreateNodeArg("", nullptr);
nchwc_input_args_count[2] = 1;
}
nchwc_input_defs[3] = nchwc_inputs[n ^ 1]->output_node_.MutableOutputDefs()[0];
nchwc_input_args_count[3] = 1;
FuseNchwcArgument(node, *nchwc_input_n);
removed_nodes_.push_front(node.Index());
return;
}
nchwc_input_defs[3] = nchwc_inputs[n ^ 1]->output_node_.MutableOutputDefs()[0];
nchwc_input_args_count[3] = 1;
FuseNchwcArgument(node, *nchwc_input_n);
removed_nodes_.push_front(node.Index());
return;
}
}
CreateNchwcArgument(node, node, nchwc_input_0->channels_, nchwc_input_0->shape_);
return;
}
CreateNchwcArgument(node, node, nchwc_input_0->channels_, nchwc_input_0->shape_);
if (add_node) {
// The input shapes cannot be shown to be identical, but the channel dimension
// is the same. Reshape the tensors to explicitly use the true NCHWc shape in
// order to perform the binary operation.
//
// Typically, both tensors are of the same shape at inferencing time, but this
// could not be proven using symbolic dimensions. This reshaping avoids the
// alternative of reordering the tensors back to NCHW.
//
// This optimization is restricted to Add/Sum nodes. Mul nodes would also work
// using this code, however the common case here is multiplying a NxCxHxW
// matrix by a NxCx1x1 vector. The implementation of Mul does not currently
// vectorize well for the case of broadcasting a NCHWc sized channel block.
// This case should be handled by a ScaleShift kernel that can broadcast a
// multiply/add vector (BatchNormalization can also be reimplemented with this).
for (size_t n = 0; n < input_defs_count; n++) {
std::string reshape_input_def_name = graph_.GenerateNodeArgName("reshape");
auto* reshape_input_arg = &graph_.GetOrCreateNodeArg(reshape_input_def_name, nullptr);
InsertReshape(nchwc_inputs[n]->nchwc_arg_, reshape_input_arg, channels, true);
input_defs[n] = reshape_input_arg;
nchwc_inputs[n]->remaining_original_uses_--;
}
std::string output_reshaped_def_name = graph_.GenerateNodeArgName("reshape");
auto* output_reshaped_arg = &graph_.GetOrCreateNodeArg(output_reshaped_def_name, nullptr);
Node& nchwc_node = InsertReshape(output_reshaped_arg, output_defs[0], channels, false);
NchwcArgument::Shape output_shape(output_defs[0]);
CreateNchwcArgument(node, nchwc_node, channels, output_shape);
output_defs[0] = output_reshaped_arg;
return;
}
}
void NchwcTransformerImpl::TransformConcat(Node& node) {

View file

@ -165,7 +165,7 @@ struct NchwcTestHelper {
void NchwcOptimizerTester(const std::function<void(NchwcTestHelper& helper)>& build_test_case,
const std::function<void(InferenceSessionWrapper& session)>& check_nchwc_graph,
int opset_version = 12) {
int opset_version = 13) {
// Ignore the test if NCHWc is not supported by the platform.
if (MlasNchwcGetBlockSize() <= 1) {
return;
@ -673,6 +673,40 @@ TEST(NchwcOptimizerTests, ConvBinary) {
}
}
// Exercises Add/Sum where one operand is a convolution output and the other
// is a GlobalAveragePool output over the same input, and checks the operator
// counts in the optimized graph.
TEST(NchwcOptimizerTests, ConvBinaryBroadcast) {
  auto run_case = [&](const std::string& op_type) {
    auto build_graph = [&](NchwcTestHelper& helper) {
      auto* graph_input = helper.MakeInput<float>({1, 32, 25, 21});
      auto* conv_result = helper.MakeIntermediate();
      auto* pool_result = helper.MakeIntermediate();
      auto* graph_output = helper.MakeOutput();

      helper.AddConvNode(graph_input, conv_result, {32, 32, 3, 3});
      helper.AddNode("GlobalAveragePool", {graph_input}, {pool_result});
      helper.AddNode(op_type, {conv_result, pool_result}, {graph_output});
    };

    auto verify_graph = [&](InferenceSessionWrapper& session) {
      auto op_counts = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(op_counts["com.microsoft.nchwc.Conv"], 1);
      EXPECT_EQ(op_counts["com.microsoft.nchwc.GlobalAveragePool"], 1);
      // A single reorder on the way in and a single reorder on the way out.
      EXPECT_EQ(op_counts["com.microsoft.nchwc.ReorderInput"], 1);
      EXPECT_EQ(op_counts["com.microsoft.nchwc.ReorderOutput"], 1);
      // Three Reshape nodes are expected around the binary operator.
      EXPECT_EQ(op_counts["Reshape"], 3);
      EXPECT_EQ(op_counts[op_type], 1);
    };

    NchwcOptimizerTester(build_graph, verify_graph);
  };

  // The binary operator's inputs should remain in NCHWc format, with only the
  // operator's output being reordered.
  const std::vector<std::string> binary_ops{"Add", "Sum"};
  for (const auto& binary_op : binary_ops) {
    run_case(binary_op);
  }
}
TEST(NchwcOptimizerTests, ConvConcat) {
auto test_case = [&](int axis, int channel_count, int reorder_output_count) {
auto build_test_case = [&](NchwcTestHelper& helper) {