From 4f309f05ca580e7ee641cb6dcdd95d5de850b21b Mon Sep 17 00:00:00 2001 From: JiCheng Date: Sat, 14 Jan 2023 06:57:23 +0800 Subject: [PATCH] [CPU] Resize of Opset 18 (#13890) ### Description To Implement Resize 18. This PR depends on https://github.com/microsoft/onnxruntime/pull/13765. ### Motivation and Context --- ThirdPartyNotices.txt | 33 + cgmanifests/cgmanifest.json | 12 +- .../InferenceTest.netcore.cs | 15 +- .../core/providers/cpu/tensor/upsample.cc | 210 +++-- .../providers/cpu/tensor/upsample_antialias.h | 770 ++++++++++++++++++ .../core/providers/cpu/tensor/upsamplebase.h | 225 ++++- .../core/providers/cuda/tensor/upsample.cc | 25 +- .../core/providers/xnnpack/nn/resize.cc | 14 +- onnxruntime/test/onnx/main.cc | 13 - .../providers/cpu/tensor/resize_op_test.cc | 440 +++++++++- .../onnx_backend_test_series_filters.jsonc | 1 - 11 files changed, 1577 insertions(+), 181 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/tensor/upsample_antialias.h diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 137727883f..e215bc8d2a 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -5166,3 +5166,36 @@ THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +_____ + +The Python Imaging Library (PIL) is + + Copyright © 1997-2011 by Secret Labs AB + Copyright © 1995-2011 by Fredrik Lundh + +Pillow is the friendly PIL fork. It is + + Copyright © 2010-2023 by Alex Clark and contributors + +Like PIL, Pillow is licensed under the open source HPND License: + +By obtaining, using, and/or copying this software and/or its associated +documentation, you agree that you have read, understood, and will comply +with the following terms and conditions: + +Permission to use, copy, modify, and distribute this software and its +associated documentation for any purpose and without fee is hereby granted, +provided that the above copyright notice appears in all copies, and that +both that copyright notice and this permission notice appear in supporting +documentation, and that the name of Secret Labs AB or the author not be +used in advertising or publicity pertaining to distribution of the software +without specific, written prior permission. + +SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS +SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. +IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, +INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index 62b4cffd27..72327806e4 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -553,7 +553,17 @@ }, "comments": "dlfcn-win32" } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "6812205f18ca4ef54372e87e1a13ce4a859434df", + "repositoryUrl": "https://github.com/python-pillow/Pillow.git" + }, + "comments": "python-pillow. Implementation logic for anti-aliasing copied by Resize CPU kernel." + } } ], "Version": 1 -} +} \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index c520e4bef6..c34259fb96 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -380,7 +380,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests { "test_sequence_map_add_2_sequences", "sequence type is not supported in test infra." }, { "test_sequence_map_identity_1_sequence", "sequence type is not supported in test infra." }, { "BERT-Squad-int8", "training domain"}, - { "YOLOv3-12-int8", "training_domain"}, + { "YOLOv3-12-int8", "training_domain"}, // opset 18 models. these should be supported by ORT 1.14 when released { "test_bitwise_and_i16_3d", "pending opset 18 support"}, { "test_bitwise_and_i32_2d", "pending opset 18 support"}, @@ -416,19 +416,6 @@ namespace Microsoft.ML.OnnxRuntime.Tests { "test_constant_pad_axes", "pending opset 18 support"}, { "test_edge_pad", "pending opset 18 support"}, { "test_reflect_pad", "pending opset 18 support"}, - { "test_resize_downsample_scales_cubic_antialias", "pending opset 18 support"}, - { "test_resize_downsample_scales_linear_antialias", "pending opset 18 support"}, - { "test_resize_downsample_sizes_cubic_antialias", "pending opset 18 support"}, - { "test_resize_downsample_sizes_linear_antialias", "pending opset 18 support"}, - { "test_resize_downsample_sizes_nearest_not_larger", "pending opset 18 support"}, - { "test_resize_downsample_sizes_nearest_not_smaller", "pending opset 18 support"}, - { "test_resize_tf_crop_and_resize_axes_2_3", "pending opset 18 support"}, - { "test_resize_tf_crop_and_resize_axes_3_2", "pending opset 18 support"}, - { "test_resize_upsample_scales_nearest_axes_2_3", "pending opset 18 support"}, - { "test_resize_upsample_scales_nearest_axes_3_2", "pending opset 18 support"}, - { "test_resize_upsample_sizes_nearest_axes_2_3", "pending opset 18 support"}, - { "test_resize_upsample_sizes_nearest_axes_3_2", "pending opset 18 support"}, - { "test_resize_upsample_sizes_nearest_not_larger", "pending opset 18 support"}, { "test_scatter_elements_with_axis", "pending opset 18 support"}, { "test_scatter_elements_without_axis", "pending opset 18 support"}, { "test_scatter_elements_with_duplicate_indices", "pending opset 18 support"}, diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index f72e24b0db..29ff7cc7a7 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -4,7 +4,7 @@ #include "core/common/safeint.h" #include "core/platform/threadpool.h" #include "core/providers/cpu/tensor/upsample.h" - +#include "core/providers/cpu/tensor/upsample_antialias.h" using namespace onnxruntime::common; using namespace std; using onnxruntime::narrow; @@ -1071,9 +1071,8 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector& scales, const gsl::span& output_dims) const { const auto* X = context->Input(0); - ORT_ENFORCE(X != nullptr); auto dims = X->Shape().GetDims(); - ORT_ENFORCE(output_dims.size() == dims.size(), "Rank of input and output tensor should be same."); + ORT_RETURN_IF_NOT(output_dims.size() == dims.size(), "Rank of input and output tensor should be same."); Tensor* Y = context->Output(0, output_dims); @@ -1087,7 +1086,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, is_resize_ ? "Resize: input tensor's dimension does not match the scales." : "Upsample: input tensor's dimension does not match the scales."); - if (roi.size() != 2 * X->Shape().GetDims().size()) + if (roi.size() != 2 * dims.size()) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: size of roi array should be 2 * N where N is the rank of input tensor X."); @@ -1100,7 +1099,8 @@ Status Upsample::BaseCompute(OpKernelContext* context, memcpy(Y->MutableDataRaw(), X->DataRaw(), Y->SizeInBytes()); return Status::OK(); } - + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); switch (mode_) { case UpsampleMode::NN: return UpsampleNearest(X->Data(), Y->MutableData(), X->Shape(), Y->Shape(), @@ -1150,7 +1150,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, height_scale = scales[2]; width_scale = scales[3]; } else { - ORT_ENFORCE(scales[3] == 1.0f, "4-D input with innermost scale (usually channel of NHWC) as 1."); + ORT_RETURN_IF_NOT(scales[3] == 1.0f, "4-D input with innermost scale (usually channel of NHWC) as 1."); is_nchw = false; batch_size = static_cast(dims[0]); @@ -1166,46 +1166,65 @@ Status Upsample::BaseCompute(OpKernelContext* context, } } - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); if (is_nchw) { - UpsampleBilinear(batch_size, num_channels, input_height, input_width, output_height, output_width, - height_scale, width_scale, roi, - use_extrapolation_, extrapolation_value_, X->Data(), - Y->MutableData(), alloc, get_original_coordinate_, - output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr); + if (antialias_) { + UpsampleBilinearAntiAlias(batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, roi, use_extrapolation_, extrapolation_value_, exclude_outside_, + X, Y->MutableData(), alloc, get_original_coordinate_, + output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr); + } else { + UpsampleBilinear(batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, roi, + use_extrapolation_, extrapolation_value_, X->Data(), + Y->MutableData(), alloc, get_original_coordinate_, + output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr); + } } else { if (use_extrapolation_) { - if (!is_2D && - (Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || - Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT8)) { - NhwcUpsampleBilinearInteger( - batch_size, num_channels, input_height, input_width, output_height, output_width, - height_scale, width_scale, roi, extrapolation_value_, X->Data(), Y->MutableData(), - alloc, get_original_coordinate_, - output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + if (antialias_) { + NhwcUpsampleBilinearAntiAlias(batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, roi, use_extrapolation_, extrapolation_value_, exclude_outside_, + X, Y->MutableData(), alloc, get_original_coordinate_, + output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr); } else { - NhwcUpsampleBilinear( - batch_size, num_channels, input_height, input_width, output_height, output_width, - height_scale, width_scale, roi, extrapolation_value_, X->Data(), Y->MutableData(), - alloc, get_original_coordinate_, - output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + if (!is_2D && + (Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || + Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT8)) { + NhwcUpsampleBilinearInteger( + batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, roi, extrapolation_value_, X->Data(), Y->MutableData(), + alloc, get_original_coordinate_, + output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + } else { + NhwcUpsampleBilinear( + batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, roi, extrapolation_value_, X->Data(), Y->MutableData(), + alloc, get_original_coordinate_, + output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + } } } else { - if (!is_2D && - (Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || - Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT8)) { - NhwcUpsampleBilinearInteger( - batch_size, num_channels, input_height, input_width, output_height, output_width, - height_scale, width_scale, roi, extrapolation_value_, X->Data(), Y->MutableData(), - alloc, get_original_coordinate_, - output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + if (antialias_) { + NhwcUpsampleBilinearAntiAlias(batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, roi, use_extrapolation_, extrapolation_value_, exclude_outside_, + X, Y->MutableData(), alloc, get_original_coordinate_, + output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr); } else { - NhwcUpsampleBilinear( - batch_size, num_channels, input_height, input_width, output_height, output_width, - height_scale, width_scale, roi, extrapolation_value_, X->Data(), Y->MutableData(), - alloc, get_original_coordinate_, - output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + if (!is_2D && + (Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_UINT8 || + Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT8)) { + NhwcUpsampleBilinearInteger( + batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, roi, extrapolation_value_, X->Data(), Y->MutableData(), + alloc, get_original_coordinate_, + output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + } else { + NhwcUpsampleBilinear( + batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, roi, extrapolation_value_, X->Data(), Y->MutableData(), + alloc, get_original_coordinate_, + output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + } } } } @@ -1224,14 +1243,21 @@ Status Upsample::BaseCompute(OpKernelContext* context, const int64_t output_height = is_3D ? output_dims[1] : output_dims[3]; const int64_t output_width = is_3D ? output_dims[2] : output_dims[4]; - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); - UpsampleTrilinear(batch_size, num_channels, input_depth, input_height, input_width, - output_depth, output_height, output_width, - is_3D ? scales[0] : scales[2], is_3D ? scales[1] : scales[3], - is_3D ? scales[2] : scales[4], roi, use_extrapolation_, extrapolation_value_, - X->Data(), Y->MutableData(), alloc, get_original_coordinate_, - output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr); + if (antialias_) { + UpsampleTrilinearAntiAlias(batch_size, num_channels, input_depth, input_height, input_width, + output_depth, output_height, output_width, + is_3D ? scales[0] : scales[2], is_3D ? scales[1] : scales[3], + is_3D ? scales[2] : scales[4], roi, use_extrapolation_, extrapolation_value_, + exclude_outside_, X, Y->MutableData(), alloc, get_original_coordinate_, + output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr); + } else { + UpsampleTrilinear(batch_size, num_channels, input_depth, input_height, input_width, + output_depth, output_height, output_width, + is_3D ? scales[0] : scales[2], is_3D ? scales[1] : scales[3], + is_3D ? scales[2] : scales[4], roi, use_extrapolation_, extrapolation_value_, + X->Data(), Y->MutableData(), alloc, get_original_coordinate_, + output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr); + } return Status::OK(); } else { // User shouldn't hit this as the check has been performed in ScalesValidation() @@ -1252,17 +1278,37 @@ Status Upsample::BaseCompute(OpKernelContext* context, } bool is_2D = dims.size() == 2; - const int64_t batch_size = is_2D ? 1 : dims[0]; - const int64_t num_channels = is_2D ? 1 : dims[1]; - const int64_t input_height = is_2D ? dims[0] : dims[2]; - const int64_t input_width = is_2D ? dims[1] : dims[3]; - const int64_t output_height = is_2D ? output_dims[0] : output_dims[2]; - const int64_t output_width = is_2D ? output_dims[1] : output_dims[3]; + bool is_nchw = is_2D ? true : (scales[1] == 1.0f); - ResizeBiCubic(batch_size, num_channels, input_height, input_width, output_height, output_width, - is_2D ? scales[0] : scales[2], is_2D ? scales[1] : scales[3], cubic_coeff_a_, use_extrapolation_, - extrapolation_value_, exclude_outside_, roi, X->Data(), - Y->MutableData(), get_original_coordinate_); + const int64_t batch_size = is_2D ? 1 : dims[0]; + const int64_t num_channels = is_2D ? 1 : (is_nchw ? dims[1] : dims[3]); + const int64_t input_height = is_2D ? dims[0] : (is_nchw ? dims[2] : dims[1]); + const int64_t input_width = is_2D ? dims[1] : (is_nchw ? dims[3] : dims[2]); + const int64_t output_height = is_2D ? output_dims[0] : (is_nchw ? output_dims[2] : output_dims[1]); + const int64_t output_width = is_2D ? output_dims[1] : (is_nchw ? output_dims[3] : output_dims[2]); + const float height_scale = is_2D ? scales[0] : (is_nchw ? scales[2] : scales[1]); + const float width_scale = is_2D ? scales[1] : (is_nchw ? scales[3] : scales[2]); + + if (antialias_) { + if (!is_nchw) { + NhwcResizeBiCubicAntiAlias(batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, cubic_coeff_a_, use_extrapolation_, + extrapolation_value_, exclude_outside_, roi, X, + Y->MutableData(), alloc, get_original_coordinate_, + output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + } else { + ResizeBiCubicAntiAlias(batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, cubic_coeff_a_, use_extrapolation_, + extrapolation_value_, exclude_outside_, roi, X, + Y->MutableData(), alloc, get_original_coordinate_, + output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr); + } + } else { + ResizeBiCubic(batch_size, num_channels, input_height, input_width, output_height, output_width, + height_scale, width_scale, cubic_coeff_a_, use_extrapolation_, + extrapolation_value_, exclude_outside_, roi, X->Data(), + Y->MutableData(), get_original_coordinate_); + } return Status::OK(); } default: @@ -1273,21 +1319,20 @@ Status Upsample::BaseCompute(OpKernelContext* context, template Status Upsample::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); - ORT_ENFORCE(X != nullptr); - TensorShapeVector output_dims(X->Shape().GetDims().size()); + auto input_dims = X->Shape().GetDims(); + TensorShapeVector output_dims(input_dims.size()); // Get roi data // Initialize the roi array to all zeros as this will be the most common case // Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize // for all other cases we need a 0 initialized roi array - std::vector roi_array; - const std::vector* roi_ptr = roi_cached_ ? &roi_ : &roi_array; + std::vector roi_array(roi_); if (!roi_cached_) { bool use_default_roi = true; if (need_roi_input_) { - ORT_ENFORCE(roi_input_idx_ > 0, "Invalid roi input index."); + ORT_RETURN_IF_NOT(roi_input_idx_ > 0, "Invalid roi input index."); const auto* roi = context->Input(roi_input_idx_); if (roi != nullptr) { ParseRoiData(roi, roi_array); @@ -1297,7 +1342,6 @@ Status Upsample::Compute(OpKernelContext* context) const { if (use_default_roi) { // default roi includes ensures all the values in that axis are included in the roi // normalized roi is thus : [start, end] = [0, 1] - const auto& input_dims = X->Shape().GetDims(); size_t input_rank = input_dims.size(); roi_array.resize(input_rank * 2); for (size_t i = 0; i < input_rank; ++i) { @@ -1307,44 +1351,44 @@ Status Upsample::Compute(OpKernelContext* context) const { } } + ComputeROIWithAxes(roi_array, input_dims.size()); + // Get scales data + std::vector scales_array(input_dims.size()); + if (OpKernel::Node().InputDefs().size() == 1) { // Compute output shape from scales and input dims - ComputeOutputShape(scales_, X->Shape().GetDims(), output_dims); - return BaseCompute(context, *roi_ptr, scales_, output_dims); + scales_array = scales_; + + ComputeOutputShape(scales_array, input_dims, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } const auto* scales = context->Input(scales_input_idx_); const auto* sizes = context->Input(sizes_input_idx_); + // Get scales data if (scales_cached_) { - ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); - + ORT_RETURN_IF_NOT(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + scales_array = scales_; // Compute output shape from scales and input dims - ComputeOutputShape(scales_, X->Shape().GetDims(), output_dims); - return BaseCompute(context, *roi_ptr, scales_, output_dims); + ComputeOutputShape(scales_array, input_dims, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } - // Get scales data - std::vector scales_array(X->Shape().GetDims().size()); - if (scales != nullptr && scales->Shape().Size() != 0) { - ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); - ParseScalesData(scales, scales_array); + ORT_RETURN_IF_NOT(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, input_dims.size())); // Compute output shape from scales and input dims - ComputeOutputShape(scales_array, X->Shape().GetDims(), output_dims); + ComputeOutputShape(scales_array, input_dims, output_dims); } else { - ORT_ENFORCE(sizes != nullptr && sizes->Shape().Size() != 0, "Either scales or sizes MUST be provided as input."); + ORT_RETURN_IF_NOT(sizes != nullptr && sizes->Shape().Size() != 0, "Either scales or sizes MUST be provided as input."); - // When sizes input is available directly populate it into the output_dims array. - memcpy(output_dims.data(), sizes->template Data(), SafeInt(sizes->Shape().Size())* sizeof(int64_t)); + ORT_RETURN_IF_ERROR(ParseSizesData(sizes, output_dims, input_dims)); - ORT_ENFORCE(X->Shape().GetDims().size() == output_dims.size(), - "Resize: input tensor's rank does not match the output tensor's rank."); - - ParseScalesDataFromOutputSize(output_dims, X->Shape().GetDims(), scales_array); + ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims, input_dims, scales_array)); } - return BaseCompute(context, *roi_ptr, scales_array, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h new file mode 100644 index 0000000000..e5e641e96f --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h @@ -0,0 +1,770 @@ +// Copyright c Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/* + * Pillow 's Resize is corresponding to ONNX op with exclude_outside equaling 1. + * And, for cubic mode, PIllow has a default value of 0.5 for "cubic_coeff_a", + * while ONNX op has a default value of 0.75. + */ + +#pragma once + +#include +#include // for round +#include +#include "core/framework/tensor.h" +#include "gsl/span" +#ifndef SHARED_PROVIDER +#include "core/framework/op_kernel.h" +#endif +#include "core/providers/cpu/tensor/upsamplebase.h" + +namespace onnxruntime { + +namespace ConstValue { +constexpr int32_t mag_factor = 1 << (22 - 1); +} + +namespace { +const uint8_t* GetLookupTableShared() { + // initialized once + static const auto* lookup_table = []() { + // if we have already initialized the lookup table, just return + // ideally we could have a global lookup table, but that account for too much space. + /* Handles values form -640 to 639. */ + static uint8_t table[1280] = {0}; + + // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94 + // we need to handle negative values + // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639] + // we will accept a negative x for (&table[640])[x] means table +640 -x + for (int i = 0; i < 1280; ++i) { + table[i] = static_cast(std::min(std::max(i - 640, 0), 255)); + } + return table; + }(); + return lookup_table; +} +} // namespace + +template +struct FilterParamsBaseAntiAlias { + std::vector bound; + std::vector out_of_bound_idx; + int64_t window_size = 2; + IAllocatorUniquePtr weight_coefficients; +}; + +template +struct FilterParamsAntiAlias { + float support_size = 2.0f; + float cubic_coeff_a = -0.75f; + + FilterParamsBaseAntiAlias dim_x; + FilterParamsBaseAntiAlias dim_y; + FilterParamsBaseAntiAlias dim_z; + + const uint8_t* GetClip8LookupTable() const { + return GetLookupTableShared(); + } + virtual ~FilterParamsAntiAlias() = default; + virtual float Filter(float x) const = 0; +}; + +template +struct BilinearParamsAntiAlias : FilterParamsAntiAlias { + // taken from + // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/src/libImaging/Resample.c#L20-L29 + float Filter(float x) const override { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +template +struct BiCubicParamsAntiAlias : FilterParamsAntiAlias { + BiCubicParamsAntiAlias() { + this->support_size = 4.0f; + } + + // taken from + // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/ + // src/libImaging/Resample.c + float Filter(float x) const override { + /* https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + */ + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return ((this->cubic_coeff_a + 2.0f) * x - (this->cubic_coeff_a + 3.0f)) * x * x + 1; + } + if (x < 2.0f) { + return (((x - 5.0f) * x + 8.f) * x - 4.f) * this->cubic_coeff_a; + } + return 0.0f; + } +}; + +template +struct TriLinearParamsAntiAlias : FilterParamsAntiAlias { + float Filter(float x) const override { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +template +struct AccumulateType { + using type = int32_t; + using Dtype = T; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = double; +}; + +// The following method supports a 3/4/5-D input in 'Linear mode, cubic mode' +// that amounts to 'Bilinear,TriLinear, Bicubic/Tricubic' Upsampling/Resizing in the sense that it assumes +// A N-D tensor has +// 1. the scale values for the outermost 2 dimensions are 1 or +// 2. the scale values for the outermost and innermost dimensions are 1 +// This is the common use-case where the 4-D input (batched multi-channel images) +// is usually of shapes: +// - [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale] +// - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0] +template +void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, + const gsl::span input_h_w_c, + const gsl::span output_h_w_c, + const gsl::span scale_h_w_c, + const std::vector& roi, + AllocatorPtr& alloc, + const GetOriginalCoordinateFunc& get_original_coordinate, + bool exclude_outside, const bool is_nchw) { + auto compute_weight_coefficients = [&alloc, &roi, &get_original_coordinate, exclude_outside](const FilterParamsAntiAlias& p, + const int64_t input_size, + const int64_t output_size, + size_t rindex, + FilterParamsBaseAntiAlias& param_base, + const float rscale) -> int64_t { + param_base.bound.reserve(static_cast(output_size) * 2); + param_base.out_of_bound_idx.reserve(static_cast(output_size)); + + float scale = 1.0f / rscale; + float support = (scale >= 1.0f) ? (p.support_size * 0.5f) * scale : p.support_size * 0.5f; + + int32_t window_size = narrow(ceilf(support)) * 2 + 1; + const size_t scale_buffer_size = narrow(window_size * output_size); + + param_base.weight_coefficients = IAllocator::MakeUniquePtr(alloc, scale_buffer_size); + // Get pointers to appropriate memory locations in the scratch buffer + auto* scale_data = reinterpret_cast(param_base.weight_coefficients.get()); + int64_t xmin = 0, xmax = 0; + float inv_scale = (scale >= 1.0f) ? 1.0f / scale : 1.0f; + + const auto roi_start = roi.size() / 2 - (rindex + 1); + const auto roi_end = roi.size() - (rindex + 1); + + for (int32_t i = 0; i < output_size; i++) { + // double center = (i + 0.5) * scale; + float center = 0.5f + (scale == 1.0f ? static_cast(i) + : get_original_coordinate(static_cast(i), rscale, + static_cast(output_size), + static_cast(input_size), + roi[roi_start], roi[roi_end])); + if (center - 0.5f < 0 || center - 0.5f > narrow(input_size - 1)) { + param_base.out_of_bound_idx.emplace_back(i); + } + float total_weight = 0.0; + + auto fmin = std::floor(center - support + 0.5f); + auto fmax = std::floor(center + support + 0.5f); + int64_t xmin_real = static_cast(fmin); + int64_t xmax_real = static_cast(fmax); + int64_t xmin_cut = std::max(xmin_real, (0)); + int64_t xmax_cut = std::min(xmax_real, input_size); + + xmin = exclude_outside ? xmin_cut : xmin_real; + xmax = exclude_outside ? xmax_cut : xmax_real; + param_base.bound.push_back(xmin_cut); + param_base.bound.push_back(xmax_cut); + + auto* scale_buffer = &scale_data[i * window_size]; + int64_t x = 0; + xmax -= xmin; + for (; x < xmax; x++) { + float w = p.Filter((x + xmin - center + 0.5f) * inv_scale); + scale_buffer[x] = w; + total_weight += w; + } + + if (!exclude_outside) { + int64_t neg_xsize = xmin < 0 ? -xmin : 0; + for (x = 0; x < neg_xsize; x++) { + scale_buffer[neg_xsize] += scale_buffer[x]; + } + + int64_t bound_xsize = + xmax + xmin > input_size ? xmax + xmin - input_size : 0; + for (x = xmax - bound_xsize; x < xmax; x++) { + scale_buffer[xmax - bound_xsize - 1] += + scale_buffer[x]; + } + + for (x = 0; (neg_xsize | bound_xsize) > 0 && x < xmax_cut - xmin_cut; x++) { + scale_buffer[x] = scale_buffer[x + neg_xsize]; + } + } + + float total_weight_inv = total_weight == 0.0f ? 1.f : 1.0f / total_weight; + auto* scale_buffer_int = reinterpret_cast(scale_buffer); + for (x = 0; x < xmax_cut - xmin_cut; x++) { + scale_buffer[x] *= total_weight_inv; + + // normalize the scale to 1 << 22 for int8/uint8 + if constexpr (std::is_same::value) { + scale_buffer_int[x] = static_cast(std::round(scale_buffer[x] * ConstValue::mag_factor * 2.f)); + } + } + /*for (; x < window_size; x++) { + scale_buffer[x] = 0; + }*/ + } + return window_size; + }; + + const size_t width_rindex = is_nchw ? 0 : 1; + const size_t height_rindex = is_nchw ? 1 : 2; + const size_t channel_rindex = is_nchw ? 2 : 2; // only works for trilinear NC(chw) + + p.dim_x.window_size = compute_weight_coefficients(p, input_h_w_c[1], output_h_w_c[1], width_rindex, + p.dim_x, scale_h_w_c[1]); + p.dim_y.window_size = compute_weight_coefficients(p, input_h_w_c[0], output_h_w_c[0], height_rindex, + p.dim_y, scale_h_w_c[0]); + if (input_h_w_c.size() == 3) { + p.dim_z.window_size = compute_weight_coefficients(p, input_h_w_c[2], output_h_w_c[2], channel_rindex, + p.dim_z, scale_h_w_c[2]); + } +} + +template +inline constexpr bool is_8bit_v = std::is_same::value || std::is_same::value; + +/** + * @brief To compute interpolation along with the last axis. + * For brief,we assume the input tensor has 3 dimensions and we all it CHW for each character represent a dim. + * But it doesn't mean the input tensor has semantic meaning of CHW in traditional. + * we can treat a tensor with rank 4 NCHW as (NC)HW or CHW with a for loop in N dimension. + * @param num_channels The number of C in CHW. + * @param input_height The number of H in CHW. + * @param input_width The number of W in CHW. + * @param output_height The number of H in CHW. + * @param output_width The number of W in CHW. + * @param Xdata_span The input tensor data. + * @param Ydata_span The output tensor data. + * @param p The filter params. + * @param p_dim The filter params for each dim. + * @param tp The thread pool. + * + */ +template +void ComputeInterpolationAtLevel1(int64_t num_channels, int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + gsl::span Xdata_span, gsl::span Ydata_span, + const FilterParamsAntiAlias& p, + const FilterParamsBaseAntiAlias& p_dim, + concurrency::ThreadPool* tp) { + const uint8_t* clip8_lookups = &p.GetClip8LookupTable()[640]; + + concurrency::ThreadPool::TrySimpleParallelFor( + tp, narrow(num_channels), + [&](std::ptrdiff_t c) { + auto x_start = c * (input_height * input_width); + auto y_start = c * (output_height * output_width); + + const InputType* Xdata = Xdata_span.data() + x_start; + InputType* Ydata = Ydata_span.data() + y_start; + // no need to do scale + if (output_width == input_width) { + std::copy_n(Xdata_span.begin() + narrow(x_start), narrow(output_height * output_width), + Ydata_span.begin() + narrow(y_start)); + return; + } + + for (size_t y = 0; y < narrow(output_height); ++y) { + auto* Ydata_offset = Ydata + output_width * y; + auto* bound = p_dim.bound.data(); + for (size_t x = 0; x < narrow(output_width); ++x) { + AccumulateType output = is_8bit_v ? ConstValue::mag_factor : 0; + + const auto* weight_coeff = p_dim.weight_coefficients.get() + p_dim.window_size * x; + int64_t xmin = *bound++; + int64_t xmax = *bound++; + const auto* Xdata_offset = Xdata + y * input_width + xmin; + for (; xmin < xmax; ++xmin) { + output += (*Xdata_offset++) * (*weight_coeff++); + } + + if constexpr (is_8bit_v) { + *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset++ = narrow(std::round(output)); + } else { + *Ydata_offset++ = output; + } + } + } + }); +} + +/** + * @brief To calculate interpolation along with penultimate axis. + * For brief, we assume the input tensor has 3 dimensions and we all it CHW for each character represent a dim. + * But it doesn't mean the input tensor has semantic meaning of CHW in traditional. + * we can transform a tensor in formats like NCHW,NHWC,NcHWD,CHW,HWC..etc to a rank-3 tensor, + * then this function can be applied. + * @param num_channels The number of C in CHW. + * @param input_height The number of H in CHW. + * @param input_width The number of W in CHW. + * @param output_height The number of H in CHW. + * @param output_width The number of W in CHW. + * @param Xdata_span The input tensor data. + * @param Ydata_span The output tensor data. + * @param p The filter params. + * @param p_dim The filter params for each dim. + * @param tp The thread pool. + */ +template +void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + gsl::span Xdata_span, gsl::span Ydata_span, + const FilterParamsAntiAlias& p, + const FilterParamsBaseAntiAlias& p_dim, + concurrency::ThreadPool* tp) { + const uint8_t* clip8_lookups = &p.GetClip8LookupTable()[640]; + // This condition is set for higher performance. + // Observed that TrySimpleParallelFor in dim num_channels is always have higher efficiency, so I would rather + // choose the first path as long as num_channels is 3 or bigger. + if (num_channels > 2 && num_channels >= tp->DegreeOfParallelism(tp)) { + concurrency::ThreadPool::TrySimpleParallelFor( + tp, narrow(num_channels), + [&](std::ptrdiff_t c) { + auto x_start = c * (input_height * input_width); + auto y_start = c * (output_height * output_width); + + const InputType* Xdata = Xdata_span.data() + x_start; + InputType* Ydata = Ydata_span.data() + y_start; + + if (output_height == input_height) { + std::copy_n(Xdata_span.begin() + narrow(x_start), narrow(output_height * output_width), + Ydata_span.begin() + narrow(y_start)); + return; + } + + const auto* y_bound = p_dim.bound.data(); + for (size_t y = 0; y < narrow(output_height); ++y) { + const auto* weight_coeff = p_dim.weight_coefficients.get() + p_dim.window_size * y; + int64_t ymin = *y_bound++; + int64_t ymax = *y_bound++; + auto* Ydata_offset = Ydata + output_width * y; + for (size_t x = 0; x < narrow(output_width); ++x) { + AccumulateType output = is_8bit_v ? ConstValue::mag_factor : 0; + auto* weight_coeff_start = weight_coeff; + + const auto* Xdata_offset = Xdata + ymin * output_width + x; + for (auto idx = ymin; idx < ymax; ++idx) { + output += *Xdata_offset * (*weight_coeff_start++); + Xdata_offset += output_width; + } + if constexpr (is_8bit_v) { + *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset++ = narrow(std::round(output)); + } else { // float double + *Ydata_offset++ = output; + } + } + } + }); + } else { + concurrency::ThreadPool::TryParallelFor( + tp, static_cast(output_height * num_channels), + static_cast(output_height * 2), + [&](std::ptrdiff_t first, std::ptrdiff_t last) { + if (output_height == input_height) { + std::copy_n(Xdata_span.begin() + narrow(first * input_width), narrow((last - first) * output_width), + Ydata_span.begin() + narrow(first * output_width)); + return; + } + + for (auto start = first; start != last; start++) { + auto c = start / output_height; + auto y = start % output_height; + + auto x_start = c * (input_height * input_width); + auto y_start = c * (output_height * output_width); + + const InputType* Xdata = Xdata_span.data() + x_start; + InputType* Ydata = Ydata_span.data() + y_start; + + const auto* y_bound = p_dim.bound.data(); + const auto* weight_coeff = p_dim.weight_coefficients.get() + p_dim.window_size * y; + int64_t ymin = y_bound[2 * narrow(y)]; + int64_t ymax = y_bound[2 * narrow(y) + 1]; + auto* Ydata_offset = Ydata + output_width * y; + for (size_t x = 0; x < narrow(output_width); ++x) { + AccumulateType output = is_8bit_v ? ConstValue::mag_factor : 0; + auto* weight_coeff_start = weight_coeff; + + const auto* Xdata_offset = Xdata + ymin * output_width + x; + for (auto idx = ymin; idx < ymax; ++idx) { + output += *Xdata_offset * (*weight_coeff_start++); + Xdata_offset += output_width; + } + if constexpr (is_8bit_v) { + *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset++ = narrow(std::round(output)); + } else { // float double + *Ydata_offset++ = output; + } + } + } + }); + } +} + +template +void HandleExtrapolation(int64_t num_channels, + int64_t output_height, int64_t output_width, int64_t output_depth, + const float extrapolation_value, gsl::span Ydata_span, + const FilterParamsAntiAlias& p, + concurrency::ThreadPool* tp) { + concurrency::ThreadPool::TrySimpleParallelFor( + tp, static_cast(num_channels), + [&](std::ptrdiff_t nc) { + InputType* Ydata_base_nc = Ydata_span.data() + (nc) * (output_depth * output_height * output_width); + + for (int64_t z = 0; z < output_depth && p.dim_x.out_of_bound_idx.size() > 0; ++z) { + for (int64_t y = 0; y < output_height; ++y) { + InputType* Ydata_offset = Ydata_base_nc + (z * output_height + y) * output_width; + for (int64_t idx_x : p.dim_x.out_of_bound_idx) { + Ydata_offset[narrow(idx_x)] = static_cast(extrapolation_value); + } + } + } + + for (int64_t z = 0; z < output_depth && p.dim_y.out_of_bound_idx.size() > 0; ++z) { + for (int64_t y : p.dim_y.out_of_bound_idx) { + InputType* Ydata_offset = Ydata_base_nc + (z * output_height + y) * output_width; + std::fill_n(Ydata_offset, narrow(output_width), static_cast(extrapolation_value)); + } + } + + for (int64_t z : p.dim_z.out_of_bound_idx) { + InputType* Ydata_offset = Ydata_base_nc + z * output_height * output_width; + std::fill_n(Ydata_offset, narrow(output_height * output_width), static_cast(extrapolation_value)); + } + }); +} + +template +void UpsampleBaseAntiAlias(FilterParamsAntiAlias& p, + const int64_t batch_size, + const int64_t num_channels, + const int64_t input_height, + const int64_t input_width, + const int64_t output_height, + const int64_t output_width, + const bool use_extrapolation, + const float extrapolation_value, + const T* Xdata_base, + T* Ydata_base, + AllocatorPtr& alloc, + concurrency::ThreadPool* tp) { + IAllocatorUniquePtr image_temp_buffer = IAllocator::MakeUniquePtr( + alloc, static_cast(input_height * output_width * num_channels)); + + for (int64_t n = 0; n < batch_size; ++n) { + { + // horizon interpolate + auto xdata_span = gsl::make_span(Xdata_base + n * (input_height * num_channels * input_width), + narrow(input_height * num_channels * input_width)); + auto ydata_span = gsl::make_span(image_temp_buffer.get(), narrow(input_height * num_channels * output_width)); + + ComputeInterpolationAtLevel1(num_channels, input_height, input_width, input_height, output_width, + xdata_span, ydata_span, p, p.dim_x, tp); + } + { + // vertical interpolate + auto xdata_span = gsl::make_span(image_temp_buffer.get(), + narrow(input_height * num_channels * output_width)); + auto ydata_span = gsl::make_span(Ydata_base + n * (output_height * num_channels * output_width), + narrow(output_height * num_channels * output_width)); + + ComputeInterpolationAtLevel2(num_channels, input_height, output_width, output_height, output_width, + xdata_span, ydata_span, p, p.dim_y, tp); + } + } + if (use_extrapolation) { + auto ydata_span = gsl::make_span(Ydata_base, + narrow(batch_size * output_height * num_channels * output_width)); + HandleExtrapolation(batch_size * num_channels, output_height, output_width, 1, + extrapolation_value, ydata_span, p, tp); + } +} + +template +void UpsampleBilinearAntiAlias(const int64_t batch_size, + const int64_t num_channels, + const int64_t input_height, + const int64_t input_width, + const int64_t output_height, + const int64_t output_width, + const float height_scale, + const float width_scale, + const std::vector& roi, + const bool use_extrapolation, + const float extrapolation_value, + bool exclude_outside, + const Tensor* X, + T* Ydata_base, + AllocatorPtr& alloc, + const GetOriginalCoordinateFunc& get_original_coordinate, + concurrency::ThreadPool* tp) { + int64_t input_paras[] = {input_height, input_width}; + int64_t output_paras[] = {output_height, output_width}; + float scale_paras[] = {height_scale, width_scale}; + BilinearParamsAntiAlias::type> p; + SetupUpsampleFilterAntiAlias(p, input_paras, output_paras, scale_paras, roi, + alloc, get_original_coordinate, exclude_outside, true); + return UpsampleBaseAntiAlias(p, batch_size, num_channels, input_height, input_width, output_height, output_width, + use_extrapolation, extrapolation_value, + X->Data(), Ydata_base, alloc, tp); +} + +template +void NhwcUpsampleBilinearAntiAlias(const int64_t batch_size, + const int64_t num_channels, + const int64_t input_height, + const int64_t input_width, + const int64_t output_height, + const int64_t output_width, + const float height_scale, + const float width_scale, + const std::vector& roi, + const bool use_extrapolation, + const float extrapolation_value, + bool exclude_outside, + const Tensor* X, + T* Ydata_base, + AllocatorPtr& alloc, + const GetOriginalCoordinateFunc& get_original_coordinate, + concurrency::ThreadPool* tp) { + int64_t input_paras[] = {input_height, input_width}; + int64_t output_paras[] = {output_height, output_width}; + float scale_paras[] = {height_scale, width_scale}; + BilinearParamsAntiAlias::type> p; + SetupUpsampleFilterAntiAlias(p, input_paras, output_paras, scale_paras, roi, + alloc, get_original_coordinate, exclude_outside, false); + return NhwcUpsampleBasicAntiAlias(p, batch_size, num_channels, input_height, input_width, output_height, output_width, + use_extrapolation, extrapolation_value, + X->Data(), Ydata_base, alloc, tp); +} + +template +void NhwcResizeBiCubicAntiAlias(const int64_t batch_size, + const int64_t num_channels, + const int64_t input_height, + const int64_t input_width, + const int64_t output_height, + const int64_t output_width, + const float height_scale, + const float width_scale, + float cubic_coeff_a, + bool use_extrapolation, + float extrapolation_value, + bool exclude_outside, + const std::vector& roi, + const Tensor* X, + T* Ydata_base, + AllocatorPtr& alloc, + const GetOriginalCoordinateFunc& get_original_coordinate, + concurrency::ThreadPool* tp) { + int64_t input_paras[] = {input_height, input_width}; + int64_t output_paras[] = {output_height, output_width}; + float scale_paras[] = {height_scale, width_scale}; + BiCubicParamsAntiAlias::type> p; + p.cubic_coeff_a = cubic_coeff_a; + SetupUpsampleFilterAntiAlias(p, input_paras, output_paras, scale_paras, roi, + alloc, get_original_coordinate, exclude_outside, false); + return NhwcUpsampleBasicAntiAlias(p, batch_size, num_channels, input_height, input_width, output_height, output_width, + use_extrapolation, extrapolation_value, + X->Data(), Ydata_base, alloc, tp); +} + +template +void NhwcUpsampleBasicAntiAlias(FilterParamsAntiAlias& p, + const int64_t batch_size, + const int64_t num_channels, + const int64_t input_height, + const int64_t input_width, + const int64_t output_height, + const int64_t output_width, + const bool use_extrapolation, + const float extrapolation_value, + const T* Xdata_base, + T* Ydata_base, + AllocatorPtr& alloc, + concurrency::ThreadPool* tp) { + IAllocatorUniquePtr image_temp_buffer = IAllocator::MakeUniquePtr( + alloc, static_cast(input_height * output_width * num_channels)); + + for (int64_t n = 0; n < batch_size; ++n) { + // horizon interpolate + { + auto xdata_span = gsl::make_span(Xdata_base + n * (input_height * num_channels * input_width), + narrow(input_height * num_channels * input_width)); + auto ydata_span = gsl::make_span(image_temp_buffer.get(), narrow(input_height * num_channels * output_width)); + + ComputeInterpolationAtLevel2(input_height, input_width, num_channels, output_width, num_channels, + xdata_span, ydata_span, p, p.dim_x, tp); + } + + // vertical interpolate + { + // vertical interpolate + auto xdata_span = gsl::make_span(image_temp_buffer.get(), + narrow(input_height * num_channels * output_width)); + auto ydata_span = gsl::make_span(Ydata_base + n * (output_height * num_channels * output_width), + narrow(output_height * num_channels * output_width)); + + ComputeInterpolationAtLevel2(1, input_height, output_width * num_channels, output_height, output_width * num_channels, + xdata_span, ydata_span, p, p.dim_y, tp); + } + } + + if (use_extrapolation) { + auto ydata_span = gsl::make_span(Ydata_base, + narrow(batch_size * output_height * num_channels * output_width)); + HandleExtrapolation(batch_size * num_channels, output_height, output_width, 1, + extrapolation_value, ydata_span, p, tp); + } +} + +template +void ResizeBiCubicAntiAlias(int64_t batch_size, + int64_t num_channels, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + float height_scale, + float width_scale, + float cubic_coeff_a, + bool use_extrapolation, + float extrapolation_value, + bool exclude_outside, + const std::vector& roi, + const Tensor* X, + T* Ydata_base, + AllocatorPtr& alloc, + const GetOriginalCoordinateFunc& get_original_coordinate, + concurrency::ThreadPool* tp) { + int64_t input_paras[] = {input_height, input_width}; + int64_t output_paras[] = {output_height, output_width}; + float scale_paras[] = {height_scale, width_scale}; + BiCubicParamsAntiAlias::type> p; + p.cubic_coeff_a = cubic_coeff_a; + SetupUpsampleFilterAntiAlias(p, input_paras, output_paras, scale_paras, roi, + alloc, get_original_coordinate, exclude_outside, false); + + return UpsampleBaseAntiAlias(p, batch_size, num_channels, input_height, input_width, output_height, output_width, + use_extrapolation, extrapolation_value, + X->Data(), Ydata_base, alloc, tp); +} + +template +void UpsampleTrilinearAntiAlias(int64_t batch_size, + int64_t num_channels, + int64_t input_depth, + int64_t input_height, + int64_t input_width, + int64_t output_depth, + int64_t output_height, + int64_t output_width, + float depth_scale, + float height_scale, + float width_scale, + const std::vector& roi, + bool use_extrapolation, + float extrapolation_value, + bool exclude_outside, + const Tensor* X, + T* Ydata_base, + AllocatorPtr& alloc, + const GetOriginalCoordinateFunc& get_original_coordinate, + concurrency::ThreadPool* tp) { + int64_t input_paras[] = {input_height, input_width, input_depth}; + int64_t output_paras[] = {output_height, output_width, output_depth}; + float scale_paras[] = {height_scale, width_scale, depth_scale}; + + TriLinearParamsAntiAlias::type> p; + SetupUpsampleFilterAntiAlias(p, input_paras, output_paras, scale_paras, roi, + alloc, get_original_coordinate, exclude_outside, true); + + IAllocatorUniquePtr image_temp_buffer = IAllocator::MakeUniquePtr( + alloc, static_cast(batch_size * output_height * output_width * + input_depth * num_channels)); + + UpsampleBaseAntiAlias(p, batch_size, num_channels * input_depth, input_height, input_width, output_height, output_width, + false, extrapolation_value, + X->Data(), image_temp_buffer.get(), alloc, tp); + + auto m_batch_size = batch_size * num_channels < tp->DegreeOfParallelism(tp) ? 1 : batch_size; + auto m_channel_size = batch_size * num_channels < tp->DegreeOfParallelism(tp) ? num_channels * batch_size : num_channels; + for (int64_t n = 0; n < m_batch_size; ++n) { + // depth interpolate + { + // depth interpolate + auto xdata_span = gsl::make_span(image_temp_buffer.get() + n * (output_height * num_channels * output_width * input_depth), + narrow(output_height * num_channels * output_width * input_depth)); + auto ydata_span = gsl::make_span(Ydata_base + n * (output_height * num_channels * output_width * output_depth), + narrow(output_height * num_channels * output_width * output_depth)); + + ComputeInterpolationAtLevel2(m_channel_size, input_depth, output_height * output_width, output_depth, output_height * output_width, + xdata_span, ydata_span, p, p.dim_z, tp); + } + } + + if (use_extrapolation) { + auto ydata_span = gsl::make_span(Ydata_base, + narrow(batch_size * output_height * num_channels * output_width * output_depth)); + HandleExtrapolation(batch_size * num_channels, output_height, output_width, output_depth, + extrapolation_value, ydata_span, p, tp); + } +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index e13bff6fac..72948ae263 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -4,7 +4,11 @@ #pragma once #include +#include +#include #include +#include +#include "core/common/status.h" #include #include #ifndef SHARED_PROVIDER @@ -49,6 +53,12 @@ enum ResizeNearestMode { NearestModeCount = 5, }; +enum class AspectRatioPolicy { + STRETCH, + NOT_LARGER, + NOT_SMALLER, +}; + class UpsampleBase { protected: explicit UpsampleBase(const OpKernelInfo& info) @@ -60,14 +70,24 @@ class UpsampleBase { std::string mode; ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); mode_ = StringToUpsampleMode(mode); + antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; + if (antialias_) { + ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), + "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`."); + } auto input_count = info.GetInputCount(); if (input_count == 1) { // opset < 10 ORT_ENFORCE(info.GetAttrs("scales", scales_).IsOK()); - ScalesValidation(scales_, mode_); + ORT_THROW_IF_ERROR(ScalesValidation(scales_, mode_)); scales_cached_ = true; } + std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); + keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); + + axes_ = info.GetAttrsOrDefault("axes"); + extrapolation_value_ = info.GetAttrOrDefault("extrapolation_value", 0.0f); // Coordinate transformation mode attr was introduced in version 11. @@ -96,8 +116,12 @@ class UpsampleBase { cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", -0.75f); exclude_outside_ = info.GetAttrOrDefault("exclude_outside", 0) == 0 ? false : true; - if (exclude_outside_ == 1 && mode_ != CUBIC) { - ORT_THROW("exclude_outside can be set to 1 only when mode is CUBIC. Current mode is set to " + mode); + if ((exclude_outside_ == 1 && mode_ != CUBIC) && (antialias_ == false || mode_ != LINEAR)) { + ORT_THROW( + "exclude_outside can be set to 1 when (1 mode is CUBIC. " + "\n(2 mode is CUBIC or LINEAR when anti-aliasing is on" + ". Current mode is set to " + + mode + " and anti-aliasing is set to " + std::to_string(antialias_)); } // see if we can potentially use the nearest2x optimization. scales are checked at runtime to be {1,1,2,2} @@ -118,9 +142,10 @@ class UpsampleBase { if (scales_input_idx_ > 0) { const Tensor* scale; bool get_scale = info.TryGetConstantInput(scales_input_idx_, &scale); - - if (get_scale && scale->Shape().Size() > 0) { - ParseScalesData(scale, scales_); + auto x_shape = node.InputDefs()[0]->Shape(); + int64_t rank = x_shape ? x_shape->dim_size() : -1; + if (get_scale && scale->Shape().Size() > 0 && ((opset < 18) || (rank > 0 && opset >= 18))) { + ORT_THROW_IF_ERROR(ParseScalesData(scale, scales_, rank)); scales_cached_ = true; } } @@ -142,14 +167,18 @@ class UpsampleBase { ResizeCoordinateTransformationMode coordinate_transform_mode_; GetOriginalCoordinateFunc get_original_coordinate_; ResizeNearestMode nearest_mode_; + AspectRatioPolicy keep_aspect_ratio_policy_; GetNearestPixelFunc get_nearest_pixel_; float cubic_coeff_a_; bool exclude_outside_; + bool antialias_{false}; float extrapolation_value_; bool use_nearest2x_optimization_ = false; std::vector scales_; std::vector roi_; + std::vector axes_; + bool scales_cached_; bool roi_cached_; bool need_roi_input_; @@ -174,6 +203,20 @@ class UpsampleBase { UpsampleModeNN + "(default) or " + UpsampleModeLinear + " or " + UpsampleModeCubic + "."); } + AspectRatioPolicy StringToKeepAspectRatioPolicy(const std::string& policy) { + const static std::unordered_map policy_map{ + {"stretch", AspectRatioPolicy::STRETCH}, + {"not_larger", AspectRatioPolicy::NOT_LARGER}, + {"not_smaller", AspectRatioPolicy::NOT_SMALLER}, + }; + + if (auto it = policy_map.find(policy); it != policy_map.end()) { + return it->second; + } else { + ORT_THROW("keep_aspect_ratio of [" + policy + "] is not supported!"); + } + } + ResizeCoordinateTransformationMode StringToCoordinateTransformationMode( const std::string& coordinate_transform_mode_name) { if (coordinate_transform_mode_name == "asymmetric") { @@ -281,49 +324,68 @@ class UpsampleBase { } } - void ScalesValidation(const std::vector& scales, const UpsampleMode mode) const { + [[nodiscard]] Status ScalesValidation(const std::vector& scales, const UpsampleMode mode) const { if (!is_resize_) { for (auto& scale : scales) { - ORT_ENFORCE(scale >= 1, "Scale value should be greater than or equal to 1."); + ORT_RETURN_IF_NOT(scale >= 1, "Scale value should be greater than or equal to 1."); } } else { for (auto& scale : scales) { - ORT_ENFORCE(scale > 0, "Scale value should be greater than 0."); + ORT_RETURN_IF_NOT(scale > 0, "Scale value should be greater than 0."); } } if (UpsampleMode::LINEAR == mode) { - ORT_ENFORCE(scales.size() == 2 || - (scales.size() == 4 && scales[0] == 1 && scales[1] == 1) || - (scales.size() == 4 && scales[0] == 1 && scales[3] == 1) || - scales.size() == 3 || - (scales.size() == 5 && scales[0] == 1 && scales[1] == 1), - "'Linear' mode only support:\n" - " * 2-D inputs or\n" - " * 3-D inputs ('Bilinear', 'Trilinear') or\n" - " * 4-D inputs with the corresponding outermost 2 scale values being 1" - " or the corresponding outermost and innermost scale values being 1 or\n" - " * 5-D inputs with the corresponding outermost 2 scale values being 1" - "in the ", - is_resize_ ? "Resize operator" : "Upsample operator"); + ORT_RETURN_IF_NOT(scales.size() == 2 || + (scales.size() == 4 && scales[0] == 1 && scales[1] == 1) || + (scales.size() == 4 && scales[0] == 1 && scales[3] == 1) || + scales.size() == 3 || + (scales.size() == 5 && scales[0] == 1 && scales[1] == 1), + "'Linear' mode only support:\n" + " * 2-D inputs or\n" + " * 3-D inputs ('Bilinear', 'Trilinear') or\n" + " * 4-D inputs with the corresponding outermost 2 scale values being 1" + " or the corresponding outermost and innermost scale values being 1 or\n" + " * 5-D inputs with the corresponding outermost 2 scale values being 1" + "in the ", + is_resize_ ? "Resize operator" : "Upsample operator"); } else if (UpsampleMode::CUBIC == mode) { - ORT_ENFORCE(scales.size() == 2 || (scales.size() == 4 && scales[0] == 1 && scales[1] == 1), - "'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs " - "with the corresponding outermost 2 scale values being 1 in the ", - is_resize_ ? "Resize operator" : "Upsample operator"); + // we support cubic in NHWC format once anti-alias is enabled + ORT_RETURN_IF_NOT(scales.size() == 2 || (scales.size() == 4 && scales[0] == 1 && scales[1] == 1) || + (antialias_ && scales.size() == 4 && scales[0] == 1 && scales[3] == 1), + "'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1 in the ", + is_resize_ ? "Resize operator" : "Upsample operator"); } + return Status::OK(); } - void - ParseScalesData(const Tensor* scale, std::vector& scales) const { + [[nodiscard]] Status + ParseScalesData(const Tensor* scale, std::vector& scales, int64_t rank) const { const auto* scale_data = scale->Data(); int64_t scales_size = scale->Shape().Size(); - ORT_ENFORCE(scales_size > 0, "scales size should be greater than 0."); + ORT_RETURN_IF_NOT(scales_size > 0, "scales size should be greater than 0."); if (scales.empty()) { scales.resize(onnxruntime::narrow(scales_size)); } + memcpy(scales.data(), scale_data, SafeInt(scales_size) * sizeof(float)); - ScalesValidation(scales, mode_); + + // since opset 18, + // we allow scales only specified on axes of interest, + // in which case the other axes is ignored and use default scale of 1 + // scales_size == axes_.size() should be guaranteed if axes is not empty + if (rank > 0 && (scales_size != rank || axes_.size())) { + std::vector new_scales(size_t(rank), 1.0f); + ORT_RETURN_IF_NOT(*std::max_element(axes_.begin(), axes_.end()) < rank && (int64_t(axes_.size()) == scales_size), + "all values in axes should be less than rank of the data"); + + for (size_t i = 0; i < axes_.size(); i++) { + new_scales[static_cast(axes_[i])] = scales[i]; + } + scales = new_scales; + } + return ScalesValidation(scales, mode_); } void ParseRoiData(const Tensor* roi, std::vector& roi_array) const { @@ -334,18 +396,84 @@ class UpsampleBase { } } - void ParseScalesDataFromOutputSize(gsl::span output_dims, - gsl::span input_dims, - std::vector& scales) const { + // output_shape is changeable in opset-18 or above. + // It should be re-computed if axes is not empty. + [[nodiscard]] Status ParseSizesData(const Tensor* sizes, TensorShapeVector& output_dims, + gsl::span input_dims) const { + auto size_span = sizes->DataAsSpan(); + ORT_RETURN_IF_NOT(input_dims.size() >= size_span.size(), + "Resize: input tensor's rank does not match the output tensor's rank."); + + if (axes_.size()) { + output_dims.assign(input_dims.begin(), input_dims.end()); + ORT_RETURN_IF_NOT(*std::max_element(axes_.begin(), axes_.end()) < int64_t(output_dims.size()), + "axes should be less than output_dims.size()"); + + for (size_t i = 0; i < axes_.size(); i++) { + output_dims[static_cast(axes_[i])] = size_span[i]; + } + } else { + std::copy(size_span.begin(), size_span.end(), output_dims.begin()); + } + return Status::OK(); + } + + // it works iff output_shape is specified + void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + std::vector& scales) const { + std::unordered_set axes_set(axes_.begin(), axes_.end()); + + // AspectRatioPolicy::STRETCH is default policy when opset < 18 + if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::STRETCH) { + return; + } + + float scale_in_policy = 0.0f; + if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) { + scale_in_policy = std::numeric_limits::max(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::min(scale_in_policy, scales[i]); + } + } + } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) { + scale_in_policy = std::numeric_limits::min(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::max(scale_in_policy, scales[i]); + } + } + } + + for (size_t i = 0; i < scales.size(); i++) { + // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes + if (axes_set.empty() || axes_set.count(i) > 0) { + scales[i] = scale_in_policy; + output_dims[i] = static_cast(std::round(scales[i] * input_dims[i])); + } else { + scales[i] = 1.0f; + output_dims[i] = input_dims[i]; + } + } + } + + // It's different in Opset 18 and before. + // we will modify output_shape by sorts of policy even if it's specified + [[nodiscard]] Status ParseScalesDataAndAdjustOutputSize(TensorShapeVector& output_dims, + gsl::span input_dims, + std::vector& scales) const { for (size_t i = 0, end = input_dims.size(); i < end; ++i) { // Handle corner case to avoid dividing by zero in the next step if (input_dims[i] == 0) { // Enforce that output_dim is 0, given that we cannot scale 0 by any factor to // result in any non-zero value - ORT_ENFORCE(output_dims[i] == 0, - "Input dim is zero but required output dim is non-zero. ", - "Cannot scale 0 by any factor to generate a non-zero value. ", - "Dimension: ", i, " Input dim value: ", input_dims[i], " Output dim value: ", output_dims[i]); + ORT_RETURN_IF_NOT(output_dims[i] == 0, + "Input dim is zero but required output dim is non-zero. ", + "Cannot scale 0 by any factor to generate a non-zero value. ", + "Dimension: ", i, " Input dim value: ", input_dims[i], " Output dim value: ", output_dims[i]); + // Scale can be any arbitrary value as technically scaling 0 by any factor // results in 0. Keeping scale as 1 is more intuitive given that input_dim == output_dim. scales[i] = 1.f; @@ -353,16 +481,35 @@ class UpsampleBase { scales[i] = static_cast(output_dims[i]) / static_cast(input_dims[i]); } } - ScalesValidation(scales, mode_); + + AdjustOutputSizeAsPolicy(output_dims, input_dims, scales); + return ScalesValidation(scales, mode_); } - void ComputeOutputShape(const std::vector& scales, + void ComputeOutputShape(gsl::span scales, gsl::span input_dims, TensorShapeVector& output_dims) const { for (std::size_t i = 0; i < input_dims.size(); i++) { output_dims[i] = static_cast(scales[i] * input_dims[i]); } } + + // Roi is redefined in Opset-18, we have a concept of axes. + // So we need to update it accordingly. + void ComputeROIWithAxes(std::vector& roi_array, size_t rank) const { + if (axes_.size()) { + std::vector roi_tmp(rank * 2, 0); + for (size_t i = rank; i < rank * 2; ++i) { + roi_tmp[i] = 1; + } + for (size_t i = 0; i < axes_.size(); i++) { + auto v_in_axes = static_cast(axes_[i]); + roi_tmp[v_in_axes] = (roi_array[i]); + roi_tmp[rank + v_in_axes] = (roi_array[axes_.size() + i]); + } + roi_array = roi_tmp; + } + } }; // UpsampleBase } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index c5b1ed489f..ae12ca328b 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -53,7 +53,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, if (rank != static_cast(scales.size())) return Status(ONNXRUNTIME, INVALID_ARGUMENT, is_resize_ ? "Resize: input tensor's dimension does not match the scales." : "Upsample: input tensor's dimension does not match the scales."); - if (roi.size() != 2 * X->Shape().GetDims().size()) + if (roi.size() != 2 * X_dims.size()) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: size of roi array should be 2 * N where N is the rank of input tensor X."); @@ -121,9 +121,10 @@ template Status Upsample::ComputeInternal(OpKernelContext* context) const { const Tensor* X = context->Input(0); ORT_ENFORCE(X != nullptr); + auto input_dims = X->Shape().GetDims(); - TensorShapeVector output_dims(X->Shape().GetDims().size()); - std::vector roi_array(X->Shape().GetDims().size() * 2, 0.0f); + TensorShapeVector output_dims(input_dims.size()); + std::vector roi_array(input_dims.size() * 2, 0.0f); if (!roi_cached_) { bool use_default_roi = true; if (need_roi_input_) { @@ -137,7 +138,6 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { if (use_default_roi) { // default roi includes ensures all the values in that axis are included in the roi // normalized roi is thus : [start, end] = [0, 1] - const auto input_dims = X->Shape().GetDims(); size_t input_rank = input_dims.size(); roi_array.resize(input_rank * 2); for (size_t i = 0; i < input_rank; ++i) { @@ -148,10 +148,11 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { } const std::vector& roi = roi_cached_ ? roi_ : roi_array; + std::vector scales_array = scales_; if (OpKernel::Node().InputDefs().size() == 1) { // Compute output shape from scales and input dims - ComputeOutputShape(scales_, X->Shape().GetDims(), output_dims); + ComputeOutputShape(scales_array, input_dims, output_dims); return BaseCompute(context, roi, scales_, output_dims); } @@ -160,24 +161,22 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { if (scales_cached_) { ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); - ComputeOutputShape(scales_, X->Shape().GetDims(), output_dims); + ComputeOutputShape(scales_array, input_dims, output_dims); return BaseCompute(context, roi, scales_, output_dims); } - std::vector scales_array(X->Shape().GetDims().size()); + scales_array.resize((input_dims.size())); if (scales != nullptr && scales->Shape().Size() != 0) { // use scales input data ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); - ParseScalesData(scales, scales_array); - ComputeOutputShape(scales_array, X->Shape().GetDims(), output_dims); + ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, input_dims.size())); + ComputeOutputShape(scales_array, input_dims, output_dims); } else { // When sizes input is available directly populate it into the output_dims array. ORT_ENFORCE(sizes != nullptr && sizes->Shape().Size() != 0, "Either scales or sizes MUST be provided as input."); - ORT_ENFORCE(sizes->Shape().Size() == static_cast(output_dims.size()), - "Resize: input tensor's rank does not match the output tensor's rank."); - memcpy(output_dims.data(), sizes->Data(), sizes->Shape().Size() * sizeof(int64_t)); - ParseScalesDataFromOutputSize(output_dims, X->Shape().GetDims(), scales_array); + ORT_RETURN_IF_ERROR(ParseSizesData(sizes, output_dims, input_dims)); + ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims, input_dims, scales_array)); } return BaseCompute(context, roi, scales_array, output_dims); diff --git a/onnxruntime/core/providers/xnnpack/nn/resize.cc b/onnxruntime/core/providers/xnnpack/nn/resize.cc index 930967da8d..672b259727 100644 --- a/onnxruntime/core/providers/xnnpack/nn/resize.cc +++ b/onnxruntime/core/providers/xnnpack/nn/resize.cc @@ -81,8 +81,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, break; } - std::string keep_aspect_ratio_policy = "stretch"; - info.GetAttrOrDefault("keep_aspect_ratio_policy", &mode, "stretch"); + std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); if (keep_aspect_ratio_policy != "stretch") { break; } @@ -203,9 +202,8 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf output_dims_.resize(input_dims); if (sizes && sizes->Shape().Size() == 4) { scales_.resize(input_shape.NumDimensions()); - auto size_span = sizes->DataAsSpan(); - ParseScalesDataFromOutputSize(size_span, input_shape.GetDims(), scales_); - std::copy(size_span.begin(), size_span.end(), output_dims_.begin()); + ORT_THROW_IF_ERROR(ParseSizesData(sizes, output_dims_, input_shape.GetDims())); + ORT_THROW_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims_, input_shape.GetDims(), scales_)); scales_cached_ = true; } else { ComputeOutputShape(scales_, input_shape.GetDims(), output_dims_); @@ -303,14 +301,14 @@ Status Resize::Compute(OpKernelContext* ctx) const { std::vector scales_array(X->Shape().GetDims().size()); if (scales != nullptr && scales->Shape().Size() != 0) { - ParseScalesData(scales, scales_array); + ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, output_shape.size())); // Compute output shape from scales and input dims ComputeOutputShape(scales_array, X->Shape().GetDims(), output_shape); } else { const Tensor* sizes = ctx->Input(sizes_input_idx_); // When sizes input is available directly populate it into the output_dims array. - memcpy(output_shape.data(), sizes->Data(), sizes->SizeInBytes()); - ParseScalesDataFromOutputSize(output_shape, X->Shape().GetDims(), scales_array); + ORT_RETURN_IF_ERROR(ParseSizesData(sizes, output_shape, X->Shape().GetDims())); + ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_shape, X->Shape().GetDims(), scales_array)); } } output_shape[0] = X->Shape()[0]; diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index f368f437cf..3abe682c1e 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -680,19 +680,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); {"test_scatternd_add", "Opset 16 not supported yet."}, {"test_scatternd_multiply", "Opset 16 not supported yet."}, {"test_scatter_elements_with_duplicate_indices", "Opset 16 not supported yet."}, - {"resize_downsample_scales_cubic_antialias", "resize kernel needs update for opset18."}, - {"resize_downsample_scales_linear_antialias", "resize kernel needs update for opset18."}, - {"resize_downsample_sizes_cubic_antialias", "resize kernel needs update for opset18."}, - {"resize_downsample_sizes_linear_antialias", "resize kernel needs update for opset18."}, - {"resize_downsample_sizes_nearest_not_larger", "resize kernel needs update for opset18."}, - {"resize_downsample_sizes_nearest_not_smaller", "resize kernel needs update for opset18."}, - {"resize_tf_crop_and_resize_axes_2_3", "resize kernel needs update for opset18."}, - {"resize_tf_crop_and_resize_axes_3_2", "resize kernel needs update for opset18."}, - {"resize_upsample_scales_nearest_axes_2_3", "resize kernel needs update for opset18."}, - {"resize_upsample_scales_nearest_axes_3_2", "resize kernel needs update for opset18."}, - {"resize_upsample_sizes_nearest_axes_2_3", "resize kernel needs update for opset18."}, - {"resize_upsample_sizes_nearest_axes_3_2", "resize kernel needs update for opset18."}, - {"resize_upsample_sizes_nearest_not_larger", "resize kernel needs update for opset18."}, #if defined(DISABLE_OPTIONAL_TYPE) {"test_optional_get_element", "Optional type not supported in this build flavor."}, diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 05215c618d..94ee8bd567 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -1725,7 +1726,7 @@ TEST(ResizeOpTest, ResizeOpTypeCheck_Ver_10) { } template -void ResizeOpTypeCheck_Ver_11_13(int opset_version) { +void ResizeOpTypeCheck_Ver_11_13_18(int opset_version) { OpTester test("Resize", opset_version); std::vector roi{}; std::vector scales{1.0f, 1.0f, 2.0f, 3.0f}; @@ -1749,17 +1750,438 @@ void ResizeOpTypeCheck_Ver_11_13(int opset_version) { } TEST(ResizeOpTest, ResizeOpTypeCheck_Ver11) { - ResizeOpTypeCheck_Ver_11_13(11); - ResizeOpTypeCheck_Ver_11_13(11); - ResizeOpTypeCheck_Ver_11_13(11); - ResizeOpTypeCheck_Ver_11_13(11); + ResizeOpTypeCheck_Ver_11_13_18(11); + ResizeOpTypeCheck_Ver_11_13_18(11); + ResizeOpTypeCheck_Ver_11_13_18(11); + ResizeOpTypeCheck_Ver_11_13_18(11); } TEST(ResizeOpTest, ResizeOpTypeCheck_Ver13) { - ResizeOpTypeCheck_Ver_11_13(13); - ResizeOpTypeCheck_Ver_11_13(13); - ResizeOpTypeCheck_Ver_11_13(13); - ResizeOpTypeCheck_Ver_11_13(13); + ResizeOpTypeCheck_Ver_11_13_18(13); + ResizeOpTypeCheck_Ver_11_13_18(13); + ResizeOpTypeCheck_Ver_11_13_18(13); + ResizeOpTypeCheck_Ver_11_13_18(13); +} + +TEST(ResizeOpTest, ResizeOpTypeCheck_Ver18) { + ResizeOpTypeCheck_Ver_11_13_18(18); + ResizeOpTypeCheck_Ver_11_13_18(18); + ResizeOpTypeCheck_Ver_11_13_18(18); + ResizeOpTypeCheck_Ver_11_13_18(18); +} + +/* + * Most of TestCase against Anti-aliasing will have the attribute of "exclude_outside" as 1. + * It's as Pillow 's Resize is corresponding to ONNX op with exclude_outside equaling 1. + * Besides, for cubic mode, PIllow's one has a default value of 0.5 for "cubic_coeff_a", + * while ONNX op has a default value of 0.75. + */ +template +void TestAntialiasing(std::map attributes, + std::vector input_shape, + std::vector input_data, + std::vector output_shape_or_scale, std::vector output_data) { + auto parse_attr = [](const std::string& str, auto typed_v) { + using Tdata = decltype(typed_v); + std::vector vect; + + std::stringstream ss(str.substr(1, str.size() - 2)); + + for (Tdata i; ss >> i;) { + vect.push_back(i); + if (ss.peek() == ',') + ss.ignore(); + } + return vect; + }; + + OpTester test("Resize", 18); + test.AddAttribute("antialias", 1LL); + + std::vector roi{}; + std::vector scales{}; + std::vector output_shape; + + for (auto& [k, v] : attributes) { + if (k == "mode") { + test.AddAttribute("mode", v); + } else if (k == "exclude_outside") { + test.AddAttribute("exclude_outside", std::stoll(v)); + } else if (k == "cubic_coeff_a") { + test.AddAttribute("cubic_coeff_a", std::stof(v)); + } else if (k == "axes") { + int64_t type = 0; + test.AddAttribute>("axes", parse_attr(v, type)); + } else if (k == "output_shape") { + int64_t type = 0; + output_shape = parse_attr(v, type); + } else if (k == "coordinate_transformation_mode") { + test.AddAttribute("coordinate_transformation_mode", v); + } else if (k == "policy") { + test.AddAttribute("keep_aspect_ratio_policy", v); + } else if (k == "extrapolation_value") { + test.AddAttribute("extrapolation_value", std::stof(v)); + } else if (k == "roi") { + roi = parse_attr(v, 0.0f); + } else { + throw std::invalid_argument("Unknown attribute"); + } + } + + test.AddInput("X", input_shape, input_data); + test.AddInput("roi", {int64_t(roi.size())}, roi); + + if constexpr (std::is_same_v) { + test.AddInput("scales", {int64_t(output_shape_or_scale.size())}, output_shape_or_scale, true); + } else { + test.AddInput("", {0}, scales); + test.AddInput("sizes", {int64_t(output_shape_or_scale.size())}, output_shape_or_scale, true); + if (output_shape.empty()) { + output_shape = output_shape_or_scale; + } + } + + test.AddOutput("Y", output_shape, output_data); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) { + std::vector X(16); + std::iota(X.begin(), X.end(), 1.f); + + std::vector Y = {2.3636363f, 3.590909f, 4.818182f, + 7.2727275f, 8.5f, 9.727273f, + 12.181818f, 13.409091f, 14.636364f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); +} + +// match pillow +TEST(ResizeOpTest, Antialias_Bilinear_ExcludeOutside) { + std::vector X(16); + std::iota(X.begin(), X.end(), 1.f); + std::vector Y = {2.5f, 3.7f, 4.9f, + 7.3f, 8.5f, 9.7f, + 12.1f, 13.3f, 14.5f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); +} + +TEST(ResizeOpTest, Antialias_Bilinear_Scale_Is_All_1) { + std::vector X(3 * 4 * 5 * 6); + std::iota(X.begin(), X.end(), 1.f); + std::vector Y = X; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {3, 4, 5, 6}, X, {3, 4, 5, 6}, Y); +} + +TEST(ResizeOpTest, Antialias_Bilinear_dtype) { + { + std::vector X(16); + std::iota(X.begin(), X.end(), uint8_t(0)); + std::vector Y = {1, 3, 4, + 6, 8, 9, + 11, 13, 14}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); + } + { + std::vector X(16); + std::iota(X.begin(), X.end(), int8_t(0)); + std::vector Y = {1, 3, 4, + 6, 8, 9, + 11, 13, 14}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); + } + { + std::vector X(16); + std::iota(X.begin(), X.end(), 0); + std::vector Y = {1, 3, 4, + 6, 8, 9, + 11, 13, 14}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); + } +} + +TEST(ResizeOpTest, Antialias_NhwcBilinear) { + std::vector X(3 * 5 * 8); + std::vector X1(3 * 5 * 8); + std::iota(X1.begin(), X1.end(), 0.f); + for (size_t x = 0; x < 5; x++) { + for (size_t y = 0; y < 8; y++) { + for (size_t c = 0; c < 3; c++) { + X[x * 8 * 3 + y * 3 + c] = X1[c * 5 * 8 + x * 8 + y]; + } + } + } + std::vector Y = {2.409091f, 42.409092f, 82.40909f, + 3.925926f, 43.925926f, 83.92593f, + 5.5f, 45.5f, 85.5f, + 7.0740743f, 47.074074f, 87.07407f, + 8.590909f, 48.590908f, 88.59091f, + 11.742424f, 51.742424f, 91.742424f, + 13.259259f, 53.25926f, 93.25926f, + 14.833333f, 54.833332f, 94.833336f, + 16.407408f, 56.407406f, 96.40741f, + 17.924242f, 57.924244f, 97.92424f, + 21.075758f, 61.075756f, 101.07576f, + 22.592592f, 62.592594f, 102.59259f, + 24.166666f, 64.166664f, 104.166664f, + 25.74074f, 65.74074f, 105.74074f, + 27.257576f, 67.257576f, 107.257576f, + 30.40909f, 70.40909f, 110.40909f, + 31.925926f, 71.92593f, 111.92593f, + 33.5f, 73.5f, 113.5f, + 35.074074f, 75.07407f, 115.07407f, + 36.590908f, 76.59091f, 116.59091f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y); +} + +TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { + { + std::vector X(16); + std::iota(X.begin(), X.end(), uint8_t(0)); + std::vector Y = {1, 3, 4, + 6, 8, 9, + 11, 13, 14}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + } + { + std::vector X(16); + std::iota(X.begin(), X.end(), int8_t(0)); + std::vector Y = {1, 3, 4, + 6, 8, 9, + 11, 13, 14}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + } + { + std::vector X(16); + std::iota(X.begin(), X.end(), 0); + std::vector Y = {1, 3, 4, + 6, 8, 9, + 11, 13, 14}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + } +} + +TEST(ResizeOpTest, Antialias_Trilinear_No_ExcludeOutside) { + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {5.7272725f, 6.9545455f, 8.181818f, 10.636364f, 11.863636f, + 13.090909f, 15.545455f, 16.772728f, 18.f, 25.363636f, + 26.59091f, 27.818182f, 30.272728f, 31.5f, 32.727272f, + 35.18182f, 36.409092f, 37.636364f, 45.f, 46.227272f, + 47.454544f, 49.909092f, 51.136364f, 52.363636f, 54.81818f, + 56.045456f, 57.272728f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 4, 4}, X, {3, 3, 3}, Y); +} + +TEST(ResizeOpTest, Antialias_Trilinear_ExcludeOutside) { + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, + 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, + 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 4, 4}, X, {3, 3, 3}, Y); +} + +TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { + std::vector X(16 * 4 * 4); + std::iota(X.begin(), X.end(), 0.f); + { + std::vector Y = X; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 4}, Y); + } + { + std::vector Y = {0.625f, 2.375f, 4.625f, 6.375f, 8.625f, 10.375f, 12.625f, + 14.375f, 16.625f, 18.375f, 20.625f, 22.375f, 24.625f, 26.375f, + 28.625f, 30.375f, 32.625f, 34.375f, 36.625f, 38.375f, 40.625f, + 42.375f, 44.625f, 46.375f, 48.625f, 50.375f, 52.625f, 54.375f, + 56.625f, 58.375f, 60.625f, 62.375f, 64.625f, 66.375f, 68.625f, + 70.375f, 72.625f, 74.375f, 76.625f, 78.375f, 80.625f, 82.375f, + 84.625f, 86.375f, 88.625f, 90.375f, 92.625f, 94.375f, 96.625f, + 98.375f, 100.625f, 102.375f, 104.625f, 106.375f, 108.625f, 110.375f, + 112.625f, 114.375f, 116.625f, 118.375f, 120.625f, 122.375f, 124.625f, + 126.375f, 128.625f, 130.375f, 132.625f, 134.375f, 136.625f, 138.375f, + 140.625f, 142.375f, 144.625f, 146.375f, 148.625f, 150.375f, 152.625f, + 154.375f, 156.625f, 158.375f, 160.625f, 162.375f, 164.625f, 166.375f, + 168.625f, 170.375f, 172.625f, 174.375f, 176.625f, 178.375f, 180.625f, + 182.375f, 184.625f, 186.375f, 188.625f, 190.375f, 192.625f, 194.375f, + 196.625f, 198.375f, 200.625f, 202.375f, 204.625f, 206.375f, 208.625f, + 210.375f, 212.625f, 214.375f, 216.625f, 218.375f, 220.625f, 222.375f, + 224.625f, 226.375f, 228.625f, 230.375f, 232.625f, 234.375f, 236.625f, + 238.375f, 240.625f, 242.375f, 244.625f, 246.375f, 248.625f, 250.375f, + 252.625f, 254.375f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 2}, Y); + } + { + std::vector Y = {2.5f, 3.5f, 4.5f, 5.5f, 9.5f, 10.5f, 11.5f, 12.5f, 18.5f, + 19.5f, 20.5f, 21.5f, 25.5f, 26.5f, 27.5f, 28.5f, 34.5f, 35.5f, + 36.5f, 37.5f, 41.5f, 42.5f, 43.5f, 44.5f, 50.5f, 51.5f, 52.5f, + 53.5f, 57.5f, 58.5f, 59.5f, 60.5f, 66.5f, 67.5f, 68.5f, 69.5f, + 73.5f, 74.5f, 75.5f, 76.5f, 82.5f, 83.5f, 84.5f, 85.5f, 89.5f, + 90.5f, 91.5f, 92.5f, 98.5f, 99.5f, 100.5f, 101.5f, 105.5f, 106.5f, + 107.5f, 108.5f, 114.5f, 115.5f, 116.5f, 117.5f, 121.5f, 122.5f, 123.5f, + 124.5f, 130.5f, 131.5f, 132.5f, 133.5f, 137.5f, 138.5f, 139.5f, 140.5f, + 146.5f, 147.5f, 148.5f, 149.5f, 153.5f, 154.5f, 155.5f, 156.5f, 162.5f, + 163.5f, 164.5f, 165.5f, 169.5f, 170.5f, 171.5f, 172.5f, 178.5f, 179.5f, + 180.5f, 181.5f, 185.5f, 186.5f, 187.5f, 188.5f, 194.5f, 195.5f, 196.5f, + 197.5f, 201.5f, 202.5f, 203.5f, 204.5f, 210.5f, 211.5f, 212.5f, 213.5f, + 217.5f, 218.5f, 219.5f, 220.5f, 226.5f, 227.5f, 228.5f, 229.5f, 233.5f, + 234.5f, 235.5f, 236.5f, 242.5f, 243.5f, 244.5f, 245.5f, 249.5f, 250.5f, + 251.5f, 252.5f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 2, 4}, Y); + } +} + +TEST(ResizeOpTest, Antialias_Bicubic_No_ExcludeOutside) { + std::vector X(48); + std::iota(X.begin(), X.end(), 1.0f); + std::vector Y = {2.175381f, 3.655320f, 5.204702f, 6.684640f, 10.245370f, + 11.725308f, 13.274692f, 14.754630f, 18.315359f, 19.795298f, + 21.344681f, 22.824619f, 26.175381f, 27.655319f, 29.204702f, + 30.684641f, 34.245369f, 35.725307f, 37.274693f, 38.754631f, + 42.315361f, 43.795300f, 45.344681f, 46.824619f}; + TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 2, 4, 6}, X, {1, 2, 3, 4}, Y); +} + +TEST(ResizeOpTest, Antialias_NHWCBicubic_ExcludeOutside) { + std::vector X(48); + std::vector X1(48); + std::iota(X1.begin(), X1.end(), 1.0f); + for (size_t x = 0; x < 4; x++) { + for (size_t y = 0; y < 6; y++) { + for (size_t c = 0; c < 2; c++) { + X[x * 6 * 2 + y * 2 + c] = X1[c * 4 * 6 + x * 6 + y]; + } + } + } + std::vector Y = { + 0.6125579f, 24.612558f, 2.0924962f, 26.092497f, 3.6418788f, + 27.641878f, 5.121817f, 29.121817f, 2.393808f, 26.393808f, + 3.8737462f, 27.873747f, 5.423129f, 29.423128f, 6.903067f, + 30.903067f, 5.253183f, 29.253183f, 6.733121f, 30.733122f, + 8.282504f, 32.282505f, 9.762443f, 33.762444f, 9.02662f, + 33.02662f, 10.506558f, 34.506557f, 12.055942f, 36.055943f, + 13.53588f, 37.53588f, 11.46412f, 35.46412f, 12.944058f, + 36.944057f, 14.493442f, 38.493443f, 15.97338f, 39.97338f, + 15.237557f, 39.237556f, 16.717497f, 40.717495f, 18.266878f, + 42.26688f, 19.746817f, 43.74682f, 18.096933f, 42.09693f, + 19.576872f, 43.57687f, 21.126253f, 45.126255f, 22.606192f, + 46.606194f, 19.878183f, 43.87818f, 21.358122f, 45.35812f, + 22.907503f, 46.907505f, 24.387442f, 48.387444f}; + TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y); +} + +TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { + std::vector X(256); + std::iota(X.begin(), X.end(), 0.0f); + + std::vector Y = { + 3.9166667f, 6.4166665f, 13.916667f, 16.416666f, 25.25f, + 27.75f, 35.25f, 37.75f, 46.583332f, 49.083332f, + 56.583332f, 59.083332f, 67.916664f, 70.416664f, 77.916664f, + 80.416664f, 89.25f, 91.75f, 99.25f, 101.75f, + 110.583336f, 113.083336f, 120.583336f, 123.083336f, 131.91667f, + 134.41667f, 141.91667f, 144.41667f, 153.25f, 155.75f, + 163.25f, 165.75f, 174.58333f, 177.08333f, 184.58333f, + 187.08333f, 195.91667f, 198.41667f, 205.91667f, 208.41667f, + 217.25f, 219.75f, 227.25f, 229.75f, 238.58333f, + 241.08333f, 248.58333f, 251.08333f}; + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}}, + {4, 1, 4, 4, 4}, X, {4, 1, 3, 2, 2}, Y); +} + +TEST(ResizeOpTest, Antialias_Bicubic_ExcludeOutside) { + std::vector X(48); + std::iota(X.begin(), X.end(), 1.0f); + std::vector Y = {2.222252f, 3.670954f, 5.259818f, 6.708520f, 10.256866f, 11.705568f, + 13.294432f, 14.743134f, 18.291479f, 19.740183f, 21.329046f, + 22.777748f, 26.222252f, 27.670954f, 29.259817f, + 30.708521f, 34.256866f, 35.705566f, 37.294434f, 38.743134f, 42.291481f, + 43.740181f, 45.329044f, 46.777748f}; + TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "1"}}, {1, 2, 4, 6}, X, {1, 2, 3, 4}, Y); +} + +TEST(ResizeOpTest, Antialias_Bicubic_Dtype) { + { + std::vector X(36); + std::iota(X.begin(), X.end(), uint8_t(0)); + std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + } + { + std::vector X(36); + std::iota(X.begin(), X.end(), int8_t(0)); + std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + } + { + std::vector X(36); + std::iota(X.begin(), X.end(), 0); + std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + } +} + +// test new attributes +TEST(ResizeOpTest, Antialias_Axes_and_Scale) { + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, + 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, + 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X, + std::vector{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}, Y); +} + +TEST(ResizeOpTest, Antialias_Axes_and_Size) { + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, + 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, + 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X, + {3, 3, 3}, Y); +} + +TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoLarger) { + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, + 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, + 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_larger"}}, + {1, 1, 4, 4, 4}, X, + {3, 4, 5}, Y); +} + +TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoSmaller) { + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, + 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, + 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_smaller"}}, + {1, 1, 4, 4, 4}, X, + {1, 2, 3}, Y); +} + +TEST(ResizeOpTest, Antialias_Use_Extrapolation) { + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {4.5555553f, 5.4385967f, 6.1666665f, 9.888889f, 10.77193f, + 11.5f, 15.222222f, 16.105263f, 16.833334f, 16.20468f, + 17.08772f, 17.81579f, 21.538012f, 22.421053f, 23.149124f, + 26.871346f, 27.754387f, 28.482456f, 30.333334f, 31.216375f, + 31.944447f, 35.666668f, 36.54971f, 37.27778f, 41., + 41.88304f, 42.61111f}; + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "0"}, {"extrapolation_value", "1.1f"}, + + {"coordinate_transformation_mode", "tf_crop_and_resize"}, + {"roi", "{0, 0, 0.4, 0.6, 1, 1}"}, + {"axes", "{0,1,2}"} + + }, + {4, 4, 4}, X, {3, 3, 3}, Y); } } // namespace test diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 674c2d5cfa..d91511e7d9 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -127,7 +127,6 @@ "^test_constant_pad_cpu", "^test_edge_pad_cpu", "^test_reflect_pad_cpu", - "^test_resize_*", "^test_scatter_elements_*", "^test_softplus_example_expanded_cpu", "^test_softplus_expanded_cpu",