mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-15 01:23:42 +00:00
NCHWc ReorderOutput->Transpose(NHWC) fusion (#3035)
Add support to fuse ReorderOutput+Transpose(NHWC). Converting from NCHWc to NHWC tensors is a trivial copy of data and avoids the cost of a transpose node.
This commit is contained in:
parent
71ca43b345
commit
ecdcd682bb
11 changed files with 534 additions and 195 deletions
|
|
@ -16,7 +16,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
|||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ReorderInput<float>);
|
||||
ReorderInput);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
ReorderOutput,
|
||||
|
|
@ -24,7 +24,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
|||
float,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ReorderOutput<float>);
|
||||
ReorderOutput);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
||||
Conv,
|
||||
|
|
@ -67,27 +67,41 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
|
|||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
NchwcAveragePool);
|
||||
|
||||
template <typename T>
|
||||
Status ReorderInput<T>::Compute(OpKernelContext* context) const {
|
||||
Status ReorderInput::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto& X_shape = X->Shape();
|
||||
ORT_ENFORCE(X_shape.NumDimensions() == 4);
|
||||
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
|
||||
auto* Y = context->Output(0, X_shape);
|
||||
MlasReorderInput(X_shape.GetDims().data(), X->template Data<T>(), Y->template MutableData<T>());
|
||||
MlasReorderInput(X_shape.GetDims().data(), X->template Data<float>(), Y->template MutableData<float>());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ReorderOutput<T>::Compute(OpKernelContext* context) const {
|
||||
Status ReorderOutput::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto& X_shape = X->Shape();
|
||||
ORT_ENFORCE(X_shape.NumDimensions() == 4);
|
||||
std::vector<int64_t> Y_shape(X_shape.GetDims());
|
||||
ORT_ENFORCE(channels_ <= Y_shape[1]);
|
||||
Y_shape[1] = channels_;
|
||||
const auto X_rank = X_shape.NumDimensions();
|
||||
ORT_ENFORCE(X_rank == 4);
|
||||
ORT_ENFORCE(channels_ <= X_shape[1]);
|
||||
|
||||
// Build the output shape in NCHW or NHWC order.
|
||||
std::vector<int64_t> Y_shape(X_rank);
|
||||
Y_shape[0] = X_shape[0];
|
||||
Y_shape[channels_last_ ? X_rank - 1 : 1] = channels_;
|
||||
auto* Y_spatial_dims = Y_shape.data() + (channels_last_ ? 1 : 2);
|
||||
for (size_t i = 0; i < X_rank - 2; i++) {
|
||||
Y_spatial_dims[i] = X_shape[2 + i];
|
||||
}
|
||||
auto* Y = context->Output(0, Y_shape);
|
||||
MlasReorderOutput(Y_shape.data(), X->template Data<T>(), Y->template MutableData<T>());
|
||||
|
||||
const auto* x_data = X->template Data<float>();
|
||||
auto* y_data = Y->template MutableData<float>();
|
||||
if (channels_last_) {
|
||||
MlasReorderOutputNhwc(Y_shape.data(), x_data, y_data);
|
||||
} else {
|
||||
MlasReorderOutputNchw(Y_shape.data(), x_data, y_data);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
class ReorderInput : public OpKernel {
|
||||
public:
|
||||
ReorderInput(const OpKernelInfo& info) : OpKernel(info) {
|
||||
|
|
@ -21,18 +20,19 @@ class ReorderInput : public OpKernel {
|
|||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ReorderOutput : public OpKernel {
|
||||
public:
|
||||
ReorderOutput(const OpKernelInfo& info) : OpKernel(info) {
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("channels", &channels_).IsOK());
|
||||
ORT_ENFORCE(channels_ > 0, "invalid channel count");
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("channels_last", &channels_last_).IsOK());
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
int64_t channels_;
|
||||
int64_t channels_last_;
|
||||
};
|
||||
|
||||
class NchwcConv : public OpKernel {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "core/graph/constants.h"
|
||||
#include "core/graph/contrib_ops/attn_lstm_schema_defs.h"
|
||||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "core/graph/contrib_ops/nchwc_schema_defs.h"
|
||||
#include "core/graph/contrib_ops/range_schema_defs.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
|
|
@ -18,7 +19,6 @@ void convPoolShapeInference(
|
|||
bool use_dilation, bool require_kernel_shape,
|
||||
int input1Idx,
|
||||
int input2Idx);
|
||||
void globalPoolTypeShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);
|
||||
void matmulShapeInference(
|
||||
ONNX_NAMESPACE::InferenceContext& ctx,
|
||||
int input1Idx,
|
||||
|
|
@ -166,37 +166,6 @@ using ONNX_NAMESPACE::AttributeProto;
|
|||
using ONNX_NAMESPACE::OpSchema;
|
||||
using ONNX_NAMESPACE::OPTIONAL;
|
||||
|
||||
// Populates `schema` with everything the NCHWc pooling operators share: the
// NCHWc domain and version, the pooling attributes, a single float
// input/output pair and the type/shape inference callback.
void NchwcPoolOpSchemaGenerator(OpSchema& schema) {
  schema.SetDomain(kMSNchwcDomain)
      .SinceVersion(1)
      .SetDoc(R"DOC(For internal use.)DOC")
      .Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"))
      .Attr("kernel_shape", "", AttributeProto::INTS)
      .Attr("dilations", "", AttributeProto::INTS, OPTIONAL)
      .Attr("strides", "", AttributeProto::INTS, OPTIONAL)
      .Attr("pads", "", AttributeProto::INTS, OPTIONAL)
      .Attr("ceil_mode", "", AttributeProto::INT, static_cast<int64_t>(0))
      .Input(0, "X", "", "T")
      .Output(0, "Y", "", "T")
      .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
        // Element type comes from input 0; the shape is delegated to the
        // shared conv/pool inference helper.
        ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
        ONNX_NAMESPACE::convPoolShapeInference(ctx, true, true, 0, 1);
      });
}
|
||||
|
||||
// Populates `schema` with the pieces shared by the NCHWc global pooling
// operators: domain, version, one float input/output pair and the global
// pooling type/shape inference callback.
void NchwcGlobalPoolOpSchemaGenerator(OpSchema& schema) {
  schema.SetDomain(kMSNchwcDomain)
      .SinceVersion(1)
      .SetDoc(R"DOC(For internal use.)DOC")
      .Input(0, "X", "", "T")
      .Output(0, "Y", "", "T")
      .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
        ONNX_NAMESPACE::globalPoolTypeShapeInference(ctx);
      });
}
|
||||
|
||||
void ValidateTypeAndShapeForScaleAndZP(ONNX_NAMESPACE::InferenceContext& ctx, int index, ::google::protobuf::int32 expectedType, bool isScalar, int expectedTensorSize = 0) {
|
||||
if (ctx.getNumInputs() > static_cast<size_t>(index)) {
|
||||
auto data_type = ctx.getInputType(index);
|
||||
|
|
@ -320,132 +289,6 @@ const char* contrib_ops_auto_pad_doc =
|
|||
"In case of odd number add the extra padding at the end for SAME_UPPER and at the "
|
||||
"beginning for SAME_LOWER. VALID mean no padding.";
|
||||
|
||||
void RegisterNchwcSchemas() {
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderInput)
|
||||
.SetDomain(kMSNchwcDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(For internal use.)DOC")
|
||||
.Input(0, "X", "", "T")
|
||||
.Output(0, "Y", "", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float)", "tensor(int8)", "tensor(uint8)"},
|
||||
"Constrain input and output types to float/quantized tensors")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput)
|
||||
.SetDomain(kMSNchwcDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(For internal use.)DOC")
|
||||
.Attr(
|
||||
"channels",
|
||||
"",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Input(0, "X", "", "T")
|
||||
.Output(0, "Y", "", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float)", "tensor(int8)", "tensor(uint8)"},
|
||||
"Constrain input and output types to float/quantized tensors")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (!hasNInputShapes(ctx, 1)) {
|
||||
return;
|
||||
}
|
||||
propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
// Update the output shape with the actual number of channels.
|
||||
auto channels = getAttribute(ctx, "channels", 0);
|
||||
if (channels <= 0) {
|
||||
fail_shape_inference("invalid channel count");
|
||||
}
|
||||
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
|
||||
if (output_shape->dim_size() < 2) {
|
||||
fail_shape_inference("tensor rank too small");
|
||||
}
|
||||
auto* channels_dim = output_shape->mutable_dim(1);
|
||||
channels_dim->clear_dim_param();
|
||||
channels_dim->set_dim_value(channels);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Conv)
|
||||
.SetDomain(kMSNchwcDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(For internal use.)DOC")
|
||||
.Attr(
|
||||
"auto_pad",
|
||||
"",
|
||||
AttributeProto::STRING,
|
||||
std::string("NOTSET"))
|
||||
.Attr(
|
||||
"kernel_shape",
|
||||
"",
|
||||
AttributeProto::INTS,
|
||||
OPTIONAL)
|
||||
.Attr(
|
||||
"dilations",
|
||||
"",
|
||||
AttributeProto::INTS,
|
||||
OPTIONAL)
|
||||
.Attr(
|
||||
"strides",
|
||||
"",
|
||||
AttributeProto::INTS,
|
||||
OPTIONAL)
|
||||
.Attr(
|
||||
"pads",
|
||||
"",
|
||||
AttributeProto::INTS, OPTIONAL)
|
||||
.Attr(
|
||||
"group",
|
||||
"",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
.Attr(
|
||||
"activation",
|
||||
"",
|
||||
AttributeProto::STRING,
|
||||
OPTIONAL)
|
||||
.Attr(
|
||||
"activation_params",
|
||||
"",
|
||||
AttributeProto::FLOATS,
|
||||
OPTIONAL)
|
||||
.Input(0, "X", "", "T")
|
||||
.Input(1, "W", "", "T")
|
||||
.Input(2, "B", "", "T", OpSchema::Optional)
|
||||
.Input(3, "Sum", "", "T", OpSchema::Optional)
|
||||
.Output(0, "Y", "", "T")
|
||||
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
ONNX_NAMESPACE::convPoolShapeInference(ctx, true, false, 0, 1);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(MaxPool)
|
||||
.FillUsing(NchwcPoolOpSchemaGenerator)
|
||||
.Attr(
|
||||
"storage_order",
|
||||
"",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0));
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(AveragePool)
|
||||
.FillUsing(NchwcPoolOpSchemaGenerator)
|
||||
.Attr(
|
||||
"count_include_pad",
|
||||
"",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0));
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalMaxPool)
|
||||
.FillUsing(NchwcGlobalPoolOpSchemaGenerator);
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalAveragePool)
|
||||
.FillUsing(NchwcGlobalPoolOpSchemaGenerator);
|
||||
}
|
||||
|
||||
void RegisterBertSchemas() {
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Attention)
|
||||
.SetDomain(kMSDomain)
|
||||
|
|
@ -1383,8 +1226,8 @@ activation and leaky_relu_alpha.)DOC")
|
|||
ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema);
|
||||
|
||||
static const char* QuantizeLinear_ver1_doc = R"DOC(
|
||||
The linear quantization operator. It consumes a full precision data, a scale, a zero point and computes the quantized data.
|
||||
The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to arg (in floating-point format),
|
||||
The linear quantization operator. It consumes a full precision data, a scale, a zero point and computes the quantized data.
|
||||
The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to arg (in floating-point format),
|
||||
rounding halfway cases away from zero. Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC";
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear)
|
||||
|
|
@ -1440,8 +1283,8 @@ The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale),
|
|||
});
|
||||
|
||||
static const char* DequantizeLinear_ver1_doc = R"DOC(
|
||||
The linear dequantization operator. It consumes a quantized data, a scale, a zero point and computes the full precision data.
|
||||
The dequantization formula is y = (x - x_zero_point) * x_scale.
|
||||
The linear dequantization operator. It consumes a quantized data, a scale, a zero point and computes the full precision data.
|
||||
The dequantization formula is y = (x - x_zero_point) * x_scale.
|
||||
Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC";
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(DequantizeLinear)
|
||||
|
|
@ -1682,7 +1525,7 @@ Computes the mean of the low-precision input tensor's element along the provided
|
|||
The resulting tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0,
|
||||
then the resulting tensor have the reduced dimension pruned. The above behavior is similar to numpy,
|
||||
with the exception that numpy default keepdims to False instead of True.
|
||||
Input and Output scales and zero points are used to requantize the output in a new range.
|
||||
Input and Output scales and zero points are used to requantize the output in a new range.
|
||||
This helps to improve accuracy as after ReduceMean operation the range of the output is expected to decrease.
|
||||
|
||||
```
|
||||
|
|
@ -1861,7 +1704,7 @@ C (int32) = (A - A_zero_point) * (B - B_zero_point)
|
|||
```
|
||||
pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + kernel_spatial_shape[i] - input_spatial_shape[i]
|
||||
```
|
||||
|
||||
|
||||
The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero).
|
||||
|
||||
Input and output scales and zero points are used to convert the output to a new quantization range.
|
||||
|
|
@ -2448,7 +2291,7 @@ Example 4:
|
|||
R"DOC(Gaussian Error Linear Unit.
|
||||
A high-performing neural network activation function.The GELU nonlinearity is
|
||||
the expected transformation of a stochastic regularizer which randomly applies
|
||||
the identity or zero map to a neuron's input. The GELU nonlinearity weights
|
||||
the identity or zero map to a neuron's input. The GELU nonlinearity weights
|
||||
inputs by their magnitude, rather than gates inputs by their sign as in ReLUs.)DOC";
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Gelu)
|
||||
|
|
|
|||
153
onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc
Normal file
153
onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "core/graph/contrib_ops/nchwc_schema_defs.h"
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
void convPoolShapeInference(
|
||||
ONNX_NAMESPACE::InferenceContext& ctx,
|
||||
bool use_dilation, bool require_kernel_shape,
|
||||
int input1Idx,
|
||||
int input2Idx);
|
||||
void globalPoolTypeShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);
|
||||
} // namespace ONNX_NAMESPACE
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
using ONNX_NAMESPACE::AttributeProto;
|
||||
using ONNX_NAMESPACE::InferenceContext;
|
||||
using ONNX_NAMESPACE::OpSchema;
|
||||
using ONNX_NAMESPACE::OPTIONAL;
|
||||
|
||||
// Populates `schema` with everything the NCHWc pooling operators share: the
// NCHWc domain and version, the pooling attributes, a single float
// input/output pair and the type/shape inference callback.
void NchwcPoolOpSchemaGenerator(OpSchema& schema) {
  schema.SetDomain(kMSNchwcDomain)
      .SinceVersion(1)
      .SetDoc(R"DOC(For internal use.)DOC")
      .Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"))
      .Attr("kernel_shape", "", AttributeProto::INTS)
      .Attr("dilations", "", AttributeProto::INTS, OPTIONAL)
      .Attr("strides", "", AttributeProto::INTS, OPTIONAL)
      .Attr("pads", "", AttributeProto::INTS, OPTIONAL)
      .Attr("ceil_mode", "", AttributeProto::INT, static_cast<int64_t>(0))
      .Input(0, "X", "", "T")
      .Output(0, "Y", "", "T")
      .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
        // Element type comes from input 0; the shape is delegated to the
        // shared conv/pool inference helper.
        ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
        ONNX_NAMESPACE::convPoolShapeInference(ctx, true, true, 0, 1);
      });
}
|
||||
|
||||
// Populates `schema` with the pieces shared by the NCHWc global pooling
// operators: domain, version, one float input/output pair and the global
// pooling type/shape inference callback.
void NchwcGlobalPoolOpSchemaGenerator(OpSchema& schema) {
  schema.SetDomain(kMSNchwcDomain)
      .SinceVersion(1)
      .SetDoc(R"DOC(For internal use.)DOC")
      .Input(0, "X", "", "T")
      .Output(0, "Y", "", "T")
      .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
        ONNX_NAMESPACE::globalPoolTypeShapeInference(ctx);
      });
}
|
||||
|
||||
void RegisterNchwcSchemas() {
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderInput)
|
||||
.SetDomain(kMSNchwcDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(For internal use.)DOC")
|
||||
.Input(0, "X", "", "T")
|
||||
.Output(0, "Y", "", "T")
|
||||
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput)
|
||||
.SetDomain(kMSNchwcDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(For internal use.)DOC")
|
||||
.Attr("channels", "", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("channels_last", "", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Input(0, "X", "", "T")
|
||||
.Output(0, "Y", "", "T")
|
||||
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (!hasNInputShapes(ctx, 1)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto input_shape = ctx.getInputType(0)->tensor_type().shape();
|
||||
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
|
||||
|
||||
auto input_rank = input_shape.dim_size();
|
||||
if (input_rank < 2) {
|
||||
fail_shape_inference("tensor rank too small");
|
||||
}
|
||||
|
||||
// Update the output shape with the actual number of channels.
|
||||
auto channels = getAttribute(ctx, "channels", 0);
|
||||
if (channels <= 0) {
|
||||
fail_shape_inference("invalid channel count");
|
||||
}
|
||||
|
||||
// Copy batch dimension.
|
||||
*output_shape->add_dim() = input_shape.dim(0);
|
||||
|
||||
auto channels_last = getAttribute(ctx, "channels_last", 0);
|
||||
if (channels_last == 0) {
|
||||
output_shape->add_dim()->set_dim_value(channels);
|
||||
}
|
||||
|
||||
// Copy spatial dimensions.
|
||||
for (int i = 0; i < input_rank - 2; i++) {
|
||||
*output_shape->add_dim() = input_shape.dim(2 + i);
|
||||
}
|
||||
|
||||
if (channels_last != 0) {
|
||||
output_shape->add_dim()->set_dim_value(channels);
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Conv)
|
||||
.SetDomain(kMSNchwcDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(For internal use.)DOC")
|
||||
.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"))
|
||||
.Attr("kernel_shape", "", AttributeProto::INTS, OPTIONAL)
|
||||
.Attr("dilations", "", AttributeProto::INTS, OPTIONAL)
|
||||
.Attr("strides", "", AttributeProto::INTS, OPTIONAL)
|
||||
.Attr("pads", "", AttributeProto::INTS, OPTIONAL)
|
||||
.Attr("group", "", AttributeProto::INT, static_cast<int64_t>(1))
|
||||
.Attr("activation", "", AttributeProto::STRING, OPTIONAL)
|
||||
.Attr("activation_params", "", AttributeProto::FLOATS, OPTIONAL)
|
||||
.Input(0, "X", "", "T")
|
||||
.Input(1, "W", "", "T")
|
||||
.Input(2, "B", "", "T", OpSchema::Optional)
|
||||
.Input(3, "Sum", "", "T", OpSchema::Optional)
|
||||
.Output(0, "Y", "", "T")
|
||||
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
ONNX_NAMESPACE::convPoolShapeInference(ctx, true, false, 0, 1);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(MaxPool)
|
||||
.FillUsing(NchwcPoolOpSchemaGenerator)
|
||||
.Attr("storage_order", "", AttributeProto::INT, static_cast<int64_t>(0));
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(AveragePool)
|
||||
.FillUsing(NchwcPoolOpSchemaGenerator)
|
||||
.Attr("count_include_pad", "", AttributeProto::INT, static_cast<int64_t>(0));
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalMaxPool)
|
||||
.FillUsing(NchwcGlobalPoolOpSchemaGenerator);
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalAveragePool)
|
||||
.FillUsing(NchwcGlobalPoolOpSchemaGenerator);
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
12
onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.h
Normal file
12
onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.h
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
void RegisterNchwcSchemas();
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -333,7 +333,15 @@ MlasReorderInput(
|
|||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderOutput(
|
||||
MlasReorderOutputNchw(
|
||||
const int64_t* OutputShape,
|
||||
const float* S,
|
||||
float* D
|
||||
);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderOutputNhwc(
|
||||
const int64_t* OutputShape,
|
||||
const float* S,
|
||||
float* D
|
||||
|
|
|
|||
|
|
@ -269,7 +269,7 @@ Return Value:
|
|||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderOutput(
|
||||
MlasReorderOutputNchw(
|
||||
const int64_t* OutputShape,
|
||||
const float* S,
|
||||
float* D
|
||||
|
|
@ -365,6 +365,81 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderOutputNhwc(
|
||||
const int64_t* OutputShape,
|
||||
const float* S,
|
||||
float* D
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine reorders an output buffer from NCHWc to NHWC format.
|
||||
|
||||
Arguments:
|
||||
|
||||
OutputShape - Supplies the shape of the output tensor.
|
||||
|
||||
S - Supplies the address of the source tensor.
|
||||
|
||||
D - Supplies the address of the destination tensor.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const size_t BlockSize = MlasNchwcGetBlockSize();
|
||||
|
||||
const size_t BatchCount = size_t(OutputShape[0]);
|
||||
const size_t OutputChannels = size_t(OutputShape[3]);
|
||||
const size_t OutputSize = size_t(OutputShape[1]) * size_t(OutputShape[2]);
|
||||
|
||||
const size_t AlignedOutputChannels = (OutputChannels + BlockSize - 1) & ~(BlockSize - 1);
|
||||
|
||||
//
|
||||
// Copy NCHWc blocks from the source buffer to the destination buffer.
|
||||
//
|
||||
|
||||
for (size_t batch = 0; batch < BatchCount; batch++) {
|
||||
|
||||
const float* s = S;
|
||||
size_t OutputSizeRemaining = OutputSize;
|
||||
|
||||
for (; OutputSizeRemaining > 0; OutputSizeRemaining--) {
|
||||
|
||||
const float* ss = s;
|
||||
|
||||
for (size_t o = OutputChannels; o > 0;) {
|
||||
|
||||
const size_t OutputChannelsThisIteration = (std::min)(o, BlockSize);
|
||||
const size_t AlignedOutputChannelsThisIteration = OutputChannelsThisIteration & (~3);
|
||||
o -= OutputChannelsThisIteration;
|
||||
|
||||
size_t bc = 0;
|
||||
|
||||
for (; bc < AlignedOutputChannelsThisIteration; bc += 4) {
|
||||
MlasStoreFloat32x4(&D[bc], MlasLoadFloat32x4(&ss[bc]));
|
||||
}
|
||||
|
||||
for (; bc < OutputChannelsThisIteration; bc += 1) {
|
||||
D[bc] = ss[bc];
|
||||
}
|
||||
|
||||
ss += BlockSize * OutputSize;
|
||||
D += OutputChannelsThisIteration;
|
||||
}
|
||||
|
||||
s += BlockSize;
|
||||
}
|
||||
|
||||
S += AlignedOutputChannels * OutputSize;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasReorderFilterOIHWBiBo(
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include <deque>
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
|
|
@ -23,7 +22,6 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
GraphViewer graph_viewer(graph);
|
||||
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
std::deque<onnxruntime::NodeIndex> removed_nodes;
|
||||
for (auto index : order) {
|
||||
auto* node_ptr = graph.GetNode(index);
|
||||
if (!node_ptr)
|
||||
|
|
|
|||
|
|
@ -119,6 +119,7 @@ class NchwcTransformerImpl {
|
|||
void TransformConcat(Node& node);
|
||||
void TransformActivation(Node& node);
|
||||
void TransformBatchNormalization(Node& node);
|
||||
void TransformTranspose(Node& node);
|
||||
|
||||
Graph& graph_;
|
||||
|
||||
|
|
@ -199,7 +200,7 @@ void NchwcTransformerImpl::InsertReorderInput(Node& node) {
|
|||
{input_nchwc_arg},
|
||||
nullptr,
|
||||
kMSNchwcDomain);
|
||||
reorder_input_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
reorder_input_node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
input_defs[0] = input_nchwc_arg;
|
||||
} else {
|
||||
input_defs[0] = it->second;
|
||||
|
|
@ -426,7 +427,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) {
|
|||
output_defs,
|
||||
&node.GetAttributes(),
|
||||
kMSNchwcDomain);
|
||||
nchwc_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
|
||||
nchwc_node.MutableInputDefs()[1] = nchwc_conv_W_arg;
|
||||
|
||||
|
|
@ -485,7 +486,7 @@ void NchwcTransformerImpl::TransformPool(Node& node) {
|
|||
output_defs,
|
||||
&node.GetAttributes(),
|
||||
kMSNchwcDomain);
|
||||
nchwc_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
|
||||
NchwcArgument::Shape output_shape(output_defs[0]);
|
||||
|
||||
|
|
@ -774,7 +775,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
|
|||
output_defs,
|
||||
nullptr,
|
||||
kMSNchwcDomain);
|
||||
nchwc_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
nchwc_node.AddAttribute("group", nchwc_channels);
|
||||
|
||||
nchwc_input->remaining_original_uses_--;
|
||||
|
|
@ -783,6 +784,47 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) {
|
|||
removed_nodes_.push_front(node.Index());
|
||||
}
|
||||
|
||||
// Replaces a Transpose that permutes NCHW to NHWC, and whose input is already
// tracked as an NCHWc tensor, with a single ReorderOutput node that has
// channels_last set. Nodes that do not match are left untouched.
void NchwcTransformerImpl::TransformTranspose(Node& node) {
  auto& input_defs = node.MutableInputDefs();
  auto& output_defs = node.MutableOutputDefs();

  // Only a four-element "perm" attribute can describe the wanted transpose.
  const ONNX_NAMESPACE::AttributeProto* perm_attr = graph_utils::GetNodeAttribute(node, "perm");
  if (perm_attr == nullptr || perm_attr->ints_size() != 4) {
    return;
  }

  // Test if this transposes from NCHW to NHWC layout order.
  static constexpr int64_t kNchwToNhwcPerm[4] = {0, 2, 3, 1};
  const int64_t* perm_values = perm_attr->ints().data();
  for (int axis = 0; axis < 4; axis++) {
    if (perm_values[axis] != kNchwToNhwcPerm[axis]) {
      return;
    }
  }

  // Don't transform the node if the input is not already in NCHWc format.
  const auto nchwc_entry = nchwc_args_.find(input_defs[0]);
  if (nchwc_entry == nchwc_args_.end()) {
    return;
  }
  auto* nchwc_input = nchwc_entry->second.get();

  // Create the replacement node: it consumes the NCHWc argument directly and
  // takes over the Transpose node's outputs.
  Node& reorder_output_node = graph_.AddNode(graph_.GenerateNodeName("ReorderOutput"),
                                             "ReorderOutput",
                                             "ReorderOutput",
                                             {nchwc_input->nchwc_arg_},
                                             output_defs,
                                             nullptr,
                                             kMSNchwcDomain);
  reorder_output_node.SetExecutionProviderType(kCpuExecutionProvider);
  reorder_output_node.AddAttribute("channels", nchwc_input->channels_);
  reorder_output_node.AddAttribute("channels_last", static_cast<int64_t>(1));

  // The Transpose no longer counts as a consumer of the original tensor.
  nchwc_input->remaining_original_uses_--;

  // Detach the old node's output edges and record it in removed_nodes_.
  graph_utils::RemoveNodeOutputEdges(graph_, node);
  removed_nodes_.push_front(node.Index());
}
|
||||
|
||||
void NchwcTransformerImpl::Transform(Node& node) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)) {
|
||||
|
|
@ -807,6 +849,8 @@ void NchwcTransformerImpl::Transform(Node& node) {
|
|||
TransformActivation(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) {
|
||||
TransformBatchNormalization(node);
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1})) {
|
||||
TransformTranspose(node);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -183,7 +183,9 @@ public:
|
|||
void
|
||||
ExecuteShort(
|
||||
void
|
||||
) = 0;
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
//
|
||||
// Contains tests that can run slowly to more exhaustively test that
|
||||
|
|
@ -194,7 +196,9 @@ public:
|
|||
void
|
||||
ExecuteLong(
|
||||
void
|
||||
) = 0;
|
||||
)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -1101,7 +1105,7 @@ protected:
|
|||
// Reorder the output buffer.
|
||||
//
|
||||
|
||||
MlasReorderOutput(OutputShape, NchwcOutput, Output);
|
||||
MlasReorderOutputNchw(OutputShape, NchwcOutput, Output);
|
||||
}
|
||||
|
||||
const size_t BlockSize = MlasNchwcGetBlockSize();
|
||||
|
|
@ -1510,7 +1514,7 @@ protected:
|
|||
NchwcOutput,
|
||||
nullptr);
|
||||
|
||||
MlasReorderOutput(OutputShape, NchwcOutput, Output);
|
||||
MlasReorderOutputNchw(OutputShape, NchwcOutput, Output);
|
||||
}
|
||||
|
||||
MatrixGuardBuffer<float> BufferNchwcInput;
|
||||
|
|
@ -1963,12 +1967,107 @@ public:
|
|||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class MlasReorderOutputTest : public MlasTestBase
|
||||
{
|
||||
private:
|
||||
const size_t BlockSize = MlasNchwcGetBlockSize();
|
||||
|
||||
MatrixGuardBuffer<float> BufferInput;
|
||||
MatrixGuardBuffer<float> BufferOutput;
|
||||
MatrixGuardBuffer<float> BufferOutput2;
|
||||
MatrixGuardBuffer<float> BufferOutputReference;
|
||||
|
||||
void
|
||||
ExecuteLong(
|
||||
Test(
|
||||
size_t BatchCount,
|
||||
size_t Channels,
|
||||
size_t Height,
|
||||
size_t Width
|
||||
)
|
||||
{
|
||||
size_t NchwcChannels = (Channels + BlockSize - 1) & ~(BlockSize - 1);
|
||||
|
||||
size_t InputBufferElements = BatchCount * NchwcChannels * Height * Width;
|
||||
size_t OutputBufferElements = BatchCount * Channels * Height * Width;
|
||||
|
||||
const float* Input = BufferInput.GetBuffer(InputBufferElements);
|
||||
float* Output = BufferOutput.GetBuffer(OutputBufferElements);
|
||||
float* OutputReference = BufferOutputReference.GetBuffer(OutputBufferElements);
|
||||
|
||||
int64_t NchwOutputShape[] = { int64_t(BatchCount), int64_t(Channels), int64_t(Height), int64_t(Width) };
|
||||
|
||||
std::fill_n(Output, OutputBufferElements, -0.5f);
|
||||
std::fill_n(OutputReference, OutputBufferElements, -0.5f);
|
||||
|
||||
MlasReorderOutputNchw(NchwOutputShape, Input, Output);
|
||||
ReferenceReorderOutput(BatchCount, Channels, Height, Width, Input, OutputReference, false);
|
||||
|
||||
if (memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)) != 0) {
|
||||
printf("mismatch ReorderOutputNchw: batch=%zd channels=%zd height=%zd width=%zd\n",
|
||||
BatchCount, Channels, Height, Width);
|
||||
}
|
||||
|
||||
int64_t NhwcOutputShape[] = { int64_t(BatchCount), int64_t(Height), int64_t(Width), int64_t(Channels) };
|
||||
|
||||
std::fill_n(Output, OutputBufferElements, -0.5f);
|
||||
std::fill_n(OutputReference, OutputBufferElements, -0.5f);
|
||||
|
||||
MlasReorderOutputNhwc(NhwcOutputShape, Input, Output);
|
||||
ReferenceReorderOutput(BatchCount, Channels, Height, Width, Input, OutputReference, true);
|
||||
|
||||
if (memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)) != 0) {
|
||||
printf("mismatch ReorderOutputNhwc: batch=%zd channels=%zd height=%zd width=%zd\n",
|
||||
BatchCount, Channels, Height, Width);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ReferenceReorderOutput(
|
||||
size_t BatchCount,
|
||||
size_t Channels,
|
||||
size_t Height,
|
||||
size_t Width,
|
||||
const float* Input,
|
||||
float* Output,
|
||||
bool NhwcFormat
|
||||
)
|
||||
{
|
||||
size_t NchwcChannels = (Channels + (BlockSize - 1)) & ~(BlockSize - 1);
|
||||
size_t SpatialSize = Height * Width;
|
||||
|
||||
size_t ChannelStride = NhwcFormat ? 1 : SpatialSize;
|
||||
size_t SpatialStride = NhwcFormat ? Channels : 1;
|
||||
|
||||
for (size_t n = 0; n < BatchCount; n++) {
|
||||
|
||||
for (size_t c = 0; c < Channels; c++) {
|
||||
|
||||
const float* input = Input + ((c & ~(BlockSize - 1)) * SpatialSize) + (c & (BlockSize - 1));
|
||||
float* output = Output + (c * ChannelStride);
|
||||
|
||||
for (size_t hw = 0; hw < SpatialSize; hw++) {
|
||||
output[hw * SpatialStride] = input[hw * BlockSize];
|
||||
}
|
||||
}
|
||||
|
||||
Input += NchwcChannels * SpatialSize;
|
||||
Output += Channels * SpatialSize;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void
|
||||
ExecuteShort(
|
||||
void
|
||||
) override
|
||||
{
|
||||
for (size_t c = 1; c < 48; c++) {
|
||||
Test(1, c, 112, 112);
|
||||
Test(4, c, 15, 21);
|
||||
Test(16, c, 11, 11);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -2010,9 +2109,6 @@ main(
|
|||
printf("Pool3D tests.\n");
|
||||
onnxruntime::make_unique<MlasPool3DTest>()->ExecuteShort();
|
||||
|
||||
printf("Activation tests.\n");
|
||||
onnxruntime::make_unique<MlasActivationTest>()->ExecuteShort();
|
||||
|
||||
printf("Done.\n");
|
||||
#if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL)
|
||||
if(threadpool != nullptr) threadpool = new onnxruntime::concurrency::ThreadPool("test", 2);
|
||||
|
|
@ -2023,5 +2119,14 @@ main(
|
|||
#if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL)
|
||||
delete threadpool;
|
||||
#endif
|
||||
|
||||
printf("Activation tests.\n");
|
||||
onnxruntime::make_unique<MlasActivationTest>()->ExecuteShort();
|
||||
|
||||
printf("ReorderOutput tests.\n");
|
||||
if (MlasNchwcGetBlockSize() > 1) {
|
||||
onnxruntime::make_unique<MlasReorderOutputTest>()->ExecuteShort();
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -138,6 +138,24 @@ struct NchwcTestHelper {
|
|||
return node;
|
||||
}
|
||||
|
||||
// Adds a "Transpose" node from input_arg to output_arg and sets its "perm"
// attribute to the supplied permutation. Returns the newly added node.
Node& AddTransposeNode(NodeArg* input_arg, NodeArg* output_arg, const std::vector<int64_t>& perm) {
  Node& transpose_node = AddNode("Transpose", {input_arg}, {output_arg});
  transpose_node.AddAttribute("perm", perm);
  return transpose_node;
}
|
||||
|
||||
// Adds a Transpose node with perm {0, 3, 1, 2}: output axes are taken from
// input axes 0, 3, 1, 2 in that order.
Node& AddTransposeToNchwNode(NodeArg* input_arg, NodeArg* output_arg) {
  const std::vector<int64_t> to_nchw_perm{0, 3, 1, 2};
  return AddTransposeNode(input_arg, output_arg, to_nchw_perm);
}
|
||||
|
||||
// Adds a Transpose node with perm {0, 2, 3, 1}: output axes are taken from
// input axes 0, 2, 3, 1 in that order.
Node& AddTransposeToNhwcNode(NodeArg* input_arg, NodeArg* output_arg) {
  const std::vector<int64_t> to_nhwc_perm{0, 2, 3, 1};
  return AddTransposeNode(input_arg, output_arg, to_nhwc_perm);
}
|
||||
|
||||
// Adds a Transpose node with perm {1, 0, 2, 3}, i.e. one that swaps the first
// two axes and leaves the last two in place.
Node& AddTransposeToCnhwNode(NodeArg* input_arg, NodeArg* output_arg) {
  const std::vector<int64_t> to_cnhw_perm{1, 0, 2, 3};
  return AddTransposeNode(input_arg, output_arg, to_cnhw_perm);
}
|
||||
|
||||
std::vector<float> FillRandomData(size_t count) {
|
||||
constexpr int min_fill_value = -23;
|
||||
constexpr int max_fill_value = 23;
|
||||
|
|
@ -1041,6 +1059,75 @@ TEST(NchwcOptimizerTests, BatchNormalization) {
|
|||
test_case(true);
|
||||
}
|
||||
|
||||
// Conv followed by a Transpose with perm {0, 2, 3, 1}: the transpose is
// expected to be fused into ReorderOutput, leaving no Transpose node.
TEST(NchwcOptimizerTests, ConvReorderOutputNhwc) {
  auto build_graph = [&](NchwcTestHelper& helper) {
    auto* conv_input = helper.MakeInput({1, 64, 28, 32});
    auto* conv_result = helper.MakeIntermediate();
    auto* transposed_result = helper.MakeOutput();

    helper.AddConvNode(conv_input, conv_result, {130, 64, 1, 1});
    helper.AddTransposeToNhwcNode(conv_result, transposed_result);
  };

  auto verify_graph = [&](NchwcInferenceSession& session) {
    auto op_counts = session.CountOpsInGraph();
    EXPECT_EQ(op_counts["nchwc.Conv"], 1);
    EXPECT_EQ(op_counts["nchwc.ReorderInput"], 1);
    EXPECT_EQ(op_counts["nchwc.ReorderOutput"], 1);
    EXPECT_EQ(op_counts["Transpose"], 0);
  };

  // The NHWC transpose must be absorbed by the ReorderOutput node.
  NchwcOptimizerTester(build_graph, verify_graph);
}
|
||||
|
||||
// Conv whose output feeds both a Transpose with perm {0, 2, 3, 1} and a Neg
// node: two ReorderOutput nodes are expected and no Transpose remains.
TEST(NchwcOptimizerTests, ConvReorderOutputBoth) {
  auto build_graph = [&](NchwcTestHelper& helper) {
    auto* conv_input = helper.MakeInput({5, 64, 33, 37});
    auto* conv_result = helper.MakeIntermediate();
    auto* neg_result = helper.MakeOutput();
    auto* transposed_result = helper.MakeOutput();

    helper.AddConvNode(conv_input, conv_result, {7, 64, 1, 1});
    helper.AddTransposeToNhwcNode(conv_result, transposed_result);
    helper.AddNode("Neg", {conv_result}, {neg_result});
  };

  auto verify_graph = [&](NchwcInferenceSession& session) {
    auto op_counts = session.CountOpsInGraph();
    EXPECT_EQ(op_counts["nchwc.Conv"], 1);
    EXPECT_EQ(op_counts["nchwc.ReorderInput"], 1);
    EXPECT_EQ(op_counts["nchwc.ReorderOutput"], 2);
    EXPECT_EQ(op_counts["Transpose"], 0);
  };

  // When one argument is consumed both as NCHW and as NHWC, a separate
  // ReorderOutput node must be inserted for each layout.
  NchwcOptimizerTester(build_graph, verify_graph);
}
|
||||
|
||||
// Conv followed by a Transpose with perm {1, 0, 2, 3}: this permutation is
// not expected to be fused, so the Transpose node must survive.
TEST(NchwcOptimizerTests, ConvReorderOutputCnhw) {
  auto build_graph = [&](NchwcTestHelper& helper) {
    auto* conv_input = helper.MakeInput({1, 64, 28, 32});
    auto* conv_result = helper.MakeIntermediate();
    auto* cnhw_result = helper.MakeOutput();

    helper.AddConvNode(conv_input, conv_result, {130, 64, 1, 1});
    helper.AddTransposeToCnhwNode(conv_result, cnhw_result);
  };

  auto verify_graph = [&](NchwcInferenceSession& session) {
    auto op_counts = session.CountOpsInGraph();
    EXPECT_EQ(op_counts["nchwc.Conv"], 1);
    EXPECT_EQ(op_counts["nchwc.ReorderInput"], 1);
    EXPECT_EQ(op_counts["nchwc.ReorderOutput"], 1);
    EXPECT_EQ(op_counts["Transpose"], 1);
  };

  // A CNHW transpose must be left alone rather than folded into ReorderOutput.
  NchwcOptimizerTester(build_graph, verify_graph);
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue