diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 4727498238..905bd37ed2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2984,6 +2984,13 @@ This version of the operator has been available since version 1 of the 'com.micr This version of the operator has been available since version 1 of the 'com.microsoft.nchwc' operator set. +#### Attributes + +
+
channels_last : int
+
+
+ #### Inputs
diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc index e81a04d587..fd9c1f3a40 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc @@ -7,90 +7,128 @@ namespace onnxruntime { namespace contrib { -#define ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(name, ver, type, builder, ...) \ - ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSNchwcDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__) - -ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( - ReorderInput, - 1, - float, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - ReorderInput); - -ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( - ReorderOutput, - 1, - float, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - ReorderOutput); - -ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( - Conv, - 1, - float, - KernelDefBuilder() - .MayInplace(3, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - NchwcConv); - -ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( - MaxPool, - 1, - float, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - NchwcMaxPool); - -ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( - GlobalMaxPool, - 1, - float, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - NchwcMaxPool); - -ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( - AveragePool, - 1, - float, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - NchwcAveragePool); - -ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( - GlobalAveragePool, - 1, - float, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - NchwcAveragePool); - -ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( - Upsample, - 1, - float, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - NchwcUpsample); - Status ReorderInput::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); - const auto& X_shape = X->Shape(); - ORT_ENFORCE(X_shape.NumDimensions() == 4); - ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0); + const auto& X_shape = X->Shape().GetDims(); + const auto X_rank = X_shape.size(); + ORT_ENFORCE(X_rank == 4); - auto* Y = context->Output(0, X_shape); - MlasReorderInput(X_shape.GetDims().data(), X->template Data(), Y->template MutableData()); + const int64_t batch_count = X_shape[0]; + const int64_t channels = X_shape[channels_last_ ? X_rank - 1 : 1]; + const auto* X_spatial_dims = X_shape.data() + (channels_last_ ? 1 : 2); + + // The current implementation of MlasReorderInputNchw does not work for channels that + // are not a multiple of 4. + ORT_ENFORCE((channels % 4) == 0); + + const int64_t nchwc_block_size = static_cast(MlasNchwcGetBlockSize()); + const int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1); + + std::vector Y_shape(X_rank); + Y_shape[0] = batch_count; + Y_shape[1] = nchwc_channels; + int64_t spatial_size = 1; + for (size_t i = 0; i < X_rank - 2; i++) { + const int64_t spatial_dim = X_spatial_dims[i]; + spatial_size *= spatial_dim; + Y_shape[2 + i] = spatial_dim; + } + + auto* Y = context->Output(0, Y_shape); + + // Bail out early if one of the dimensions is zero. + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + // Compute the total amount of work depending on NCHW or NHWC format and estimate + // a number of workers to use. + ptrdiff_t total_work; + ptrdiff_t worker_count; + + if (channels_last_) { + total_work = static_cast(batch_count * spatial_size); + // Partition the work with the goal of reordering the following number of + // elements, so that operations involving a smaller number of channels will + // process more rows per worker. + constexpr ptrdiff_t worker_goal = 48 * 1024; + ptrdiff_t work_per_worker = worker_goal / nchwc_channels; + if (work_per_worker == 0) { + work_per_worker = 1; + } + worker_count = total_work / work_per_worker; + if (worker_count == 0) { + worker_count = 1; + } + } else { + // Each iteration produces one spatial_size chunk of NCHWc blocks. + total_work = static_cast(batch_count * (nchwc_channels / nchwc_block_size)); + worker_count = total_work; + } + + const auto* x_data = X->template Data(); + auto* y_data = Y->template MutableData(); + + auto reorder_worker = [&](ptrdiff_t batch) { + auto work = concurrency::ThreadPool::PartitionWork(batch, worker_count, total_work); + + if (channels_last_) { + int64_t work_index = static_cast(work.start); + int64_t work_remaining = static_cast(work.end - work.start); + + while (work_remaining > 0) { + const int64_t batch_index = work_index / spatial_size; + const int64_t spatial_index = work_index % spatial_size; + const int64_t rows_this_iteration = std::min(work_remaining, spatial_size - spatial_index); + + MlasReorderInputNhwc( + x_data + ((batch_index * spatial_size) + spatial_index) * channels, + y_data + (batch_index * spatial_size * nchwc_channels) + (spatial_index * nchwc_block_size), + static_cast(channels), + static_cast(rows_this_iteration), + static_cast(spatial_size)); + + work_index += rows_this_iteration; + work_remaining -= rows_this_iteration; + } + } else { + int64_t work_index = static_cast(work.start) * nchwc_block_size; + int64_t work_remaining = static_cast(work.end - work.start) * nchwc_block_size; + + while (work_remaining > 0) { + const int64_t batch_index = work_index / nchwc_channels; + const int64_t channel_index = work_index % nchwc_channels; + const int64_t channels_this_iteration = std::min(work_remaining, channels - channel_index); + + MlasReorderInputNchw( + x_data + ((batch_index * channels) + channel_index) * spatial_size, + y_data + ((batch_index * nchwc_channels) + channel_index) * spatial_size, + static_cast(channels_this_iteration), + static_cast(spatial_size)); + + const int64_t nchwc_channels_this_iteration = std::min(work_remaining, nchwc_channels - channel_index); + work_index += nchwc_channels_this_iteration; + work_remaining -= nchwc_channels_this_iteration; + } + } + }; + + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + + // Handle the work in a single batch if only a single thread is available. + if (concurrency::ThreadPool::DegreeOfParallelism(thread_pool) == 1) { + worker_count = 1; + } + + concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, worker_count, reorder_worker); return Status::OK(); } Status ReorderOutput::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); - const auto& X_shape = X->Shape(); - const auto X_rank = X_shape.NumDimensions(); + const auto& X_shape = X->Shape().GetDims(); + const auto X_rank = X_shape.size(); ORT_ENFORCE(X_rank == 4); ORT_ENFORCE(channels_ <= X_shape[1]); @@ -235,5 +273,73 @@ Status NchwcUpsample::Compute(OpKernelContext* context) const { return Status::OK(); } +#define ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(name, ver, type, builder, ...) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSNchwcDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__) + +ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( + ReorderInput, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + ReorderInput); + +ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( + ReorderOutput, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + ReorderOutput); + +ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( + Conv, + 1, + float, + KernelDefBuilder() + .MayInplace(3, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + NchwcConv); + +ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( + MaxPool, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + NchwcMaxPool); + +ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( + GlobalMaxPool, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + NchwcMaxPool); + +ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( + AveragePool, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + NchwcAveragePool); + +ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( + GlobalAveragePool, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + NchwcAveragePool); + +ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL( + Upsample, + 1, + float, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + NchwcUpsample); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h index 402f029640..f6639d03f4 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h @@ -15,9 +15,13 @@ namespace contrib { class ReorderInput : public OpKernel { public: ReorderInput(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(info.GetAttr("channels_last", &channels_last_).IsOK()); } Status Compute(OpKernelContext* context) const override; + + private: + int64_t channels_last_; }; class ReorderOutput : public OpKernel { diff --git a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc index dcb4d34541..39d7621eec 100644 --- a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc @@ -4,6 +4,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/constants.h" #include "core/graph/contrib_ops/contrib_defs.h" +#include "core/mlas/inc/mlas.h" namespace ONNX_NAMESPACE { void convPoolShapeInference( @@ -58,10 +59,45 @@ void RegisterNchwcSchemas() { .SetDomain(kMSNchwcDomain) .SinceVersion(1) .SetDoc(R"DOC(For internal use.)DOC") + .Attr("channels_last", "", AttributeProto::INT, static_cast(0)) .Input(0, "X", "", "T") .Output(0, "Y", "", "T") .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 1)) { + return; + } + + const auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + auto input_rank = input_shape.dim_size(); + if (input_rank < 2) { + fail_shape_inference("tensor rank too small"); + } + + auto channels_last = getAttribute(ctx, "channels_last", 0); + + // Copy the batch dimension. + *output_shape->add_dim() = input_shape.dim(0); + + // Block align the channel dimension. + const auto& input_channel_dim = input_shape.dim((channels_last == 0) ? 1 : input_rank - 1); + auto* output_channel_dim = output_shape->add_dim(); + if (input_channel_dim.has_dim_value()) { + const int64_t channels = input_channel_dim.dim_value(); + const int64_t nchwc_block_size = static_cast(MlasNchwcGetBlockSize()); + int64_t nchwc_channels = (channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1); + output_channel_dim->set_dim_value(nchwc_channels); + } + + // Copy the spatial dimensions. + int first_spatial_dim = (channels_last == 0) ? 2 : 1; + for (int i = 0; i < input_rank - 2; i++) { + *output_shape->add_dim() = input_shape.dim(first_spatial_dim + i); + } + }); ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput) .SetDomain(kMSNchwcDomain) @@ -78,8 +114,8 @@ void RegisterNchwcSchemas() { return; } - auto input_shape = ctx.getInputType(0)->tensor_type().shape(); - auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + const auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); auto input_rank = input_shape.dim_size(); if (input_rank < 2) { @@ -92,7 +128,7 @@ void RegisterNchwcSchemas() { fail_shape_inference("invalid channel count"); } - // Copy batch dimension. + // Copy the batch dimension. *output_shape->add_dim() = input_shape.dim(0); auto channels_last = getAttribute(ctx, "channels_last", 0); @@ -100,7 +136,7 @@ void RegisterNchwcSchemas() { output_shape->add_dim()->set_dim_value(channels); } - // Copy spatial dimensions. + // Copy the spatial dimensions. for (int i = 0; i < input_rank - 2; i++) { *output_shape->add_dim() = input_shape.dim(2 + i); } @@ -161,8 +197,8 @@ void RegisterNchwcSchemas() { return; } - auto input_shape = ctx.getInputType(0)->tensor_type().shape(); - auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + const auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); auto input_rank = input_shape.dim_size(); if (input_rank < 2) { diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index c417bdb17d..b2ce940482 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -552,12 +552,12 @@ MlasGemm( MLAS_THREADPOOL* ThreadPool ); -/** - * @brief Batched GEMM, for multiplying multiple pairs of matrices. +/** + * @brief Batched GEMM, for multiplying multiple pairs of matrices. * Note: We only support uniform batching, so shapes and types of the * input must be same: M, N, K, BIsSigned must be the - * same across all parameter blocks. - * + * same across all parameter blocks. + * * @param [IN] Shape A single shape descriptor for all the multiplications * @param [IN] DataParams Array of data descriptors for the matrices. * @param [IN] BatchN Size of the parameters array, also number of multiplications to perform @@ -834,10 +834,21 @@ MlasTranspose( void MLASCALL -MlasReorderInput( - const int64_t* InputShape, +MlasReorderInputNchw( const float* S, - float* D + float* D, + size_t InputChannels, + size_t InputSize + ); + +void +MLASCALL +MlasReorderInputNhwc( + const float* S, + float* D, + size_t InputChannels, + size_t RowCount, + size_t FullRowCount ); void diff --git a/onnxruntime/core/mlas/lib/reorder.cpp b/onnxruntime/core/mlas/lib/reorder.cpp index 821bb0c9a2..0d7fbd97a4 100644 --- a/onnxruntime/core/mlas/lib/reorder.cpp +++ b/onnxruntime/core/mlas/lib/reorder.cpp @@ -179,10 +179,11 @@ Return Value: void MLASCALL -MlasReorderInput( - const int64_t* InputShape, +MlasReorderInputNchw( const float* S, - float* D + float* D, + size_t InputChannels, + size_t InputSize ) /*++ @@ -192,12 +193,14 @@ Routine Description: Arguments: - InputShape - Supplies the shape of the input tensor. - S - Supplies the address of the source tensor. D - Supplies the address of the destination tensor. + InputChannels - Supplies the number of NCHW channels. + + InputSize - Supplies the spatial input size of the tensors. + Return Value: None. @@ -206,11 +209,12 @@ Return Value: { const size_t BlockSize = MlasNchwcGetBlockSize(); - const size_t InputChannels = size_t(InputShape[0] * InputShape[1]); - const size_t InputSize = size_t(InputShape[2]) * size_t(InputShape[3]); - const MLAS_FLOAT32X4 ZeroFloat32x4 = MlasZeroFloat32x4(); + // + // Iterate over BlockSize batches of the input channels. + // + for (size_t i = InputChannels; i > 0;) { const size_t InputChannelsThisIteration = std::min(i, BlockSize); @@ -233,7 +237,10 @@ Return Value: } for (; bc < BlockSize; bc += 4) { - MlasStoreFloat32x4(dd, ZeroFloat32x4); + MlasStoreFloat32x4(&dd[BlockSize * 0], ZeroFloat32x4); + MlasStoreFloat32x4(&dd[BlockSize * 1], ZeroFloat32x4); + MlasStoreFloat32x4(&dd[BlockSize * 2], ZeroFloat32x4); + MlasStoreFloat32x4(&dd[BlockSize * 3], ZeroFloat32x4); dd += 4; } @@ -267,6 +274,127 @@ Return Value: } } +void +MLASCALL +MlasReorderInputNhwc( + const float* S, + float* D, + size_t InputChannels, + size_t RowCount, + size_t FullRowCount + ) +/*++ + +Routine Description: + + This routine reorders an input buffer from NHWC to NCHWc format. + +Arguments: + + S - Supplies the address of the source tensor. + + D - Supplies the address of the destination tensor. + + InputChannels - Supplies the number of NHWC channels. + + RowCount - Supplies the number of NHWC rows to process. This number may be + less than FullRowCount to support threaded operation. + + FullRowCount - Supplies the total number of NHWC rows per image. + +Return Value: + + None. + +--*/ +{ + const size_t BlockSize = MlasNchwcGetBlockSize(); + + // + // Iterate over batches of the input size to improve locality. + // + + for (size_t OuterRowCountRemaining = RowCount; OuterRowCountRemaining > 0; ) { + + constexpr size_t OuterRowCountBatch = 32; + + const size_t OuterRowCountThisIteration = std::min(OuterRowCountRemaining, OuterRowCountBatch); + OuterRowCountRemaining -= OuterRowCountThisIteration; + + // + // Iterate over BlockSize batches of the input channels. + // + + const float* s = S; + float* d = D; + + for (size_t i = InputChannels; i > 0;) { + + const size_t InputChannelsThisIteration = std::min(i, BlockSize); + i -= InputChannelsThisIteration; + + const float* ss = s; + float* dd = d; + size_t InnerRowCountRemaining = OuterRowCountThisIteration; + + if (InputChannelsThisIteration == BlockSize) { + + if (BlockSize == 8) { + + while (InnerRowCountRemaining-- > 0) { + + MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(&ss[0]); + MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(&ss[4]); + + MlasStoreFloat32x4(&dd[0], v0); + MlasStoreFloat32x4(&dd[4], v1); + + ss += InputChannels; + dd += 8; + } + + } else { + + while (InnerRowCountRemaining-- > 0) { + + MLAS_FLOAT32X4 v0 = MlasLoadFloat32x4(&ss[0]); + MLAS_FLOAT32X4 v1 = MlasLoadFloat32x4(&ss[4]); + MLAS_FLOAT32X4 v2 = MlasLoadFloat32x4(&ss[8]); + MLAS_FLOAT32X4 v3 = MlasLoadFloat32x4(&ss[12]); + + MlasStoreFloat32x4(&dd[0], v0); + MlasStoreFloat32x4(&dd[4], v1); + MlasStoreFloat32x4(&dd[8], v2); + MlasStoreFloat32x4(&dd[12], v3); + + ss += InputChannels; + dd += 16; + } + } + + } else { + + size_t BlockPadding = BlockSize - InputChannelsThisIteration; + + while (InnerRowCountRemaining-- > 0) { + + std::copy_n(ss, InputChannelsThisIteration, dd); + std::fill_n(dd + InputChannelsThisIteration, BlockPadding, 0.0f); + + ss += InputChannels; + dd += BlockSize; + } + } + + s += InputChannelsThisIteration; + d += BlockSize * FullRowCount; + } + + S += InputChannels * OuterRowCountThisIteration; + D += BlockSize * OuterRowCountThisIteration; + } +} + void MLASCALL MlasReorderOutputNchw( diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 87d931576d..623202b3cb 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -120,8 +120,9 @@ class NchwcTransformerImpl { void TransformConcat(Node& node); void TransformActivation(Node& node); void TransformBatchNormalization(Node& node); - void TransformTranspose(Node& node); + void TransformTransposeToNhwc(Node& node); void TransformResize(Node& node); + void TrackTransposeFromNhwc(Node& node); Graph& graph_; @@ -149,6 +150,11 @@ class NchwcTransformerImpl { // or unsplitting the channels dimension of a tensor. std::unordered_map reshape_split_; std::unordered_map reshape_unsplit_; + + // Tracks the last Transpose node and output NodeArg that transposed from + // NHWC to NCHW format. + Node* transpose_from_nhwc_node_{nullptr}; + NodeArg* transpose_from_nhwc_output_arg_{nullptr}; }; size_t NchwcTransformerImpl::RemoveOutputEdges(Node& node) { @@ -209,6 +215,20 @@ void NchwcTransformerImpl::InsertReorderInput(Node& node) { kMSNchwcDomain); reorder_input_node.SetExecutionProviderType(kCpuExecutionProvider); input_defs[0] = input_nchwc_arg; + + // Attempt to fuse the ReorderInput with a previous Transpose of NHWC->NCHW. + // If the last known node to transpose from NHWC is the same as this input + // argument, then the Transpose node can be removed and the ReorderInput node + // is modified to consume a tensor with NHWC layout order. The transpose was + // already determined to have a single use and not be a graph output. + if (transpose_from_nhwc_output_arg_ == input_original_arg) { + reorder_input_node.MutableInputDefs()[0] = transpose_from_nhwc_node_->MutableInputDefs()[0]; + reorder_input_node.AddAttribute("channels_last", static_cast(1)); + graph_utils::RemoveNodeOutputEdges(graph_, *transpose_from_nhwc_node_); + removed_nodes_.push_front(transpose_from_nhwc_node_->Index()); + transpose_from_nhwc_node_ = nullptr; + } + } else { input_defs[0] = it->second; } @@ -324,14 +344,21 @@ void NchwcTransformerImpl::TransformConv(Node& node) { bool do_reorder_input = true; bool reorder_filter_OIHWBo = false; + int64_t filter_input_channels = input_channels; + int64_t nchwc_group_count = group_count; + + // The current implementation of ReorderInput requires the channel count to be + // aligned to this value. + constexpr int64_t channel_alignment = 4; if (group_count > 1) { - if ((output_channels % nchwc_block_size) != 0) { + if ((output_channels % channel_alignment) != 0) { return; } if (input_channels == 1 && output_channels == group_count) { // Depthwise convolution. reorder_filter_OIHWBo = true; + nchwc_group_count = nchwc_output_channels; } else if (((input_channels % nchwc_block_size) != 0) || ((output_channels % group_count) != 0) || (((output_channels / group_count) % nchwc_block_size) != 0)) { @@ -342,8 +369,11 @@ void NchwcTransformerImpl::TransformConv(Node& node) { // Use NCHW input buffer directly. reorder_filter_OIHWBo = true; do_reorder_input = false; - } else if ((input_channels % nchwc_block_size) != 0) { - return; + } else { + if ((input_channels % channel_alignment) != 0) { + return; + } + filter_input_channels = (input_channels + nchwc_block_size - 1) & ~(nchwc_block_size - 1); } } @@ -374,15 +404,19 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_W_arg = filters_it->second; } else { Initializer conv_W{*conv_W_tensor_proto, graph_.ModelPath()}; + const auto& conv_W_dims = conv_W.dims(); - int64_t reordered_filter_vec_size = conv_W.size() / output_channels * nchwc_output_channels; - std::vector reordered_filter(gsl::narrow(reordered_filter_vec_size)); + int64_t reordered_filter_size = nchwc_output_channels * filter_input_channels; + for (size_t i = 2; i < 4; i++) { + reordered_filter_size *= conv_W_dims[i]; + } + std::vector reordered_filter(gsl::narrow(reordered_filter_size)); // Reorder the weights tensor statically. if (reorder_filter_OIHWBo) { - MlasReorderFilterOIHWBo(conv_W.dims().data(), conv_W.data(), reordered_filter.data()); + MlasReorderFilterOIHWBo(conv_W_dims.data(), conv_W.data(), reordered_filter.data()); } else { - MlasReorderFilterOIHWBiBo(conv_W.dims().data(), conv_W.data(), reordered_filter.data()); + MlasReorderFilterOIHWBiBo(conv_W_dims.data(), conv_W.data(), reordered_filter.data()); } ONNX_NAMESPACE::TensorProto nchwc_conv_W_tensor_proto; @@ -392,8 +426,9 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_W_tensor_proto.set_raw_data(reordered_filter.data(), reordered_filter.size() * sizeof(float)); nchwc_conv_W_tensor_proto.add_dims(nchwc_output_channels); - for (size_t i = 1; i < 4; i++) { - nchwc_conv_W_tensor_proto.add_dims(conv_W.dims()[i]); + nchwc_conv_W_tensor_proto.add_dims(filter_input_channels); + for (size_t i = 2; i < 4; i++) { + nchwc_conv_W_tensor_proto.add_dims(conv_W_dims[i]); } nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto); @@ -436,6 +471,9 @@ void NchwcTransformerImpl::TransformConv(Node& node) { &node.GetAttributes(), kMSNchwcDomain); nchwc_node.SetExecutionProviderType(kCpuExecutionProvider); + if (nchwc_group_count != group_count) { + nchwc_node.AddAttribute("group", nchwc_group_count); + } nchwc_node.MutableInputDefs()[1] = nchwc_conv_W_arg; @@ -892,7 +930,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { removed_nodes_.push_front(node.Index()); } -void NchwcTransformerImpl::TransformTranspose(Node& node) { +void NchwcTransformerImpl::TransformTransposeToNhwc(Node& node) { auto& input_defs = node.MutableInputDefs(); auto& output_defs = node.MutableOutputDefs(); @@ -1070,7 +1108,33 @@ void NchwcTransformerImpl::TransformResize(Node& node) { removed_nodes_.push_front(node.Index()); } +void NchwcTransformerImpl::TrackTransposeFromNhwc(Node& node) { + const auto* perm_attr = graph_utils::GetNodeAttribute(node, "perm"); + if (perm_attr == nullptr || perm_attr->ints_size() != 4) { + return; + } + + // Test if this transposes from NHWC to NCHW layout order. + const int64_t* perm_data = perm_attr->ints().data(); + if (perm_data[0] != 0 || perm_data[1] != 3 || perm_data[2] != 1 || perm_data[3] != 2) { + return; + } + + // Verify that the node does not produce a graph output and produces output + // for a single node. + if (!graph_.GetNodeOutputsInGraphOutputs(node).empty() || node.GetOutputEdgesCount() != 1) { + return; + } + + transpose_from_nhwc_node_ = &node; + transpose_from_nhwc_output_arg_ = node.MutableOutputDefs()[0]; +} + void NchwcTransformerImpl::Transform(Node& node) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) { + TrackTransposeFromNhwc(node); + } + if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain)) { TransformConv(node); @@ -1097,7 +1161,7 @@ void NchwcTransformerImpl::Transform(Node& node) { } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) { TransformBatchNormalization(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Transpose", {1, 13})) { - TransformTranspose(node); + TransformTransposeToNhwc(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Upsample", {9, 13}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, "Resize", {10, 11, 13})) { TransformResize(node); diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h index 5741be2d2a..439ebb027e 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h +++ b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h @@ -105,7 +105,7 @@ class MlasNchwcConv2DTest : public MlasConv2DTest { if (DoReorderInput) { size_t NchwcInputElements = BatchCount * NchwcInputChannels * InputHeight * InputWidth; float* NchwcInput = BufferNchwcInput.GetBuffer(NchwcInputElements); - MlasReorderInput(InputShape, Input, NchwcInput); + ReorderInputNchw(InputShape, Input, NchwcInput); Input = NchwcInput; InputShape[1] = NchwcInputChannels; } diff --git a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h b/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h index 38efcd7deb..10e3f7f927 100644 --- a/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h +++ b/onnxruntime/test/mlas/unittest/test_pool2d_nchwc.h @@ -36,7 +36,7 @@ class MlasNchwcPool2DTest : public MlasPool2DTest { size_t NchwcOutputElements = size_t(NchwcOutputShape[0]) * size_t(NchwcOutputShape[1]) * size_t(NchwcOutputShape[2]) * size_t(NchwcOutputShape[3]); float* NchwcOutput = BufferNchwcOutput.GetBuffer(NchwcOutputElements); - MlasReorderInput(InputShape, Input, NchwcInput); + ReorderInputNchw(InputShape, Input, NchwcInput); MlasNchwcPool(PoolingKind, NchwcInputShape, diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index 9c72409e36..45f8c33f66 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -211,7 +211,7 @@ class MlasLongExecuteTests : public MlasTestFixture { } }; -// Some short Execute may not need to distinguish each parameters, +// Some short Execute may not need to distinguish each parameters, // because they finish quickly, and may disturb others by inject too many small tests. // Register it as whole using following helper. template @@ -237,3 +237,16 @@ class MlasDirectShortExecuteTests : public MlasTestFixture { } }; +inline +void ReorderInputNchw(const int64_t* input_shape, const float* S, float* D) { + const int64_t nchwc_block_size = static_cast(MlasNchwcGetBlockSize()); + int64_t batch_count = input_shape[0]; + int64_t channel_count = input_shape[1]; + int64_t nchwc_channel_count = (channel_count + nchwc_block_size - 1) & ~(nchwc_block_size - 1); + int64_t spatial_count = input_shape[2] * input_shape[3]; + for (int64_t n = 0; n < batch_count; n++) { + MlasReorderInputNchw(S, D, static_cast(channel_count), static_cast(spatial_count)); + S += spatial_count * channel_count; + D += spatial_count * nchwc_channel_count; + } +} diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index d1b495e536..9f357d57bd 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -970,10 +970,10 @@ TEST(NchwcOptimizerTests, MixedOutputUsage) { TEST(NchwcOptimizerTests, TensorAlignment) { auto build_test_case = [&](NchwcTestHelper& helper) { - // Input channel count must currently be a multiple of the NCHWc block size. - auto* input1_arg = helper.MakeInput({1, 60, 28, 42}); + // Input channel count must currently be a multiple of 4. + auto* input1_arg = helper.MakeInput({1, 62, 28, 42}); auto* output1_arg = helper.MakeOutput(); - helper.AddConvNode(input1_arg, output1_arg, {128, 60, 1, 1}); + helper.AddConvNode(input1_arg, output1_arg, {128, 62, 1, 1}); // Grouped input channel count must be a multiple of the NCHWc block size. auto* input2_arg = helper.MakeInput({1, 48, 28, 42}); @@ -1116,6 +1116,34 @@ TEST(NchwcOptimizerTests, BatchNormalization) { test_case(true); } +TEST(NchwcOptimizerTests, ConvReorderInputNhwc) { + auto test_case = [&](int64_t channels) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({5, 27, 29, channels}); + auto* transpose_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddTransposeToNchwNode(input_arg, transpose_output_arg); + helper.AddConvNode(transpose_output_arg, output_arg, {34, channels, 1, 1}); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["Transpose"], 0); + }; + + // Verify that a NHWC->NCHW transpose is fused into ReorderInput. + NchwcOptimizerTester(build_test_case, check_nchwc_graph); + }; + + for (int64_t channels = 16; channels <= 32; channels += 4) { + test_case(channels); + } +} + TEST(NchwcOptimizerTests, ConvReorderOutputNhwc) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({1, 64, 28, 32});