NCHWc ReorderOutput->Transpose(NHWC) fusion (#3035)

Add support to fuse ReorderOutput+Transpose(NHWC). Converting from NCHWc to NHWC tensors is a trivial copy of data and avoids the cost of a transpose node.
This commit is contained in:
Tracy Sharpe 2020-02-18 10:23:48 -08:00 committed by GitHub
parent 71ca43b345
commit ecdcd682bb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 534 additions and 195 deletions

View file

@ -16,7 +16,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
float,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ReorderInput<float>);
ReorderInput);
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
ReorderOutput,
@ -24,7 +24,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
float,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ReorderOutput<float>);
ReorderOutput);
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
Conv,
@ -67,27 +67,41 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
NchwcAveragePool);
template <typename T>
Status ReorderInput<T>::Compute(OpKernelContext* context) const {
Status ReorderInput::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
const auto& X_shape = X->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
auto* Y = context->Output(0, X_shape);
MlasReorderInput(X_shape.GetDims().data(), X->template Data<T>(), Y->template MutableData<T>());
MlasReorderInput(X_shape.GetDims().data(), X->template Data<float>(), Y->template MutableData<float>());
return Status::OK();
}
template <typename T>
Status ReorderOutput<T>::Compute(OpKernelContext* context) const {
Status ReorderOutput::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
const auto& X_shape = X->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
std::vector<int64_t> Y_shape(X_shape.GetDims());
ORT_ENFORCE(channels_ <= Y_shape[1]);
Y_shape[1] = channels_;
const auto X_rank = X_shape.NumDimensions();
ORT_ENFORCE(X_rank == 4);
ORT_ENFORCE(channels_ <= X_shape[1]);
// Build the output shape in NCHW or NHWC order.
std::vector<int64_t> Y_shape(X_rank);
Y_shape[0] = X_shape[0];
Y_shape[channels_last_ ? X_rank - 1 : 1] = channels_;
auto* Y_spatial_dims = Y_shape.data() + (channels_last_ ? 1 : 2);
for (size_t i = 0; i < X_rank - 2; i++) {
Y_spatial_dims[i] = X_shape[2 + i];
}
auto* Y = context->Output(0, Y_shape);
MlasReorderOutput(Y_shape.data(), X->template Data<T>(), Y->template MutableData<T>());
const auto* x_data = X->template Data<float>();
auto* y_data = Y->template MutableData<float>();
if (channels_last_) {
MlasReorderOutputNhwc(Y_shape.data(), x_data, y_data);
} else {
MlasReorderOutputNchw(Y_shape.data(), x_data, y_data);
}
return Status::OK();
}

View file

@ -12,7 +12,6 @@
namespace onnxruntime {
namespace contrib {
template <typename T>
class ReorderInput : public OpKernel {
public:
ReorderInput(const OpKernelInfo& info) : OpKernel(info) {
@ -21,18 +20,19 @@ class ReorderInput : public OpKernel {
Status Compute(OpKernelContext* context) const override;
};
template <typename T>
class ReorderOutput : public OpKernel {
public:
ReorderOutput(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("channels", &channels_).IsOK());
ORT_ENFORCE(channels_ > 0, "invalid channel count");
ORT_ENFORCE(info.GetAttr<int64_t>("channels_last", &channels_last_).IsOK());
}
Status Compute(OpKernelContext* context) const override;
private:
int64_t channels_;
int64_t channels_last_;
};
class NchwcConv : public OpKernel {

View file

@ -5,6 +5,7 @@
#include "core/graph/constants.h"
#include "core/graph/contrib_ops/attn_lstm_schema_defs.h"
#include "core/graph/contrib_ops/contrib_defs.h"
#include "core/graph/contrib_ops/nchwc_schema_defs.h"
#include "core/graph/contrib_ops/range_schema_defs.h"
#include "core/graph/op.h"
#include "onnx/defs/schema.h"
@ -18,7 +19,6 @@ void convPoolShapeInference(
bool use_dilation, bool require_kernel_shape,
int input1Idx,
int input2Idx);
void globalPoolTypeShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);
void matmulShapeInference(
ONNX_NAMESPACE::InferenceContext& ctx,
int input1Idx,
@ -166,37 +166,6 @@ using ONNX_NAMESPACE::AttributeProto;
using ONNX_NAMESPACE::OpSchema;
using ONNX_NAMESPACE::OPTIONAL;
void NchwcPoolOpSchemaGenerator(OpSchema& schema) {
schema.SetDomain(kMSNchwcDomain);
schema.SinceVersion(1);
schema.SetDoc(R"DOC(For internal use.)DOC");
schema.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"));
schema.Attr("kernel_shape", "", AttributeProto::INTS);
schema.Attr("dilations", "", AttributeProto::INTS, OPTIONAL);
schema.Attr("strides", "", AttributeProto::INTS, OPTIONAL);
schema.Attr("pads", "", AttributeProto::INTS, OPTIONAL);
schema.Attr("ceil_mode", "", AttributeProto::INT, static_cast<int64_t>(0));
schema.Input(0, "X", "", "T");
schema.Output(0, "Y", "", "T");
schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors");
schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
ONNX_NAMESPACE::convPoolShapeInference(ctx, true, true, 0, 1);
});
}
void NchwcGlobalPoolOpSchemaGenerator(OpSchema& schema) {
schema.SetDomain(kMSNchwcDomain);
schema.SinceVersion(1);
schema.SetDoc(R"DOC(For internal use.)DOC");
schema.Input(0, "X", "", "T");
schema.Output(0, "Y", "", "T");
schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors");
schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::globalPoolTypeShapeInference(ctx);
});
}
void ValidateTypeAndShapeForScaleAndZP(ONNX_NAMESPACE::InferenceContext& ctx, int index, ::google::protobuf::int32 expectedType, bool isScalar, int expectedTensorSize = 0) {
if (ctx.getNumInputs() > static_cast<size_t>(index)) {
auto data_type = ctx.getInputType(index);
@ -320,132 +289,6 @@ const char* contrib_ops_auto_pad_doc =
"In case of odd number add the extra padding at the end for SAME_UPPER and at the "
"beginning for SAME_LOWER. VALID mean no padding.";
void RegisterNchwcSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderInput)
.SetDomain(kMSNchwcDomain)
.SinceVersion(1)
.SetDoc(R"DOC(For internal use.)DOC")
.Input(0, "X", "", "T")
.Output(0, "Y", "", "T")
.TypeConstraint(
"T",
{"tensor(float)", "tensor(int8)", "tensor(uint8)"},
"Constrain input and output types to float/quantized tensors")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput)
.SetDomain(kMSNchwcDomain)
.SinceVersion(1)
.SetDoc(R"DOC(For internal use.)DOC")
.Attr(
"channels",
"",
AttributeProto::INT,
static_cast<int64_t>(0))
.Input(0, "X", "", "T")
.Output(0, "Y", "", "T")
.TypeConstraint(
"T",
{"tensor(float)", "tensor(int8)", "tensor(uint8)"},
"Constrain input and output types to float/quantized tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateShapeFromInputToOutput(ctx, 0, 0);
// Update the output shape with the actual number of channels.
auto channels = getAttribute(ctx, "channels", 0);
if (channels <= 0) {
fail_shape_inference("invalid channel count");
}
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
if (output_shape->dim_size() < 2) {
fail_shape_inference("tensor rank too small");
}
auto* channels_dim = output_shape->mutable_dim(1);
channels_dim->clear_dim_param();
channels_dim->set_dim_value(channels);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(Conv)
.SetDomain(kMSNchwcDomain)
.SinceVersion(1)
.SetDoc(R"DOC(For internal use.)DOC")
.Attr(
"auto_pad",
"",
AttributeProto::STRING,
std::string("NOTSET"))
.Attr(
"kernel_shape",
"",
AttributeProto::INTS,
OPTIONAL)
.Attr(
"dilations",
"",
AttributeProto::INTS,
OPTIONAL)
.Attr(
"strides",
"",
AttributeProto::INTS,
OPTIONAL)
.Attr(
"pads",
"",
AttributeProto::INTS, OPTIONAL)
.Attr(
"group",
"",
AttributeProto::INT,
static_cast<int64_t>(1))
.Attr(
"activation",
"",
AttributeProto::STRING,
OPTIONAL)
.Attr(
"activation_params",
"",
AttributeProto::FLOATS,
OPTIONAL)
.Input(0, "X", "", "T")
.Input(1, "W", "", "T")
.Input(2, "B", "", "T", OpSchema::Optional)
.Input(3, "Sum", "", "T", OpSchema::Optional)
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
ONNX_NAMESPACE::convPoolShapeInference(ctx, true, false, 0, 1);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(MaxPool)
.FillUsing(NchwcPoolOpSchemaGenerator)
.Attr(
"storage_order",
"",
AttributeProto::INT,
static_cast<int64_t>(0));
ONNX_CONTRIB_OPERATOR_SCHEMA(AveragePool)
.FillUsing(NchwcPoolOpSchemaGenerator)
.Attr(
"count_include_pad",
"",
AttributeProto::INT,
static_cast<int64_t>(0));
ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalMaxPool)
.FillUsing(NchwcGlobalPoolOpSchemaGenerator);
ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalAveragePool)
.FillUsing(NchwcGlobalPoolOpSchemaGenerator);
}
void RegisterBertSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA(Attention)
.SetDomain(kMSDomain)
@ -1383,8 +1226,8 @@ activation and leaky_relu_alpha.)DOC")
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema);
static const char* QuantizeLinear_ver1_doc = R"DOC(
The linear quantization operator. It consumes a full precision data, a scale, a zero point and computes the quantized data.
The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to arg (in floating-point format),
The linear quantization operator. It consumes a full precision data, a scale, a zero point and computes the quantized data.
The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to arg (in floating-point format),
rounding halfway cases away from zero. Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear)
@ -1440,8 +1283,8 @@ The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale),
});
static const char* DequantizeLinear_ver1_doc = R"DOC(
The linear dequantization operator. It consumes a quantized data, a scale, a zero point and computes the full precision data.
The dequantization formula is y = (x - x_zero_point) * x_scale.
The linear dequantization operator. It consumes a quantized data, a scale, a zero point and computes the full precision data.
The dequantization formula is y = (x - x_zero_point) * x_scale.
Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(DequantizeLinear)
@ -1682,7 +1525,7 @@ Computes the mean of the low-precision input tensor's element along the provided
The resulting tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0,
then the resulting tensor have the reduced dimension pruned. The above behavior is similar to numpy,
with the exception that numpy default keepdims to False instead of True.
Input and Output scales and zero points are used to requantize the output in a new range.
Input and Output scales and zero points are used to requantize the output in a new range.
This helps to improve accuracy as after ReduceMean operation the range of the output is expected to decrease.
```
@ -1861,7 +1704,7 @@ C (int32) = (A - A_zero_point) * (B - B_zero_point)
```
pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + kernel_spatial_shape[i] - input_spatial_shape[i]
```
The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero).
Input and output scales and zero points are used to convert the output to a new quantization range.
@ -2448,7 +2291,7 @@ Example 4:
R"DOC(Gaussian Error Linear Unit.
A high-performing neural network activation function.The GELU nonlinearity is
the expected transformation of a stochastic regularizer which randomly applies
the identity or zero map to a neuron's input. The GELU nonlinearity weights
the identity or zero map to a neuron's input. The GELU nonlinearity weights
inputs by their magnitude, rather than gates inputs by their sign as in ReLUs.)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(Gelu)

View file

@ -0,0 +1,153 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/tensorprotoutils.h"
#include "core/graph/constants.h"
#include "core/graph/contrib_ops/contrib_defs.h"
#include "core/graph/contrib_ops/nchwc_schema_defs.h"
namespace ONNX_NAMESPACE {
void convPoolShapeInference(
ONNX_NAMESPACE::InferenceContext& ctx,
bool use_dilation, bool require_kernel_shape,
int input1Idx,
int input2Idx);
void globalPoolTypeShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);
} // namespace ONNX_NAMESPACE
namespace onnxruntime {
namespace contrib {
using ONNX_NAMESPACE::AttributeProto;
using ONNX_NAMESPACE::InferenceContext;
using ONNX_NAMESPACE::OpSchema;
using ONNX_NAMESPACE::OPTIONAL;
void NchwcPoolOpSchemaGenerator(OpSchema& schema) {
schema.SetDomain(kMSNchwcDomain);
schema.SinceVersion(1);
schema.SetDoc(R"DOC(For internal use.)DOC");
schema.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"));
schema.Attr("kernel_shape", "", AttributeProto::INTS);
schema.Attr("dilations", "", AttributeProto::INTS, OPTIONAL);
schema.Attr("strides", "", AttributeProto::INTS, OPTIONAL);
schema.Attr("pads", "", AttributeProto::INTS, OPTIONAL);
schema.Attr("ceil_mode", "", AttributeProto::INT, static_cast<int64_t>(0));
schema.Input(0, "X", "", "T");
schema.Output(0, "Y", "", "T");
schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors");
schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
ONNX_NAMESPACE::convPoolShapeInference(ctx, true, true, 0, 1);
});
}
void NchwcGlobalPoolOpSchemaGenerator(OpSchema& schema) {
schema.SetDomain(kMSNchwcDomain);
schema.SinceVersion(1);
schema.SetDoc(R"DOC(For internal use.)DOC");
schema.Input(0, "X", "", "T");
schema.Output(0, "Y", "", "T");
schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors");
schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::globalPoolTypeShapeInference(ctx);
});
}
void RegisterNchwcSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderInput)
.SetDomain(kMSNchwcDomain)
.SinceVersion(1)
.SetDoc(R"DOC(For internal use.)DOC")
.Input(0, "X", "", "T")
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput)
.SetDomain(kMSNchwcDomain)
.SinceVersion(1)
.SetDoc(R"DOC(For internal use.)DOC")
.Attr("channels", "", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("channels_last", "", AttributeProto::INT, static_cast<int64_t>(0))
.Input(0, "X", "", "T")
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
auto input_shape = ctx.getInputType(0)->tensor_type().shape();
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
auto input_rank = input_shape.dim_size();
if (input_rank < 2) {
fail_shape_inference("tensor rank too small");
}
// Update the output shape with the actual number of channels.
auto channels = getAttribute(ctx, "channels", 0);
if (channels <= 0) {
fail_shape_inference("invalid channel count");
}
// Copy batch dimension.
*output_shape->add_dim() = input_shape.dim(0);
auto channels_last = getAttribute(ctx, "channels_last", 0);
if (channels_last == 0) {
output_shape->add_dim()->set_dim_value(channels);
}
// Copy spatial dimensions.
for (int i = 0; i < input_rank - 2; i++) {
*output_shape->add_dim() = input_shape.dim(2 + i);
}
if (channels_last != 0) {
output_shape->add_dim()->set_dim_value(channels);
}
});
ONNX_CONTRIB_OPERATOR_SCHEMA(Conv)
.SetDomain(kMSNchwcDomain)
.SinceVersion(1)
.SetDoc(R"DOC(For internal use.)DOC")
.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"))
.Attr("kernel_shape", "", AttributeProto::INTS, OPTIONAL)
.Attr("dilations", "", AttributeProto::INTS, OPTIONAL)
.Attr("strides", "", AttributeProto::INTS, OPTIONAL)
.Attr("pads", "", AttributeProto::INTS, OPTIONAL)
.Attr("group", "", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("activation", "", AttributeProto::STRING, OPTIONAL)
.Attr("activation_params", "", AttributeProto::FLOATS, OPTIONAL)
.Input(0, "X", "", "T")
.Input(1, "W", "", "T")
.Input(2, "B", "", "T", OpSchema::Optional)
.Input(3, "Sum", "", "T", OpSchema::Optional)
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
ONNX_NAMESPACE::convPoolShapeInference(ctx, true, false, 0, 1);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(MaxPool)
.FillUsing(NchwcPoolOpSchemaGenerator)
.Attr("storage_order", "", AttributeProto::INT, static_cast<int64_t>(0));
ONNX_CONTRIB_OPERATOR_SCHEMA(AveragePool)
.FillUsing(NchwcPoolOpSchemaGenerator)
.Attr("count_include_pad", "", AttributeProto::INT, static_cast<int64_t>(0));
ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalMaxPool)
.FillUsing(NchwcGlobalPoolOpSchemaGenerator);
ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalAveragePool)
.FillUsing(NchwcGlobalPoolOpSchemaGenerator);
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace contrib {
void RegisterNchwcSchemas();
} // namespace contrib
} // namespace onnxruntime

View file

@ -333,7 +333,15 @@ MlasReorderInput(
void
MLASCALL
MlasReorderOutput(
MlasReorderOutputNchw(
const int64_t* OutputShape,
const float* S,
float* D
);
void
MLASCALL
MlasReorderOutputNhwc(
const int64_t* OutputShape,
const float* S,
float* D

View file

@ -269,7 +269,7 @@ Return Value:
void
MLASCALL
MlasReorderOutput(
MlasReorderOutputNchw(
const int64_t* OutputShape,
const float* S,
float* D
@ -365,6 +365,81 @@ Return Value:
}
}
void
MLASCALL
MlasReorderOutputNhwc(
const int64_t* OutputShape,
const float* S,
float* D
)
/*++
Routine Description:
This routine reorders an output buffer from NCHWc to NHWC format.
Arguments:
OutputShape - Supplies the shape of the output tensor.
S - Supplies the address of the source tensor.
D - Supplies the address of the destination tensor.
Return Value:
None.
--*/
{
const size_t BlockSize = MlasNchwcGetBlockSize();
const size_t BatchCount = size_t(OutputShape[0]);
const size_t OutputChannels = size_t(OutputShape[3]);
const size_t OutputSize = size_t(OutputShape[1]) * size_t(OutputShape[2]);
const size_t AlignedOutputChannels = (OutputChannels + BlockSize - 1) & ~(BlockSize - 1);
//
// Copy NCHWc blocks from the source buffer to the destination buffer.
//
for (size_t batch = 0; batch < BatchCount; batch++) {
const float* s = S;
size_t OutputSizeRemaining = OutputSize;
for (; OutputSizeRemaining > 0; OutputSizeRemaining--) {
const float* ss = s;
for (size_t o = OutputChannels; o > 0;) {
const size_t OutputChannelsThisIteration = (std::min)(o, BlockSize);
const size_t AlignedOutputChannelsThisIteration = OutputChannelsThisIteration & (~3);
o -= OutputChannelsThisIteration;
size_t bc = 0;
for (; bc < AlignedOutputChannelsThisIteration; bc += 4) {
MlasStoreFloat32x4(&D[bc], MlasLoadFloat32x4(&ss[bc]));
}
for (; bc < OutputChannelsThisIteration; bc += 1) {
D[bc] = ss[bc];
}
ss += BlockSize * OutputSize;
D += OutputChannelsThisIteration;
}
s += BlockSize;
}
S += AlignedOutputChannels * OutputSize;
}
}
void
MLASCALL
MlasReorderFilterOIHWBiBo(

View file

@ -4,7 +4,6 @@
#include "core/optimizer/initializer.h"
#include "core/optimizer/gemm_activation_fusion.h"
#include "core/graph/graph_utils.h"
#include <deque>
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
@ -23,7 +22,6 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
std::deque<onnxruntime::NodeIndex> removed_nodes;
for (auto index : order) {
auto* node_ptr = graph.GetNode(index);
if (!node_ptr)

View file

@ -119,6 +119,7 @@ class NchwcTransformerImpl {
void TransformConcat(Node& node);
void TransformActivation(Node& node);
void TransformBatchNormalization(Node& node);
void TransformTranspose(Node& node);
Graph& graph_;
@ -199,7 +200,7 @@ void NchwcTransformerImpl::InsertReorderInput(Node& node) {
{input_nchwc_arg},
nullptr,
kMSNchwcDomain);
reorder_input_node.SetExecutionProviderType(node.GetExecutionProviderType());
reorder_input_node.SetExecutionProviderType(kCpuExecutionProvider);
input_defs[0] = input_nchwc_arg;
} else {
input_defs[0] = it->second;
@ -426,7 +427,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
output_defs,
&node.GetAttributes(),
kMSNchwcDomain);
nchwc_node.SetExecutionProviderType(node.GetExecutionProviderType());
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
nchwc_node.MutableInputDefs()[1] = nchwc_conv_W_arg;
@ -485,7 +486,7 @@ void NchwcTransformerImpl::TransformPool(Node& node) {
output_defs,
&node.GetAttributes(),
kMSNchwcDomain);
nchwc_node.SetExecutionProviderType(node.GetExecutionProviderType());
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
NchwcArgument::Shape output_shape(output_defs[0]);
@ -774,7 +775,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
output_defs,
nullptr,
kMSNchwcDomain);
nchwc_node.SetExecutionProviderType(node.GetExecutionProviderType());
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
nchwc_node.AddAttribute("group", nchwc_channels);
nchwc_input->remaining_original_uses_--;
@ -783,6 +784,47 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
removed_nodes_.push_front(node.Index());
}
void NchwcTransformerImpl::TransformTranspose(Node& node) {
auto& input_defs = node.MutableInputDefs();
auto& output_defs = node.MutableOutputDefs();
const ONNX_NAMESPACE::AttributeProto* perm_attr = graph_utils::GetNodeAttribute(node, "perm");
if (perm_attr == nullptr || perm_attr->ints_size() != 4) {
return;
}
// Test if this transposes from NCHW to NHWC layout order.
const int64_t* perm_data = perm_attr->ints().data();
if (perm_data[0] != 0 || perm_data[1] != 2 || perm_data[2] != 3 || perm_data[3] != 1) {
return;
}
// Don't transform the node if the input is not already in NCHWc format.
auto it = nchwc_args_.find(input_defs[0]);
if (it == nchwc_args_.end()) {
return;
}
auto* nchwc_input = it->second.get();
// Create the replacement node.
Node& reorder_output_node = graph_.AddNode(graph_.GenerateNodeName("ReorderOutput"),
"ReorderOutput",
"ReorderOutput",
{nchwc_input->nchwc_arg_},
output_defs,
nullptr,
kMSNchwcDomain);
reorder_output_node.SetExecutionProviderType(kCpuExecutionProvider);
reorder_output_node.AddAttribute("channels", nchwc_input->channels_);
reorder_output_node.AddAttribute("channels_last", static_cast<int64_t>(1));
nchwc_input->remaining_original_uses_--;
graph_utils::RemoveNodeOutputEdges(graph_, node);
removed_nodes_.push_front(node.Index());
}
void NchwcTransformerImpl::Transform(Node& node) {
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)) {
@ -807,6 +849,8 @@ void NchwcTransformerImpl::Transform(Node& node) {
TransformActivation(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) {
TransformBatchNormalization(node);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1})) {
TransformTranspose(node);
}
}

View file

@ -183,7 +183,9 @@ public:
void
ExecuteShort(
void
) = 0;
)
{
}
//
// Contains tests that can run slowly to more exhaustively test that
@ -194,7 +196,9 @@ public:
void
ExecuteLong(
void
) = 0;
)
{
}
};
template <typename T>
@ -1101,7 +1105,7 @@ protected:
// Reorder the output buffer.
//
MlasReorderOutput(OutputShape, NchwcOutput, Output);
MlasReorderOutputNchw(OutputShape, NchwcOutput, Output);
}
const size_t BlockSize = MlasNchwcGetBlockSize();
@ -1510,7 +1514,7 @@ protected:
NchwcOutput,
nullptr);
MlasReorderOutput(OutputShape, NchwcOutput, Output);
MlasReorderOutputNchw(OutputShape, NchwcOutput, Output);
}
MatrixGuardBuffer<float> BufferNchwcInput;
@ -1963,12 +1967,107 @@ public:
}
}
}
};
class MlasReorderOutputTest : public MlasTestBase
{
private:
const size_t BlockSize = MlasNchwcGetBlockSize();
MatrixGuardBuffer<float> BufferInput;
MatrixGuardBuffer<float> BufferOutput;
MatrixGuardBuffer<float> BufferOutput2;
MatrixGuardBuffer<float> BufferOutputReference;
void
ExecuteLong(
Test(
size_t BatchCount,
size_t Channels,
size_t Height,
size_t Width
)
{
size_t NchwcChannels = (Channels + BlockSize - 1) & ~(BlockSize - 1);
size_t InputBufferElements = BatchCount * NchwcChannels * Height * Width;
size_t OutputBufferElements = BatchCount * Channels * Height * Width;
const float* Input = BufferInput.GetBuffer(InputBufferElements);
float* Output = BufferOutput.GetBuffer(OutputBufferElements);
float* OutputReference = BufferOutputReference.GetBuffer(OutputBufferElements);
int64_t NchwOutputShape[] = { int64_t(BatchCount), int64_t(Channels), int64_t(Height), int64_t(Width) };
std::fill_n(Output, OutputBufferElements, -0.5f);
std::fill_n(OutputReference, OutputBufferElements, -0.5f);
MlasReorderOutputNchw(NchwOutputShape, Input, Output);
ReferenceReorderOutput(BatchCount, Channels, Height, Width, Input, OutputReference, false);
if (memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)) != 0) {
printf("mismatch ReorderOutputNchw: batch=%zd channels=%zd height=%zd width=%zd\n",
BatchCount, Channels, Height, Width);
}
int64_t NhwcOutputShape[] = { int64_t(BatchCount), int64_t(Height), int64_t(Width), int64_t(Channels) };
std::fill_n(Output, OutputBufferElements, -0.5f);
std::fill_n(OutputReference, OutputBufferElements, -0.5f);
MlasReorderOutputNhwc(NhwcOutputShape, Input, Output);
ReferenceReorderOutput(BatchCount, Channels, Height, Width, Input, OutputReference, true);
if (memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)) != 0) {
printf("mismatch ReorderOutputNhwc: batch=%zd channels=%zd height=%zd width=%zd\n",
BatchCount, Channels, Height, Width);
}
}
void
ReferenceReorderOutput(
size_t BatchCount,
size_t Channels,
size_t Height,
size_t Width,
const float* Input,
float* Output,
bool NhwcFormat
)
{
size_t NchwcChannels = (Channels + (BlockSize - 1)) & ~(BlockSize - 1);
size_t SpatialSize = Height * Width;
size_t ChannelStride = NhwcFormat ? 1 : SpatialSize;
size_t SpatialStride = NhwcFormat ? Channels : 1;
for (size_t n = 0; n < BatchCount; n++) {
for (size_t c = 0; c < Channels; c++) {
const float* input = Input + ((c & ~(BlockSize - 1)) * SpatialSize) + (c & (BlockSize - 1));
float* output = Output + (c * ChannelStride);
for (size_t hw = 0; hw < SpatialSize; hw++) {
output[hw * SpatialStride] = input[hw * BlockSize];
}
}
Input += NchwcChannels * SpatialSize;
Output += Channels * SpatialSize;
}
}
public:
void
ExecuteShort(
void
) override
{
for (size_t c = 1; c < 48; c++) {
Test(1, c, 112, 112);
Test(4, c, 15, 21);
Test(16, c, 11, 11);
}
}
};
@ -2010,9 +2109,6 @@ main(
printf("Pool3D tests.\n");
onnxruntime::make_unique<MlasPool3DTest>()->ExecuteShort();
printf("Activation tests.\n");
onnxruntime::make_unique<MlasActivationTest>()->ExecuteShort();
printf("Done.\n");
#if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL)
if(threadpool != nullptr) threadpool = new onnxruntime::concurrency::ThreadPool("test", 2);
@ -2023,5 +2119,14 @@ main(
#if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL)
delete threadpool;
#endif
printf("Activation tests.\n");
onnxruntime::make_unique<MlasActivationTest>()->ExecuteShort();
printf("ReorderOutput tests.\n");
if (MlasNchwcGetBlockSize() > 1) {
onnxruntime::make_unique<MlasReorderOutputTest>()->ExecuteShort();
}
return 0;
}

View file

@ -138,6 +138,24 @@ struct NchwcTestHelper {
return node;
}
Node& AddTransposeNode(NodeArg* input_arg, NodeArg* output_arg, const std::vector<int64_t>& perm) {
auto& node = AddNode("Transpose", {input_arg}, {output_arg});
node.AddAttribute("perm", perm);
return node;
}
Node& AddTransposeToNchwNode(NodeArg* input_arg, NodeArg* output_arg) {
return AddTransposeNode(input_arg, output_arg, {0, 3, 1, 2});
}
Node& AddTransposeToNhwcNode(NodeArg* input_arg, NodeArg* output_arg) {
return AddTransposeNode(input_arg, output_arg, {0, 2, 3, 1});
}
Node& AddTransposeToCnhwNode(NodeArg* input_arg, NodeArg* output_arg) {
return AddTransposeNode(input_arg, output_arg, {1, 0, 2, 3});
}
std::vector<float> FillRandomData(size_t count) {
constexpr int min_fill_value = -23;
constexpr int max_fill_value = 23;
@ -1041,6 +1059,75 @@ TEST(NchwcOptimizerTests, BatchNormalization) {
test_case(true);
}
TEST(NchwcOptimizerTests, ConvReorderOutputNhwc) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput({1, 64, 28, 32});
auto* conv_output_arg = helper.MakeIntermediate();
auto* nhwc_output_arg = helper.MakeOutput();
helper.AddConvNode(input_arg, conv_output_arg, {130, 64, 1, 1});
helper.AddTransposeToNhwcNode(conv_output_arg, nhwc_output_arg);
};
auto check_nchwc_graph = [&](NchwcInferenceSession& session) {
auto op_to_count = session.CountOpsInGraph();
EXPECT_EQ(op_to_count["nchwc.Conv"], 1);
EXPECT_EQ(op_to_count["nchwc.ReorderInput"], 1);
EXPECT_EQ(op_to_count["nchwc.ReorderOutput"], 1);
EXPECT_EQ(op_to_count["Transpose"], 0);
};
// Verify that a NHWC transpose is fused into ReorderOutput.
NchwcOptimizerTester(build_test_case, check_nchwc_graph);
}
TEST(NchwcOptimizerTests, ConvReorderOutputBoth) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput({5, 64, 33, 37});
auto* conv_output_arg = helper.MakeIntermediate();
auto* nchw_output_arg = helper.MakeOutput();
auto* nhwc_output_arg = helper.MakeOutput();
helper.AddConvNode(input_arg, conv_output_arg, {7, 64, 1, 1});
helper.AddTransposeToNhwcNode(conv_output_arg, nhwc_output_arg);
helper.AddNode("Neg", {conv_output_arg}, {nchw_output_arg});
};
auto check_nchwc_graph = [&](NchwcInferenceSession& session) {
auto op_to_count = session.CountOpsInGraph();
EXPECT_EQ(op_to_count["nchwc.Conv"], 1);
EXPECT_EQ(op_to_count["nchwc.ReorderInput"], 1);
EXPECT_EQ(op_to_count["nchwc.ReorderOutput"], 2);
EXPECT_EQ(op_to_count["Transpose"], 0);
};
// Verify that if an output argument is used as both NCHW and NHWC, then
// two ReorderOutput nodes are inserted.
NchwcOptimizerTester(build_test_case, check_nchwc_graph);
}
TEST(NchwcOptimizerTests, ConvReorderOutputCnhw) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput({1, 64, 28, 32});
auto* conv_output_arg = helper.MakeIntermediate();
auto* nhwc_output_arg = helper.MakeOutput();
helper.AddConvNode(input_arg, conv_output_arg, {130, 64, 1, 1});
helper.AddTransposeToCnhwNode(conv_output_arg, nhwc_output_arg);
};
auto check_nchwc_graph = [&](NchwcInferenceSession& session) {
auto op_to_count = session.CountOpsInGraph();
EXPECT_EQ(op_to_count["nchwc.Conv"], 1);
EXPECT_EQ(op_to_count["nchwc.ReorderInput"], 1);
EXPECT_EQ(op_to_count["nchwc.ReorderOutput"], 1);
EXPECT_EQ(op_to_count["Transpose"], 1);
};
// Verify that a CNHW transpose is not fused into ReorderOutput.
NchwcOptimizerTester(build_test_case, check_nchwc_graph);
}
#endif
} // namespace test