diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index fc9d26441c..4e18737a31 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3063,6 +3063,10 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
+- coordinate_transformation_mode : string
+
+- mode : string
+
- scales : list of ints
diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
index fd9c1f3a40..6439b2c8f4 100644
--- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
+++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc
@@ -52,14 +52,8 @@ Status ReorderInput::Compute(OpKernelContext* context) const {
// elements, so that operations involving a smaller number of channels will
// process more rows per worker.
constexpr ptrdiff_t worker_goal = 48 * 1024;
- ptrdiff_t work_per_worker = worker_goal / nchwc_channels;
- if (work_per_worker == 0) {
- work_per_worker = 1;
- }
- worker_count = total_work / work_per_worker;
- if (worker_count == 0) {
- worker_count = 1;
- }
+ ptrdiff_t work_per_worker = std::max(worker_goal / nchwc_channels, 1);
+ worker_count = std::max(total_work / work_per_worker, 1);
} else {
// Each iteration produces one spatial_size chunk of NCHWc blocks.
total_work = static_cast(batch_count * (nchwc_channels / nchwc_block_size));
@@ -205,20 +199,21 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
}
}
- MlasNchwcConv(X_shape.GetDims().data(),
- kernel_shape.data(),
- dilations.data(),
- pads.data(),
- strides.data(),
- Y_dims.data(),
- static_cast(conv_attrs_.group),
- X->template Data(),
- W->template Data(),
- B != nullptr ? B->template Data() : nullptr,
- y_data,
- &activation_,
- Sum == nullptr,
- context->GetOperatorThreadPool());
+ MlasNchwcConv(
+ X_shape.GetDims().data(),
+ kernel_shape.data(),
+ dilations.data(),
+ pads.data(),
+ strides.data(),
+ Y_dims.data(),
+ static_cast(conv_attrs_.group),
+ X->template Data(),
+ W->template Data(),
+ B != nullptr ? B->template Data() : nullptr,
+ y_data,
+ &activation_,
+ Sum == nullptr,
+ context->GetOperatorThreadPool());
return Status::OK();
}
@@ -233,16 +228,17 @@ Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind
std::vector output_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads);
auto* Y = context->Output(0, output_dims);
- MlasNchwcPool(kind,
- X_shape.GetDims().data(),
- pool_attrs_.global_pooling ? nullptr : pool_attrs_.kernel_shape.data(),
- pool_attrs_.global_pooling ? nullptr : pool_attrs_.dilations.data(),
- pool_attrs_.global_pooling ? nullptr : pads.data(),
- pool_attrs_.global_pooling ? nullptr : pool_attrs_.strides.data(),
- output_dims.data(),
- X->template Data(),
- Y->template MutableData(),
- context->GetOperatorThreadPool());
+ MlasNchwcPool(
+ kind,
+ X_shape.GetDims().data(),
+ pool_attrs_.global_pooling ? nullptr : pool_attrs_.kernel_shape.data(),
+ pool_attrs_.global_pooling ? nullptr : pool_attrs_.dilations.data(),
+ pool_attrs_.global_pooling ? nullptr : pads.data(),
+ pool_attrs_.global_pooling ? nullptr : pool_attrs_.strides.data(),
+ output_dims.data(),
+ X->template Data(),
+ Y->template MutableData(),
+ context->GetOperatorThreadPool());
return Status::OK();
}
@@ -256,19 +252,124 @@ Status NchwcAveragePool::Compute(OpKernelContext* context) const {
: MlasAveragePoolingExcludePad);
}
+std::vector NchwcUpsample::ComputeInterpolation(int64_t input_length,
+ int64_t output_length,
+ int64_t scale) const {
+ std::vector interpolation;
+ interpolation.resize(output_length);
+
+ if (scale == 1) {
+ // Identity map for unscaled.
+ for (int64_t o = 0; o < output_length; o++) {
+ interpolation[o] = static_cast(o);
+ }
+ } else if (transformation_mode_ == TransformationMode::ALIGN_CORNERS) {
+ for (int64_t o = 0; o < output_length; o++) {
+ interpolation[o] =
+ static_cast(o) * static_cast(input_length - 1) / static_cast(output_length - 1);
+ }
+ } else if (transformation_mode_ == TransformationMode::HALF_PIXEL) {
+ for (int64_t o = 0; o < output_length; o++) {
+ interpolation[o] =
+ std::max(0.0f, (static_cast(o) + 0.5f) / static_cast(scale) - 0.5f);
+ }
+ } else {
+ // Default to TransformationMode::ASYMMETRIC.
+ for (int64_t o = 0; o < output_length; o++) {
+ interpolation[o] = static_cast(o) / static_cast(scale);
+ }
+ }
+
+ return interpolation;
+}
+
Status NchwcUpsample::Compute(OpKernelContext* context) const {
const auto* X = context->Input(0);
- const auto& X_shape = X->Shape();
- ORT_ENFORCE(X_shape.NumDimensions() == 4);
+ const auto& X_shape = X->Shape().GetDims();
+ ORT_ENFORCE(X_shape.size() == 4);
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
- TensorShape Y_shape{X_shape[0], X_shape[1], X_shape[2] * scales_[2], X_shape[3] * scales_[3]};
- auto* Y = context->Output(0, Y_shape);
+ const int64_t batch_count = X_shape[0];
+ const int64_t nchwc_channels = X_shape[1];
- MlasNchwcUpsample(X_shape.GetDims().data(),
- scales_.data() + 2,
- X->template Data(),
- Y->template MutableData());
+ const int64_t input_h = X_shape[2];
+ const int64_t input_w = X_shape[3];
+
+ const int64_t output_h = input_h * scales_[2];
+ const int64_t output_w = input_w * scales_[3];
+
+ auto* Y = context->Output(0, {batch_count, nchwc_channels, output_h, output_w});
+
+ // Bail out early if one of the dimensions is zero.
+ if (Y->Shape().Size() == 0) {
+ return Status::OK();
+ }
+
+ const auto* x_data = X->template Data();
+ auto* y_data = Y->template MutableData();
+
+ if (nearest_mode_) {
+ MlasNchwcUpsampleNearest(
+ X_shape.data(),
+ scales_.data() + 2,
+ x_data,
+ y_data);
+ } else {
+ // Compute the interpolation value per output height and width.
+ const auto interpolation_h = ComputeInterpolation(input_h, output_h, scales_[2]);
+ const auto interpolation_w = ComputeInterpolation(input_w, output_w, scales_[3]);
+
+ const int64_t nchwc_block_size = static_cast(MlasNchwcGetBlockSize());
+ const ptrdiff_t total_work = ((batch_count * nchwc_channels) / nchwc_block_size) * output_h;
+ // Partition the work with the goal of generating the following number of
+ // elements, so that operations involving a smaller number of columns will
+ // process more rows per worker.
+ constexpr ptrdiff_t worker_goal = 16 * 1024;
+ ptrdiff_t work_per_worker = std::max(worker_goal / (output_w * nchwc_block_size), 1);
+ ptrdiff_t worker_count = std::max(total_work / work_per_worker, 1);
+
+ auto upsample_worker = [&](ptrdiff_t batch) {
+ auto work = concurrency::ThreadPool::PartitionWork(batch, worker_count, total_work);
+ int64_t work_index = static_cast(work.start);
+ int64_t work_remaining = static_cast(work.end - work.start);
+
+ while (work_remaining > 0) {
+ // Limit the current loop iteration to the same source image.
+ const int64_t channel_index = work_index / output_h;
+ int64_t row_index = work_index % output_h;
+ int64_t rows_this_iteration = std::min(work_remaining, output_h - row_index);
+
+ work_index += rows_this_iteration;
+ work_remaining -= rows_this_iteration;
+
+ const auto* x_channel_base = x_data + (channel_index * input_h * input_w * nchwc_block_size);
+ auto* y_row = y_data + (((channel_index * output_h) + row_index) * output_w * nchwc_block_size);
+
+ // Loop upsampling each row of the output.
+ do {
+ MlasNchwcUpsampleLinear(
+ static_cast(input_h),
+ static_cast(input_w),
+ static_cast(output_w),
+ interpolation_h[row_index],
+ interpolation_w.data(),
+ x_channel_base,
+ y_row);
+ y_row += output_w * nchwc_block_size;
+ row_index++;
+ } while (--rows_this_iteration);
+ }
+ };
+
+ concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
+
+ // Handle the work in a single batch if only a single thread is available.
+ if (concurrency::ThreadPool::DegreeOfParallelism(thread_pool) == 1) {
+ worker_count = 1;
+ }
+
+ concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, worker_count, upsample_worker);
+ }
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h
index f6639d03f4..7e8e735694 100644
--- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h
+++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h
@@ -12,7 +12,7 @@
namespace onnxruntime {
namespace contrib {
-class ReorderInput : public OpKernel {
+class ReorderInput final : public OpKernel {
public:
ReorderInput(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttr("channels_last", &channels_last_).IsOK());
@@ -24,7 +24,7 @@ class ReorderInput : public OpKernel {
int64_t channels_last_;
};
-class ReorderOutput : public OpKernel {
+class ReorderOutput final : public OpKernel {
public:
ReorderOutput(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttr("channels", &channels_).IsOK());
@@ -39,7 +39,7 @@ class ReorderOutput : public OpKernel {
int64_t channels_last_;
};
-class NchwcConv : public OpKernel {
+class NchwcConv final : public OpKernel {
public:
NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
@@ -64,7 +64,7 @@ class NchwcPoolBase : public PoolBase {
Status NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const;
};
-class NchwcMaxPool : public OpKernel, public NchwcPoolBase {
+class NchwcMaxPool final : public OpKernel, public NchwcPoolBase {
public:
NchwcMaxPool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) {
}
@@ -72,7 +72,7 @@ class NchwcMaxPool : public OpKernel, public NchwcPoolBase {
Status Compute(OpKernelContext* context) const override;
};
-class NchwcAveragePool : public OpKernel, public NchwcPoolBase {
+class NchwcAveragePool final : public OpKernel, public NchwcPoolBase {
public:
NchwcAveragePool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) {
}
@@ -80,19 +80,55 @@ class NchwcAveragePool : public OpKernel, public NchwcPoolBase {
Status Compute(OpKernelContext* context) const override;
};
-class NchwcUpsample : public OpKernel {
+class NchwcUpsample final : public OpKernel {
+ private:
+ enum class TransformationMode {
+ ASYMMETRIC,
+ ALIGN_CORNERS,
+ HALF_PIXEL,
+ };
+
public:
NchwcUpsample(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttrs("scales", scales_).IsOK());
ORT_ENFORCE(scales_.size() == 4);
// Batch and channel dimensions cannot scale and spatial scaling must be positive.
ORT_ENFORCE(scales_[0] == 1 && scales_[1] == 1 && scales_[2] >= 1 && scales_[3] >= 1);
+
+ std::string transformation_mode;
+ ORT_ENFORCE(info.GetAttr("coordinate_transformation_mode", &transformation_mode).IsOK());
+ if (transformation_mode == "asymmetric") {
+ transformation_mode_ = TransformationMode::ASYMMETRIC;
+ } else if (transformation_mode == "align_corners") {
+ transformation_mode_ = TransformationMode::ALIGN_CORNERS;
+ } else if (transformation_mode == "half_pixel") {
+ transformation_mode_ = TransformationMode::HALF_PIXEL;
+ } else {
+ ORT_THROW("Unsupported transformation mode '" + transformation_mode + "' for NCHWc Upsample");
+ }
+
+ std::string mode;
+ ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK());
+ if (mode == "nearest") {
+ nearest_mode_ = true;
+ ORT_ENFORCE(transformation_mode_ == TransformationMode::ASYMMETRIC);
+ } else if (mode == "linear") {
+ nearest_mode_ = false;
+ } else {
+ ORT_THROW("Unsupported mode '" + mode + "' for NCHWc Upsample");
+ }
}
Status Compute(OpKernelContext* context) const override;
private:
+ std::vector ComputeInterpolation(int64_t input_length,
+ int64_t output_length,
+ int64_t scale) const;
+
std::vector scales_;
+ TransformationMode transformation_mode_;
+ bool nearest_mode_;
};
} // namespace contrib
diff --git a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc
index 39d7621eec..065e2912b5 100644
--- a/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc
@@ -188,6 +188,8 @@ void RegisterNchwcSchemas() {
.SinceVersion(1)
.SetDoc(R"DOC(For internal use.)DOC")
.Attr("scales", "", AttributeProto::INTS, OPTIONAL_VALUE)
+ .Attr("mode", "", AttributeProto::STRING, std::string("nearest"))
+ .Attr("coordinate_transformation_mode", "", AttributeProto::STRING, std::string("asymmetric"))
.Input(0, "X", "", "T")
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index b2ce940482..ecceb64f18 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -168,7 +168,7 @@ struct MLAS_SGEMM_DATA_PARAMS {
/**
* @brief Batched single precision matrix/matrix multiply operation (SGEMM)
- *
+ *
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@@ -195,7 +195,7 @@ MlasGemmBatch(
/**
* @brief Single precision matrix/matrix multiply operation (SGEMM)
- *
+ *
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@@ -223,7 +223,7 @@ MlasGemm(
/**
* @brief Single precision matrix/matrix multiply operation (SGEMM)
- *
+ *
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@@ -231,7 +231,7 @@ MlasGemm(
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
- * @param A Supplies the address of matrix A
+ * @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
@@ -341,7 +341,7 @@ struct MLAS_DGEMM_DATA_PARAMS {
/**
* @brief Batched double precision matrix/matrix multiply operation (DGEMM)
- *
+ *
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@@ -369,7 +369,7 @@ MlasGemmBatch(
/**
* @brief Double precision matrix/matrix multiply operation (DGEMM)
- *
+ *
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@@ -397,7 +397,7 @@ MlasGemm(
/**
* @brief Double precision matrix/matrix multiply operation (DGEMM)
- *
+ *
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@@ -405,7 +405,7 @@ MlasGemm(
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
- * @param A Supplies the address of matrix A
+ * @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
@@ -929,13 +929,25 @@ MlasNchwcPool(
void
MLASCALL
-MlasNchwcUpsample(
+MlasNchwcUpsampleNearest(
const int64_t* InputShape,
const int64_t* Scales,
const float* Input,
float* Output
);
+void
+MLASCALL
+MlasNchwcUpsampleLinear(
+ size_t InputHeight,
+ size_t InputWidth,
+ size_t OutputWidth,
+ float InterpolationHeight,
+ const float* InterpolationWidth,
+ const float* Input,
+ float* Output
+ );
+
//
// Linear quantization routines.
//
diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp
index a30ee219af..2bc0b5d661 100644
--- a/onnxruntime/core/mlas/lib/snchwc.cpp
+++ b/onnxruntime/core/mlas/lib/snchwc.cpp
@@ -1410,7 +1410,7 @@ Return Value:
void
MLASCALL
-MlasNchwcUpsample(
+MlasNchwcUpsampleNearest(
const int64_t* InputShape,
const int64_t* Scales,
const float* Input,
@@ -1511,6 +1511,116 @@ Return Value:
}
}
+MLAS_FORCEINLINE
+void
+MlasNchwcExtractInterpolation(
+ float InterpolationValue,
+ size_t InputLimit,
+ ptrdiff_t InputIndex[2],
+ MLAS_FLOAT32X4 Multipliers[2]
+ )
+{
+ InputIndex[0] = ptrdiff_t(InterpolationValue);
+ InputIndex[1] = std::min(InputIndex[0] + 1, ptrdiff_t(InputLimit - 1));
+
+ float ScalarMultiplier0 = InterpolationValue - float(InputIndex[0]);
+ float ScalarMultiplier1 = 1.0f - ScalarMultiplier0;
+
+ Multipliers[0] = MlasBroadcastFloat32x4(ScalarMultiplier0);
+ Multipliers[1] = MlasBroadcastFloat32x4(ScalarMultiplier1);
+}
+
+void
+MLASCALL
+MlasNchwcUpsampleLinear(
+ size_t InputHeight,
+ size_t InputWidth,
+ size_t OutputWidth,
+ float InterpolationHeight,
+ const float* InterpolationWidth,
+ const float* Input,
+ float* Output
+ )
+/*++
+
+Routine Description:
+
+ This routine implements the NCHWc upsample linear operation for a single row.
+
+ The integer portion of each interpolation float supplies the mapping from
+ output element to input element. The fractional portion supplies the relative
+ weights for the four points of the interpolation.
+
+Arguments:
+
+ InputHeight - Supplies the input height.
+
+ InputWidth - Supplies the input width.
+
+ OutputWidth - Supplies the output width.
+
+ InterpolationHeight - Supplies the height interpolation values for the target
+ row.
+
+ InterpolationWidth - Supplies an array of computed interpolation values of
+ length OutputWidth.
+
+ Input - Supplies the input spatial buffer.
+
+ Output - Supplies the output row buffer.
+
+Return Value:
+
+ None.
+
+--*/
+{
+ const size_t BlockSize = MlasNchwcGetBlockSize();
+
+ ptrdiff_t InputIndexY[2];
+ MLAS_FLOAT32X4 MultipliersY[2];
+
+ MlasNchwcExtractInterpolation(InterpolationHeight, InputHeight, InputIndexY, MultipliersY);
+
+ const float* InputRowY0 = Input + InputIndexY[0] * InputWidth * BlockSize;
+ const float* InputRowY1 = Input + InputIndexY[1] * InputWidth * BlockSize;
+
+ for (size_t ow = 0; ow < OutputWidth; ow++) {
+
+ ptrdiff_t InputIndexX[2];
+ MLAS_FLOAT32X4 MultipliersX[2];
+
+ MlasNchwcExtractInterpolation(InterpolationWidth[ow], InputWidth, InputIndexX, MultipliersX);
+
+ MLAS_FLOAT32X4 MultiplierY0X0 = MlasMultiplyFloat32x4(MultipliersY[0], MultipliersX[0]);
+ MLAS_FLOAT32X4 MultiplierY0X1 = MlasMultiplyFloat32x4(MultipliersY[0], MultipliersX[1]);
+ MLAS_FLOAT32X4 MultiplierY1X0 = MlasMultiplyFloat32x4(MultipliersY[1], MultipliersX[0]);
+ MLAS_FLOAT32X4 MultiplierY1X1 = MlasMultiplyFloat32x4(MultipliersY[1], MultipliersX[1]);
+
+ for (size_t bc = 0; bc < BlockSize; bc += 4) {
+
+ MLAS_FLOAT32X4 v00 = MlasLoadFloat32x4(InputRowY0 + InputIndexX[0] * BlockSize + bc);
+ MLAS_FLOAT32X4 v01 = MlasLoadFloat32x4(InputRowY0 + InputIndexX[1] * BlockSize + bc);
+ MLAS_FLOAT32X4 v10 = MlasLoadFloat32x4(InputRowY1 + InputIndexX[0] * BlockSize + bc);
+ MLAS_FLOAT32X4 v11 = MlasLoadFloat32x4(InputRowY1 + InputIndexX[1] * BlockSize + bc);
+
+ v00 = MlasMultiplyFloat32x4(MultiplierY1X1, v00);
+ v01 = MlasMultiplyFloat32x4(MultiplierY1X0, v01);
+ v10 = MlasMultiplyFloat32x4(MultiplierY0X1, v10);
+ v11 = MlasMultiplyFloat32x4(MultiplierY0X0, v11);
+
+ MLAS_FLOAT32X4 Reduction0 = MlasAddFloat32x4(v00, v01);
+ MLAS_FLOAT32X4 Reduction1 = MlasAddFloat32x4(v10, v11);
+
+ MLAS_FLOAT32X4 Reduction = MlasAddFloat32x4(Reduction0, Reduction1);
+
+ MlasStoreFloat32x4(&Output[bc], Reduction);
+ }
+
+ Output += BlockSize;
+ }
+}
+
#if !defined(MLAS_TARGET_AMD64)
//
diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc
index 7065876510..253db85e64 100644
--- a/onnxruntime/core/optimizer/nchwc_transformer.cc
+++ b/onnxruntime/core/optimizer/nchwc_transformer.cc
@@ -982,14 +982,21 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
}
auto* nchwc_input = it->second.get();
- // Only support the nearest interpolation mode (the default value).
+ // Support nearest (default) and linear modes.
const auto* mode_attr = graph_utils::GetNodeAttribute(node, "mode");
+ bool nearest_mode = true;
if (mode_attr != nullptr && utils::HasString(*mode_attr)) {
if (mode_attr->s() != "nearest") {
- return;
+ if (mode_attr->s() == "linear") {
+ nearest_mode = false;
+ } else {
+ return;
+ }
}
}
+ const ONNX_NAMESPACE::AttributeProto* transformation_mode_attr = nullptr;
+
NodeArg* sizes_arg = nullptr;
NodeArg* scales_arg = nullptr;
@@ -1001,20 +1008,30 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
scales_arg = input_defs[2];
}
- // Only support the asymmetric coordinate transformation mode.
- const auto* transform_mode_attr = graph_utils::GetNodeAttribute(node, "coordinate_transformation_mode");
- if ((transform_mode_attr == nullptr) ||
- !utils::HasString(*transform_mode_attr) ||
- (transform_mode_attr->s() != "asymmetric")) {
+ transformation_mode_attr = graph_utils::GetNodeAttribute(node, "coordinate_transformation_mode");
+ if ((transformation_mode_attr == nullptr) ||
+ !utils::HasString(*transformation_mode_attr)) {
return;
}
+ if (transformation_mode_attr->s() != "asymmetric") {
+ // Nearest mode kernel support asymmetric transformation mode only.
+ if (nearest_mode) {
+ return;
+ }
+ if ((transformation_mode_attr->s() != "align_corners") &&
+ (transformation_mode_attr->s() != "half_pixel")) {
+ return;
+ }
+ }
- // Only support the floor rounding mode.
- const auto* nearest_mode_attr = graph_utils::GetNodeAttribute(node, "nearest_mode");
- if ((nearest_mode_attr == nullptr) ||
- !utils::HasString(*nearest_mode_attr) ||
- (nearest_mode_attr->s() != "floor")) {
- return;
+ if (nearest_mode) {
+ // Only support the floor rounding mode.
+ const auto* nearest_mode_attr = graph_utils::GetNodeAttribute(node, "nearest_mode");
+ if ((nearest_mode_attr == nullptr) ||
+ !utils::HasString(*nearest_mode_attr) ||
+ (nearest_mode_attr->s() != "floor")) {
+ return;
+ }
}
} else {
scales_arg = input_defs[1];
@@ -1099,6 +1116,12 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
kMSNchwcDomain);
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
nchwc_node.AddAttribute("scales", scales_attr);
+ if (!nearest_mode) {
+ nchwc_node.AddAttribute("mode", mode_attr->s());
+ if (transformation_mode_attr != nullptr) {
+ nchwc_node.AddAttribute("coordinate_transformation_mode", transformation_mode_attr->s());
+ }
+ }
nchwc_input->remaining_original_uses_--;
diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc
index 9f357d57bd..69ca27e241 100644
--- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc
@@ -1085,7 +1085,7 @@ TEST(NchwcOptimizerTests, BatchNormalization) {
// Override the sample tolerance for this test. By default, the NCHWc
// tests generate bit identical results when run with and without
- // optimizations, but the BatchNormalizationtransform does introduce
+ // optimizations, but the BatchNormalization transform does introduce
// small bit differences.
helper.per_sample_tolerance_ = .00025;
};
@@ -1213,7 +1213,7 @@ TEST(NchwcOptimizerTests, ConvReorderOutputCnhw) {
NchwcOptimizerTester(build_test_case, check_nchwc_graph);
}
-TEST(NchwcOptimizerTests, Upsample) {
+TEST(NchwcOptimizerTests, UpsampleNearest) {
auto test_case = [&](int opset_version, float scale_h, float scale_w, bool use_sizes_arg) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput({3, 16, 27, 15});
@@ -1283,6 +1283,61 @@ TEST(NchwcOptimizerTests, Upsample) {
test_case(13, 2.2f, 2.8f, true);
}
+TEST(NchwcOptimizerTests, UpsampleLinear) {
+ auto test_case = [&](int opset_version, float scale_h, float scale_w, const std::string& transformation_mode) {
+ auto build_test_case = [&](NchwcTestHelper& helper) {
+ auto* input_arg = helper.MakeInput({3, 16, 21, 25});
+ auto* conv_output_arg = helper.MakeIntermediate();
+ auto* output_arg = helper.MakeOutput();
+
+ helper.AddConvNode(input_arg, conv_output_arg, {28, 16, 1, 1});
+
+ std::string op_name = opset_version >= 10 ? "Resize" : "Upsample";
+ std::vector input_args;
+ input_args.push_back(conv_output_arg);
+ if (opset_version >= 11) {
+ input_args.push_back(helper.Make1DInitializer({0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f}));
+ }
+ input_args.push_back(helper.Make1DInitializer({1.f, 1.f, scale_h, scale_w}));
+ Node& resize_node = helper.AddNode(op_name, input_args, {output_arg});
+ resize_node.AddAttribute("mode", "linear");
+ if (opset_version >= 11) {
+ resize_node.AddAttribute("coordinate_transformation_mode", transformation_mode);
+ }
+
+ helper.per_sample_tolerance_ = .001f;
+ };
+
+ auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
+ auto op_to_count = CountOpsInGraph(session.GetGraph());
+ EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1);
+ EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
+ EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
+ EXPECT_EQ(op_to_count["com.microsoft.nchwc.Upsample"], 1);
+ EXPECT_EQ(op_to_count["Resize"] + op_to_count["Upsample"], 0);
+ };
+
+ NchwcOptimizerTester(build_test_case, check_nchwc_graph, opset_version);
+ };
+
+ // Verify that upsample nodes can be converted to the NCHWc format for
+ // various versions of the operator.
+ std::vector transformation_modes{"asymmetric", "align_corners", "half_pixel"};
+ for (auto& transformation_mode : transformation_modes) {
+ static const int opset_versions[] = {9, 10, 11, 13};
+ for (auto opset_version : opset_versions) {
+ // Older versions of the operator do not support transformation modes.
+ if (opset_version < 11 && transformation_mode == "asymmetric") {
+ continue;
+ }
+ test_case(opset_version, 1.f, 1.f, transformation_mode);
+ test_case(opset_version, 2.f, 2.f, transformation_mode);
+ test_case(opset_version, 3.f, 5.f, transformation_mode);
+ test_case(opset_version, 9.f, 7.f, transformation_mode);
+ }
+ }
+}
+
TEST(NchwcOptimizerTests, Activation) {
auto test_case = [&](const std::string& activation_op_type) {
auto build_test_case = [&](NchwcTestHelper& helper) {