Implement ConstantOfShape (#443)

Implement ConstantOfShape
This commit is contained in:
Dmitri Smirnov 2019-02-06 11:38:22 -08:00 committed by GitHub
parent 4038db14e2
commit c932ab8e99
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 428 additions and 5 deletions

View file

@ -229,6 +229,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Loo
// Opset 9
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less);
@ -474,6 +475,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
// Opset 9
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less)>());

View file

@ -0,0 +1,188 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/tensorutils.h"
#include "core/providers/cpu/generator/constant_of_shape.h"
#include "gsl/span"
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
ONNX_CPU_OPERATOR_KERNEL(
ConstantOfShape,
9,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T2", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<uint16_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<bool>()}),
ConstantOfShape);
#define FETCH_VALUE_DATA(field, c_type) \
{ \
c_type t; \
ORT_ENFORCE(TensorUtils::UnpackTensor(t_proto, &t, 1).IsOK(), "Value attribute unpacking failed"); \
field = t; \
}
void onnxruntime::ConstantOfShape::SetValue(const ONNX_NAMESPACE::TensorProto& t_proto) {
using namespace utils;
ORT_ENFORCE(t_proto.has_data_type());
ORT_ENFORCE(TensorProto::DataType_IsValid(t_proto.data_type()));
tensor_type_ = static_cast<TensorProto_DataType>(t_proto.data_type());
switch (tensor_type_) {
case TensorProto::BOOL:
FETCH_VALUE_DATA(value_.ui64_, bool);
break;
case TensorProto::FLOAT:
FETCH_VALUE_DATA(value_.fl_, float);
break;
case TensorProto::FLOAT16:
FETCH_VALUE_DATA(value_.fl16_, MLFloat16);
break;
case TensorProto::DOUBLE:
FETCH_VALUE_DATA(value_.dbl_, double);
break;
case TensorProto::INT8:
FETCH_VALUE_DATA(value_.i64_, int8_t);
break;
case TensorProto::INT16:
FETCH_VALUE_DATA(value_.i64_, int16_t);
break;
case TensorProto::INT32:
FETCH_VALUE_DATA(value_.i64_, int32_t);
break;
case TensorProto::INT64:
FETCH_VALUE_DATA(value_.i64_, int64_t);
break;
case TensorProto::UINT8:
FETCH_VALUE_DATA(value_.ui64_, uint8_t);
break;
case TensorProto::UINT16:
FETCH_VALUE_DATA(value_.ui64_, uint16_t);
break;
case TensorProto::UINT32:
FETCH_VALUE_DATA(value_.ui64_, uint32_t);
break;
case TensorProto::UINT64:
FETCH_VALUE_DATA(value_.ui64_, uint64_t);
break;
default:
ORT_THROW("Unsupported value attribute datatype: ", TensorProto::DataType_Name(tensor_type_));
break;
}
}
#undef FETCH_VALUE_DATA
template <class T>
inline T onnxruntime::ConstantOfShape::Value::GetFromSigned() const {
return static_cast<T>(i64_);
}
template <class T>
inline T onnxruntime::ConstantOfShape::Value::GetFromUnsigned() const {
return static_cast<T>(ui64_);
}
template <class T>
inline void FilloutOutput(T value, Tensor* output_tensor) {
auto out = gsl::make_span(output_tensor->template MutableData<T>(), output_tensor->Shape().Size());
std::fill(out.begin(), out.end(), value);
}
void onnxruntime::ConstantOfShape::DispatchTypeAndFillOutput(Tensor* output_tensor) const {
switch (tensor_type_) {
case TensorProto::BOOL:
FilloutOutput(value_.GetFromUnsigned<bool>(), output_tensor);
break;
case TensorProto::FLOAT:
FilloutOutput(value_.GetFloat(), output_tensor);
break;
case TensorProto::FLOAT16:
FilloutOutput(value_.GetFloat16(), output_tensor);
break;
case TensorProto::DOUBLE:
FilloutOutput(value_.GetDouble(), output_tensor);
break;
case TensorProto::INT8:
FilloutOutput(value_.GetFromSigned<int8_t>(), output_tensor);
break;
case TensorProto::INT16:
FilloutOutput(value_.GetFromSigned<int16_t>(), output_tensor);
break;
case TensorProto::INT32:
FilloutOutput(value_.GetFromSigned<int32_t>(), output_tensor);
break;
case TensorProto::INT64:
FilloutOutput(value_.GetFromSigned<int64_t>(), output_tensor);
break;
case TensorProto::UINT8:
FilloutOutput(value_.GetFromUnsigned<uint8_t>(), output_tensor);
break;
case TensorProto::UINT16:
FilloutOutput(value_.GetFromUnsigned<uint16_t>(), output_tensor);
break;
case TensorProto::UINT32:
FilloutOutput(value_.GetFromUnsigned<uint32_t>(), output_tensor);
break;
case TensorProto::UINT64:
FilloutOutput(value_.GetFromUnsigned<uint64_t>(), output_tensor);
break;
default:
ORT_THROW("Unsupported value attribute datatype: ", TensorProto::DataType_Name(tensor_type_));
break;
}
}
ConstantOfShape::ConstantOfShape(const OpKernelInfo& info) : OpKernel(info) {
TensorProto t_proto;
if (info.GetAttr<TensorProto>("value", &t_proto).IsOK()) {
ORT_ENFORCE(t_proto.dims_size() == 1, "Must have a single dimension");
ORT_ENFORCE(t_proto.dims()[0] == 1, "Must have a single dimension of 1");
SetValue(t_proto);
} else {
tensor_type_ = TensorProto::FLOAT;
value_.fl_ = 0.f;
}
}
Status ConstantOfShape::Compute(OpKernelContext* ctx) const {
auto shape_tensor = ctx->Input<Tensor>(0);
if (shape_tensor->DataType() != DataTypeImpl::GetType<int64_t>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input tensor expected to contain int64 data");
}
auto& input_shape = shape_tensor->Shape();
// If empty the output is a scalar, which means a single value
std::vector<int64_t> output_dims;
const auto input_size = input_shape.Size();
if (input_size > 0) {
auto span = gsl::make_span(shape_tensor->Data<int64_t>(), input_size);
output_dims.insert(output_dims.end(), span.cbegin(), span.cend());
} else if (input_size == 0) {
output_dims.push_back(1); // scalar
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input tensor dimensions is expected to be either 1-D or empty");
}
TensorShape output_shape(output_dims);
auto output_tensor = ctx->Output(0, output_shape);
DispatchTypeAndFillOutput(output_tensor);
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/data_types.h"
#include "core/framework/op_kernel.h"
namespace onnxruntime {
class ConstantOfShape final : public OpKernel {
public:
explicit ConstantOfShape(const OpKernelInfo& info);
Status Compute(OpKernelContext* ctx) const override;
private:
ONNX_NAMESPACE::TensorProto_DataType tensor_type_;
union Value {
float fl_;
MLFloat16 fl16_;
double dbl_;
int64_t i64_;
uint64_t ui64_;
Value() : ui64_(0) {}
float GetFloat() const {
return fl_;
}
MLFloat16 GetFloat16() const {
return fl16_;
}
double GetDouble() const {
return dbl_;
}
template <class T>
T GetFromSigned() const;
template <class T>
T GetFromUnsigned() const;
} value_;
void SetValue(const ONNX_NAMESPACE::TensorProto&);
void DispatchTypeAndFillOutput(Tensor* output_tensor) const;
};
} // namespace onnxruntime

View file

@ -271,7 +271,7 @@ int real_main(int argc, char* argv[]) {
{"PoissonNLLLLoss_no_reduce", "disable reason"},
{"Softsign", "disable reason"},
{"convtranspose_1d", "disable reason"},
{"convtranspose_3d", "disable reason"},
{"convtranspose_3d", "disable reason"},
{"flatten_axis0", "disable reason"},
{"flatten_axis1", "disable reason"},
{"flatten_axis2", "disable reason"},
@ -314,11 +314,9 @@ int real_main(int argc, char* argv[]) {
{"scatter_without_axis", "opset 9 not supported yet"},
{"scan_sum", "opset 9 not supported yet"},
{"shrink", "opset 9 not supported yet"},
{"constantofshape_int_zeros", "opset 9 not supported yet"},
{"shrink_hard", "opset 9 not supported yet"},
{"shrink_soft", "opset 9 not supported yet"},
{"where_example", "opset 9 not supported yet"},
{"constantofshape_float_ones", "opset 9 not supported yet"},
{"cast_DOUBLE_to_FLOAT16", "Cast opset 9 not supported yet"},
{"cast_DOUBLE_to_FLOAT", "Cast opset 9 not supported yet"},
{"cast_FLOAT_to_DOUBLE", "Cast opset 9 not supported yet"},

View file

@ -0,0 +1,184 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include <type_traits>
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {
TEST(ConstantOfShape, Float_Ones) {
OpTester test("ConstantOfShape", 9);
TensorProto t_proto;
t_proto.set_data_type(TensorProto::FLOAT);
t_proto.mutable_dims()->Add(1);
t_proto.mutable_float_data()->Add(1.f);
test.AddAttribute("value", t_proto);
// We will input 1-D Tensor that will store 3 dimensions
// and will provide shape for the output
std::vector<int64_t> input_dims{3};
std::vector<int64_t> input{4, 3, 2};
test.AddInput<int64_t>("input", input_dims, input);
std::vector<int64_t> output_dims(input);
std::vector<float> output;
output.resize(4 * 3 * 2);
std::fill_n(output.begin(), 4 * 3 * 2, 1.f);
test.AddOutput<float>("output", output_dims, output);
test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ConstantOfShape, Int32_Zeros) {
OpTester test("ConstantOfShape", 9);
TensorProto t_proto;
t_proto.set_data_type(TensorProto::INT32);
t_proto.mutable_dims()->Add(1);
t_proto.mutable_int32_data()->Add(0);
test.AddAttribute("value", t_proto);
std::vector<int64_t> input_dims{2};
std::vector<int64_t> input{10, 6};
test.AddInput<int64_t>("input", input_dims, input);
std::vector<int64_t> output_dims(input);
std::vector<int32_t> output;
output.resize(10 * 6);
std::fill_n(output.begin(), output.size(), 0);
test.AddOutput<int32_t>("output", output_dims, output);
test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ConstantOfShape, DefaultValue) {
OpTester test("ConstantOfShape", 9);
// By default the output will be FLOAT zeros
std::vector<int64_t> input_dims{2};
std::vector<int64_t> input{2, 6};
test.AddInput<int64_t>("input", input_dims, input);
std::vector<int64_t> output_dims(input);
std::vector<float> output;
output.resize(2 * 6);
std::fill_n(output.begin(), output.size(), 0.f);
test.AddOutput<float>("output", output_dims, output);
test.Run(OpTester::ExpectResult::kExpectSuccess);
}
// Our infrastructure does not allow for empty input.
// But the spec makes a provision for it
//TEST(ConstantOfShape, EmptyInput) {
// // Output must contain a scalar
// OpTester test("ConstantOfShape", 9);
//
// TensorProto t_proto;
// t_proto.set_data_type(TensorProto::INT32);
// t_proto.mutable_dims()->Add(1);
// t_proto.mutable_int32_data()->Add(0);
// test.AddAttribute("value", t_proto);
//
// std::vector<int64_t> input_dims;
// std::vector<int64_t> input;
// test.AddInput<int64_t>("input", input_dims, input);
//
// std::vector<int64_t> output_dims{1};
// std::vector<float> output{0};
// test.AddOutput<float>("output", output_dims, output);
//
// test.Run(OpTester::ExpectResult::kExpectSuccess);
//}
inline void SetValue(TensorProto& t_proto, float value) {
t_proto.mutable_float_data()->Add(value);
}
inline void SetValue(TensorProto& t_proto, double value) {
t_proto.mutable_double_data()->Add(value);
}
inline void SetValue(TensorProto& t_proto, MLFloat16 value) {
t_proto.mutable_int32_data()->Add(value.val);
}
// This works for int64_t
template <class T>
inline void SetValue(TensorProto& t_proto, T value,
typename std::enable_if<std::is_same<T, int64_t>::value>::type* = nullptr) {
t_proto.mutable_int64_data()->Add(value);
}
// For uint32 and uint64
template <class T>
inline void SetValue(TensorProto& t_proto, T value,
typename std::enable_if<std::is_same<T, uint64_t>::value ||
std::is_same<T, uint32_t>::value>::type* = nullptr) {
t_proto.mutable_uint64_data()->Add(value);
}
// For everything else except float, double and MLFloat16
template <class T>
inline void SetValue(TensorProto& t_proto, T value,
typename std::enable_if<!std::is_same<T, int64_t>::value &&
!std::is_same<T, uint32_t>::value &&
!std::is_same<T, uint64_t>::value>::type* = nullptr) {
t_proto.mutable_int32_data()->Add(value);
}
template <class T>
void RunTypedTest(TensorProto::DataType dt, T value) {
OpTester test("ConstantOfShape", 9);
TensorProto t_proto;
t_proto.set_data_type(dt);
t_proto.mutable_dims()->Add(1);
SetValue(t_proto, value);
test.AddAttribute("value", t_proto);
// By default the output will be FLOAT zeros
std::vector<int64_t> input_dims{2};
std::vector<int64_t> input{2, 6};
test.AddInput<int64_t>("input", input_dims, input);
std::vector<int64_t> output_dims(input);
std::vector<T> output;
output.resize(2 * 6);
std::fill_n(output.begin(), output.size(), value);
test.AddOutput<T>("output", output_dims, output);
test.Run(OpTester::ExpectResult::kExpectSuccess);
}
TEST(ConstantOfShape, TypeTests) {
// bool can not be tested due to a shortcoming of
// our test infrastructure which makes use of
// std::vector<T> which has a specialization for bool
// and does not have a continuous buffer implementation
//RunTypedTest(TensorProto::BOOL, true);
// The following two types even though supported by the
// operator cause a failure at
// onnx\onnx\checker.cc tensor_checker() where these
// two types are not listed among those that a tensor may
// contain
//RunTypedTest(TensorProto::INT8, int8_t(8));
//RunTypedTest(TensorProto::INT16, int16_t(16));
RunTypedTest(TensorProto::FLOAT, 1.f);
RunTypedTest(TensorProto::FLOAT16, MLFloat16(5));
RunTypedTest(TensorProto::DOUBLE, 1.0);
RunTypedTest(TensorProto::INT32, int32_t(32));
RunTypedTest(TensorProto::INT64, int64_t(64));
RunTypedTest(TensorProto::UINT8, uint8_t(8U));
RunTypedTest(TensorProto::UINT16, uint16_t(6U));
RunTypedTest(TensorProto::UINT32, uint32_t(32U));
RunTypedTest(TensorProto::UINT64, uint64_t(64U));
}
} // namespace test
} // namespace onnxruntime

View file

@ -22,8 +22,6 @@ backend_test.exclude(r'('
'|^test_cast_FLOAT_to_DOUBLE_cpu.*'
'|^test_cast_FLOAT_to_STRING_cpu.*'
'|^test_cast_STRING_to_FLOAT_cpu.*'
'|^test_constantofshape_float_ones_cpu.*'
'|^test_constantofshape_int_zeros_cpu.*'
'|^test_convtranspose_1d_cpu.*'
'|^test_convtranspose_3d_cpu.*'
'|^test_eyelike_populate_off_main_diagonal_cpu.*'