mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Implement Shrink operator (#485)
* Initial commit * Adding shrink tests * Fix formatting in shrink_test.cc * Fix broken build * More changes * PR feedback and formatting * Place files in the right location corresponding to def file location in onnx * Exclude shrink model test in test_series.py * Remove shrink from exclusion list in main.cc * Adding test to exclusion list * More tests * Formatting * PR feedback * PR feedback * More changes * PR feedback * More changes * Fix broken build * Fix nit * Fix nit
This commit is contained in:
parent
9fb80ea927
commit
a697e0b710
9 changed files with 302 additions and 4 deletions
|
|
@ -186,6 +186,7 @@ class DataTypeImpl {
|
|||
|
||||
static const std::vector<MLDataType>& AllTensorTypes();
|
||||
static const std::vector<MLDataType>& AllFixedSizeTensorTypes();
|
||||
static const std::vector<MLDataType>& AllNumericTensorTypes();
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, MLDataType data_type);
|
||||
|
|
|
|||
|
|
@ -736,6 +736,24 @@ const std::vector<MLDataType>& DataTypeImpl::AllTensorTypes() {
|
|||
return all_tensor_types;
|
||||
}
|
||||
|
||||
const std::vector<MLDataType>& DataTypeImpl::AllNumericTensorTypes() {
|
||||
static std::vector<MLDataType> all_numeric_size_tensor_types =
|
||||
{DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<uint64_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<uint32_t>(),
|
||||
DataTypeImpl::GetTensorType<int16_t>(),
|
||||
DataTypeImpl::GetTensorType<uint16_t>(),
|
||||
DataTypeImpl::GetTensorType<int8_t>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<MLFloat16>(),
|
||||
DataTypeImpl::GetTensorType<BFloat16>()};
|
||||
|
||||
return all_numeric_size_tensor_types;
|
||||
}
|
||||
|
||||
// helper to stream. expected to only be used for error output, so any typeid lookup
|
||||
// cost should be fine. alternative would be to add a static string field to DataTypeImpl
|
||||
// that we set in the register macro to the type name, and output that instead.
|
||||
|
|
|
|||
|
|
@ -243,6 +243,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Eye
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot);
|
||||
|
|
@ -502,6 +503,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot)>());
|
||||
|
|
|
|||
80
onnxruntime/core/providers/cpu/nn/shrink.cc
Normal file
80
onnxruntime/core/providers/cpu/nn/shrink.cc
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/nn/shrink.h"
|
||||
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/framework/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Shrink,
|
||||
9,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllNumericTensorTypes()),
|
||||
Shrink);
|
||||
|
||||
namespace shrink_internal {
|
||||
template <class T>
|
||||
inline T ShrinkCore(const T& val, float bias, float lambd) {
|
||||
// The ONNX spec doesn't take numeric overflow and underflow into account
|
||||
// Implementing the spec as is for now
|
||||
if (val < -lambd) {
|
||||
return T(val + bias);
|
||||
} else if (val > lambd) {
|
||||
return T(val - bias);
|
||||
} else {
|
||||
return T(0);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
Status ShrinkImpl(const Tensor* input, Tensor* output, float bias, float lambd) {
|
||||
EigenMap<T>(*output) = EigenMap<T>(*input).unaryExpr([bias, lambd](const T& val) { return ShrinkCore<T>(val, bias, lambd); });
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status ShrinkImpl<MLFloat16>(const Tensor* input, Tensor* output, float bias, float lambd) {
|
||||
const auto& span = gsl::make_span(input->Data<MLFloat16>(), input->Shape().Size());
|
||||
auto* output_data = output->template MutableData<MLFloat16>();
|
||||
std::transform(span.cbegin(), span.cend(), output_data, [bias, lambd](const MLFloat16& val) {
|
||||
float fl = math::halfToFloat(val.val);
|
||||
return MLFloat16(math::floatToHalf(ShrinkCore<float>(fl, bias, lambd)));
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status ShrinkImpl<BFloat16>(const Tensor* input, Tensor* output, float bias, float lambd) {
|
||||
const auto& span = gsl::make_span(input->Data<BFloat16>(), input->Shape().Size());
|
||||
auto* output_data = output->template MutableData<BFloat16>();
|
||||
std::transform(span.cbegin(), span.cend(), output_data, [bias, lambd](const BFloat16& val) {
|
||||
float fl = val.ToFloat();
|
||||
return BFloat16(ShrinkCore<float>(fl, bias, lambd));
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status ShrinkImpl<bool>(const Tensor* /*input*/, Tensor* /*output*/, float /*bias*/, float /*lambd*/) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input types for the Shrink operator are constrained "
|
||||
"to all numeric types only. Got bool type here."
|
||||
);
|
||||
}
|
||||
|
||||
} // namespace shrink_internal
|
||||
|
||||
Status Shrink::Compute(OpKernelContext* p_op_kernel_context) const {
|
||||
using namespace shrink_internal;
|
||||
|
||||
const auto* input = p_op_kernel_context->Input<Tensor>(0);
|
||||
auto* output = p_op_kernel_context->Output(0, input->Shape());
|
||||
const auto& dtype = input->DataType();
|
||||
Status status;
|
||||
DispatchOnTensorTypeWithReturn(dtype, status, ShrinkImpl, input, output, bias_, lambd_);
|
||||
return status;
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
27
onnxruntime/core/providers/cpu/nn/shrink.h
Normal file
27
onnxruntime/core/providers/cpu/nn/shrink.h
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class Shrink final : public OpKernel {
|
||||
public:
|
||||
Shrink(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
|
||||
float bias_temp;
|
||||
ORT_ENFORCE(op_kernel_info.GetAttr<float>("bias", &bias_temp).IsOK());
|
||||
bias_ = gsl::narrow_cast<float>(bias_temp);
|
||||
|
||||
float lambd_temp;
|
||||
ORT_ENFORCE(op_kernel_info.GetAttr<float>("lambd", &lambd_temp).IsOK());
|
||||
lambd_ = gsl::narrow_cast<float>(lambd_temp);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* p_op_kernel_context) const override;
|
||||
|
||||
private:
|
||||
float bias_;
|
||||
float lambd_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -312,8 +312,6 @@ int real_main(int argc, char* argv[]) {
|
|||
{"scatter_without_axis", "opset 9 not supported yet"},
|
||||
{"scan_sum", "opset 9 not supported yet"},
|
||||
{"shrink", "opset 9 not supported yet"},
|
||||
{"shrink_hard", "opset 9 not supported yet"},
|
||||
{"shrink_soft", "opset 9 not supported yet"},
|
||||
{"cast_DOUBLE_to_FLOAT16", "Cast opset 9 not supported yet"},
|
||||
{"cast_DOUBLE_to_FLOAT", "Cast opset 9 not supported yet"},
|
||||
{"cast_FLOAT_to_DOUBLE", "Cast opset 9 not supported yet"},
|
||||
|
|
|
|||
174
onnxruntime/test/providers/cpu/nn/shrink_test.cc
Normal file
174
onnxruntime/test/providers/cpu/nn/shrink_test.cc
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "core/util/math.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
template <typename T>
|
||||
struct ShrinkTestData {
|
||||
std::string name;
|
||||
float bias;
|
||||
float lambd;
|
||||
std::vector<T> input_vals;
|
||||
std::vector<int64_t> input_dimensions;
|
||||
std::vector<T> expected_vals;
|
||||
std::vector<int64_t> expected_dimensions;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::vector<ShrinkTestData<T>> GenerateSignedTestCases() {
|
||||
std::vector<ShrinkTestData<T>> test_cases;
|
||||
test_cases.push_back(
|
||||
{"default attributes",
|
||||
0.0f,
|
||||
0.5f,
|
||||
{-1, 0, 0, 1},
|
||||
{2, 2},
|
||||
{-1, 0, 0, 1},
|
||||
{2, 2}});
|
||||
|
||||
test_cases.push_back(
|
||||
{"non-default attributes",
|
||||
10.0f,
|
||||
2.0f,
|
||||
{-3, -1, 1, 4},
|
||||
{2, 2},
|
||||
{7, 0, 0, -6},
|
||||
{2, 2}});
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<ShrinkTestData<T>> GenerateUnsignedTestCases() {
|
||||
std::vector<ShrinkTestData<T>> test_cases;
|
||||
|
||||
test_cases.push_back(
|
||||
{"default attributes",
|
||||
0.0f,
|
||||
0.5f,
|
||||
{0, 0, 0, 1},
|
||||
{2, 2},
|
||||
{0, 0, 0, 1},
|
||||
{2, 2}});
|
||||
|
||||
test_cases.push_back(
|
||||
{"non-default attributes",
|
||||
10.0f,
|
||||
2.0f,
|
||||
{37, 1, 1, 11},
|
||||
{2, 2},
|
||||
{27, 0, 0, 1},
|
||||
{2, 2}});
|
||||
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RunShrinkTest(const std::vector<ShrinkTestData<T>>& test_cases) {
|
||||
for (const auto& test_data : test_cases) {
|
||||
OpTester test("Shrink", 9);
|
||||
|
||||
if (test_data.bias != 0.0f) {
|
||||
test.AddAttribute("bias", test_data.bias);
|
||||
}
|
||||
|
||||
if (test_data.lambd != 0.5f) {
|
||||
test.AddAttribute("lambd", test_data.lambd);
|
||||
}
|
||||
|
||||
test.AddInput<T>("X", test_data.input_dimensions, test_data.input_vals);
|
||||
test.AddOutput<T>("Y", test_data.expected_dimensions, test_data.expected_vals);
|
||||
test.Run();
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<MLFloat16> ConvertFloatToMLFloat16(const std::vector<float>& float_data) {
|
||||
std::vector<MLFloat16> new_data;
|
||||
for (const auto& f : float_data) {
|
||||
new_data.push_back(MLFloat16(math::floatToHalf(f)));
|
||||
}
|
||||
return new_data;
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkInt8Type) {
|
||||
const auto& test_cases = GenerateSignedTestCases<int8_t>();
|
||||
RunShrinkTest<int8_t>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkUint8Type) {
|
||||
const auto& test_cases = GenerateUnsignedTestCases<uint8_t>();
|
||||
RunShrinkTest<uint8_t>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkInt16Type) {
|
||||
const auto& test_cases = GenerateSignedTestCases<int16_t>();
|
||||
RunShrinkTest<int16_t>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkUint16Type) {
|
||||
const auto& test_cases = GenerateUnsignedTestCases<uint16_t>();
|
||||
RunShrinkTest<uint16_t>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkInt32Type) {
|
||||
const auto& test_cases = GenerateSignedTestCases<int32_t>();
|
||||
RunShrinkTest<int32_t>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkUint32Type) {
|
||||
const auto& test_cases = GenerateUnsignedTestCases<uint32_t>();
|
||||
RunShrinkTest<uint32_t>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkInt64Type) {
|
||||
const auto& test_cases = GenerateSignedTestCases<int64_t>();
|
||||
RunShrinkTest<int64_t>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkUint64Type) {
|
||||
const auto& test_cases = GenerateUnsignedTestCases<uint64_t>();
|
||||
RunShrinkTest<uint64_t>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkFloatType) {
|
||||
const auto& test_cases = GenerateSignedTestCases<float>();
|
||||
RunShrinkTest<float>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkDoubleType) {
|
||||
const auto& test_cases = GenerateSignedTestCases<double>();
|
||||
RunShrinkTest<double>(test_cases);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, ShrinkMLFloat16Type) {
|
||||
const std::vector<MLFloat16> input_test_data_default = ConvertFloatToMLFloat16({-1, 0, 0, 1});
|
||||
const std::vector<MLFloat16> output_test_data_default = ConvertFloatToMLFloat16({-1, 0, 0, 1});
|
||||
|
||||
const std::vector<MLFloat16> input_test_data_nondefault = ConvertFloatToMLFloat16({-3, -1, 1, 4});
|
||||
const std::vector<MLFloat16> output_test_data_nondefault = ConvertFloatToMLFloat16({7, 0, 0, -6});
|
||||
std::vector<ShrinkTestData<MLFloat16>> test_cases;
|
||||
test_cases.push_back(
|
||||
{"default attributes",
|
||||
0.0f,
|
||||
0.5f,
|
||||
input_test_data_default,
|
||||
{2, 2},
|
||||
output_test_data_default,
|
||||
{2, 2}});
|
||||
test_cases.push_back(
|
||||
{"non-default attributes",
|
||||
10.0f,
|
||||
2.0f,
|
||||
input_test_data_nondefault,
|
||||
{2, 2},
|
||||
output_test_data_nondefault,
|
||||
{2, 2}});
|
||||
RunShrinkTest<MLFloat16>(test_cases);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -26,8 +26,6 @@ backend_test.exclude(r'('
|
|||
'|^test_convtranspose_3d_cpu.*'
|
||||
'|^test_scatter_with_axis_cpu.*'
|
||||
'|^test_scatter_without_axis_cpu.*'
|
||||
'|^test_shrink_hard_cpu.*'
|
||||
'|^test_shrink_soft_cpu.*'
|
||||
'|^test_AvgPool1d_cpu.*'
|
||||
'|^test_AvgPool1d_stride_cpu.*'
|
||||
'|^test_AvgPool2d_cpu.*'
|
||||
|
|
|
|||
Loading…
Reference in a new issue