Implement NCHWc Upsample linear mode (#7623)

Extend the existing NCHWc Upsample operator to support linear modes too.
This commit is contained in:
Tracy Sharpe 2021-05-10 12:16:16 -07:00 committed by GitHub
parent ec885040ef
commit 16297a8e61
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 414 additions and 71 deletions

View file

@ -3063,6 +3063,10 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
<dl>
<dt><tt>coordinate_transformation_mode</tt> : string</dt>
<dd></dd>
<dt><tt>mode</tt> : string</dt>
<dd></dd>
<dt><tt>scales</tt> : list of ints</dt>
<dd></dd>
</dl>

View file

@ -52,14 +52,8 @@ Status ReorderInput::Compute(OpKernelContext* context) const {
// elements, so that operations involving a smaller number of channels will
// process more rows per worker.
constexpr ptrdiff_t worker_goal = 48 * 1024;
ptrdiff_t work_per_worker = worker_goal / nchwc_channels;
if (work_per_worker == 0) {
work_per_worker = 1;
}
worker_count = total_work / work_per_worker;
if (worker_count == 0) {
worker_count = 1;
}
ptrdiff_t work_per_worker = std::max<ptrdiff_t>(worker_goal / nchwc_channels, 1);
worker_count = std::max<ptrdiff_t>(total_work / work_per_worker, 1);
} else {
// Each iteration produces one spatial_size chunk of NCHWc blocks.
total_work = static_cast<ptrdiff_t>(batch_count * (nchwc_channels / nchwc_block_size));
@ -205,20 +199,21 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
}
}
MlasNchwcConv(X_shape.GetDims().data(),
kernel_shape.data(),
dilations.data(),
pads.data(),
strides.data(),
Y_dims.data(),
static_cast<size_t>(conv_attrs_.group),
X->template Data<float>(),
W->template Data<float>(),
B != nullptr ? B->template Data<float>() : nullptr,
y_data,
&activation_,
Sum == nullptr,
context->GetOperatorThreadPool());
MlasNchwcConv(
X_shape.GetDims().data(),
kernel_shape.data(),
dilations.data(),
pads.data(),
strides.data(),
Y_dims.data(),
static_cast<size_t>(conv_attrs_.group),
X->template Data<float>(),
W->template Data<float>(),
B != nullptr ? B->template Data<float>() : nullptr,
y_data,
&activation_,
Sum == nullptr,
context->GetOperatorThreadPool());
return Status::OK();
}
@ -233,16 +228,17 @@ Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind
std::vector<int64_t> output_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads);
auto* Y = context->Output(0, output_dims);
MlasNchwcPool(kind,
X_shape.GetDims().data(),
pool_attrs_.global_pooling ? nullptr : pool_attrs_.kernel_shape.data(),
pool_attrs_.global_pooling ? nullptr : pool_attrs_.dilations.data(),
pool_attrs_.global_pooling ? nullptr : pads.data(),
pool_attrs_.global_pooling ? nullptr : pool_attrs_.strides.data(),
output_dims.data(),
X->template Data<float>(),
Y->template MutableData<float>(),
context->GetOperatorThreadPool());
MlasNchwcPool(
kind,
X_shape.GetDims().data(),
pool_attrs_.global_pooling ? nullptr : pool_attrs_.kernel_shape.data(),
pool_attrs_.global_pooling ? nullptr : pool_attrs_.dilations.data(),
pool_attrs_.global_pooling ? nullptr : pads.data(),
pool_attrs_.global_pooling ? nullptr : pool_attrs_.strides.data(),
output_dims.data(),
X->template Data<float>(),
Y->template MutableData<float>(),
context->GetOperatorThreadPool());
return Status::OK();
}
@ -256,19 +252,124 @@ Status NchwcAveragePool::Compute(OpKernelContext* context) const {
: MlasAveragePoolingExcludePad);
}
// Builds the per-output-element source coordinate table for one spatial axis.
//
// Entry `o` of the returned vector is the (fractional) input coordinate that
// output element `o` samples from. The mapping depends on the kernel's
// coordinate transformation mode, except that an unscaled axis (scale == 1)
// always uses the identity mapping.
//
//   input_length  - extent of the axis in the input tensor.
//   output_length - extent of the axis in the output tensor; also the length
//                   of the returned vector.
//   scale         - integer scale factor applied to the axis.
std::vector<float> NchwcUpsample::ComputeInterpolation(int64_t input_length,
                                                       int64_t output_length,
                                                       int64_t scale) const {
  std::vector<float> coordinates(static_cast<size_t>(output_length));

  // Unscaled axis: every output element maps straight onto the same input element.
  if (scale == 1) {
    for (int64_t index = 0; index < output_length; index++) {
      coordinates[index] = static_cast<float>(index);
    }
    return coordinates;
  }

  switch (transformation_mode_) {
    case TransformationMode::ALIGN_CORNERS: {
      // First and last output elements land exactly on the first and last input elements.
      const float input_span = static_cast<float>(input_length - 1);
      const float output_span = static_cast<float>(output_length - 1);
      for (int64_t index = 0; index < output_length; index++) {
        coordinates[index] = static_cast<float>(index) * input_span / output_span;
      }
      break;
    }

    case TransformationMode::HALF_PIXEL: {
      // Sample at pixel centers; coordinates that would fall below zero are clamped to zero.
      const float scale_value = static_cast<float>(scale);
      for (int64_t index = 0; index < output_length; index++) {
        coordinates[index] = std::max(0.0f, (static_cast<float>(index) + 0.5f) / scale_value - 0.5f);
      }
      break;
    }

    default: {
      // TransformationMode::ASYMMETRIC: plain division by the scale factor.
      const float scale_value = static_cast<float>(scale);
      for (int64_t index = 0; index < output_length; index++) {
        coordinates[index] = static_cast<float>(index) / scale_value;
      }
      break;
    }
  }

  return coordinates;
}
// Upsamples a 4-D NCHWc-blocked float tensor along its two spatial dimensions
// by the integer factors scales_[2] (height) and scales_[3] (width). Nearest
// mode is forwarded to MlasNchwcUpsampleNearest; otherwise each output row is
// produced by MlasNchwcUpsampleLinear, with the rows partitioned across the
// operator thread pool. Returns Status::OK(); shape violations are reported
// through ORT_ENFORCE.
//
// NOTE(review): this listing appears to interleave the statements removed by
// the change with their replacements (X_shape is declared twice, the output Y
// is requested twice and the old MlasNchwcUpsample call is still present).
// Confirm which statements belong to the final function before building.
Status NchwcUpsample::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
const auto& X_shape = X->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
const auto& X_shape = X->Shape().GetDims();
ORT_ENFORCE(X_shape.size() == 4);
// The channel count must be a whole number of NCHWc blocks.
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
TensorShape Y_shape{X_shape[0], X_shape[1], X_shape[2] * scales_[2], X_shape[3] * scales_[3]};
auto* Y = context->Output(0, Y_shape);
const int64_t batch_count = X_shape[0];
const int64_t nchwc_channels = X_shape[1];
MlasNchwcUpsample(X_shape.GetDims().data(),
scales_.data() + 2,
X->template Data<float>(),
Y->template MutableData<float>());
const int64_t input_h = X_shape[2];
const int64_t input_w = X_shape[3];
// Only the spatial dimensions grow; batch and channel counts carry over unchanged.
const int64_t output_h = input_h * scales_[2];
const int64_t output_w = input_w * scales_[3];
auto* Y = context->Output(0, {batch_count, nchwc_channels, output_h, output_w});
// Bail out early if one of the dimensions is zero.
if (Y->Shape().Size() == 0) {
return Status::OK();
}
const auto* x_data = X->template Data<float>();
auto* y_data = Y->template MutableData<float>();
if (nearest_mode_) {
// scales_.data() + 2 skips the batch and channel entries so that only the
// height and width scale factors are passed.
MlasNchwcUpsampleNearest(
X_shape.data(),
scales_.data() + 2,
x_data,
y_data);
} else {
// Compute the interpolation value per output height and width.
const auto interpolation_h = ComputeInterpolation(input_h, output_h, scales_[2]);
const auto interpolation_w = ComputeInterpolation(input_w, output_w, scales_[3]);
const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
// One unit of work is one output row of one NCHWc channel block.
const ptrdiff_t total_work = ((batch_count * nchwc_channels) / nchwc_block_size) * output_h;
// Partition the work with the goal of generating the following number of
// elements, so that operations involving a smaller number of columns will
// process more rows per worker.
constexpr ptrdiff_t worker_goal = 16 * 1024;
ptrdiff_t work_per_worker = std::max<ptrdiff_t>(worker_goal / (output_w * nchwc_block_size), 1);
ptrdiff_t worker_count = std::max<ptrdiff_t>(total_work / work_per_worker, 1);
// Processes the range of work items that PartitionWork assigns to partition
// `batch`. worker_count is captured by reference, so the single-thread
// override applied below is the value used when partitioning.
auto upsample_worker = [&](ptrdiff_t batch) {
auto work = concurrency::ThreadPool::PartitionWork(batch, worker_count, total_work);
int64_t work_index = static_cast<int64_t>(work.start);
int64_t work_remaining = static_cast<int64_t>(work.end - work.start);
while (work_remaining > 0) {
// Limit the current loop iteration to the same source image.
// The quotient selects the channel block image, the remainder the first
// output row to produce within that image.
const int64_t channel_index = work_index / output_h;
int64_t row_index = work_index % output_h;
int64_t rows_this_iteration = std::min(work_remaining, output_h - row_index);
work_index += rows_this_iteration;
work_remaining -= rows_this_iteration;
const auto* x_channel_base = x_data + (channel_index * input_h * input_w * nchwc_block_size);
auto* y_row = y_data + (((channel_index * output_h) + row_index) * output_w * nchwc_block_size);
// Loop upsampling each row of the output.
// rows_this_iteration is at least one here (work_remaining > 0 and
// row_index < output_h), so the do/while form never underflows the counter.
do {
MlasNchwcUpsampleLinear(
static_cast<size_t>(input_h),
static_cast<size_t>(input_w),
static_cast<size_t>(output_w),
interpolation_h[row_index],
interpolation_w.data(),
x_channel_base,
y_row);
y_row += output_w * nchwc_block_size;
row_index++;
} while (--rows_this_iteration);
}
};
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
// Handle the work in a single batch if only a single thread is available.
if (concurrency::ThreadPool::DegreeOfParallelism(thread_pool) == 1) {
worker_count = 1;
}
concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, worker_count, upsample_worker);
}
return Status::OK();
}

View file

@ -12,7 +12,7 @@
namespace onnxruntime {
namespace contrib {
class ReorderInput : public OpKernel {
class ReorderInput final : public OpKernel {
public:
ReorderInput(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("channels_last", &channels_last_).IsOK());
@ -24,7 +24,7 @@ class ReorderInput : public OpKernel {
int64_t channels_last_;
};
class ReorderOutput : public OpKernel {
class ReorderOutput final : public OpKernel {
public:
ReorderOutput(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("channels", &channels_).IsOK());
@ -39,7 +39,7 @@ class ReorderOutput : public OpKernel {
int64_t channels_last_;
};
class NchwcConv : public OpKernel {
class NchwcConv final : public OpKernel {
public:
NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
@ -64,7 +64,7 @@ class NchwcPoolBase : public PoolBase {
Status NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const;
};
class NchwcMaxPool : public OpKernel, public NchwcPoolBase {
class NchwcMaxPool final : public OpKernel, public NchwcPoolBase {
public:
NchwcMaxPool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) {
}
@ -72,7 +72,7 @@ class NchwcMaxPool : public OpKernel, public NchwcPoolBase {
Status Compute(OpKernelContext* context) const override;
};
class NchwcAveragePool : public OpKernel, public NchwcPoolBase {
class NchwcAveragePool final : public OpKernel, public NchwcPoolBase {
public:
NchwcAveragePool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) {
}
@ -80,19 +80,55 @@ class NchwcAveragePool : public OpKernel, public NchwcPoolBase {
Status Compute(OpKernelContext* context) const override;
};
class NchwcUpsample : public OpKernel {
// Kernel for the NCHWc Upsample operator. The constructor validates the
// "scales", "coordinate_transformation_mode" and "mode" attributes once so
// that Compute() can rely on them.
class NchwcUpsample final : public OpKernel {
 private:
  // How an output coordinate is mapped back to an input coordinate when
  // interpolating.
  enum class TransformationMode {
    ASYMMETRIC,
    ALIGN_CORNERS,
    HALF_PIXEL,
  };

 public:
  NchwcUpsample(const OpKernelInfo& info) : OpKernel(info) {
    // "scales" must supply one integer factor per NCHW dimension.
    ORT_ENFORCE(info.GetAttrs<int64_t>("scales", scales_).IsOK());
    ORT_ENFORCE(scales_.size() == 4);

    // Only the two spatial dimensions may be scaled, and only by a factor of
    // at least one.
    ORT_ENFORCE(scales_[0] == 1 && scales_[1] == 1 && scales_[2] >= 1 && scales_[3] >= 1);

    // Translate the coordinate transformation mode string into the enum.
    std::string transformation_mode;
    ORT_ENFORCE(info.GetAttr<std::string>("coordinate_transformation_mode", &transformation_mode).IsOK());

    if (transformation_mode == "half_pixel") {
      transformation_mode_ = TransformationMode::HALF_PIXEL;
    } else if (transformation_mode == "align_corners") {
      transformation_mode_ = TransformationMode::ALIGN_CORNERS;
    } else if (transformation_mode == "asymmetric") {
      transformation_mode_ = TransformationMode::ASYMMETRIC;
    } else {
      ORT_THROW("Unsupported transformation mode '" + transformation_mode + "' for NCHWc Upsample");
    }

    // Translate the interpolation mode; anything but nearest/linear is rejected.
    std::string mode;
    ORT_ENFORCE(info.GetAttr<std::string>("mode", &mode).IsOK());

    const bool is_nearest = (mode == "nearest");
    if (!is_nearest && mode != "linear") {
      ORT_THROW("Unsupported mode '" + mode + "' for NCHWc Upsample");
    }

    nearest_mode_ = is_nearest;
    if (nearest_mode_) {
      // Nearest mode is only accepted together with the asymmetric mapping.
      ORT_ENFORCE(transformation_mode_ == TransformationMode::ASYMMETRIC);
    }
  }

  Status Compute(OpKernelContext* context) const override;

 private:
  // Returns, for each of output_length output elements along one axis, the
  // fractional input coordinate to sample from.
  std::vector<float> ComputeInterpolation(int64_t input_length,
                                          int64_t output_length,
                                          int64_t scale) const;

  std::vector<int64_t> scales_;             // Per-dimension scale factors (N, C, H, W).
  TransformationMode transformation_mode_;  // Parsed "coordinate_transformation_mode".
  bool nearest_mode_;                       // True for "nearest", false for "linear".
};
} // namespace contrib

View file

@ -188,6 +188,8 @@ void RegisterNchwcSchemas() {
.SinceVersion(1)
.SetDoc(R"DOC(For internal use.)DOC")
.Attr("scales", "", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("mode", "", AttributeProto::STRING, std::string("nearest"))
.Attr("coordinate_transformation_mode", "", AttributeProto::STRING, std::string("asymmetric"))
.Input(0, "X", "", "T")
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")

View file

@ -168,7 +168,7 @@ struct MLAS_SGEMM_DATA_PARAMS {
/**
* @brief Batched single precision matrix/matrix multiply operation (SGEMM)
*
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@ -195,7 +195,7 @@ MlasGemmBatch(
/**
* @brief Single precision matrix/matrix multiply operation (SGEMM)
*
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@ -223,7 +223,7 @@ MlasGemm(
/**
* @brief Single precision matrix/matrix multiply operation (SGEMM)
*
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@ -231,7 +231,7 @@ MlasGemm(
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
* @param A Supplies the address of matrix A
* @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
@ -341,7 +341,7 @@ struct MLAS_DGEMM_DATA_PARAMS {
/**
* @brief Batched double precision matrix/matrix multiply operation (DGEMM)
*
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@ -369,7 +369,7 @@ MlasGemmBatch(
/**
* @brief Double precision matrix/matrix multiply operation (DGEMM)
*
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@ -397,7 +397,7 @@ MlasGemm(
/**
* @brief Double precision matrix/matrix multiply operation (DGEMM)
*
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
@ -405,7 +405,7 @@ MlasGemm(
* @param K Supplies the number of columns of matrix A and the number
of rows of matrix B.
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
* @param A Supplies the address of matrix A
* @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
@ -929,13 +929,25 @@ MlasNchwcPool(
void
MLASCALL
MlasNchwcUpsample(
MlasNchwcUpsampleNearest(
const int64_t* InputShape,
const int64_t* Scales,
const float* Input,
float* Output
);
/**
 * @brief Computes a single output row of the NCHWc upsample linear operation.
 *
 * The integer portion of each interpolation value selects the input element
 * and the fractional portion supplies the interpolation weight.
 *
 * @param InputHeight          Supplies the input height.
 * @param InputWidth           Supplies the input width.
 * @param OutputWidth          Supplies the output width.
 * @param InterpolationHeight  Supplies the height interpolation value for the
 *                             target row.
 * @param InterpolationWidth   Supplies an array of interpolation values of
 *                             length OutputWidth.
 * @param Input                Supplies the input spatial buffer.
 * @param Output               Supplies the output row buffer.
 */
void
MLASCALL
MlasNchwcUpsampleLinear(
size_t InputHeight,
size_t InputWidth,
size_t OutputWidth,
float InterpolationHeight,
const float* InterpolationWidth,
const float* Input,
float* Output
);
//
// Linear quantization routines.
//

View file

@ -1410,7 +1410,7 @@ Return Value:
void
MLASCALL
MlasNchwcUpsample(
MlasNchwcUpsampleNearest(
const int64_t* InputShape,
const int64_t* Scales,
const float* Input,
@ -1511,6 +1511,116 @@ Return Value:
}
}
MLAS_FORCEINLINE
void
MlasNchwcExtractInterpolation(
    float InterpolationValue,
    size_t InputLimit,
    ptrdiff_t InputIndex[2],
    MLAS_FLOAT32X4 Multipliers[2]
    )
/*++

Routine Description:

    This routine splits an interpolation value into the pair of neighboring
    input indices and the pair of broadcast multipliers derived from its
    fractional portion.

Arguments:

    InterpolationValue - Supplies the interpolation value to decompose.

    InputLimit - Supplies the length of the input dimension; the second index
        is clamped to InputLimit - 1.

    InputIndex - Receives the truncated index and its (clamped) successor.

    Multipliers - Receives the broadcast fractional portion in element 0 and
        one minus the fractional portion in element 1.

Return Value:

    None.

--*/
{
    const ptrdiff_t LowerIndex = static_cast<ptrdiff_t>(InterpolationValue);
    const ptrdiff_t LastIndex = static_cast<ptrdiff_t>(InputLimit - 1);

    //
    // Clamp the neighboring index so that it never steps past the last input
    // element.
    //

    const ptrdiff_t UpperIndex = (LowerIndex + 1 < LastIndex) ? (LowerIndex + 1) : LastIndex;

    const float Fraction = InterpolationValue - static_cast<float>(LowerIndex);
    const float Complement = 1.0f - Fraction;

    InputIndex[0] = LowerIndex;
    InputIndex[1] = UpperIndex;

    Multipliers[0] = MlasBroadcastFloat32x4(Fraction);
    Multipliers[1] = MlasBroadcastFloat32x4(Complement);
}
void
MLASCALL
MlasNchwcUpsampleLinear(
    size_t InputHeight,
    size_t InputWidth,
    size_t OutputWidth,
    float InterpolationHeight,
    const float* InterpolationWidth,
    const float* Input,
    float* Output
    )
/*++

Routine Description:

    This routine implements the NCHWc upsample linear operation for a single
    output row.

    The integer portion of each interpolation float supplies the mapping from
    output element to input element. The fractional portion supplies the
    relative weights for the four points of the interpolation.

Arguments:

    InputHeight - Supplies the input height.

    InputWidth - Supplies the input width.

    OutputWidth - Supplies the output width.

    InterpolationHeight - Supplies the height interpolation value for the
        target row.

    InterpolationWidth - Supplies an array of computed interpolation values of
        length OutputWidth.

    Input - Supplies the input spatial buffer.

    Output - Supplies the output row buffer.

Return Value:

    None.

--*/
{
    const size_t BlockSize = MlasNchwcGetBlockSize();

    //
    // Resolve the two input rows that contribute to this output row. The
    // height decomposition is shared by every element of the row.
    //

    ptrdiff_t RowIndex[2];
    MLAS_FLOAT32X4 RowMultipliers[2];

    MlasNchwcExtractInterpolation(InterpolationHeight, InputHeight, RowIndex, RowMultipliers);

    const float* InputRow0 = Input + RowIndex[0] * InputWidth * BlockSize;
    const float* InputRow1 = Input + RowIndex[1] * InputWidth * BlockSize;

    float* const OutputEnd = Output + OutputWidth * BlockSize;

    for (; Output != OutputEnd; Output += BlockSize, InterpolationWidth++) {

        ptrdiff_t ColumnIndex[2];
        MLAS_FLOAT32X4 ColumnMultipliers[2];

        MlasNchwcExtractInterpolation(*InterpolationWidth, InputWidth, ColumnIndex, ColumnMultipliers);

        //
        // Combine the row and column multipliers into one weight per corner.
        // Multiplier element 1 (one minus the fraction) weights the corner at
        // index 0 and multiplier element 0 (the fraction) weights the corner
        // at index 1.
        //

        const MLAS_FLOAT32X4 Weight11 = MlasMultiplyFloat32x4(RowMultipliers[0], ColumnMultipliers[0]);
        const MLAS_FLOAT32X4 Weight10 = MlasMultiplyFloat32x4(RowMultipliers[0], ColumnMultipliers[1]);
        const MLAS_FLOAT32X4 Weight01 = MlasMultiplyFloat32x4(RowMultipliers[1], ColumnMultipliers[0]);
        const MLAS_FLOAT32X4 Weight00 = MlasMultiplyFloat32x4(RowMultipliers[1], ColumnMultipliers[1]);

        const float* Corner00 = InputRow0 + ColumnIndex[0] * BlockSize;
        const float* Corner01 = InputRow0 + ColumnIndex[1] * BlockSize;
        const float* Corner10 = InputRow1 + ColumnIndex[0] * BlockSize;
        const float* Corner11 = InputRow1 + ColumnIndex[1] * BlockSize;

        //
        // Blend the four corners for each vector of channels in the block.
        //

        for (size_t bc = 0; bc < BlockSize; bc += 4) {

            const MLAS_FLOAT32X4 Term00 = MlasMultiplyFloat32x4(Weight00, MlasLoadFloat32x4(Corner00 + bc));
            const MLAS_FLOAT32X4 Term01 = MlasMultiplyFloat32x4(Weight01, MlasLoadFloat32x4(Corner01 + bc));
            const MLAS_FLOAT32X4 Term10 = MlasMultiplyFloat32x4(Weight10, MlasLoadFloat32x4(Corner10 + bc));
            const MLAS_FLOAT32X4 Term11 = MlasMultiplyFloat32x4(Weight11, MlasLoadFloat32x4(Corner11 + bc));

            const MLAS_FLOAT32X4 RowSum0 = MlasAddFloat32x4(Term00, Term01);
            const MLAS_FLOAT32X4 RowSum1 = MlasAddFloat32x4(Term10, Term11);

            MlasStoreFloat32x4(Output + bc, MlasAddFloat32x4(RowSum0, RowSum1));
        }
    }
}
#if !defined(MLAS_TARGET_AMD64)
//

View file

@ -982,14 +982,21 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
}
auto* nchwc_input = it->second.get();
// Only support the nearest interpolation mode (the default value).
// Support nearest (default) and linear modes.
const auto* mode_attr = graph_utils::GetNodeAttribute(node, "mode");
bool nearest_mode = true;
if (mode_attr != nullptr && utils::HasString(*mode_attr)) {
if (mode_attr->s() != "nearest") {
return;
if (mode_attr->s() == "linear") {
nearest_mode = false;
} else {
return;
}
}
}
const ONNX_NAMESPACE::AttributeProto* transformation_mode_attr = nullptr;
NodeArg* sizes_arg = nullptr;
NodeArg* scales_arg = nullptr;
@ -1001,20 +1008,30 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
scales_arg = input_defs[2];
}
// Only support the asymmetric coordinate transformation mode.
const auto* transform_mode_attr = graph_utils::GetNodeAttribute(node, "coordinate_transformation_mode");
if ((transform_mode_attr == nullptr) ||
!utils::HasString(*transform_mode_attr) ||
(transform_mode_attr->s() != "asymmetric")) {
transformation_mode_attr = graph_utils::GetNodeAttribute(node, "coordinate_transformation_mode");
if ((transformation_mode_attr == nullptr) ||
!utils::HasString(*transformation_mode_attr)) {
return;
}
if (transformation_mode_attr->s() != "asymmetric") {
// The nearest mode kernel supports the asymmetric transformation mode only.
if (nearest_mode) {
return;
}
if ((transformation_mode_attr->s() != "align_corners") &&
(transformation_mode_attr->s() != "half_pixel")) {
return;
}
}
// Only support the floor rounding mode.
const auto* nearest_mode_attr = graph_utils::GetNodeAttribute(node, "nearest_mode");
if ((nearest_mode_attr == nullptr) ||
!utils::HasString(*nearest_mode_attr) ||
(nearest_mode_attr->s() != "floor")) {
return;
if (nearest_mode) {
// Only support the floor rounding mode.
const auto* nearest_mode_attr = graph_utils::GetNodeAttribute(node, "nearest_mode");
if ((nearest_mode_attr == nullptr) ||
!utils::HasString(*nearest_mode_attr) ||
(nearest_mode_attr->s() != "floor")) {
return;
}
}
} else {
scales_arg = input_defs[1];
@ -1099,6 +1116,12 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
kMSNchwcDomain);
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
nchwc_node.AddAttribute("scales", scales_attr);
if (!nearest_mode) {
nchwc_node.AddAttribute("mode", mode_attr->s());
if (transformation_mode_attr != nullptr) {
nchwc_node.AddAttribute("coordinate_transformation_mode", transformation_mode_attr->s());
}
}
nchwc_input->remaining_original_uses_--;

View file

@ -1085,7 +1085,7 @@ TEST(NchwcOptimizerTests, BatchNormalization) {
// Override the sample tolerance for this test. By default, the NCHWc
// tests generate bit identical results when run with and without
// optimizations, but the BatchNormalizationtransform does introduce
// optimizations, but the BatchNormalization transform does introduce
// small bit differences.
helper.per_sample_tolerance_ = .00025;
};
@ -1213,7 +1213,7 @@ TEST(NchwcOptimizerTests, ConvReorderOutputCnhw) {
NchwcOptimizerTester(build_test_case, check_nchwc_graph);
}
TEST(NchwcOptimizerTests, Upsample) {
TEST(NchwcOptimizerTests, UpsampleNearest) {
auto test_case = [&](int opset_version, float scale_h, float scale_w, bool use_sizes_arg) {
auto build_test_case = [&](NchwcTestHelper& helper) {
auto* input_arg = helper.MakeInput<float>({3, 16, 27, 15});
@ -1283,6 +1283,61 @@ TEST(NchwcOptimizerTests, Upsample) {
test_case(13, 2.2f, 2.8f, true);
}
// Verifies that a linear-mode Resize/Upsample node that follows a convolution
// is rewritten to com.microsoft.nchwc.Upsample, for several operator versions,
// scale factors and coordinate transformation modes.
TEST(NchwcOptimizerTests, UpsampleLinear) {
  auto test_case = [&](int opset_version, float scale_h, float scale_w, const std::string& transformation_mode) {
    auto build_test_case = [&](NchwcTestHelper& helper) {
      auto* input_arg = helper.MakeInput<float>({3, 16, 21, 25});
      auto* conv_output_arg = helper.MakeIntermediate();
      auto* output_arg = helper.MakeOutput();

      helper.AddConvNode(input_arg, conv_output_arg, {28, 16, 1, 1});

      // The operator is named Upsample before opset 10 and Resize afterwards.
      std::string op_name = opset_version >= 10 ? "Resize" : "Upsample";
      std::vector<NodeArg*> input_args;
      input_args.push_back(conv_output_arg);
      if (opset_version >= 11) {
        // From opset 11 an extra input precedes the scales.
        input_args.push_back(helper.Make1DInitializer<float>({0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f}));
      }
      input_args.push_back(helper.Make1DInitializer<float>({1.f, 1.f, scale_h, scale_w}));
      Node& resize_node = helper.AddNode(op_name, input_args, {output_arg});
      resize_node.AddAttribute("mode", "linear");
      if (opset_version >= 11) {
        resize_node.AddAttribute("coordinate_transformation_mode", transformation_mode);
      }

      // Allow small differences between the optimized and unoptimized results.
      helper.per_sample_tolerance_ = .001f;
    };

    auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
      auto op_to_count = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1);
      EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
      EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
      EXPECT_EQ(op_to_count["com.microsoft.nchwc.Upsample"], 1);
      EXPECT_EQ(op_to_count["Resize"] + op_to_count["Upsample"], 0);
    };

    NchwcOptimizerTester(build_test_case, check_nchwc_graph, opset_version);
  };

  // Verify that upsample nodes can be converted to the NCHWc format for
  // various versions of the operator.
  const std::vector<std::string> transformation_modes{"asymmetric", "align_corners", "half_pixel"};
  for (const auto& transformation_mode : transformation_modes) {
    static const int opset_versions[] = {9, 10, 11, 13};
    for (auto opset_version : opset_versions) {
      // Older versions of the operator do not support transformation modes:
      // build_test_case never sets the attribute for them, so every mode would
      // build the identical graph. Exercise those versions once, under
      // "asymmetric", and skip the redundant combinations.
      if (opset_version < 11 && transformation_mode != "asymmetric") {
        continue;
      }
      test_case(opset_version, 1.f, 1.f, transformation_mode);
      test_case(opset_version, 2.f, 2.f, transformation_mode);
      test_case(opset_version, 3.f, 5.f, transformation_mode);
      test_case(opset_version, 9.f, 7.f, transformation_mode);
    }
  }
}
TEST(NchwcOptimizerTests, Activation) {
auto test_case = [&](const std::string& activation_op_type) {
auto build_test_case = [&](NchwcTestHelper& helper) {