diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc index 14e0a83384..ee8bb17396 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc @@ -16,7 +16,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( float, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::GetTensorType()), - ReorderInput); + ReorderInput); ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( ReorderOutput, @@ -24,7 +24,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( float, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::GetTensorType()), - ReorderOutput); + ReorderOutput); ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( Conv, @@ -67,27 +67,41 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( .TypeConstraint("T", DataTypeImpl::GetTensorType()), NchwcAveragePool); -template -Status ReorderInput::Compute(OpKernelContext* context) const { +Status ReorderInput::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); const auto& X_shape = X->Shape(); ORT_ENFORCE(X_shape.NumDimensions() == 4); ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0); auto* Y = context->Output(0, X_shape); - MlasReorderInput(X_shape.GetDims().data(), X->template Data(), Y->template MutableData()); + MlasReorderInput(X_shape.GetDims().data(), X->template Data(), Y->template MutableData()); return Status::OK(); } -template -Status ReorderOutput::Compute(OpKernelContext* context) const { +Status ReorderOutput::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); const auto& X_shape = X->Shape(); - ORT_ENFORCE(X_shape.NumDimensions() == 4); - std::vector Y_shape(X_shape.GetDims()); - ORT_ENFORCE(channels_ <= Y_shape[1]); - Y_shape[1] = channels_; + const auto X_rank = X_shape.NumDimensions(); + ORT_ENFORCE(X_rank == 4); + ORT_ENFORCE(channels_ <= X_shape[1]); + + // Build the output shape in NCHW or NHWC order. + std::vector Y_shape(X_rank); + Y_shape[0] = X_shape[0]; + Y_shape[channels_last_ ? X_rank - 1 : 1] = channels_; + auto* Y_spatial_dims = Y_shape.data() + (channels_last_ ? 1 : 2); + for (size_t i = 0; i < X_rank - 2; i++) { + Y_spatial_dims[i] = X_shape[2 + i]; + } auto* Y = context->Output(0, Y_shape); - MlasReorderOutput(Y_shape.data(), X->template Data(), Y->template MutableData()); + + const auto* x_data = X->template Data(); + auto* y_data = Y->template MutableData(); + if (channels_last_) { + MlasReorderOutputNhwc(Y_shape.data(), x_data, y_data); + } else { + MlasReorderOutputNchw(Y_shape.data(), x_data, y_data); + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h index 5a6c606231..8edcadfc43 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h @@ -12,7 +12,6 @@ namespace onnxruntime { namespace contrib { -template class ReorderInput : public OpKernel { public: ReorderInput(const OpKernelInfo& info) : OpKernel(info) { @@ -21,18 +20,19 @@ class ReorderInput : public OpKernel { Status Compute(OpKernelContext* context) const override; }; -template class ReorderOutput : public OpKernel { public: ReorderOutput(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(info.GetAttr("channels", &channels_).IsOK()); ORT_ENFORCE(channels_ > 0, "invalid channel count"); + ORT_ENFORCE(info.GetAttr("channels_last", &channels_last_).IsOK()); } Status Compute(OpKernelContext* context) const override; private: int64_t channels_; + int64_t channels_last_; }; class NchwcConv : public OpKernel { diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 0bcbe7aa32..4f0ac9a092 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -5,6 +5,7 @@ #include "core/graph/constants.h" #include "core/graph/contrib_ops/attn_lstm_schema_defs.h" #include "core/graph/contrib_ops/contrib_defs.h" +#include "core/graph/contrib_ops/nchwc_schema_defs.h" #include "core/graph/contrib_ops/range_schema_defs.h" #include "core/graph/op.h" #include "onnx/defs/schema.h" @@ -18,7 +19,6 @@ void convPoolShapeInference( bool use_dilation, bool require_kernel_shape, int input1Idx, int input2Idx); -void globalPoolTypeShapeInference(ONNX_NAMESPACE::InferenceContext& ctx); void matmulShapeInference( ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, @@ -166,37 +166,6 @@ using ONNX_NAMESPACE::AttributeProto; using ONNX_NAMESPACE::OpSchema; using ONNX_NAMESPACE::OPTIONAL; -void NchwcPoolOpSchemaGenerator(OpSchema& schema) { - schema.SetDomain(kMSNchwcDomain); - schema.SinceVersion(1); - schema.SetDoc(R"DOC(For internal use.)DOC"); - schema.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET")); - schema.Attr("kernel_shape", "", AttributeProto::INTS); - schema.Attr("dilations", "", AttributeProto::INTS, OPTIONAL); - schema.Attr("strides", "", AttributeProto::INTS, OPTIONAL); - schema.Attr("pads", "", AttributeProto::INTS, OPTIONAL); - schema.Attr("ceil_mode", "", AttributeProto::INT, static_cast(0)); - schema.Input(0, "X", "", "T"); - schema.Output(0, "Y", "", "T"); - schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors"); - schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); - ONNX_NAMESPACE::convPoolShapeInference(ctx, true, true, 0, 1); - }); -} - -void NchwcGlobalPoolOpSchemaGenerator(OpSchema& schema) { - schema.SetDomain(kMSNchwcDomain); - schema.SinceVersion(1); - schema.SetDoc(R"DOC(For internal use.)DOC"); - schema.Input(0, "X", "", "T"); - schema.Output(0, "Y", "", "T"); - schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors"); - schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - ONNX_NAMESPACE::globalPoolTypeShapeInference(ctx); - }); -} - void ValidateTypeAndShapeForScaleAndZP(ONNX_NAMESPACE::InferenceContext& ctx, int index, ::google::protobuf::int32 expectedType, bool isScalar, int expectedTensorSize = 0) { if (ctx.getNumInputs() > static_cast(index)) { auto data_type = ctx.getInputType(index); @@ -320,132 +289,6 @@ const char* contrib_ops_auto_pad_doc = "In case of odd number add the extra padding at the end for SAME_UPPER and at the " "beginning for SAME_LOWER. VALID mean no padding."; -void RegisterNchwcSchemas() { - ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderInput) - .SetDomain(kMSNchwcDomain) - .SinceVersion(1) - .SetDoc(R"DOC(For internal use.)DOC") - .Input(0, "X", "", "T") - .Output(0, "Y", "", "T") - .TypeConstraint( - "T", - {"tensor(float)", "tensor(int8)", "tensor(uint8)"}, - "Constrain input and output types to float/quantized tensors") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); - - ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput) - .SetDomain(kMSNchwcDomain) - .SinceVersion(1) - .SetDoc(R"DOC(For internal use.)DOC") - .Attr( - "channels", - "", - AttributeProto::INT, - static_cast(0)) - .Input(0, "X", "", "T") - .Output(0, "Y", "", "T") - .TypeConstraint( - "T", - {"tensor(float)", "tensor(int8)", "tensor(uint8)"}, - "Constrain input and output types to float/quantized tensors") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (!hasNInputShapes(ctx, 1)) { - return; - } - propagateShapeFromInputToOutput(ctx, 0, 0); - - // Update the output shape with the actual number of channels. - auto channels = getAttribute(ctx, "channels", 0); - if (channels <= 0) { - fail_shape_inference("invalid channel count"); - } - auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); - if (output_shape->dim_size() < 2) { - fail_shape_inference("tensor rank too small"); - } - auto* channels_dim = output_shape->mutable_dim(1); - channels_dim->clear_dim_param(); - channels_dim->set_dim_value(channels); - }); - - ONNX_CONTRIB_OPERATOR_SCHEMA(Conv) - .SetDomain(kMSNchwcDomain) - .SinceVersion(1) - .SetDoc(R"DOC(For internal use.)DOC") - .Attr( - "auto_pad", - "", - AttributeProto::STRING, - std::string("NOTSET")) - .Attr( - "kernel_shape", - "", - AttributeProto::INTS, - OPTIONAL) - .Attr( - "dilations", - "", - AttributeProto::INTS, - OPTIONAL) - .Attr( - "strides", - "", - AttributeProto::INTS, - OPTIONAL) - .Attr( - "pads", - "", - AttributeProto::INTS, OPTIONAL) - .Attr( - "group", - "", - AttributeProto::INT, - static_cast(1)) - .Attr( - "activation", - "", - AttributeProto::STRING, - OPTIONAL) - .Attr( - "activation_params", - "", - AttributeProto::FLOATS, - OPTIONAL) - .Input(0, "X", "", "T") - .Input(1, "W", "", "T") - .Input(2, "B", "", "T", OpSchema::Optional) - .Input(3, "Sum", "", "T", OpSchema::Optional) - .Output(0, "Y", "", "T") - .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); - ONNX_NAMESPACE::convPoolShapeInference(ctx, true, false, 0, 1); - }); - - ONNX_CONTRIB_OPERATOR_SCHEMA(MaxPool) - .FillUsing(NchwcPoolOpSchemaGenerator) - .Attr( - "storage_order", - "", - AttributeProto::INT, - static_cast(0)); - - ONNX_CONTRIB_OPERATOR_SCHEMA(AveragePool) - .FillUsing(NchwcPoolOpSchemaGenerator) - .Attr( - "count_include_pad", - "", - AttributeProto::INT, - static_cast(0)); - - ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalMaxPool) - .FillUsing(NchwcGlobalPoolOpSchemaGenerator); - - ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalAveragePool) - .FillUsing(NchwcGlobalPoolOpSchemaGenerator); -} - void RegisterBertSchemas() { ONNX_CONTRIB_OPERATOR_SCHEMA(Attention) .SetDomain(kMSDomain) @@ -1383,8 +1226,8 @@ activation and leaky_relu_alpha.)DOC") ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema); static const char* QuantizeLinear_ver1_doc = R"DOC( -The linear quantization operator. It consumes a full precision data, a scale, a zero point and computes the quantized data. -The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to arg (in floating-point format), +The linear quantization operator. It consumes a full precision data, a scale, a zero point and computes the quantized data. +The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to arg (in floating-point format), rounding halfway cases away from zero. Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear) @@ -1440,8 +1283,8 @@ The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), }); static const char* DequantizeLinear_ver1_doc = R"DOC( -The linear dequantization operator. It consumes a quantized data, a scale, a zero point and computes the full precision data. -The dequantization formula is y = (x - x_zero_point) * x_scale. +The linear dequantization operator. It consumes a quantized data, a scale, a zero point and computes the full precision data. +The dequantization formula is y = (x - x_zero_point) * x_scale. Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(DequantizeLinear) @@ -1682,7 +1525,7 @@ Computes the mean of the low-precision input tensor's element along the provided The resulting tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then the resulting tensor have the reduced dimension pruned. The above behavior is similar to numpy, with the exception that numpy default keepdims to False instead of True. -Input and Output scales and zero points are used to requantize the output in a new range. +Input and Output scales and zero points are used to requantize the output in a new range. This helps to improve accuracy as after ReduceMean operation the range of the output is expected to decrease. ``` @@ -1861,7 +1704,7 @@ C (int32) = (A - A_zero_point) * (B - B_zero_point) ``` pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + kernel_spatial_shape[i] - input_spatial_shape[i] ``` - + The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero). Input and output scales and zero points are used to convert the output to a new quantization range. @@ -2448,7 +2291,7 @@ Example 4: R"DOC(Gaussian Error Linear Unit. A high-performing neural network activation function.The GELU nonlinearity is the expected transformation of a stochastic regularizer which randomly applies -the identity or zero map to a neuron's input. The GELU nonlinearity weights +the identity or zero map to a neuron's input. The GELU nonlinearity weights inputs by their magnitude, rather than gates inputs by their sign as in ReLUs.)DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(Gelu) diff --git a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc new file mode 100644 index 0000000000..b92e5f9095 --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/tensorprotoutils.h" +#include "core/graph/constants.h" +#include "core/graph/contrib_ops/contrib_defs.h" +#include "core/graph/contrib_ops/nchwc_schema_defs.h" + +namespace ONNX_NAMESPACE { +void convPoolShapeInference( + ONNX_NAMESPACE::InferenceContext& ctx, + bool use_dilation, bool require_kernel_shape, + int input1Idx, + int input2Idx); +void globalPoolTypeShapeInference(ONNX_NAMESPACE::InferenceContext& ctx); +} // namespace ONNX_NAMESPACE + +namespace onnxruntime { +namespace contrib { + +using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::InferenceContext; +using ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::OPTIONAL; + +void NchwcPoolOpSchemaGenerator(OpSchema& schema) { + schema.SetDomain(kMSNchwcDomain); + schema.SinceVersion(1); + schema.SetDoc(R"DOC(For internal use.)DOC"); + schema.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET")); + schema.Attr("kernel_shape", "", AttributeProto::INTS); + schema.Attr("dilations", "", AttributeProto::INTS, OPTIONAL); + schema.Attr("strides", "", AttributeProto::INTS, OPTIONAL); + schema.Attr("pads", "", AttributeProto::INTS, OPTIONAL); + schema.Attr("ceil_mode", "", AttributeProto::INT, static_cast(0)); + schema.Input(0, "X", "", "T"); + schema.Output(0, "Y", "", "T"); + schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors"); + schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + ONNX_NAMESPACE::convPoolShapeInference(ctx, true, true, 0, 1); + }); +} + +void NchwcGlobalPoolOpSchemaGenerator(OpSchema& schema) { + schema.SetDomain(kMSNchwcDomain); + schema.SinceVersion(1); + schema.SetDoc(R"DOC(For internal use.)DOC"); + schema.Input(0, "X", "", "T"); + schema.Output(0, "Y", "", "T"); + schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors"); + schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::globalPoolTypeShapeInference(ctx); + }); +} + +void RegisterNchwcSchemas() { + ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderInput) + .SetDomain(kMSNchwcDomain) + .SinceVersion(1) + .SetDoc(R"DOC(For internal use.)DOC") + .Input(0, "X", "", "T") + .Output(0, "Y", "", "T") + .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); + + ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput) + .SetDomain(kMSNchwcDomain) + .SinceVersion(1) + .SetDoc(R"DOC(For internal use.)DOC") + .Attr("channels", "", AttributeProto::INT, static_cast(0)) + .Attr("channels_last", "", AttributeProto::INT, static_cast(0)) + .Input(0, "X", "", "T") + .Output(0, "Y", "", "T") + .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 1)) { + return; + } + + auto input_shape = ctx.getInputType(0)->tensor_type().shape(); + auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + auto input_rank = input_shape.dim_size(); + if (input_rank < 2) { + fail_shape_inference("tensor rank too small"); + } + + // Update the output shape with the actual number of channels. + auto channels = getAttribute(ctx, "channels", 0); + if (channels <= 0) { + fail_shape_inference("invalid channel count"); + } + + // Copy batch dimension. + *output_shape->add_dim() = input_shape.dim(0); + + auto channels_last = getAttribute(ctx, "channels_last", 0); + if (channels_last == 0) { + output_shape->add_dim()->set_dim_value(channels); + } + + // Copy spatial dimensions. + for (int i = 0; i < input_rank - 2; i++) { + *output_shape->add_dim() = input_shape.dim(2 + i); + } + + if (channels_last != 0) { + output_shape->add_dim()->set_dim_value(channels); + } + }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(Conv) + .SetDomain(kMSNchwcDomain) + .SinceVersion(1) + .SetDoc(R"DOC(For internal use.)DOC") + .Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET")) + .Attr("kernel_shape", "", AttributeProto::INTS, OPTIONAL) + .Attr("dilations", "", AttributeProto::INTS, OPTIONAL) + .Attr("strides", "", AttributeProto::INTS, OPTIONAL) + .Attr("pads", "", AttributeProto::INTS, OPTIONAL) + .Attr("group", "", AttributeProto::INT, static_cast(1)) + .Attr("activation", "", AttributeProto::STRING, OPTIONAL) + .Attr("activation_params", "", AttributeProto::FLOATS, OPTIONAL) + .Input(0, "X", "", "T") + .Input(1, "W", "", "T") + .Input(2, "B", "", "T", OpSchema::Optional) + .Input(3, "Sum", "", "T", OpSchema::Optional) + .Output(0, "Y", "", "T") + .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + ONNX_NAMESPACE::convPoolShapeInference(ctx, true, false, 0, 1); + }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(MaxPool) + .FillUsing(NchwcPoolOpSchemaGenerator) + .Attr("storage_order", "", AttributeProto::INT, static_cast(0)); + + ONNX_CONTRIB_OPERATOR_SCHEMA(AveragePool) + .FillUsing(NchwcPoolOpSchemaGenerator) + .Attr("count_include_pad", "", AttributeProto::INT, static_cast(0)); + + ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalMaxPool) + .FillUsing(NchwcGlobalPoolOpSchemaGenerator); + + ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalAveragePool) + .FillUsing(NchwcGlobalPoolOpSchemaGenerator); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.h b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.h new file mode 100644 index 0000000000..3036506d0a --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace contrib { + +void RegisterNchwcSchemas(); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 899b11f91f..74263f3986 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -333,7 +333,15 @@ MlasReorderInput( void MLASCALL -MlasReorderOutput( +MlasReorderOutputNchw( + const int64_t* OutputShape, + const float* S, + float* D + ); + +void +MLASCALL +MlasReorderOutputNhwc( const int64_t* OutputShape, const float* S, float* D diff --git a/onnxruntime/core/mlas/lib/reorder.cpp b/onnxruntime/core/mlas/lib/reorder.cpp index 2e3da2fc00..7687a207c7 100644 --- a/onnxruntime/core/mlas/lib/reorder.cpp +++ b/onnxruntime/core/mlas/lib/reorder.cpp @@ -269,7 +269,7 @@ Return Value: void MLASCALL -MlasReorderOutput( +MlasReorderOutputNchw( const int64_t* OutputShape, const float* S, float* D @@ -365,6 +365,81 @@ Return Value: } } +void +MLASCALL +MlasReorderOutputNhwc( + const int64_t* OutputShape, + const float* S, + float* D + ) +/*++ + +Routine Description: + + This routine reorders an output buffer from NCHWc to NHWC format. + +Arguments: + + OutputShape - Supplies the shape of the output tensor. + + S - Supplies the address of the source tensor. + + D - Supplies the address of the destination tensor. + +Return Value: + + None. + +--*/ +{ + const size_t BlockSize = MlasNchwcGetBlockSize(); + + const size_t BatchCount = size_t(OutputShape[0]); + const size_t OutputChannels = size_t(OutputShape[3]); + const size_t OutputSize = size_t(OutputShape[1]) * size_t(OutputShape[2]); + + const size_t AlignedOutputChannels = (OutputChannels + BlockSize - 1) & ~(BlockSize - 1); + + // + // Copy NCHWc blocks from the source buffer to the destination buffer. + // + + for (size_t batch = 0; batch < BatchCount; batch++) { + + const float* s = S; + size_t OutputSizeRemaining = OutputSize; + + for (; OutputSizeRemaining > 0; OutputSizeRemaining--) { + + const float* ss = s; + + for (size_t o = OutputChannels; o > 0;) { + + const size_t OutputChannelsThisIteration = (std::min)(o, BlockSize); + const size_t AlignedOutputChannelsThisIteration = OutputChannelsThisIteration & (~3); + o -= OutputChannelsThisIteration; + + size_t bc = 0; + + for (; bc < AlignedOutputChannelsThisIteration; bc += 4) { + MlasStoreFloat32x4(&D[bc], MlasLoadFloat32x4(&ss[bc])); + } + + for (; bc < OutputChannelsThisIteration; bc += 1) { + D[bc] = ss[bc]; + } + + ss += BlockSize * OutputSize; + D += OutputChannelsThisIteration; + } + + s += BlockSize; + } + + S += AlignedOutputChannels * OutputSize; + } +} + void MLASCALL MlasReorderFilterOIHWBiBo( diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index c6c5e5316c..4ffb96dea0 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -4,7 +4,6 @@ #include "core/optimizer/initializer.h" #include "core/optimizer/gemm_activation_fusion.h" #include "core/graph/graph_utils.h" -#include using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; @@ -23,7 +22,6 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l GraphViewer graph_viewer(graph); const auto& order = graph_viewer.GetNodesInTopologicalOrder(); - std::deque removed_nodes; for (auto index : order) { auto* node_ptr = graph.GetNode(index); if (!node_ptr) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 30835f264b..d9cad823c8 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -119,6 +119,7 @@ class NchwcTransformerImpl { void TransformConcat(Node& node); void TransformActivation(Node& node); void TransformBatchNormalization(Node& node); + void TransformTranspose(Node& node); Graph& graph_; @@ -199,7 +200,7 @@ void NchwcTransformerImpl::InsertReorderInput(Node& node) { {input_nchwc_arg}, nullptr, kMSNchwcDomain); - reorder_input_node.SetExecutionProviderType(node.GetExecutionProviderType()); + reorder_input_node.SetExecutionProviderType(kCpuExecutionProvider); input_defs[0] = input_nchwc_arg; } else { input_defs[0] = it->second; @@ -426,7 +427,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { output_defs, &node.GetAttributes(), kMSNchwcDomain); - nchwc_node.SetExecutionProviderType(node.GetExecutionProviderType()); + nchwc_node.SetExecutionProviderType(kCpuExecutionProvider); nchwc_node.MutableInputDefs()[1] = nchwc_conv_W_arg; @@ -485,7 +486,7 @@ void NchwcTransformerImpl::TransformPool(Node& node) { output_defs, &node.GetAttributes(), kMSNchwcDomain); - nchwc_node.SetExecutionProviderType(node.GetExecutionProviderType()); + nchwc_node.SetExecutionProviderType(kCpuExecutionProvider); NchwcArgument::Shape output_shape(output_defs[0]); @@ -774,7 +775,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { output_defs, nullptr, kMSNchwcDomain); - nchwc_node.SetExecutionProviderType(node.GetExecutionProviderType()); + nchwc_node.SetExecutionProviderType(kCpuExecutionProvider); nchwc_node.AddAttribute("group", nchwc_channels); nchwc_input->remaining_original_uses_--; @@ -783,6 +784,47 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { removed_nodes_.push_front(node.Index()); } +void NchwcTransformerImpl::TransformTranspose(Node& node) { + auto& input_defs = node.MutableInputDefs(); + auto& output_defs = node.MutableOutputDefs(); + + const ONNX_NAMESPACE::AttributeProto* perm_attr = graph_utils::GetNodeAttribute(node, "perm"); + if (perm_attr == nullptr || perm_attr->ints_size() != 4) { + return; + } + + // Test if this transposes from NCHW to NHWC layout order. + const int64_t* perm_data = perm_attr->ints().data(); + if (perm_data[0] != 0 || perm_data[1] != 2 || perm_data[2] != 3 || perm_data[3] != 1) { + return; + } + + // Don't transform the node if the input is not already in NCHWc format. + auto it = nchwc_args_.find(input_defs[0]); + if (it == nchwc_args_.end()) { + return; + } + auto* nchwc_input = it->second.get(); + + // Create the replacement node. + Node& reorder_output_node = graph_.AddNode(graph_.GenerateNodeName("ReorderOutput"), + "ReorderOutput", + "ReorderOutput", + {nchwc_input->nchwc_arg_}, + output_defs, + nullptr, + kMSNchwcDomain); + reorder_output_node.SetExecutionProviderType(kCpuExecutionProvider); + reorder_output_node.AddAttribute("channels", nchwc_input->channels_); + reorder_output_node.AddAttribute("channels_last", static_cast(1)); + + nchwc_input->remaining_original_uses_--; + + graph_utils::RemoveNodeOutputEdges(graph_, node); + + removed_nodes_.push_front(node.Index()); +} + void NchwcTransformerImpl::Transform(Node& node) { if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)) { @@ -807,6 +849,8 @@ void NchwcTransformerImpl::Transform(Node& node) { TransformActivation(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) { TransformBatchNormalization(node); + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1})) { + TransformTranspose(node); } } diff --git a/onnxruntime/test/mlas/unittest.cpp b/onnxruntime/test/mlas/unittest.cpp index d78167f8a5..051df88dfd 100644 --- a/onnxruntime/test/mlas/unittest.cpp +++ b/onnxruntime/test/mlas/unittest.cpp @@ -183,7 +183,9 @@ public: void ExecuteShort( void - ) = 0; + ) + { + } // // Contains tests that can run slowly to more exhaustively test that @@ -194,7 +196,9 @@ public: void ExecuteLong( void - ) = 0; + ) + { + } }; template @@ -1101,7 +1105,7 @@ protected: // Reorder the output buffer. // - MlasReorderOutput(OutputShape, NchwcOutput, Output); + MlasReorderOutputNchw(OutputShape, NchwcOutput, Output); } const size_t BlockSize = MlasNchwcGetBlockSize(); @@ -1510,7 +1514,7 @@ protected: NchwcOutput, nullptr); - MlasReorderOutput(OutputShape, NchwcOutput, Output); + MlasReorderOutputNchw(OutputShape, NchwcOutput, Output); } MatrixGuardBuffer BufferNchwcInput; @@ -1963,12 +1967,107 @@ public: } } } +}; + +class MlasReorderOutputTest : public MlasTestBase +{ +private: + const size_t BlockSize = MlasNchwcGetBlockSize(); + + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutput2; + MatrixGuardBuffer BufferOutputReference; void - ExecuteLong( + Test( + size_t BatchCount, + size_t Channels, + size_t Height, + size_t Width + ) + { + size_t NchwcChannels = (Channels + BlockSize - 1) & ~(BlockSize - 1); + + size_t InputBufferElements = BatchCount * NchwcChannels * Height * Width; + size_t OutputBufferElements = BatchCount * Channels * Height * Width; + + const float* Input = BufferInput.GetBuffer(InputBufferElements); + float* Output = BufferOutput.GetBuffer(OutputBufferElements); + float* OutputReference = BufferOutputReference.GetBuffer(OutputBufferElements); + + int64_t NchwOutputShape[] = { int64_t(BatchCount), int64_t(Channels), int64_t(Height), int64_t(Width) }; + + std::fill_n(Output, OutputBufferElements, -0.5f); + std::fill_n(OutputReference, OutputBufferElements, -0.5f); + + MlasReorderOutputNchw(NchwOutputShape, Input, Output); + ReferenceReorderOutput(BatchCount, Channels, Height, Width, Input, OutputReference, false); + + if (memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)) != 0) { + printf("mismatch ReorderOutputNchw: batch=%zd channels=%zd height=%zd width=%zd\n", + BatchCount, Channels, Height, Width); + } + + int64_t NhwcOutputShape[] = { int64_t(BatchCount), int64_t(Height), int64_t(Width), int64_t(Channels) }; + + std::fill_n(Output, OutputBufferElements, -0.5f); + std::fill_n(OutputReference, OutputBufferElements, -0.5f); + + MlasReorderOutputNhwc(NhwcOutputShape, Input, Output); + ReferenceReorderOutput(BatchCount, Channels, Height, Width, Input, OutputReference, true); + + if (memcmp(Output, OutputReference, OutputBufferElements * sizeof(float)) != 0) { + printf("mismatch ReorderOutputNhwc: batch=%zd channels=%zd height=%zd width=%zd\n", + BatchCount, Channels, Height, Width); + } + } + + void + ReferenceReorderOutput( + size_t BatchCount, + size_t Channels, + size_t Height, + size_t Width, + const float* Input, + float* Output, + bool NhwcFormat + ) + { + size_t NchwcChannels = (Channels + (BlockSize - 1)) & ~(BlockSize - 1); + size_t SpatialSize = Height * Width; + + size_t ChannelStride = NhwcFormat ? 1 : SpatialSize; + size_t SpatialStride = NhwcFormat ? Channels : 1; + + for (size_t n = 0; n < BatchCount; n++) { + + for (size_t c = 0; c < Channels; c++) { + + const float* input = Input + ((c & ~(BlockSize - 1)) * SpatialSize) + (c & (BlockSize - 1)); + float* output = Output + (c * ChannelStride); + + for (size_t hw = 0; hw < SpatialSize; hw++) { + output[hw * SpatialStride] = input[hw * BlockSize]; + } + } + + Input += NchwcChannels * SpatialSize; + Output += Channels * SpatialSize; + } + } + +public: + void + ExecuteShort( void ) override { + for (size_t c = 1; c < 48; c++) { + Test(1, c, 112, 112); + Test(4, c, 15, 21); + Test(16, c, 11, 11); + } } }; @@ -2010,9 +2109,6 @@ main( printf("Pool3D tests.\n"); onnxruntime::make_unique()->ExecuteShort(); - printf("Activation tests.\n"); - onnxruntime::make_unique()->ExecuteShort(); - printf("Done.\n"); #if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL) if(threadpool != nullptr) threadpool = new onnxruntime::concurrency::ThreadPool("test", 2); @@ -2023,5 +2119,14 @@ main( #if !defined(MLAS_NO_ONNXRUNTIME_THREADPOOL) delete threadpool; #endif + + printf("Activation tests.\n"); + onnxruntime::make_unique()->ExecuteShort(); + + printf("ReorderOutput tests.\n"); + if (MlasNchwcGetBlockSize() > 1) { + onnxruntime::make_unique()->ExecuteShort(); + } + return 0; } diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index d3c468b70e..8407f1cdff 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -138,6 +138,24 @@ struct NchwcTestHelper { return node; } + Node& AddTransposeNode(NodeArg* input_arg, NodeArg* output_arg, const std::vector& perm) { + auto& node = AddNode("Transpose", {input_arg}, {output_arg}); + node.AddAttribute("perm", perm); + return node; + } + + Node& AddTransposeToNchwNode(NodeArg* input_arg, NodeArg* output_arg) { + return AddTransposeNode(input_arg, output_arg, {0, 3, 1, 2}); + } + + Node& AddTransposeToNhwcNode(NodeArg* input_arg, NodeArg* output_arg) { + return AddTransposeNode(input_arg, output_arg, {0, 2, 3, 1}); + } + + Node& AddTransposeToCnhwNode(NodeArg* input_arg, NodeArg* output_arg) { + return AddTransposeNode(input_arg, output_arg, {1, 0, 2, 3}); + } + std::vector FillRandomData(size_t count) { constexpr int min_fill_value = -23; constexpr int max_fill_value = 23; @@ -1041,6 +1059,75 @@ TEST(NchwcOptimizerTests, BatchNormalization) { test_case(true); } +TEST(NchwcOptimizerTests, ConvReorderOutputNhwc) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 64, 28, 32}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* nhwc_output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv_output_arg, {130, 64, 1, 1}); + helper.AddTransposeToNhwcNode(conv_output_arg, nhwc_output_arg); + }; + + auto check_nchwc_graph = [&](NchwcInferenceSession& session) { + auto op_to_count = session.CountOpsInGraph(); + EXPECT_EQ(op_to_count["nchwc.Conv"], 1); + EXPECT_EQ(op_to_count["nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["Transpose"], 0); + }; + + // Verify that a NHWC transpose is fused into ReorderOutput. + NchwcOptimizerTester(build_test_case, check_nchwc_graph); +} + +TEST(NchwcOptimizerTests, ConvReorderOutputBoth) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({5, 64, 33, 37}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* nchw_output_arg = helper.MakeOutput(); + auto* nhwc_output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv_output_arg, {7, 64, 1, 1}); + helper.AddTransposeToNhwcNode(conv_output_arg, nhwc_output_arg); + helper.AddNode("Neg", {conv_output_arg}, {nchw_output_arg}); + }; + + auto check_nchwc_graph = [&](NchwcInferenceSession& session) { + auto op_to_count = session.CountOpsInGraph(); + EXPECT_EQ(op_to_count["nchwc.Conv"], 1); + EXPECT_EQ(op_to_count["nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["nchwc.ReorderOutput"], 2); + EXPECT_EQ(op_to_count["Transpose"], 0); + }; + + // Verify that if an output argument is used as both NCHW and NHWC, then + // two ReorderOutput nodes are inserted. + NchwcOptimizerTester(build_test_case, check_nchwc_graph); +} + +TEST(NchwcOptimizerTests, ConvReorderOutputCnhw) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 64, 28, 32}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* nhwc_output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv_output_arg, {130, 64, 1, 1}); + helper.AddTransposeToCnhwNode(conv_output_arg, nhwc_output_arg); + }; + + auto check_nchwc_graph = [&](NchwcInferenceSession& session) { + auto op_to_count = session.CountOpsInGraph(); + EXPECT_EQ(op_to_count["nchwc.Conv"], 1); + EXPECT_EQ(op_to_count["nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["Transpose"], 1); + }; + + // Verify that a CNHW transpose is not fused into ReorderOutput. + NchwcOptimizerTester(build_test_case, check_nchwc_graph); +} + #endif } // namespace test