mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Implement NCHWc Upsample linear mode (#7623)
Extend the existing NCHWc Upsample operator to support linear modes too.
This commit is contained in:
parent
ec885040ef
commit
16297a8e61
8 changed files with 414 additions and 71 deletions
|
|
@ -3063,6 +3063,10 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>coordinate_transformation_mode</tt> : string</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>mode</tt> : string</dt>
|
||||
<dd></dd>
|
||||
<dt><tt>scales</tt> : list of ints</dt>
|
||||
<dd></dd>
|
||||
</dl>
|
||||
|
|
|
|||
|
|
@ -52,14 +52,8 @@ Status ReorderInput::Compute(OpKernelContext* context) const {
|
|||
// elements, so that operations involving a smaller number of channels will
|
||||
// process more rows per worker.
|
||||
constexpr ptrdiff_t worker_goal = 48 * 1024;
|
||||
ptrdiff_t work_per_worker = worker_goal / nchwc_channels;
|
||||
if (work_per_worker == 0) {
|
||||
work_per_worker = 1;
|
||||
}
|
||||
worker_count = total_work / work_per_worker;
|
||||
if (worker_count == 0) {
|
||||
worker_count = 1;
|
||||
}
|
||||
ptrdiff_t work_per_worker = std::max<ptrdiff_t>(worker_goal / nchwc_channels, 1);
|
||||
worker_count = std::max<ptrdiff_t>(total_work / work_per_worker, 1);
|
||||
} else {
|
||||
// Each iteration produces one spatial_size chunk of NCHWc blocks.
|
||||
total_work = static_cast<ptrdiff_t>(batch_count * (nchwc_channels / nchwc_block_size));
|
||||
|
|
@ -205,20 +199,21 @@ Status NchwcConv::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
}
|
||||
|
||||
MlasNchwcConv(X_shape.GetDims().data(),
|
||||
kernel_shape.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
strides.data(),
|
||||
Y_dims.data(),
|
||||
static_cast<size_t>(conv_attrs_.group),
|
||||
X->template Data<float>(),
|
||||
W->template Data<float>(),
|
||||
B != nullptr ? B->template Data<float>() : nullptr,
|
||||
y_data,
|
||||
&activation_,
|
||||
Sum == nullptr,
|
||||
context->GetOperatorThreadPool());
|
||||
MlasNchwcConv(
|
||||
X_shape.GetDims().data(),
|
||||
kernel_shape.data(),
|
||||
dilations.data(),
|
||||
pads.data(),
|
||||
strides.data(),
|
||||
Y_dims.data(),
|
||||
static_cast<size_t>(conv_attrs_.group),
|
||||
X->template Data<float>(),
|
||||
W->template Data<float>(),
|
||||
B != nullptr ? B->template Data<float>() : nullptr,
|
||||
y_data,
|
||||
&activation_,
|
||||
Sum == nullptr,
|
||||
context->GetOperatorThreadPool());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -233,16 +228,17 @@ Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind
|
|||
std::vector<int64_t> output_dims = pool_attrs_.SetOutputSize(X_shape, X_shape[1], &pads);
|
||||
auto* Y = context->Output(0, output_dims);
|
||||
|
||||
MlasNchwcPool(kind,
|
||||
X_shape.GetDims().data(),
|
||||
pool_attrs_.global_pooling ? nullptr : pool_attrs_.kernel_shape.data(),
|
||||
pool_attrs_.global_pooling ? nullptr : pool_attrs_.dilations.data(),
|
||||
pool_attrs_.global_pooling ? nullptr : pads.data(),
|
||||
pool_attrs_.global_pooling ? nullptr : pool_attrs_.strides.data(),
|
||||
output_dims.data(),
|
||||
X->template Data<float>(),
|
||||
Y->template MutableData<float>(),
|
||||
context->GetOperatorThreadPool());
|
||||
MlasNchwcPool(
|
||||
kind,
|
||||
X_shape.GetDims().data(),
|
||||
pool_attrs_.global_pooling ? nullptr : pool_attrs_.kernel_shape.data(),
|
||||
pool_attrs_.global_pooling ? nullptr : pool_attrs_.dilations.data(),
|
||||
pool_attrs_.global_pooling ? nullptr : pads.data(),
|
||||
pool_attrs_.global_pooling ? nullptr : pool_attrs_.strides.data(),
|
||||
output_dims.data(),
|
||||
X->template Data<float>(),
|
||||
Y->template MutableData<float>(),
|
||||
context->GetOperatorThreadPool());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -256,19 +252,124 @@ Status NchwcAveragePool::Compute(OpKernelContext* context) const {
|
|||
: MlasAveragePoolingExcludePad);
|
||||
}
|
||||
|
||||
std::vector<float> NchwcUpsample::ComputeInterpolation(int64_t input_length,
|
||||
int64_t output_length,
|
||||
int64_t scale) const {
|
||||
std::vector<float> interpolation;
|
||||
interpolation.resize(output_length);
|
||||
|
||||
if (scale == 1) {
|
||||
// Identity map for unscaled.
|
||||
for (int64_t o = 0; o < output_length; o++) {
|
||||
interpolation[o] = static_cast<float>(o);
|
||||
}
|
||||
} else if (transformation_mode_ == TransformationMode::ALIGN_CORNERS) {
|
||||
for (int64_t o = 0; o < output_length; o++) {
|
||||
interpolation[o] =
|
||||
static_cast<float>(o) * static_cast<float>(input_length - 1) / static_cast<float>(output_length - 1);
|
||||
}
|
||||
} else if (transformation_mode_ == TransformationMode::HALF_PIXEL) {
|
||||
for (int64_t o = 0; o < output_length; o++) {
|
||||
interpolation[o] =
|
||||
std::max(0.0f, (static_cast<float>(o) + 0.5f) / static_cast<float>(scale) - 0.5f);
|
||||
}
|
||||
} else {
|
||||
// Default to TransformationMode::ASYMMETRIC.
|
||||
for (int64_t o = 0; o < output_length; o++) {
|
||||
interpolation[o] = static_cast<float>(o) / static_cast<float>(scale);
|
||||
}
|
||||
}
|
||||
|
||||
return interpolation;
|
||||
}
|
||||
|
||||
Status NchwcUpsample::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto& X_shape = X->Shape();
|
||||
ORT_ENFORCE(X_shape.NumDimensions() == 4);
|
||||
const auto& X_shape = X->Shape().GetDims();
|
||||
ORT_ENFORCE(X_shape.size() == 4);
|
||||
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
|
||||
|
||||
TensorShape Y_shape{X_shape[0], X_shape[1], X_shape[2] * scales_[2], X_shape[3] * scales_[3]};
|
||||
auto* Y = context->Output(0, Y_shape);
|
||||
const int64_t batch_count = X_shape[0];
|
||||
const int64_t nchwc_channels = X_shape[1];
|
||||
|
||||
MlasNchwcUpsample(X_shape.GetDims().data(),
|
||||
scales_.data() + 2,
|
||||
X->template Data<float>(),
|
||||
Y->template MutableData<float>());
|
||||
const int64_t input_h = X_shape[2];
|
||||
const int64_t input_w = X_shape[3];
|
||||
|
||||
const int64_t output_h = input_h * scales_[2];
|
||||
const int64_t output_w = input_w * scales_[3];
|
||||
|
||||
auto* Y = context->Output(0, {batch_count, nchwc_channels, output_h, output_w});
|
||||
|
||||
// Bail out early if one of the dimensions is zero.
|
||||
if (Y->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto* x_data = X->template Data<float>();
|
||||
auto* y_data = Y->template MutableData<float>();
|
||||
|
||||
if (nearest_mode_) {
|
||||
MlasNchwcUpsampleNearest(
|
||||
X_shape.data(),
|
||||
scales_.data() + 2,
|
||||
x_data,
|
||||
y_data);
|
||||
} else {
|
||||
// Compute the interpolation value per output height and width.
|
||||
const auto interpolation_h = ComputeInterpolation(input_h, output_h, scales_[2]);
|
||||
const auto interpolation_w = ComputeInterpolation(input_w, output_w, scales_[3]);
|
||||
|
||||
const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
|
||||
const ptrdiff_t total_work = ((batch_count * nchwc_channels) / nchwc_block_size) * output_h;
|
||||
// Partition the work with the goal of generating the following number of
|
||||
// elements, so that operations involving a smaller number of columns will
|
||||
// process more rows per worker.
|
||||
constexpr ptrdiff_t worker_goal = 16 * 1024;
|
||||
ptrdiff_t work_per_worker = std::max<ptrdiff_t>(worker_goal / (output_w * nchwc_block_size), 1);
|
||||
ptrdiff_t worker_count = std::max<ptrdiff_t>(total_work / work_per_worker, 1);
|
||||
|
||||
auto upsample_worker = [&](ptrdiff_t batch) {
|
||||
auto work = concurrency::ThreadPool::PartitionWork(batch, worker_count, total_work);
|
||||
int64_t work_index = static_cast<int64_t>(work.start);
|
||||
int64_t work_remaining = static_cast<int64_t>(work.end - work.start);
|
||||
|
||||
while (work_remaining > 0) {
|
||||
// Limit the current loop iteration to the same source image.
|
||||
const int64_t channel_index = work_index / output_h;
|
||||
int64_t row_index = work_index % output_h;
|
||||
int64_t rows_this_iteration = std::min(work_remaining, output_h - row_index);
|
||||
|
||||
work_index += rows_this_iteration;
|
||||
work_remaining -= rows_this_iteration;
|
||||
|
||||
const auto* x_channel_base = x_data + (channel_index * input_h * input_w * nchwc_block_size);
|
||||
auto* y_row = y_data + (((channel_index * output_h) + row_index) * output_w * nchwc_block_size);
|
||||
|
||||
// Loop upsampling each row of the output.
|
||||
do {
|
||||
MlasNchwcUpsampleLinear(
|
||||
static_cast<size_t>(input_h),
|
||||
static_cast<size_t>(input_w),
|
||||
static_cast<size_t>(output_w),
|
||||
interpolation_h[row_index],
|
||||
interpolation_w.data(),
|
||||
x_channel_base,
|
||||
y_row);
|
||||
y_row += output_w * nchwc_block_size;
|
||||
row_index++;
|
||||
} while (--rows_this_iteration);
|
||||
}
|
||||
};
|
||||
|
||||
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
|
||||
|
||||
// Handle the work in a single batch if only a single thread is available.
|
||||
if (concurrency::ThreadPool::DegreeOfParallelism(thread_pool) == 1) {
|
||||
worker_count = 1;
|
||||
}
|
||||
|
||||
concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, worker_count, upsample_worker);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class ReorderInput : public OpKernel {
|
||||
class ReorderInput final : public OpKernel {
|
||||
public:
|
||||
ReorderInput(const OpKernelInfo& info) : OpKernel(info) {
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("channels_last", &channels_last_).IsOK());
|
||||
|
|
@ -24,7 +24,7 @@ class ReorderInput : public OpKernel {
|
|||
int64_t channels_last_;
|
||||
};
|
||||
|
||||
class ReorderOutput : public OpKernel {
|
||||
class ReorderOutput final : public OpKernel {
|
||||
public:
|
||||
ReorderOutput(const OpKernelInfo& info) : OpKernel(info) {
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("channels", &channels_).IsOK());
|
||||
|
|
@ -39,7 +39,7 @@ class ReorderOutput : public OpKernel {
|
|||
int64_t channels_last_;
|
||||
};
|
||||
|
||||
class NchwcConv : public OpKernel {
|
||||
class NchwcConv final : public OpKernel {
|
||||
public:
|
||||
NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) {
|
||||
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
|
||||
|
|
@ -64,7 +64,7 @@ class NchwcPoolBase : public PoolBase {
|
|||
Status NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const;
|
||||
};
|
||||
|
||||
class NchwcMaxPool : public OpKernel, public NchwcPoolBase {
|
||||
class NchwcMaxPool final : public OpKernel, public NchwcPoolBase {
|
||||
public:
|
||||
NchwcMaxPool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) {
|
||||
}
|
||||
|
|
@ -72,7 +72,7 @@ class NchwcMaxPool : public OpKernel, public NchwcPoolBase {
|
|||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
class NchwcAveragePool : public OpKernel, public NchwcPoolBase {
|
||||
class NchwcAveragePool final : public OpKernel, public NchwcPoolBase {
|
||||
public:
|
||||
NchwcAveragePool(const OpKernelInfo& info) : OpKernel(info), NchwcPoolBase(info) {
|
||||
}
|
||||
|
|
@ -80,19 +80,55 @@ class NchwcAveragePool : public OpKernel, public NchwcPoolBase {
|
|||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
class NchwcUpsample : public OpKernel {
|
||||
class NchwcUpsample final : public OpKernel {
|
||||
private:
|
||||
enum class TransformationMode {
|
||||
ASYMMETRIC,
|
||||
ALIGN_CORNERS,
|
||||
HALF_PIXEL,
|
||||
};
|
||||
|
||||
public:
|
||||
NchwcUpsample(const OpKernelInfo& info) : OpKernel(info) {
|
||||
ORT_ENFORCE(info.GetAttrs<int64_t>("scales", scales_).IsOK());
|
||||
ORT_ENFORCE(scales_.size() == 4);
|
||||
// Batch and channel dimensions cannot scale and spatial scaling must be positive.
|
||||
ORT_ENFORCE(scales_[0] == 1 && scales_[1] == 1 && scales_[2] >= 1 && scales_[3] >= 1);
|
||||
|
||||
std::string transformation_mode;
|
||||
ORT_ENFORCE(info.GetAttr<std::string>("coordinate_transformation_mode", &transformation_mode).IsOK());
|
||||
if (transformation_mode == "asymmetric") {
|
||||
transformation_mode_ = TransformationMode::ASYMMETRIC;
|
||||
} else if (transformation_mode == "align_corners") {
|
||||
transformation_mode_ = TransformationMode::ALIGN_CORNERS;
|
||||
} else if (transformation_mode == "half_pixel") {
|
||||
transformation_mode_ = TransformationMode::HALF_PIXEL;
|
||||
} else {
|
||||
ORT_THROW("Unsupported transformation mode '" + transformation_mode + "' for NCHWc Upsample");
|
||||
}
|
||||
|
||||
std::string mode;
|
||||
ORT_ENFORCE(info.GetAttr<std::string>("mode", &mode).IsOK());
|
||||
if (mode == "nearest") {
|
||||
nearest_mode_ = true;
|
||||
ORT_ENFORCE(transformation_mode_ == TransformationMode::ASYMMETRIC);
|
||||
} else if (mode == "linear") {
|
||||
nearest_mode_ = false;
|
||||
} else {
|
||||
ORT_THROW("Unsupported mode '" + mode + "' for NCHWc Upsample");
|
||||
}
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
std::vector<float> ComputeInterpolation(int64_t input_length,
|
||||
int64_t output_length,
|
||||
int64_t scale) const;
|
||||
|
||||
std::vector<int64_t> scales_;
|
||||
TransformationMode transformation_mode_;
|
||||
bool nearest_mode_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -188,6 +188,8 @@ void RegisterNchwcSchemas() {
|
|||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(For internal use.)DOC")
|
||||
.Attr("scales", "", AttributeProto::INTS, OPTIONAL_VALUE)
|
||||
.Attr("mode", "", AttributeProto::STRING, std::string("nearest"))
|
||||
.Attr("coordinate_transformation_mode", "", AttributeProto::STRING, std::string("asymmetric"))
|
||||
.Input(0, "X", "", "T")
|
||||
.Output(0, "Y", "", "T")
|
||||
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ struct MLAS_SGEMM_DATA_PARAMS {
|
|||
|
||||
/**
|
||||
* @brief Batched single precision matrix/matrix multiply operation (SGEMM)
|
||||
*
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
|
|
@ -195,7 +195,7 @@ MlasGemmBatch(
|
|||
|
||||
/**
|
||||
* @brief Single precision matrix/matrix multiply operation (SGEMM)
|
||||
*
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
|
|
@ -223,7 +223,7 @@ MlasGemm(
|
|||
|
||||
/**
|
||||
* @brief Single precision matrix/matrix multiply operation (SGEMM)
|
||||
*
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
|
|
@ -231,7 +231,7 @@ MlasGemm(
|
|||
* @param K Supplies the number of columns of matrix A and the number
|
||||
of rows of matrix B.
|
||||
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
|
||||
* @param A Supplies the address of matrix A
|
||||
* @param A Supplies the address of matrix A
|
||||
* @param lda Supplies the first dimension of matrix A.
|
||||
* @param B Supplies the address of matrix B
|
||||
* @param ldb Supplies the first dimension of matrix B.
|
||||
|
|
@ -341,7 +341,7 @@ struct MLAS_DGEMM_DATA_PARAMS {
|
|||
|
||||
/**
|
||||
* @brief Batched double precision matrix/matrix multiply operation (DGEMM)
|
||||
*
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
|
|
@ -369,7 +369,7 @@ MlasGemmBatch(
|
|||
|
||||
/**
|
||||
* @brief Double precision matrix/matrix multiply operation (DGEMM)
|
||||
*
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
|
|
@ -397,7 +397,7 @@ MlasGemm(
|
|||
|
||||
/**
|
||||
* @brief Double precision matrix/matrix multiply operation (DGEMM)
|
||||
*
|
||||
*
|
||||
* @param TransA Supplies the transpose operation for matrix A.
|
||||
* @param TransB Supplies the transpose operation for matrix B.
|
||||
* @param M Supplies the number of rows of matrix A and matrix C.
|
||||
|
|
@ -405,7 +405,7 @@ MlasGemm(
|
|||
* @param K Supplies the number of columns of matrix A and the number
|
||||
of rows of matrix B.
|
||||
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
|
||||
* @param A Supplies the address of matrix A
|
||||
* @param A Supplies the address of matrix A
|
||||
* @param lda Supplies the first dimension of matrix A.
|
||||
* @param B Supplies the address of matrix B
|
||||
* @param ldb Supplies the first dimension of matrix B.
|
||||
|
|
@ -929,13 +929,25 @@ MlasNchwcPool(
|
|||
|
||||
void
|
||||
MLASCALL
|
||||
MlasNchwcUpsample(
|
||||
MlasNchwcUpsampleNearest(
|
||||
const int64_t* InputShape,
|
||||
const int64_t* Scales,
|
||||
const float* Input,
|
||||
float* Output
|
||||
);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasNchwcUpsampleLinear(
|
||||
size_t InputHeight,
|
||||
size_t InputWidth,
|
||||
size_t OutputWidth,
|
||||
float InterpolationHeight,
|
||||
const float* InterpolationWidth,
|
||||
const float* Input,
|
||||
float* Output
|
||||
);
|
||||
|
||||
//
|
||||
// Linear quantization routines.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -1410,7 +1410,7 @@ Return Value:
|
|||
|
||||
void
|
||||
MLASCALL
|
||||
MlasNchwcUpsample(
|
||||
MlasNchwcUpsampleNearest(
|
||||
const int64_t* InputShape,
|
||||
const int64_t* Scales,
|
||||
const float* Input,
|
||||
|
|
@ -1511,6 +1511,116 @@ Return Value:
|
|||
}
|
||||
}
|
||||
|
||||
MLAS_FORCEINLINE
|
||||
void
|
||||
MlasNchwcExtractInterpolation(
|
||||
float InterpolationValue,
|
||||
size_t InputLimit,
|
||||
ptrdiff_t InputIndex[2],
|
||||
MLAS_FLOAT32X4 Multipliers[2]
|
||||
)
|
||||
{
|
||||
InputIndex[0] = ptrdiff_t(InterpolationValue);
|
||||
InputIndex[1] = std::min<ptrdiff_t>(InputIndex[0] + 1, ptrdiff_t(InputLimit - 1));
|
||||
|
||||
float ScalarMultiplier0 = InterpolationValue - float(InputIndex[0]);
|
||||
float ScalarMultiplier1 = 1.0f - ScalarMultiplier0;
|
||||
|
||||
Multipliers[0] = MlasBroadcastFloat32x4(ScalarMultiplier0);
|
||||
Multipliers[1] = MlasBroadcastFloat32x4(ScalarMultiplier1);
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasNchwcUpsampleLinear(
|
||||
size_t InputHeight,
|
||||
size_t InputWidth,
|
||||
size_t OutputWidth,
|
||||
float InterpolationHeight,
|
||||
const float* InterpolationWidth,
|
||||
const float* Input,
|
||||
float* Output
|
||||
)
|
||||
/*++
|
||||
|
||||
Routine Description:
|
||||
|
||||
This routine implements the NCHWc upsample linear operation for a single row.
|
||||
|
||||
The integer portion of each interpolation float supplies the mapping from
|
||||
output element to input element. The fractional portion supplies the relative
|
||||
weights for the four points of the interpolation.
|
||||
|
||||
Arguments:
|
||||
|
||||
InputHeight - Supplies the input height.
|
||||
|
||||
InputWidth - Supplies the input width.
|
||||
|
||||
OutputWidth - Supplies the output width.
|
||||
|
||||
InterpolationHeight - Supplies the height interpolation values for the target
|
||||
row.
|
||||
|
||||
InterpolationWidth - Supplies an array of computed interpolation values of
|
||||
length OutputWidth.
|
||||
|
||||
Input - Supplies the input spatial buffer.
|
||||
|
||||
Output - Supplies the output row buffer.
|
||||
|
||||
Return Value:
|
||||
|
||||
None.
|
||||
|
||||
--*/
|
||||
{
|
||||
const size_t BlockSize = MlasNchwcGetBlockSize();
|
||||
|
||||
ptrdiff_t InputIndexY[2];
|
||||
MLAS_FLOAT32X4 MultipliersY[2];
|
||||
|
||||
MlasNchwcExtractInterpolation(InterpolationHeight, InputHeight, InputIndexY, MultipliersY);
|
||||
|
||||
const float* InputRowY0 = Input + InputIndexY[0] * InputWidth * BlockSize;
|
||||
const float* InputRowY1 = Input + InputIndexY[1] * InputWidth * BlockSize;
|
||||
|
||||
for (size_t ow = 0; ow < OutputWidth; ow++) {
|
||||
|
||||
ptrdiff_t InputIndexX[2];
|
||||
MLAS_FLOAT32X4 MultipliersX[2];
|
||||
|
||||
MlasNchwcExtractInterpolation(InterpolationWidth[ow], InputWidth, InputIndexX, MultipliersX);
|
||||
|
||||
MLAS_FLOAT32X4 MultiplierY0X0 = MlasMultiplyFloat32x4(MultipliersY[0], MultipliersX[0]);
|
||||
MLAS_FLOAT32X4 MultiplierY0X1 = MlasMultiplyFloat32x4(MultipliersY[0], MultipliersX[1]);
|
||||
MLAS_FLOAT32X4 MultiplierY1X0 = MlasMultiplyFloat32x4(MultipliersY[1], MultipliersX[0]);
|
||||
MLAS_FLOAT32X4 MultiplierY1X1 = MlasMultiplyFloat32x4(MultipliersY[1], MultipliersX[1]);
|
||||
|
||||
for (size_t bc = 0; bc < BlockSize; bc += 4) {
|
||||
|
||||
MLAS_FLOAT32X4 v00 = MlasLoadFloat32x4(InputRowY0 + InputIndexX[0] * BlockSize + bc);
|
||||
MLAS_FLOAT32X4 v01 = MlasLoadFloat32x4(InputRowY0 + InputIndexX[1] * BlockSize + bc);
|
||||
MLAS_FLOAT32X4 v10 = MlasLoadFloat32x4(InputRowY1 + InputIndexX[0] * BlockSize + bc);
|
||||
MLAS_FLOAT32X4 v11 = MlasLoadFloat32x4(InputRowY1 + InputIndexX[1] * BlockSize + bc);
|
||||
|
||||
v00 = MlasMultiplyFloat32x4(MultiplierY1X1, v00);
|
||||
v01 = MlasMultiplyFloat32x4(MultiplierY1X0, v01);
|
||||
v10 = MlasMultiplyFloat32x4(MultiplierY0X1, v10);
|
||||
v11 = MlasMultiplyFloat32x4(MultiplierY0X0, v11);
|
||||
|
||||
MLAS_FLOAT32X4 Reduction0 = MlasAddFloat32x4(v00, v01);
|
||||
MLAS_FLOAT32X4 Reduction1 = MlasAddFloat32x4(v10, v11);
|
||||
|
||||
MLAS_FLOAT32X4 Reduction = MlasAddFloat32x4(Reduction0, Reduction1);
|
||||
|
||||
MlasStoreFloat32x4(&Output[bc], Reduction);
|
||||
}
|
||||
|
||||
Output += BlockSize;
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(MLAS_TARGET_AMD64)
|
||||
|
||||
//
|
||||
|
|
|
|||
|
|
@ -982,14 +982,21 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
|
|||
}
|
||||
auto* nchwc_input = it->second.get();
|
||||
|
||||
// Only support the nearest interpolation mode (the default value).
|
||||
// Support nearest (default) and linear modes.
|
||||
const auto* mode_attr = graph_utils::GetNodeAttribute(node, "mode");
|
||||
bool nearest_mode = true;
|
||||
if (mode_attr != nullptr && utils::HasString(*mode_attr)) {
|
||||
if (mode_attr->s() != "nearest") {
|
||||
return;
|
||||
if (mode_attr->s() == "linear") {
|
||||
nearest_mode = false;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::AttributeProto* transformation_mode_attr = nullptr;
|
||||
|
||||
NodeArg* sizes_arg = nullptr;
|
||||
NodeArg* scales_arg = nullptr;
|
||||
|
||||
|
|
@ -1001,20 +1008,30 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
|
|||
scales_arg = input_defs[2];
|
||||
}
|
||||
|
||||
// Only support the asymmetric coordinate transformation mode.
|
||||
const auto* transform_mode_attr = graph_utils::GetNodeAttribute(node, "coordinate_transformation_mode");
|
||||
if ((transform_mode_attr == nullptr) ||
|
||||
!utils::HasString(*transform_mode_attr) ||
|
||||
(transform_mode_attr->s() != "asymmetric")) {
|
||||
transformation_mode_attr = graph_utils::GetNodeAttribute(node, "coordinate_transformation_mode");
|
||||
if ((transformation_mode_attr == nullptr) ||
|
||||
!utils::HasString(*transformation_mode_attr)) {
|
||||
return;
|
||||
}
|
||||
if (transformation_mode_attr->s() != "asymmetric") {
|
||||
// Nearest mode kernel support asymmetric transformation mode only.
|
||||
if (nearest_mode) {
|
||||
return;
|
||||
}
|
||||
if ((transformation_mode_attr->s() != "align_corners") &&
|
||||
(transformation_mode_attr->s() != "half_pixel")) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Only support the floor rounding mode.
|
||||
const auto* nearest_mode_attr = graph_utils::GetNodeAttribute(node, "nearest_mode");
|
||||
if ((nearest_mode_attr == nullptr) ||
|
||||
!utils::HasString(*nearest_mode_attr) ||
|
||||
(nearest_mode_attr->s() != "floor")) {
|
||||
return;
|
||||
if (nearest_mode) {
|
||||
// Only support the floor rounding mode.
|
||||
const auto* nearest_mode_attr = graph_utils::GetNodeAttribute(node, "nearest_mode");
|
||||
if ((nearest_mode_attr == nullptr) ||
|
||||
!utils::HasString(*nearest_mode_attr) ||
|
||||
(nearest_mode_attr->s() != "floor")) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scales_arg = input_defs[1];
|
||||
|
|
@ -1099,6 +1116,12 @@ void NchwcTransformerImpl::TransformResize(Node& node) {
|
|||
kMSNchwcDomain);
|
||||
nchwc_node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
nchwc_node.AddAttribute("scales", scales_attr);
|
||||
if (!nearest_mode) {
|
||||
nchwc_node.AddAttribute("mode", mode_attr->s());
|
||||
if (transformation_mode_attr != nullptr) {
|
||||
nchwc_node.AddAttribute("coordinate_transformation_mode", transformation_mode_attr->s());
|
||||
}
|
||||
}
|
||||
|
||||
nchwc_input->remaining_original_uses_--;
|
||||
|
||||
|
|
|
|||
|
|
@ -1085,7 +1085,7 @@ TEST(NchwcOptimizerTests, BatchNormalization) {
|
|||
|
||||
// Override the sample tolerance for this test. By default, the NCHWc
|
||||
// tests generate bit identical results when run with and without
|
||||
// optimizations, but the BatchNormalizationtransform does introduce
|
||||
// optimizations, but the BatchNormalization transform does introduce
|
||||
// small bit differences.
|
||||
helper.per_sample_tolerance_ = .00025;
|
||||
};
|
||||
|
|
@ -1213,7 +1213,7 @@ TEST(NchwcOptimizerTests, ConvReorderOutputCnhw) {
|
|||
NchwcOptimizerTester(build_test_case, check_nchwc_graph);
|
||||
}
|
||||
|
||||
TEST(NchwcOptimizerTests, Upsample) {
|
||||
TEST(NchwcOptimizerTests, UpsampleNearest) {
|
||||
auto test_case = [&](int opset_version, float scale_h, float scale_w, bool use_sizes_arg) {
|
||||
auto build_test_case = [&](NchwcTestHelper& helper) {
|
||||
auto* input_arg = helper.MakeInput<float>({3, 16, 27, 15});
|
||||
|
|
@ -1283,6 +1283,61 @@ TEST(NchwcOptimizerTests, Upsample) {
|
|||
test_case(13, 2.2f, 2.8f, true);
|
||||
}
|
||||
|
||||
TEST(NchwcOptimizerTests, UpsampleLinear) {
|
||||
auto test_case = [&](int opset_version, float scale_h, float scale_w, const std::string& transformation_mode) {
|
||||
auto build_test_case = [&](NchwcTestHelper& helper) {
|
||||
auto* input_arg = helper.MakeInput<float>({3, 16, 21, 25});
|
||||
auto* conv_output_arg = helper.MakeIntermediate();
|
||||
auto* output_arg = helper.MakeOutput();
|
||||
|
||||
helper.AddConvNode(input_arg, conv_output_arg, {28, 16, 1, 1});
|
||||
|
||||
std::string op_name = opset_version >= 10 ? "Resize" : "Upsample";
|
||||
std::vector<NodeArg*> input_args;
|
||||
input_args.push_back(conv_output_arg);
|
||||
if (opset_version >= 11) {
|
||||
input_args.push_back(helper.Make1DInitializer<float>({0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f}));
|
||||
}
|
||||
input_args.push_back(helper.Make1DInitializer<float>({1.f, 1.f, scale_h, scale_w}));
|
||||
Node& resize_node = helper.AddNode(op_name, input_args, {output_arg});
|
||||
resize_node.AddAttribute("mode", "linear");
|
||||
if (opset_version >= 11) {
|
||||
resize_node.AddAttribute("coordinate_transformation_mode", transformation_mode);
|
||||
}
|
||||
|
||||
helper.per_sample_tolerance_ = .001f;
|
||||
};
|
||||
|
||||
auto check_nchwc_graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 1);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.nchwc.Upsample"], 1);
|
||||
EXPECT_EQ(op_to_count["Resize"] + op_to_count["Upsample"], 0);
|
||||
};
|
||||
|
||||
NchwcOptimizerTester(build_test_case, check_nchwc_graph, opset_version);
|
||||
};
|
||||
|
||||
// Verify that upsample nodes can be converted to the NCHWc format for
|
||||
// various versions of the operator.
|
||||
std::vector<std::string> transformation_modes{"asymmetric", "align_corners", "half_pixel"};
|
||||
for (auto& transformation_mode : transformation_modes) {
|
||||
static const int opset_versions[] = {9, 10, 11, 13};
|
||||
for (auto opset_version : opset_versions) {
|
||||
// Older versions of the operator do not support transformation modes.
|
||||
if (opset_version < 11 && transformation_mode == "asymmetric") {
|
||||
continue;
|
||||
}
|
||||
test_case(opset_version, 1.f, 1.f, transformation_mode);
|
||||
test_case(opset_version, 2.f, 2.f, transformation_mode);
|
||||
test_case(opset_version, 3.f, 5.f, transformation_mode);
|
||||
test_case(opset_version, 9.f, 7.f, transformation_mode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NchwcOptimizerTests, Activation) {
|
||||
auto test_case = [&](const std::string& activation_op_type) {
|
||||
auto build_test_case = [&](NchwcTestHelper& helper) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue