mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
NCHWc: avoid buffer reordering around Add nodes (#7279)
Use Reshape to handle more NCHWc Add cases without ReorderInput/ReorderOutput.
This commit is contained in:
parent
e14b291ce7
commit
bc6ef809bb
2 changed files with 189 additions and 55 deletions
|
|
@ -91,7 +91,7 @@ class NchwcTransformerImpl {
|
|||
const int64_t channels_;
|
||||
|
||||
// Stores the proto shape for the NCHWc output.
|
||||
NchwcArgument::Shape shape_;
|
||||
const NchwcArgument::Shape shape_;
|
||||
|
||||
NchwcArgument(Node& output_node, NodeArg* output_nchwc_arg, size_t original_uses, size_t channels, const NchwcArgument::Shape& shape)
|
||||
: output_node_(output_node),
|
||||
|
|
@ -112,6 +112,7 @@ class NchwcTransformerImpl {
|
|||
const NchwcArgument::Shape& input_shape,
|
||||
NchwcArgument::Shape& output_shape,
|
||||
const ONNX_NAMESPACE::TensorProto* filter_shape);
|
||||
Node& InsertReshape(NodeArg* input_arg, NodeArg* output_arg, int64_t channels, bool split_channels);
|
||||
|
||||
void TransformConv(Node& node);
|
||||
void TransformPool(Node& node);
|
||||
|
|
@ -143,6 +144,11 @@ class NchwcTransformerImpl {
|
|||
// Stores a mapping of NodeArg biases that have already been aligned to the
|
||||
// NCHWc block size, so multiple nodes can share the NCHWc biases.
|
||||
std::unordered_map<NodeArg*, NodeArg*> aligned_biases_;
|
||||
|
||||
// Stores a mapping of shape initializers for use by Reshape when splitting
|
||||
// or unsplitting the channels dimension of a tensor.
|
||||
std::unordered_map<int64_t, NodeArg*> reshape_split_;
|
||||
std::unordered_map<int64_t, NodeArg*> reshape_unsplit_;
|
||||
};
|
||||
|
||||
size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) {
|
||||
|
|
@ -510,11 +516,55 @@ void NchwcTransformerImpl::TransformPool(Node& node) {
|
|||
removed_nodes_.push_front(node.Index());
|
||||
}
|
||||
|
||||
// The existing Add/Sum operator implementations can be used with tensors
|
||||
// in NCHWc format if the tensor shapes are exactly the same (elementwise
|
||||
// add).
|
||||
// Adds a CPU "Reshape" node that reads `input_arg` and writes `output_arg`.
//
// `channels` is rounded up to a multiple of the NCHWc block size. When
// `split_channels` is true the target shape has one extra trailing dimension:
// the channel entry becomes (aligned channels / block size) and the block size
// is appended. When false the target shape has kNchwcDims entries and the
// channel entry is the aligned channel count. Batch and spatial entries are
// written as 0 so that Reshape passes those dimensions through.
//
// Returns the newly added Reshape node.
Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg,
                                          NodeArg* output_arg,
                                          int64_t channels,
                                          bool split_channels) {
  const int64_t block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
  // Round up with a mask; this form relies on the block size being a power of two.
  const int64_t aligned_channels = (channels + block_size - 1) & ~(block_size - 1);

  // The shape initializer is cached per aligned channel count, separately for
  // the split and unsplit directions. operator[] creates a null slot the first
  // time a channel count is seen.
  NodeArg*& cached_shape_arg =
      (split_channels ? reshape_split_ : reshape_unsplit_)[aligned_channels];

  if (cached_shape_arg == nullptr) {
    ONNX_NAMESPACE::TensorProto shape_proto;
    shape_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
    shape_proto.set_name(graph_.GenerateNodeArgName("Reshape"));

    // Batch dimension: passthrough.
    shape_proto.add_int64_data(0);

    // Channel dimension: number of blocks when splitting, padded channel
    // count when unsplitting.
    shape_proto.add_int64_data(split_channels ? aligned_channels / block_size
                                              : aligned_channels);

    // Spatial dimensions: passthrough.
    for (int spatial_dim = 0; spatial_dim < kNchwcSpatialDims; spatial_dim++) {
      shape_proto.add_int64_data(0);
    }

    // Splitting appends the block dimension, which also raises the rank by one.
    if (split_channels) {
      shape_proto.add_int64_data(block_size);
    }
    shape_proto.add_dims(split_channels ? kNchwcDims + 1 : kNchwcDims);

    cached_shape_arg = &graph_utils::AddInitializer(graph_, shape_proto);
  }

  NodeArg* const shape_arg = cached_shape_arg;

  Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"),
                                      "Reshape",
                                      "Reshape",
                                      {input_arg, shape_arg},
                                      {output_arg});
  reshape_node.SetExecutionProviderType(kCpuExecutionProvider);
  return reshape_node;
}
|
||||
|
||||
void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
|
||||
auto& input_defs = node.MutableInputDefs();
|
||||
auto& output_defs = node.MutableOutputDefs();
|
||||
|
||||
// Verify that all of the inputs to this operator are from NCHWc outputs.
|
||||
std::vector<NchwcArgument*> nchwc_inputs;
|
||||
|
|
@ -528,74 +578,124 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) {
|
|||
nchwc_inputs.push_back(it->second.get());
|
||||
}
|
||||
|
||||
// Test if all of the NCHWc inputs have a compatible shape.
|
||||
auto* nchwc_input_0 = nchwc_inputs[0];
|
||||
auto* nchwc_input_0_shape = input_defs[0]->Shape();
|
||||
const int64_t channels = nchwc_inputs[0]->channels_;
|
||||
|
||||
// Test if all of the NCHWc inputs have an equal shape.
|
||||
bool all_shapes_match = true;
|
||||
auto* input_0_shape = input_defs[0]->Shape();
|
||||
for (size_t n = 1; n < input_defs_count; n++) {
|
||||
auto* nchwc_input_n = nchwc_inputs[n];
|
||||
// Require that all inputs have the same logical number of channels.
|
||||
if (nchwc_input_n->channels_ != channels) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < kNchwcDims; i++) {
|
||||
// Test if this dimension is derived from the same NodeArg.
|
||||
if (!nchwc_input_0->shape_.IsDimEqual(nchwc_input_n->shape_, i)) {
|
||||
// Check if ONNX shape inferencing has computed a precise dimension value.
|
||||
auto* nchwc_input_n_shape = input_defs[n]->Shape();
|
||||
if ((nchwc_input_0_shape == nullptr) || (nchwc_input_n_shape == nullptr)) {
|
||||
return;
|
||||
}
|
||||
auto& nchwc_input_0_dim = nchwc_input_0_shape->dim(i);
|
||||
auto& nchwc_input_n_dim = nchwc_input_n_shape->dim(i);
|
||||
if (!utils::HasDimValue(nchwc_input_0_dim) ||
|
||||
!utils::HasDimValue(nchwc_input_n_dim) ||
|
||||
(nchwc_input_0_dim.dim_value() <= 0) ||
|
||||
(nchwc_input_0_dim.dim_value() != nchwc_input_n_dim.dim_value())) {
|
||||
return;
|
||||
auto* input_n_shape = input_defs[n]->Shape();
|
||||
if ((input_0_shape == nullptr) || (input_n_shape == nullptr)) {
|
||||
all_shapes_match = false;
|
||||
} else {
|
||||
auto& input_0_dim = input_0_shape->dim(i);
|
||||
auto& input_n_dim = input_n_shape->dim(i);
|
||||
if (!utils::HasDimValue(input_0_dim) ||
|
||||
!utils::HasDimValue(input_n_dim) ||
|
||||
(input_0_dim.dim_value() <= 0) ||
|
||||
(input_0_dim.dim_value() != input_n_dim.dim_value())) {
|
||||
if (!utils::HasDimParam(input_0_dim) ||
|
||||
!utils::HasDimParam(input_n_dim) ||
|
||||
(input_0_dim.dim_param() != input_n_dim.dim_param())) {
|
||||
all_shapes_match = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the node to use the NCHWc inputs directly and decrement
|
||||
// the original use counts of the NCHWc inputs.
|
||||
for (size_t n = 0; n < input_defs_count; n++) {
|
||||
input_defs[n] = nchwc_inputs[n]->nchwc_arg_;
|
||||
nchwc_inputs[n]->remaining_original_uses_--;
|
||||
}
|
||||
if (all_shapes_match) {
|
||||
// Update the node to use the NCHWc inputs directly and decrement
|
||||
// the original use counts of the NCHWc inputs.
|
||||
for (size_t n = 0; n < input_defs_count; n++) {
|
||||
input_defs[n] = nchwc_inputs[n]->nchwc_arg_;
|
||||
nchwc_inputs[n]->remaining_original_uses_--;
|
||||
}
|
||||
|
||||
// If one of the inputs to the Add/Sum node is a NCHWc convolution, then
|
||||
// attempt to fuse the addition into the convolution itself.
|
||||
if (add_node && input_defs_count == 2) {
|
||||
for (size_t n = 0; n < 2; n++) {
|
||||
auto* nchwc_input_n = nchwc_inputs[n];
|
||||
auto& nchwc_node = nchwc_input_n->output_node_;
|
||||
auto& nchwc_input_defs = nchwc_node.MutableInputDefs();
|
||||
auto& nchwc_input_args_count = nchwc_node.MutableInputArgsCount();
|
||||
size_t nchwc_input_defs_count = nchwc_input_defs.size();
|
||||
// Check if this is a single use NCHWc convolution that hasn't already
|
||||
// been fused with another Add/Sum node. The Add/Sum can also only be
|
||||
// fused if the convolution isn't itself fused with an activation.
|
||||
if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) &&
|
||||
(nchwc_input_defs_count < 4) && (nchwc_input_args_count.size() < 4) &&
|
||||
(nchwc_input_n->starting_original_uses_ == 1) &&
|
||||
(graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) {
|
||||
// Feed the output of the other NCHWc node into the selected convolution
|
||||
// node.
|
||||
nchwc_input_defs.resize(4);
|
||||
nchwc_input_args_count.resize(4);
|
||||
if (nchwc_input_defs_count < 3) {
|
||||
// The optional bias parameter is empty so set to an empty string.
|
||||
nchwc_input_defs[2] = &graph_.GetOrCreateNodeArg("", nullptr);
|
||||
nchwc_input_args_count[2] = 1;
|
||||
// If one of the inputs to the Add/Sum node is a NCHWc convolution, then
|
||||
// attempt to fuse the addition into the convolution itself.
|
||||
if (add_node && input_defs_count == 2) {
|
||||
for (size_t n = 0; n < 2; n++) {
|
||||
auto* nchwc_input_n = nchwc_inputs[n];
|
||||
auto& nchwc_node = nchwc_input_n->output_node_;
|
||||
auto& nchwc_input_defs = nchwc_node.MutableInputDefs();
|
||||
auto& nchwc_input_args_count = nchwc_node.MutableInputArgsCount();
|
||||
size_t nchwc_input_defs_count = nchwc_input_defs.size();
|
||||
// Check if this is a single use NCHWc convolution that hasn't already
|
||||
// been fused with another Add/Sum node. The Add/Sum can also only be
|
||||
// fused if the convolution isn't itself fused with an activation.
|
||||
if ((nchwc_node.OpType() == "Conv") && (nchwc_node.Domain() == kMSNchwcDomain) &&
|
||||
(nchwc_input_defs_count < 4) && (nchwc_input_args_count.size() < 4) &&
|
||||
(nchwc_input_n->starting_original_uses_ == 1) &&
|
||||
(graph_utils::GetNodeAttribute(nchwc_node, "activation") == nullptr)) {
|
||||
// Feed the output of the other NCHWc node into the selected convolution
|
||||
// node.
|
||||
nchwc_input_defs.resize(4);
|
||||
nchwc_input_args_count.resize(4);
|
||||
if (nchwc_input_defs_count < 3) {
|
||||
// The optional bias parameter is empty so set to an empty string.
|
||||
nchwc_input_defs[2] = &graph_.GetOrCreateNodeArg("", nullptr);
|
||||
nchwc_input_args_count[2] = 1;
|
||||
}
|
||||
nchwc_input_defs[3] = nchwc_inputs[n ^ 1]->output_node_.MutableOutputDefs()[0];
|
||||
nchwc_input_args_count[3] = 1;
|
||||
|
||||
FuseNchwcArgument(node, *nchwc_input_n);
|
||||
removed_nodes_.push_front(node.Index());
|
||||
return;
|
||||
}
|
||||
nchwc_input_defs[3] = nchwc_inputs[n ^ 1]->output_node_.MutableOutputDefs()[0];
|
||||
nchwc_input_args_count[3] = 1;
|
||||
|
||||
FuseNchwcArgument(node, *nchwc_input_n);
|
||||
removed_nodes_.push_front(node.Index());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
CreateNchwcArgument(node, node, nchwc_input_0->channels_, nchwc_input_0->shape_);
|
||||
return;
|
||||
}
|
||||
|
||||
CreateNchwcArgument(node, node, nchwc_input_0->channels_, nchwc_input_0->shape_);
|
||||
if (add_node) {
|
||||
// The input shapes cannot be shown to be identical, but the channel dimension
|
||||
// is the same. Reshape the tensors to explicitly use the true NCHWc shape in
|
||||
// order to perform the binary operation.
|
||||
//
|
||||
// Typically, both tensors are of the same shape at inferencing time, but this
|
||||
// could not be proven using symbolic dimensions. This reshaping avoids the
|
||||
// alternative of reordering the tensors back to NCHW.
|
||||
//
|
||||
// This optimization is restricted to Add/Sum nodes. Mul nodes would also work
|
||||
// using this code, however the common case here is multiplying a NxCxHxW
|
||||
// matrix by a NxCx1x1 vector. The implementation of Mul does not currently
|
||||
// vectorize well for the case of broadcasting a NCHWc sized channel block.
|
||||
// This case should be handled by a ScaleShift kernel that can broadcast a
|
||||
// multiply/add vector (BatchNormalization can also be reimplemented with this).
|
||||
for (size_t n = 0; n < input_defs_count; n++) {
|
||||
std::string reshape_input_def_name = graph_.GenerateNodeArgName("reshape");
|
||||
auto* reshape_input_arg = &graph_.GetOrCreateNodeArg(reshape_input_def_name, nullptr);
|
||||
InsertReshape(nchwc_inputs[n]->nchwc_arg_, reshape_input_arg, channels, true);
|
||||
|
||||
input_defs[n] = reshape_input_arg;
|
||||
nchwc_inputs[n]->remaining_original_uses_--;
|
||||
}
|
||||
|
||||
std::string output_reshaped_def_name = graph_.GenerateNodeArgName("reshape");
|
||||
auto* output_reshaped_arg = &graph_.GetOrCreateNodeArg(output_reshaped_def_name, nullptr);
|
||||
Node& nchwc_node = InsertReshape(output_reshaped_arg, output_defs[0], channels, false);
|
||||
|
||||
NchwcArgument::Shape output_shape(output_defs[0]);
|
||||
|
||||
CreateNchwcArgument(node, nchwc_node, channels, output_shape);
|
||||
output_defs[0] = output_reshaped_arg;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void NchwcTransformerImpl::TransformConcat(Node& node) {
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ struct NchwcTestHelper {
|
|||
|
||||
void NchwcOptimizerTester(const std::function<void(NchwcTestHelper& helper)>& build_test_case,
|
||||
const std::function<void(InferenceSessionWrapper& session)>& check_nchwc_graph,
|
||||
int opset_version = 12) {
|
||||
int opset_version = 13) {
|
||||
// Ignore the test if NCHWc is not supported by the platform.
|
||||
if (MlasNchwcGetBlockSize() <= 1) {
|
||||
return;
|
||||
|
|
@ -673,6 +673,40 @@ TEST(NchwcOptimizerTests, ConvBinary) {
|
|||
}
|
||||
}
|
||||
|
||||
// Feeds a Conv output and a GlobalAveragePool output (both computed from the
// same graph input) into a binary Add/Sum node and checks the optimized graph:
// the binary operator should stay between NCHWc producers, bracketed by three
// Reshape nodes, with exactly one ReorderInput and one ReorderOutput overall.
TEST(NchwcOptimizerTests, ConvBinaryBroadcast) {
  auto run_case = [&](const std::string& op_type) {
    auto build_test_case = [&](NchwcTestHelper& helper) {
      auto* graph_input = helper.MakeInput<float>({1, 32, 25, 21});
      auto* conv_result = helper.MakeIntermediate();
      auto* pool_result = helper.MakeIntermediate();
      auto* graph_output = helper.MakeOutput();

      helper.AddConvNode(graph_input, conv_result, {32, 32, 3, 3});
      helper.AddNode("GlobalAveragePool", {graph_input}, {pool_result});
      helper.AddNode(op_type, {conv_result, pool_result}, {graph_output});
    };

    auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
      auto op_counts = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(op_counts["com.microsoft.nchwc.Conv"], 1);
      EXPECT_EQ(op_counts["com.microsoft.nchwc.GlobalAveragePool"], 1);
      EXPECT_EQ(op_counts["com.microsoft.nchwc.ReorderInput"], 1);
      EXPECT_EQ(op_counts["com.microsoft.nchwc.ReorderOutput"], 1);
      EXPECT_EQ(op_counts["Reshape"], 3);
      EXPECT_EQ(op_counts[op_type], 1);
    };

    NchwcOptimizerTester(build_test_case, check_nchwc_graph);
  };

  // The inputs of the binary operator are expected to remain NCHWc; only the
  // output of the binary operator should be reordered.
  const std::vector<std::string> binary_op_types{"Add", "Sum"};
  for (const auto& binary_op_type : binary_op_types) {
    run_case(binary_op_type);
  }
}
|
||||
|
||||
TEST(NchwcOptimizerTests, ConvConcat) {
|
||||
auto test_case = [&](int axis, int channel_count, int reorder_output_count) {
|
||||
auto build_test_case = [&](NchwcTestHelper& helper) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue