From 16297a8e612e8a31fa6df048c2d0fe33d92ca856 Mon Sep 17 00:00:00 2001 From: Tracy Sharpe <42477615+tracysh@users.noreply.github.com> Date: Mon, 10 May 2021 12:16:16 -0700 Subject: [PATCH] Implement NCHWc Upsample linear mode (#7623) Extend the existing NCHWc Upsample operator to support linear modes too. --- docs/ContribOperators.md | 4 + onnxruntime/contrib_ops/cpu/nchwc_ops.cc | 181 ++++++++++++++---- onnxruntime/contrib_ops/cpu/nchwc_ops.h | 48 ++++- .../graph/contrib_ops/nchwc_schema_defs.cc | 2 + onnxruntime/core/mlas/inc/mlas.h | 30 ++- onnxruntime/core/mlas/lib/snchwc.cpp | 112 ++++++++++- .../core/optimizer/nchwc_transformer.cc | 49 +++-- .../test/optimizer/nchwc_optimizer_test.cc | 59 +++++- 8 files changed, 414 insertions(+), 71 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index fc9d26441c..4e18737a31 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3063,6 +3063,10 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
coordinate_transformation_mode : string
+
+
mode : string
+
scales : list of ints
diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc index fd9c1f3a40..6439b2c8f4 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc @@ -52,14 +52,8 @@ Status ReorderInput::Compute(OpKernelContext* context) const { // elements, so that operations involving a smaller number of channels will // process more rows per worker. constexpr ptrdiff_t worker_goal = 48 * 1024; - ptrdiff_t work_per_worker = worker_goal / nchwc_channels; - if (work_per_worker == 0) { - work_per_worker = 1; - } - worker_count = total_work / work_per_worker; - if (worker_count == 0) { - worker_count = 1; - } + ptrdiff_t work_per_worker = std::max(worker_goal / nchwc_channels, 1); + worker_count = std::max(total_work / work_per_worker, 1); } else { // Each iteration produces one spatial_size chunk of NCHWc blocks. total_work = static_cast(batch_count * (nchwc_channels / nchwc_block_size)); @@ -205,20 +199,21 @@ Status NchwcConv::Compute(OpKernelContext* context) const { } } - MlasNchwcConv(X_shape.GetDims().data(), - kernel_shape.data(), - dilations.data(), - pads.data(), - strides.data(), - Y_dims.data(), - static_cast(conv_attrs_.group), - X->template Data(), - W->template Data(), - B != nullptr ? B->template Data() : nullptr, - y_data, - &activation_, - Sum == nullptr, - context->GetOperatorThreadPool()); + MlasNchwcConv( + X_shape.GetDims().data(), + kernel_shape.data(), + dilations.data(), + pads.data(), + strides.data(), + Y_dims.data(), + static_cast(conv_attrs_.group), + X->template Data(), + W->template Data(), + B != nullptr ? 
B->template Data() : nullptr, + y_data, + &activation_, + Sum == nullptr, + context->GetOperatorThreadPool()); return Status::OK(); } @@ -233,16 +228,17 @@ Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind std::vector output_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads); auto* Y = context->Output(0, output_dims); - MlasNchwcPool(kind, - X_shape.GetDims().data(), - pool_attrs_.global_pooling ? nullptr : pool_attrs_.kernel_shape.data(), - pool_attrs_.global_pooling ? nullptr : pool_attrs_.dilations.data(), - pool_attrs_.global_pooling ? nullptr : pads.data(), - pool_attrs_.global_pooling ? nullptr : pool_attrs_.strides.data(), - output_dims.data(), - X->template Data(), - Y->template MutableData(), - context->GetOperatorThreadPool()); + MlasNchwcPool( + kind, + X_shape.GetDims().data(), + pool_attrs_.global_pooling ? nullptr : pool_attrs_.kernel_shape.data(), + pool_attrs_.global_pooling ? nullptr : pool_attrs_.dilations.data(), + pool_attrs_.global_pooling ? nullptr : pads.data(), + pool_attrs_.global_pooling ? nullptr : pool_attrs_.strides.data(), + output_dims.data(), + X->template Data(), + Y->template MutableData(), + context->GetOperatorThreadPool()); return Status::OK(); } @@ -256,19 +252,124 @@ Status NchwcAveragePool::Compute(OpKernelContext* context) const { : MlasAveragePoolingExcludePad); } +std::vector NchwcUpsample::ComputeInterpolation(int64_t input_length, + int64_t output_length, + int64_t scale) const { + std::vector interpolation; + interpolation.resize(output_length); + + if (scale == 1) { + // Identity map for unscaled. 
+ for (int64_t o = 0; o < output_length; o++) { + interpolation[o] = static_cast(o); + } + } else if (transformation_mode_ == TransformationMode::ALIGN_CORNERS) { + for (int64_t o = 0; o < output_length; o++) { + interpolation[o] = + static_cast(o) * static_cast(input_length - 1) / static_cast(output_length - 1); + } + } else if (transformation_mode_ == TransformationMode::HALF_PIXEL) { + for (int64_t o = 0; o < output_length; o++) { + interpolation[o] = + std::max(0.0f, (static_cast(o) + 0.5f) / static_cast(scale) - 0.5f); + } + } else { + // Default to TransformationMode::ASYMMETRIC. + for (int64_t o = 0; o < output_length; o++) { + interpolation[o] = static_cast(o) / static_cast(scale); + } + } + + return interpolation; +} + Status NchwcUpsample::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); - const auto& X_shape = X->Shape(); - ORT_ENFORCE(X_shape.NumDimensions() == 4); + const auto& X_shape = X->Shape().GetDims(); + ORT_ENFORCE(X_shape.size() == 4); ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0); - TensorShape Y_shape{X_shape[0], X_shape[1], X_shape[2] * scales_[2], X_shape[3] * scales_[3]}; - auto* Y = context->Output(0, Y_shape); + const int64_t batch_count = X_shape[0]; + const int64_t nchwc_channels = X_shape[1]; - MlasNchwcUpsample(X_shape.GetDims().data(), - scales_.data() + 2, - X->template Data(), - Y->template MutableData()); + const int64_t input_h = X_shape[2]; + const int64_t input_w = X_shape[3]; + + const int64_t output_h = input_h * scales_[2]; + const int64_t output_w = input_w * scales_[3]; + + auto* Y = context->Output(0, {batch_count, nchwc_channels, output_h, output_w}); + + // Bail out early if one of the dimensions is zero. 
+ if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + const auto* x_data = X->template Data(); + auto* y_data = Y->template MutableData(); + + if (nearest_mode_) { + MlasNchwcUpsampleNearest( + X_shape.data(), + scales_.data() + 2, + x_data, + y_data); + } else { + // Compute the interpolation value per output height and width. + const auto interpolation_h = ComputeInterpolation(input_h, output_h, scales_[2]); + const auto interpolation_w = ComputeInterpolation(input_w, output_w, scales_[3]); + + const int64_t nchwc_block_size = static_cast(MlasNchwcGetBlockSize()); + const ptrdiff_t total_work = ((batch_count * nchwc_channels) / nchwc_block_size) * output_h; + // Partition the work with the goal of generating the following number of + // elements, so that operations involving a smaller number of columns will + // process more rows per worker. + constexpr ptrdiff_t worker_goal = 16 * 1024; + ptrdiff_t work_per_worker = std::max(worker_goal / (output_w * nchwc_block_size), 1); + ptrdiff_t worker_count = std::max(total_work / work_per_worker, 1); + + auto upsample_worker = [&](ptrdiff_t batch) { + auto work = concurrency::ThreadPool::PartitionWork(batch, worker_count, total_work); + int64_t work_index = static_cast(work.start); + int64_t work_remaining = static_cast(work.end - work.start); + + while (work_remaining > 0) { + // Limit the current loop iteration to the same source image. + const int64_t channel_index = work_index / output_h; + int64_t row_index = work_index % output_h; + int64_t rows_this_iteration = std::min(work_remaining, output_h - row_index); + + work_index += rows_this_iteration; + work_remaining -= rows_this_iteration; + + const auto* x_channel_base = x_data + (channel_index * input_h * input_w * nchwc_block_size); + auto* y_row = y_data + (((channel_index * output_h) + row_index) * output_w * nchwc_block_size); + + // Loop upsampling each row of the output. 
+ do { + MlasNchwcUpsampleLinear( + static_cast(input_h), + static_cast(input_w), + static_cast(output_w), + interpolation_h[row_index], + interpolation_w.data(), + x_channel_base, + y_row); + y_row += output_w * nchwc_block_size; + row_index++; + } while (--rows_this_iteration); + } + }; + + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + + // Handle the work in a single batch if only a single thread is available. + if (concurrency::ThreadPool::DegreeOfParallelism(thread_pool) == 1) { + worker_count = 1; + } + + concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, worker_count, upsample_worker); + } return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h index f6639d03f4..7e8e735694 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h @@ -12,7 +12,7 @@ namespace onnxruntime { namespace contrib { -class ReorderInput : public OpKernel { +class ReorderInput final : public OpKernel { public: ReorderInput(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(info.GetAttr("channels_last", &channels_last_).IsOK()); @@ -24,7 +24,7 @@ class ReorderInput : public OpKernel { int64_t channels_last_; }; -class ReorderOutput : public OpKernel { +class ReorderOutput final : public OpKernel { public: ReorderOutput(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(info.GetAttr("channels", &channels_).IsOK()); @@ -39,7 +39,7 @@ class ReorderOutput : public OpKernel { int64_t channels_last_; }; -class NchwcConv : public OpKernel { +class NchwcConv final : public OpKernel { public: NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); @@ -64,7 +64,7 @@ class NchwcPoolBase : public PoolBase { Status NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const; }; -class NchwcMaxPool : public OpKernel, public NchwcPoolBase { +class NchwcMaxPool final 
: public OpKernel, public NchwcPoolBase { public: NchwcMaxPool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) { } @@ -72,7 +72,7 @@ class NchwcMaxPool : public OpKernel, public NchwcPoolBase { Status Compute(OpKernelContext* context) const override; }; -class NchwcAveragePool : public OpKernel, public NchwcPoolBase { +class NchwcAveragePool final : public OpKernel, public NchwcPoolBase { public: NchwcAveragePool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) { } @@ -80,19 +80,55 @@ class NchwcAveragePool : public OpKernel, public NchwcPoolBase { Status Compute(OpKernelContext* context) const override; }; -class NchwcUpsample : public OpKernel { +class NchwcUpsample final : public OpKernel { + private: + enum class TransformationMode { + ASYMMETRIC, + ALIGN_CORNERS, + HALF_PIXEL, + }; + public: NchwcUpsample(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(info.GetAttrs("scales", scales_).IsOK()); ORT_ENFORCE(scales_.size() == 4); // Batch and channel dimensions cannot scale and spatial scaling must be positive. 
ORT_ENFORCE(scales_[0] == 1 && scales_[1] == 1 && scales_[2] >= 1 && scales_[3] >= 1); + + std::string transformation_mode; + ORT_ENFORCE(info.GetAttr("coordinate_transformation_mode", &transformation_mode).IsOK()); + if (transformation_mode == "asymmetric") { + transformation_mode_ = TransformationMode::ASYMMETRIC; + } else if (transformation_mode == "align_corners") { + transformation_mode_ = TransformationMode::ALIGN_CORNERS; + } else if (transformation_mode == "half_pixel") { + transformation_mode_ = TransformationMode::HALF_PIXEL; + } else { + ORT_THROW("Unsupported transformation mode '" + transformation_mode + "' for NCHWc Upsample"); + } + + std::string mode; + ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); + if (mode == "nearest") { + nearest_mode_ = true; + ORT_ENFORCE(transformation_mode_ == TransformationMode::ASYMMETRIC); + } else if (mode == "linear") { + nearest_mode_ = false; + } else { + ORT_THROW("Unsupported mode '" + mode + "' for NCHWc Upsample"); + } } Status Compute(OpKernelContext* context) const override; private: + std::vector ComputeInterpolation(int64_t input_length, + int64_t output_length, + int64_t scale) const; + std::vector scales_; + TransformationMode transformation_mode_; + bool nearest_mode_; }; } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc index 39d7621eec..065e2912b5 100644 --- a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc @@ -188,6 +188,8 @@ void RegisterNchwcSchemas() { .SinceVersion(1) .SetDoc(R"DOC(For internal use.)DOC") .Attr("scales", "", AttributeProto::INTS, OPTIONAL_VALUE) + .Attr("mode", "", AttributeProto::STRING, std::string("nearest")) + .Attr("coordinate_transformation_mode", "", AttributeProto::STRING, std::string("asymmetric")) .Input(0, "X", "", "T") .Output(0, "Y", "", "T") .TypeConstraint("T", {"tensor(float)"}, "Constrain input 
and output types to float tensors") diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index b2ce940482..ecceb64f18 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -168,7 +168,7 @@ struct MLAS_SGEMM_DATA_PARAMS { /** * @brief Batched single precision matrix/matrix multiply operation (SGEMM) - * + * * @param TransA Supplies the transpose operation for matrix A. * @param TransB Supplies the transpose operation for matrix B. * @param M Supplies the number of rows of matrix A and matrix C. @@ -195,7 +195,7 @@ MlasGemmBatch( /** * @brief Single precision matrix/matrix multiply operation (SGEMM) - * + * * @param TransA Supplies the transpose operation for matrix A. * @param TransB Supplies the transpose operation for matrix B. * @param M Supplies the number of rows of matrix A and matrix C. @@ -223,7 +223,7 @@ MlasGemm( /** * @brief Single precision matrix/matrix multiply operation (SGEMM) - * + * * @param TransA Supplies the transpose operation for matrix A. * @param TransB Supplies the transpose operation for matrix B. * @param M Supplies the number of rows of matrix A and matrix C. @@ -231,7 +231,7 @@ MlasGemm( * @param K Supplies the number of columns of matrix A and the number of rows of matrix B. * @param alpha Supplies the scalar alpha multiplier (see SGEMM definition) - * @param A Supplies the address of matrix A + * @param A Supplies the address of matrix A * @param lda Supplies the first dimension of matrix A. * @param B Supplies the address of matrix B * @param ldb Supplies the first dimension of matrix B. @@ -341,7 +341,7 @@ struct MLAS_DGEMM_DATA_PARAMS { /** * @brief Batched double precision matrix/matrix multiply operation (DGEMM) - * + * * @param TransA Supplies the transpose operation for matrix A. * @param TransB Supplies the transpose operation for matrix B. * @param M Supplies the number of rows of matrix A and matrix C. 
@@ -369,7 +369,7 @@ MlasGemmBatch( /** * @brief Double precision matrix/matrix multiply operation (DGEMM) - * + * * @param TransA Supplies the transpose operation for matrix A. * @param TransB Supplies the transpose operation for matrix B. * @param M Supplies the number of rows of matrix A and matrix C. @@ -397,7 +397,7 @@ MlasGemm( /** * @brief Double precision matrix/matrix multiply operation (DGEMM) - * + * * @param TransA Supplies the transpose operation for matrix A. * @param TransB Supplies the transpose operation for matrix B. * @param M Supplies the number of rows of matrix A and matrix C. @@ -405,7 +405,7 @@ MlasGemm( * @param K Supplies the number of columns of matrix A and the number of rows of matrix B. * @param alpha Supplies the scalar alpha multiplier (see SGEMM definition) - * @param A Supplies the address of matrix A + * @param A Supplies the address of matrix A * @param lda Supplies the first dimension of matrix A. * @param B Supplies the address of matrix B * @param ldb Supplies the first dimension of matrix B. @@ -929,13 +929,25 @@ MlasNchwcPool( void MLASCALL -MlasNchwcUpsample( +MlasNchwcUpsampleNearest( const int64_t* InputShape, const int64_t* Scales, const float* Input, float* Output ); +void +MLASCALL +MlasNchwcUpsampleLinear( + size_t InputHeight, + size_t InputWidth, + size_t OutputWidth, + float InterpolationHeight, + const float* InterpolationWidth, + const float* Input, + float* Output + ); + // // Linear quantization routines. 
// diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index a30ee219af..2bc0b5d661 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -1410,7 +1410,7 @@ Return Value: void MLASCALL -MlasNchwcUpsample( +MlasNchwcUpsampleNearest( const int64_t* InputShape, const int64_t* Scales, const float* Input, @@ -1511,6 +1511,116 @@ Return Value: } } +MLAS_FORCEINLINE +void +MlasNchwcExtractInterpolation( + float InterpolationValue, + size_t InputLimit, + ptrdiff_t InputIndex[2], + MLAS_FLOAT32X4 Multipliers[2] + ) +{ + InputIndex[0] = ptrdiff_t(InterpolationValue); + InputIndex[1] = std::min(InputIndex[0] + 1, ptrdiff_t(InputLimit - 1)); + + float ScalarMultiplier0 = InterpolationValue - float(InputIndex[0]); + float ScalarMultiplier1 = 1.0f - ScalarMultiplier0; + + Multipliers[0] = MlasBroadcastFloat32x4(ScalarMultiplier0); + Multipliers[1] = MlasBroadcastFloat32x4(ScalarMultiplier1); +} + +void +MLASCALL +MlasNchwcUpsampleLinear( + size_t InputHeight, + size_t InputWidth, + size_t OutputWidth, + float InterpolationHeight, + const float* InterpolationWidth, + const float* Input, + float* Output + ) +/*++ + +Routine Description: + + This routine implements the NCHWc upsample linear operation for a single row. + + The integer portion of each interpolation float supplies the mapping from + output element to input element. The fractional portion supplies the relative + weights for the four points of the interpolation. + +Arguments: + + InputHeight - Supplies the input height. + + InputWidth - Supplies the input width. + + OutputWidth - Supplies the output width. + + InterpolationHeight - Supplies the height interpolation values for the target + row. + + InterpolationWidth - Supplies an array of computed interpolation values of + length OutputWidth. + + Input - Supplies the input spatial buffer. + + Output - Supplies the output row buffer. + +Return Value: + + None. 
+ +--*/ +{ + const size_t BlockSize = MlasNchwcGetBlockSize(); + + ptrdiff_t InputIndexY[2]; + MLAS_FLOAT32X4 MultipliersY[2]; + + MlasNchwcExtractInterpolation(InterpolationHeight, InputHeight, InputIndexY, MultipliersY); + + const float* InputRowY0 = Input + InputIndexY[0] * InputWidth * BlockSize; + const float* InputRowY1 = Input + InputIndexY[1] * InputWidth * BlockSize; + + for (size_t ow = 0; ow < OutputWidth; ow++) { + + ptrdiff_t InputIndexX[2]; + MLAS_FLOAT32X4 MultipliersX[2]; + + MlasNchwcExtractInterpolation(InterpolationWidth[ow], InputWidth, InputIndexX, MultipliersX); + + MLAS_FLOAT32X4 MultiplierY0X0 = MlasMultiplyFloat32x4(MultipliersY[0], MultipliersX[0]); + MLAS_FLOAT32X4 MultiplierY0X1 = MlasMultiplyFloat32x4(MultipliersY[0], MultipliersX[1]); + MLAS_FLOAT32X4 MultiplierY1X0 = MlasMultiplyFloat32x4(MultipliersY[1], MultipliersX[0]); + MLAS_FLOAT32X4 MultiplierY1X1 = MlasMultiplyFloat32x4(MultipliersY[1], MultipliersX[1]); + + for (size_t bc = 0; bc < BlockSize; bc += 4) { + + MLAS_FLOAT32X4 v00 = MlasLoadFloat32x4(InputRowY0 + InputIndexX[0] * BlockSize + bc); + MLAS_FLOAT32X4 v01 = MlasLoadFloat32x4(InputRowY0 + InputIndexX[1] * BlockSize + bc); + MLAS_FLOAT32X4 v10 = MlasLoadFloat32x4(InputRowY1 + InputIndexX[0] * BlockSize + bc); + MLAS_FLOAT32X4 v11 = MlasLoadFloat32x4(InputRowY1 + InputIndexX[1] * BlockSize + bc); + + v00 = MlasMultiplyFloat32x4(MultiplierY1X1, v00); + v01 = MlasMultiplyFloat32x4(MultiplierY1X0, v01); + v10 = MlasMultiplyFloat32x4(MultiplierY0X1, v10); + v11 = MlasMultiplyFloat32x4(MultiplierY0X0, v11); + + MLAS_FLOAT32X4 Reduction0 = MlasAddFloat32x4(v00, v01); + MLAS_FLOAT32X4 Reduction1 = MlasAddFloat32x4(v10, v11); + + MLAS_FLOAT32X4 Reduction = MlasAddFloat32x4(Reduction0, Reduction1); + + MlasStoreFloat32x4(&Output[bc], Reduction); + } + + Output += BlockSize; + } +} + #if !defined(MLAS_TARGET_AMD64) // diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc 
index 7065876510..253db85e64 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -982,14 +982,21 @@ void NchwcTransformerImpl::TransformResize(Node& node) { } auto* nchwc_input = it->second.get(); - // Only support the nearest interpolation mode (the default value). + // Support nearest (default) and linear modes. const auto* mode_attr = graph_utils::GetNodeAttribute(node, "mode"); + bool nearest_mode = true; if (mode_attr != nullptr && utils::HasString(*mode_attr)) { if (mode_attr->s() != "nearest") { - return; + if (mode_attr->s() == "linear") { + nearest_mode = false; + } else { + return; + } } } + const ONNX_NAMESPACE::AttributeProto* transformation_mode_attr = nullptr; + NodeArg* sizes_arg = nullptr; NodeArg* scales_arg = nullptr; @@ -1001,20 +1008,30 @@ void NchwcTransformerImpl::TransformResize(Node& node) { scales_arg = input_defs[2]; } - // Only support the asymmetric coordinate transformation mode. - const auto* transform_mode_attr = graph_utils::GetNodeAttribute(node, "coordinate_transformation_mode"); - if ((transform_mode_attr == nullptr) || - !utils::HasString(*transform_mode_attr) || - (transform_mode_attr->s() != "asymmetric")) { + transformation_mode_attr = graph_utils::GetNodeAttribute(node, "coordinate_transformation_mode"); + if ((transformation_mode_attr == nullptr) || + !utils::HasString(*transformation_mode_attr)) { return; } + if (transformation_mode_attr->s() != "asymmetric") { + // Nearest mode kernel supports asymmetric transformation mode only. + if (nearest_mode) { + return; + } + if ((transformation_mode_attr->s() != "align_corners") && + (transformation_mode_attr->s() != "half_pixel")) { + return; + } + } - // Only support the floor rounding mode. 
- const auto* nearest_mode_attr = graph_utils::GetNodeAttribute(node, "nearest_mode"); - if ((nearest_mode_attr == nullptr) || - !utils::HasString(*nearest_mode_attr) || - (nearest_mode_attr->s() != "floor")) { - return; + if (nearest_mode) { + // Only support the floor rounding mode. + const auto* nearest_mode_attr = graph_utils::GetNodeAttribute(node, "nearest_mode"); + if ((nearest_mode_attr == nullptr) || + !utils::HasString(*nearest_mode_attr) || + (nearest_mode_attr->s() != "floor")) { + return; + } } } else { scales_arg = input_defs[1]; @@ -1099,6 +1116,12 @@ void NchwcTransformerImpl::TransformResize(Node& node) { kMSNchwcDomain); nchwc_node.SetExecutionProviderType(kCpuExecutionProvider); nchwc_node.AddAttribute("scales", scales_attr); + if (!nearest_mode) { + nchwc_node.AddAttribute("mode", mode_attr->s()); + if (transformation_mode_attr != nullptr) { + nchwc_node.AddAttribute("coordinate_transformation_mode", transformation_mode_attr->s()); + } + } nchwc_input->remaining_original_uses_--; diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 9f357d57bd..69ca27e241 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -1085,7 +1085,7 @@ TEST(NchwcOptimizerTests, BatchNormalization) { // Override the sample tolerance for this test. By default, the NCHWc // tests generate bit identical results when run with and without - // optimizations, but the BatchNormalizationtransform does introduce + // optimizations, but the BatchNormalization transform does introduce // small bit differences. 
helper.per_sample_tolerance_ = .00025; }; @@ -1213,7 +1213,7 @@ TEST(NchwcOptimizerTests, ConvReorderOutputCnhw) { NchwcOptimizerTester(build_test_case, check_nchwc_graph); } -TEST(NchwcOptimizerTests, Upsample) { +TEST(NchwcOptimizerTests, UpsampleNearest) { auto test_case = [&](int opset_version, float scale_h, float scale_w, bool use_sizes_arg) { auto build_test_case = [&](NchwcTestHelper& helper) { auto* input_arg = helper.MakeInput({3, 16, 27, 15}); @@ -1283,6 +1283,61 @@ TEST(NchwcOptimizerTests, Upsample) { test_case(13, 2.2f, 2.8f, true); } +TEST(NchwcOptimizerTests, UpsampleLinear) { + auto test_case = [&](int opset_version, float scale_h, float scale_w, const std::string& transformation_mode) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({3, 16, 21, 25}); + auto* conv_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv_output_arg, {28, 16, 1, 1}); + + std::string op_name = opset_version >= 10 ? 
"Resize" : "Upsample"; + std::vector input_args; + input_args.push_back(conv_output_arg); + if (opset_version >= 11) { + input_args.push_back(helper.Make1DInitializer({0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f})); + } + input_args.push_back(helper.Make1DInitializer({1.f, 1.f, scale_h, scale_w})); + Node& resize_node = helper.AddNode(op_name, input_args, {output_arg}); + resize_node.AddAttribute("mode", "linear"); + if (opset_version >= 11) { + resize_node.AddAttribute("coordinate_transformation_mode", transformation_mode); + } + + helper.per_sample_tolerance_ = .001f; + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Upsample"], 1); + EXPECT_EQ(op_to_count["Resize"] + op_to_count["Upsample"], 0); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph, opset_version); + }; + + // Verify that upsample nodes can be converted to the NCHWc format for + // various versions of the operator. + std::vector transformation_modes{"asymmetric", "align_corners", "half_pixel"}; + for (auto& transformation_mode : transformation_modes) { + static const int opset_versions[] = {9, 10, 11, 13}; + for (auto opset_version : opset_versions) { + // Older versions of the operator do not support transformation modes. 
+ if (opset_version < 11 && transformation_mode == "asymmetric") { + continue; + } + test_case(opset_version, 1.f, 1.f, transformation_mode); + test_case(opset_version, 2.f, 2.f, transformation_mode); + test_case(opset_version, 3.f, 5.f, transformation_mode); + test_case(opset_version, 9.f, 7.f, transformation_mode); + } + } +} + TEST(NchwcOptimizerTests, Activation) { auto test_case = [&](const std::string& activation_op_type) { auto build_test_case = [&](NchwcTestHelper& helper) {