Implement Shrink operator (#485)

* Initial commit

* Adding shrink tests

* Fix formatting in shrink_test.cc

* Fix broken build

* More changes

* PR feedback and formatting

* Place files in the right location corresponding to def file location in onnx

* Exclude shrink model test in test_series.py

* Remove shrink from exclusion list in main.cc

* Adding test to exclusion list

* More tests

* Formatting

* PR feedback

* PR feedback

* More changes

* PR feedback

* More changes

* Fix broken build

* Fix nit

* Fix nit
This commit is contained in:
Hariharan Seshadri 2019-03-01 12:51:22 -08:00 committed by GitHub
parent 9fb80ea927
commit a697e0b710
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 302 additions and 4 deletions

View file

@ -186,6 +186,7 @@ class DataTypeImpl {
static const std::vector<MLDataType>& AllTensorTypes();
static const std::vector<MLDataType>& AllFixedSizeTensorTypes();
static const std::vector<MLDataType>& AllNumericTensorTypes();
};
std::ostream& operator<<(std::ostream& out, MLDataType data_type);

View file

@ -736,6 +736,24 @@ const std::vector<MLDataType>& DataTypeImpl::AllTensorTypes() {
return all_tensor_types;
}
const std::vector<MLDataType>& DataTypeImpl::AllNumericTensorTypes() {
static std::vector<MLDataType> all_numeric_size_tensor_types =
{DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<uint16_t>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<BFloat16>()};
return all_numeric_size_tensor_types;
}
// helper to stream. expected to only be used for error output, so any typeid lookup
// cost should be fine. alternative would be to add a static string field to DataTypeImpl
// that we set in the register macro to the type name, and output that instead.

View file

@ -243,6 +243,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Eye
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot);
@ -502,6 +503,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sign)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int64_t_int64_t, OneHot)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_int64_t_int64_t, OneHot)>());

View file

@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cpu/nn/shrink.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "core/framework/utils.h"
namespace onnxruntime {
ONNX_CPU_OPERATOR_KERNEL(
Shrink,
9,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllNumericTensorTypes()),
Shrink);
namespace shrink_internal {
template <class T>
inline T ShrinkCore(const T& val, float bias, float lambd) {
// The ONNX spec doesn't take numeric overflow and underflow into account
// Implementing the spec as is for now
if (val < -lambd) {
return T(val + bias);
} else if (val > lambd) {
return T(val - bias);
} else {
return T(0);
}
}
template <class T>
Status ShrinkImpl(const Tensor* input, Tensor* output, float bias, float lambd) {
EigenMap<T>(*output) = EigenMap<T>(*input).unaryExpr([bias, lambd](const T& val) { return ShrinkCore<T>(val, bias, lambd); });
return Status::OK();
}
template <>
Status ShrinkImpl<MLFloat16>(const Tensor* input, Tensor* output, float bias, float lambd) {
const auto& span = gsl::make_span(input->Data<MLFloat16>(), input->Shape().Size());
auto* output_data = output->template MutableData<MLFloat16>();
std::transform(span.cbegin(), span.cend(), output_data, [bias, lambd](const MLFloat16& val) {
float fl = math::halfToFloat(val.val);
return MLFloat16(math::floatToHalf(ShrinkCore<float>(fl, bias, lambd)));
});
return Status::OK();
}
template <>
Status ShrinkImpl<BFloat16>(const Tensor* input, Tensor* output, float bias, float lambd) {
const auto& span = gsl::make_span(input->Data<BFloat16>(), input->Shape().Size());
auto* output_data = output->template MutableData<BFloat16>();
std::transform(span.cbegin(), span.cend(), output_data, [bias, lambd](const BFloat16& val) {
float fl = val.ToFloat();
return BFloat16(ShrinkCore<float>(fl, bias, lambd));
});
return Status::OK();
}
template <>
Status ShrinkImpl<bool>(const Tensor* /*input*/, Tensor* /*output*/, float /*bias*/, float /*lambd*/) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Input types for the Shrink operator are constrained "
"to all numeric types only. Got bool type here."
);
}
} // namespace shrink_internal
Status Shrink::Compute(OpKernelContext* p_op_kernel_context) const {
using namespace shrink_internal;
const auto* input = p_op_kernel_context->Input<Tensor>(0);
auto* output = p_op_kernel_context->Output(0, input->Shape());
const auto& dtype = input->DataType();
Status status;
DispatchOnTensorTypeWithReturn(dtype, status, ShrinkImpl, input, output, bias_, lambd_);
return status;
}
} // namespace onnxruntime

View file

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/op_kernel.h"
namespace onnxruntime {
class Shrink final : public OpKernel {
public:
Shrink(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
float bias_temp;
ORT_ENFORCE(op_kernel_info.GetAttr<float>("bias", &bias_temp).IsOK());
bias_ = gsl::narrow_cast<float>(bias_temp);
float lambd_temp;
ORT_ENFORCE(op_kernel_info.GetAttr<float>("lambd", &lambd_temp).IsOK());
lambd_ = gsl::narrow_cast<float>(lambd_temp);
}
Status Compute(OpKernelContext* p_op_kernel_context) const override;
private:
float bias_;
float lambd_;
};
} // namespace onnxruntime

View file

@ -312,8 +312,6 @@ int real_main(int argc, char* argv[]) {
{"scatter_without_axis", "opset 9 not supported yet"},
{"scan_sum", "opset 9 not supported yet"},
{"shrink", "opset 9 not supported yet"},
{"shrink_hard", "opset 9 not supported yet"},
{"shrink_soft", "opset 9 not supported yet"},
{"cast_DOUBLE_to_FLOAT16", "Cast opset 9 not supported yet"},
{"cast_DOUBLE_to_FLOAT", "Cast opset 9 not supported yet"},
{"cast_FLOAT_to_DOUBLE", "Cast opset 9 not supported yet"},

View file

@ -0,0 +1,174 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "core/util/math.h"
namespace onnxruntime {
namespace test {
template <typename T>
struct ShrinkTestData {
std::string name;
float bias;
float lambd;
std::vector<T> input_vals;
std::vector<int64_t> input_dimensions;
std::vector<T> expected_vals;
std::vector<int64_t> expected_dimensions;
};
template <typename T>
std::vector<ShrinkTestData<T>> GenerateSignedTestCases() {
std::vector<ShrinkTestData<T>> test_cases;
test_cases.push_back(
{"default attributes",
0.0f,
0.5f,
{-1, 0, 0, 1},
{2, 2},
{-1, 0, 0, 1},
{2, 2}});
test_cases.push_back(
{"non-default attributes",
10.0f,
2.0f,
{-3, -1, 1, 4},
{2, 2},
{7, 0, 0, -6},
{2, 2}});
return test_cases;
}
template <typename T>
std::vector<ShrinkTestData<T>> GenerateUnsignedTestCases() {
std::vector<ShrinkTestData<T>> test_cases;
test_cases.push_back(
{"default attributes",
0.0f,
0.5f,
{0, 0, 0, 1},
{2, 2},
{0, 0, 0, 1},
{2, 2}});
test_cases.push_back(
{"non-default attributes",
10.0f,
2.0f,
{37, 1, 1, 11},
{2, 2},
{27, 0, 0, 1},
{2, 2}});
return test_cases;
}
template <typename T>
void RunShrinkTest(const std::vector<ShrinkTestData<T>>& test_cases) {
for (const auto& test_data : test_cases) {
OpTester test("Shrink", 9);
if (test_data.bias != 0.0f) {
test.AddAttribute("bias", test_data.bias);
}
if (test_data.lambd != 0.5f) {
test.AddAttribute("lambd", test_data.lambd);
}
test.AddInput<T>("X", test_data.input_dimensions, test_data.input_vals);
test.AddOutput<T>("Y", test_data.expected_dimensions, test_data.expected_vals);
test.Run();
}
}
const std::vector<MLFloat16> ConvertFloatToMLFloat16(const std::vector<float>& float_data) {
std::vector<MLFloat16> new_data;
for (const auto& f : float_data) {
new_data.push_back(MLFloat16(math::floatToHalf(f)));
}
return new_data;
}
TEST(MathOpTest, ShrinkInt8Type) {
const auto& test_cases = GenerateSignedTestCases<int8_t>();
RunShrinkTest<int8_t>(test_cases);
}
TEST(MathOpTest, ShrinkUint8Type) {
const auto& test_cases = GenerateUnsignedTestCases<uint8_t>();
RunShrinkTest<uint8_t>(test_cases);
}
TEST(MathOpTest, ShrinkInt16Type) {
const auto& test_cases = GenerateSignedTestCases<int16_t>();
RunShrinkTest<int16_t>(test_cases);
}
TEST(MathOpTest, ShrinkUint16Type) {
const auto& test_cases = GenerateUnsignedTestCases<uint16_t>();
RunShrinkTest<uint16_t>(test_cases);
}
TEST(MathOpTest, ShrinkInt32Type) {
const auto& test_cases = GenerateSignedTestCases<int32_t>();
RunShrinkTest<int32_t>(test_cases);
}
TEST(MathOpTest, ShrinkUint32Type) {
const auto& test_cases = GenerateUnsignedTestCases<uint32_t>();
RunShrinkTest<uint32_t>(test_cases);
}
TEST(MathOpTest, ShrinkInt64Type) {
const auto& test_cases = GenerateSignedTestCases<int64_t>();
RunShrinkTest<int64_t>(test_cases);
}
TEST(MathOpTest, ShrinkUint64Type) {
const auto& test_cases = GenerateUnsignedTestCases<uint64_t>();
RunShrinkTest<uint64_t>(test_cases);
}
TEST(MathOpTest, ShrinkFloatType) {
const auto& test_cases = GenerateSignedTestCases<float>();
RunShrinkTest<float>(test_cases);
}
TEST(MathOpTest, ShrinkDoubleType) {
const auto& test_cases = GenerateSignedTestCases<double>();
RunShrinkTest<double>(test_cases);
}
TEST(MathOpTest, ShrinkMLFloat16Type) {
const std::vector<MLFloat16> input_test_data_default = ConvertFloatToMLFloat16({-1, 0, 0, 1});
const std::vector<MLFloat16> output_test_data_default = ConvertFloatToMLFloat16({-1, 0, 0, 1});
const std::vector<MLFloat16> input_test_data_nondefault = ConvertFloatToMLFloat16({-3, -1, 1, 4});
const std::vector<MLFloat16> output_test_data_nondefault = ConvertFloatToMLFloat16({7, 0, 0, -6});
std::vector<ShrinkTestData<MLFloat16>> test_cases;
test_cases.push_back(
{"default attributes",
0.0f,
0.5f,
input_test_data_default,
{2, 2},
output_test_data_default,
{2, 2}});
test_cases.push_back(
{"non-default attributes",
10.0f,
2.0f,
input_test_data_nondefault,
{2, 2},
output_test_data_nondefault,
{2, 2}});
RunShrinkTest<MLFloat16>(test_cases);
}
} // namespace test
} // namespace onnxruntime

View file

@ -26,8 +26,6 @@ backend_test.exclude(r'('
'|^test_convtranspose_3d_cpu.*'
'|^test_scatter_with_axis_cpu.*'
'|^test_scatter_without_axis_cpu.*'
'|^test_shrink_hard_cpu.*'
'|^test_shrink_soft_cpu.*'
'|^test_AvgPool1d_cpu.*'
'|^test_AvgPool1d_stride_cpu.*'
'|^test_AvgPool2d_cpu.*'