diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index 69c44a63ed..3f191a82b6 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -186,6 +186,7 @@ class DataTypeImpl { static const std::vector& AllTensorTypes(); static const std::vector& AllFixedSizeTensorTypes(); + static const std::vector& AllNumericTensorTypes(); }; std::ostream& operator<<(std::ostream& out, MLDataType data_type); diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 73d1db82d6..ad8dbe8154 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -736,6 +736,24 @@ const std::vector& DataTypeImpl::AllTensorTypes() { return all_tensor_types; } +const std::vector& DataTypeImpl::AllNumericTensorTypes() { + static std::vector all_numeric_size_tensor_types = + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + + return all_numeric_size_tensor_types; +} + // helper to stream. expected to only be used for error output, so any typeid lookup // cost should be fine. alternative would be to add a static string field to DataTypeImpl // that we set in the register macro to the type name, and output that instead. diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 83a5a727c7..498ee9d4c9 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -243,6 +243,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Eye class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot); @@ -502,6 +503,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); diff --git a/onnxruntime/core/providers/cpu/nn/shrink.cc b/onnxruntime/core/providers/cpu/nn/shrink.cc new file mode 100644 index 0000000000..e893f12647 --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/shrink.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/nn/shrink.h" + +#include "core/util/math.h" +#include "core/util/math_cpuonly.h" +#include "core/framework/utils.h" + +namespace onnxruntime { +ONNX_CPU_OPERATOR_KERNEL( + Shrink, + 9, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllNumericTensorTypes()), + Shrink); + +namespace shrink_internal { +template +inline T ShrinkCore(const T& val, float bias, float lambd) { + // The ONNX spec doesn't take numeric overflow and underflow into account + // Implementing the spec as is for now + if (val < -lambd) { + return T(val + bias); + } else if (val > lambd) { + return T(val - bias); + } else { + return T(0); + } +} + +template +Status ShrinkImpl(const Tensor* input, Tensor* output, float bias, float lambd) { + EigenMap(*output) = EigenMap(*input).unaryExpr([bias, lambd](const T& val) { return ShrinkCore(val, bias, lambd); }); + return Status::OK(); +} + +template <> +Status ShrinkImpl(const Tensor* input, Tensor* output, float bias, float lambd) { + const auto& span = gsl::make_span(input->Data(), input->Shape().Size()); + auto* output_data = output->template MutableData(); + std::transform(span.cbegin(), span.cend(), output_data, [bias, lambd](const MLFloat16& val) { + float fl = math::halfToFloat(val.val); + return MLFloat16(math::floatToHalf(ShrinkCore(fl, bias, lambd))); + }); + return Status::OK(); +} + +template <> +Status ShrinkImpl(const Tensor* input, Tensor* output, float bias, float lambd) { + const auto& span = gsl::make_span(input->Data(), input->Shape().Size()); + auto* output_data = output->template MutableData(); + std::transform(span.cbegin(), span.cend(), output_data, [bias, lambd](const BFloat16& val) { + float fl = val.ToFloat(); + return BFloat16(ShrinkCore(fl, bias, lambd)); + }); + return Status::OK(); +} + +template <> +Status ShrinkImpl(const Tensor* /*input*/, Tensor* /*output*/, float /*bias*/, float /*lambd*/) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input types for the Shrink operator are constrained " + "to all numeric types only. Got bool type here." + ); +} + +} // namespace shrink_internal + +Status Shrink::Compute(OpKernelContext* p_op_kernel_context) const { + using namespace shrink_internal; + + const auto* input = p_op_kernel_context->Input(0); + auto* output = p_op_kernel_context->Output(0, input->Shape()); + const auto& dtype = input->DataType(); + Status status; + DispatchOnTensorTypeWithReturn(dtype, status, ShrinkImpl, input, output, bias_, lambd_); + return status; +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cpu/nn/shrink.h b/onnxruntime/core/providers/cpu/nn/shrink.h new file mode 100644 index 0000000000..44a9d0c20f --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/shrink.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +class Shrink final : public OpKernel { + public: + Shrink(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { + float bias_temp; + ORT_ENFORCE(op_kernel_info.GetAttr("bias", &bias_temp).IsOK()); + bias_ = gsl::narrow_cast(bias_temp); + + float lambd_temp; + ORT_ENFORCE(op_kernel_info.GetAttr("lambd", &lambd_temp).IsOK()); + lambd_ = gsl::narrow_cast(lambd_temp); + } + + Status Compute(OpKernelContext* p_op_kernel_context) const override; + + private: + float bias_; + float lambd_; +}; +} // namespace onnxruntime diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 6b89f869db..e18205c3ea 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -312,8 +312,6 @@ int real_main(int argc, char* argv[]) { {"scatter_without_axis", "opset 9 not supported yet"}, {"scan_sum", "opset 9 not supported yet"}, {"shrink", "opset 9 not supported yet"}, - {"shrink_hard", "opset 9 not supported yet"}, - {"shrink_soft", "opset 9 not supported yet"}, {"cast_DOUBLE_to_FLOAT16", "Cast opset 9 not supported yet"}, {"cast_DOUBLE_to_FLOAT", "Cast opset 9 not supported yet"}, {"cast_FLOAT_to_DOUBLE", "Cast opset 9 not supported yet"}, diff --git a/onnxruntime/test/providers/cpu/math/sing_test.cc b/onnxruntime/test/providers/cpu/math/sign_test.cc similarity index 100% rename from onnxruntime/test/providers/cpu/math/sing_test.cc rename to onnxruntime/test/providers/cpu/math/sign_test.cc diff --git a/onnxruntime/test/providers/cpu/nn/shrink_test.cc b/onnxruntime/test/providers/cpu/nn/shrink_test.cc new file mode 100644 index 0000000000..03bf0eeb15 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/shrink_test.cc @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "core/util/math.h" + +namespace onnxruntime { +namespace test { + +template +struct ShrinkTestData { + std::string name; + float bias; + float lambd; + std::vector input_vals; + std::vector input_dimensions; + std::vector expected_vals; + std::vector expected_dimensions; +}; + +template +std::vector> GenerateSignedTestCases() { + std::vector> test_cases; + test_cases.push_back( + {"default attributes", + 0.0f, + 0.5f, + {-1, 0, 0, 1}, + {2, 2}, + {-1, 0, 0, 1}, + {2, 2}}); + + test_cases.push_back( + {"non-default attributes", + 10.0f, + 2.0f, + {-3, -1, 1, 4}, + {2, 2}, + {7, 0, 0, -6}, + {2, 2}}); + return test_cases; +} + +template +std::vector> GenerateUnsignedTestCases() { + std::vector> test_cases; + + test_cases.push_back( + {"default attributes", + 0.0f, + 0.5f, + {0, 0, 0, 1}, + {2, 2}, + {0, 0, 0, 1}, + {2, 2}}); + + test_cases.push_back( + {"non-default attributes", + 10.0f, + 2.0f, + {37, 1, 1, 11}, + {2, 2}, + {27, 0, 0, 1}, + {2, 2}}); + + return test_cases; +} + +template +void RunShrinkTest(const std::vector>& test_cases) { + for (const auto& test_data : test_cases) { + OpTester test("Shrink", 9); + + if (test_data.bias != 0.0f) { + test.AddAttribute("bias", test_data.bias); + } + + if (test_data.lambd != 0.5f) { + test.AddAttribute("lambd", test_data.lambd); + } + + test.AddInput("X", test_data.input_dimensions, test_data.input_vals); + test.AddOutput("Y", test_data.expected_dimensions, test_data.expected_vals); + test.Run(); + } +} + +const std::vector ConvertFloatToMLFloat16(const std::vector& float_data) { + std::vector new_data; + for (const auto& f : float_data) { + new_data.push_back(MLFloat16(math::floatToHalf(f))); + } + return new_data; +} + +TEST(MathOpTest, ShrinkInt8Type) { + const auto& test_cases = GenerateSignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkUint8Type) { + const auto& test_cases = GenerateUnsignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkInt16Type) { + const auto& test_cases = GenerateSignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkUint16Type) { + const auto& test_cases = GenerateUnsignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkInt32Type) { + const auto& test_cases = GenerateSignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkUint32Type) { + const auto& test_cases = GenerateUnsignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkInt64Type) { + const auto& test_cases = GenerateSignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkUint64Type) { + const auto& test_cases = GenerateUnsignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkFloatType) { + const auto& test_cases = GenerateSignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkDoubleType) { + const auto& test_cases = GenerateSignedTestCases(); + RunShrinkTest(test_cases); +} + +TEST(MathOpTest, ShrinkMLFloat16Type) { + const std::vector input_test_data_default = ConvertFloatToMLFloat16({-1, 0, 0, 1}); + const std::vector output_test_data_default = ConvertFloatToMLFloat16({-1, 0, 0, 1}); + + const std::vector input_test_data_nondefault = ConvertFloatToMLFloat16({-3, -1, 1, 4}); + const std::vector output_test_data_nondefault = ConvertFloatToMLFloat16({7, 0, 0, -6}); + std::vector> test_cases; + test_cases.push_back( + {"default attributes", + 0.0f, + 0.5f, + input_test_data_default, + {2, 2}, + output_test_data_default, + {2, 2}}); + test_cases.push_back( + {"non-default attributes", + 10.0f, + 2.0f, + input_test_data_nondefault, + {2, 2}, + output_test_data_nondefault, + {2, 2}}); + RunShrinkTest(test_cases); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 2ce3f986ec..e57873e43f 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -26,8 +26,6 @@ backend_test.exclude(r'(' '|^test_convtranspose_3d_cpu.*' '|^test_scatter_with_axis_cpu.*' '|^test_scatter_without_axis_cpu.*' -'|^test_shrink_hard_cpu.*' -'|^test_shrink_soft_cpu.*' '|^test_AvgPool1d_cpu.*' '|^test_AvgPool1d_stride_cpu.*' '|^test_AvgPool2d_cpu.*'