From c932ab8e99ff7db7a16801bd12d97bc1b5d0191b Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 6 Feb 2019 11:38:22 -0800 Subject: [PATCH] Implement ConstantOfShape (#443) Implement ConstantOfShape --- .../providers/cpu/cpu_execution_provider.cc | 2 + .../cpu/generator/constant_of_shape.cc | 188 ++++++++++++++++++ .../cpu/generator/constant_of_shape.h | 53 +++++ onnxruntime/test/onnx/main.cc | 4 +- .../cpu/generator/constant_of_shape_test.cc | 184 +++++++++++++++++ .../test/python/onnx_backend_test_series.py | 2 - 6 files changed, 428 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/generator/constant_of_shape.cc create mode 100644 onnxruntime/core/providers/cpu/generator/constant_of_shape.h create mode 100644 onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index b8fc7ce270..de896fa7b1 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -229,6 +229,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Loo // Opset 9 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Compress); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MeanVarianceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, Less); @@ -474,6 +475,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 9 kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc new file mode 100644 index 0000000000..84bdefa5b0 --- /dev/null +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/tensorutils.h" +#include "core/providers/cpu/generator/constant_of_shape.h" +#include "gsl/span" + +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +namespace onnxruntime { + +ONNX_CPU_OPERATOR_KERNEL( + ConstantOfShape, + 9, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ConstantOfShape); + +#define FETCH_VALUE_DATA(field, c_type) \ + { \ + c_type t; \ + ORT_ENFORCE(TensorUtils::UnpackTensor(t_proto, &t, 1).IsOK(), "Value attribute unpacking failed"); \ + field = t; \ + } + +void onnxruntime::ConstantOfShape::SetValue(const ONNX_NAMESPACE::TensorProto& t_proto) { + using namespace utils; + ORT_ENFORCE(t_proto.has_data_type()); + ORT_ENFORCE(TensorProto::DataType_IsValid(t_proto.data_type())); + tensor_type_ = static_cast(t_proto.data_type()); + + switch (tensor_type_) { + case TensorProto::BOOL: + FETCH_VALUE_DATA(value_.ui64_, bool); + break; + case TensorProto::FLOAT: + FETCH_VALUE_DATA(value_.fl_, float); + break; + case TensorProto::FLOAT16: + FETCH_VALUE_DATA(value_.fl16_, MLFloat16); + break; + case TensorProto::DOUBLE: + FETCH_VALUE_DATA(value_.dbl_, double); + break; + case TensorProto::INT8: + FETCH_VALUE_DATA(value_.i64_, int8_t); + break; + case TensorProto::INT16: + FETCH_VALUE_DATA(value_.i64_, int16_t); + break; + case TensorProto::INT32: + FETCH_VALUE_DATA(value_.i64_, int32_t); + break; + case TensorProto::INT64: + FETCH_VALUE_DATA(value_.i64_, int64_t); + break; + case TensorProto::UINT8: + FETCH_VALUE_DATA(value_.ui64_, uint8_t); + break; + case TensorProto::UINT16: + FETCH_VALUE_DATA(value_.ui64_, uint16_t); + break; + case TensorProto::UINT32: + FETCH_VALUE_DATA(value_.ui64_, uint32_t); + break; + case TensorProto::UINT64: + FETCH_VALUE_DATA(value_.ui64_, uint64_t); + break; + default: + ORT_THROW("Unsupported value attribute datatype: ", TensorProto::DataType_Name(tensor_type_)); + break; + } +} + +#undef FETCH_VALUE_DATA + +template +inline T onnxruntime::ConstantOfShape::Value::GetFromSigned() const { + return static_cast(i64_); +} + +template +inline T onnxruntime::ConstantOfShape::Value::GetFromUnsigned() const { + return static_cast(ui64_); +} + +template +inline void FilloutOutput(T value, Tensor* output_tensor) { + auto out = gsl::make_span(output_tensor->template MutableData(), output_tensor->Shape().Size()); + std::fill(out.begin(), out.end(), value); +} + +void onnxruntime::ConstantOfShape::DispatchTypeAndFillOutput(Tensor* output_tensor) const { + switch (tensor_type_) { + case TensorProto::BOOL: + FilloutOutput(value_.GetFromUnsigned(), output_tensor); + break; + case TensorProto::FLOAT: + FilloutOutput(value_.GetFloat(), output_tensor); + break; + case TensorProto::FLOAT16: + FilloutOutput(value_.GetFloat16(), output_tensor); + break; + case TensorProto::DOUBLE: + FilloutOutput(value_.GetDouble(), output_tensor); + break; + case TensorProto::INT8: + FilloutOutput(value_.GetFromSigned(), output_tensor); + break; + case TensorProto::INT16: + FilloutOutput(value_.GetFromSigned(), output_tensor); + break; + case TensorProto::INT32: + FilloutOutput(value_.GetFromSigned(), output_tensor); + break; + case TensorProto::INT64: + FilloutOutput(value_.GetFromSigned(), output_tensor); + break; + case TensorProto::UINT8: + FilloutOutput(value_.GetFromUnsigned(), output_tensor); + break; + case TensorProto::UINT16: + FilloutOutput(value_.GetFromUnsigned(), output_tensor); + break; + case TensorProto::UINT32: + FilloutOutput(value_.GetFromUnsigned(), output_tensor); + break; + case TensorProto::UINT64: + FilloutOutput(value_.GetFromUnsigned(), output_tensor); + break; + default: + ORT_THROW("Unsupported value attribute datatype: ", TensorProto::DataType_Name(tensor_type_)); + break; + } +} + +ConstantOfShape::ConstantOfShape(const OpKernelInfo& info) : OpKernel(info) { + TensorProto t_proto; + if (info.GetAttr("value", &t_proto).IsOK()) { + ORT_ENFORCE(t_proto.dims_size() == 1, "Must have a single dimension"); + ORT_ENFORCE(t_proto.dims()[0] == 1, "Must have a single dimension of 1"); + SetValue(t_proto); + } else { + tensor_type_ = TensorProto::FLOAT; + value_.fl_ = 0.f; + } +} + +Status ConstantOfShape::Compute(OpKernelContext* ctx) const { + auto shape_tensor = ctx->Input(0); + + if (shape_tensor->DataType() != DataTypeImpl::GetType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input tensor expected to contain int64 data"); + } + + auto& input_shape = shape_tensor->Shape(); + + // If empty the output is a scalar, which means a single value + std::vector output_dims; + const auto input_size = input_shape.Size(); + if (input_size > 0) { + auto span = gsl::make_span(shape_tensor->Data(), input_size); + output_dims.insert(output_dims.end(), span.cbegin(), span.cend()); + } else if (input_size == 0) { + output_dims.push_back(1); // scalar + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input tensor dimensions is expected to be either 1-D or empty"); + } + + TensorShape output_shape(output_dims); + auto output_tensor = ctx->Output(0, output_shape); + DispatchTypeAndFillOutput(output_tensor); + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape.h new file mode 100644 index 0000000000..217e94c8d7 --- /dev/null +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/data_types.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class ConstantOfShape final : public OpKernel { + public: + explicit ConstantOfShape(const OpKernelInfo& info); + + Status Compute(OpKernelContext* ctx) const override; + + private: + ONNX_NAMESPACE::TensorProto_DataType tensor_type_; + union Value { + float fl_; + MLFloat16 fl16_; + double dbl_; + int64_t i64_; + uint64_t ui64_; + Value() : ui64_(0) {} + + float GetFloat() const { + return fl_; + } + + MLFloat16 GetFloat16() const { + return fl16_; + } + + double GetDouble() const { + return dbl_; + } + + template + T GetFromSigned() const; + + template + T GetFromUnsigned() const; + + } value_; + + void SetValue(const ONNX_NAMESPACE::TensorProto&); + + void DispatchTypeAndFillOutput(Tensor* output_tensor) const; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 79a8cd03e0..58000ee52f 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -271,7 +271,7 @@ int real_main(int argc, char* argv[]) { {"PoissonNLLLLoss_no_reduce", "disable reason"}, {"Softsign", "disable reason"}, {"convtranspose_1d", "disable reason"}, - {"convtranspose_3d", "disable reason"}, + {"convtranspose_3d", "disable reason"}, {"flatten_axis0", "disable reason"}, {"flatten_axis1", "disable reason"}, {"flatten_axis2", "disable reason"}, @@ -314,11 +314,9 @@ int real_main(int argc, char* argv[]) { {"scatter_without_axis", "opset 9 not supported yet"}, {"scan_sum", "opset 9 not supported yet"}, {"shrink", "opset 9 not supported yet"}, - {"constantofshape_int_zeros", "opset 9 not supported yet"}, {"shrink_hard", "opset 9 not supported yet"}, {"shrink_soft", "opset 9 not supported yet"}, {"where_example", "opset 9 not supported yet"}, - {"constantofshape_float_ones", "opset 9 not supported yet"}, {"cast_DOUBLE_to_FLOAT16", "Cast opset 9 not supported yet"}, {"cast_DOUBLE_to_FLOAT", "Cast opset 9 not supported yet"}, {"cast_FLOAT_to_DOUBLE", "Cast opset 9 not supported yet"}, diff --git a/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc b/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc new file mode 100644 index 0000000000..fee316f0e6 --- /dev/null +++ b/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include + +using namespace ONNX_NAMESPACE; +namespace onnxruntime { +namespace test { +TEST(ConstantOfShape, Float_Ones) { + OpTester test("ConstantOfShape", 9); + + TensorProto t_proto; + t_proto.set_data_type(TensorProto::FLOAT); + t_proto.mutable_dims()->Add(1); + t_proto.mutable_float_data()->Add(1.f); + test.AddAttribute("value", t_proto); + + // We will input 1-D Tensor that will store 3 dimensions + // and will provide shape for the output + std::vector input_dims{3}; + std::vector input{4, 3, 2}; + test.AddInput("input", input_dims, input); + + std::vector output_dims(input); + std::vector output; + output.resize(4 * 3 * 2); + std::fill_n(output.begin(), 4 * 3 * 2, 1.f); + + test.AddOutput("output", output_dims, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ConstantOfShape, Int32_Zeros) { + OpTester test("ConstantOfShape", 9); + + TensorProto t_proto; + t_proto.set_data_type(TensorProto::INT32); + t_proto.mutable_dims()->Add(1); + t_proto.mutable_int32_data()->Add(0); + test.AddAttribute("value", t_proto); + + std::vector input_dims{2}; + std::vector input{10, 6}; + test.AddInput("input", input_dims, input); + + std::vector output_dims(input); + std::vector output; + output.resize(10 * 6); + std::fill_n(output.begin(), output.size(), 0); + test.AddOutput("output", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ConstantOfShape, DefaultValue) { + OpTester test("ConstantOfShape", 9); + + // By default the output will be FLOAT zeros + std::vector input_dims{2}; + std::vector input{2, 6}; + test.AddInput("input", input_dims, input); + + std::vector output_dims(input); + std::vector output; + output.resize(2 * 6); + std::fill_n(output.begin(), output.size(), 0.f); + test.AddOutput("output", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +// Our infrastructure does not allow for empty input. +// But the spec makes a provision for it +//TEST(ConstantOfShape, EmptyInput) { +// // Output must contain a scalar +// OpTester test("ConstantOfShape", 9); +// +// TensorProto t_proto; +// t_proto.set_data_type(TensorProto::INT32); +// t_proto.mutable_dims()->Add(1); +// t_proto.mutable_int32_data()->Add(0); +// test.AddAttribute("value", t_proto); +// +// std::vector input_dims; +// std::vector input; +// test.AddInput("input", input_dims, input); +// +// std::vector output_dims{1}; +// std::vector output{0}; +// test.AddOutput("output", output_dims, output); +// +// test.Run(OpTester::ExpectResult::kExpectSuccess); +//} + +inline void SetValue(TensorProto& t_proto, float value) { + t_proto.mutable_float_data()->Add(value); +} + +inline void SetValue(TensorProto& t_proto, double value) { + t_proto.mutable_double_data()->Add(value); +} + +inline void SetValue(TensorProto& t_proto, MLFloat16 value) { + t_proto.mutable_int32_data()->Add(value.val); +} + +// This works for int64_t +template +inline void SetValue(TensorProto& t_proto, T value, + typename std::enable_if::value>::type* = nullptr) { + t_proto.mutable_int64_data()->Add(value); +} + +// For uint32 and uint64 +template +inline void SetValue(TensorProto& t_proto, T value, + typename std::enable_if::value || + std::is_same::value>::type* = nullptr) { + t_proto.mutable_uint64_data()->Add(value); +} + +// For everything else except float, double and MLFloat16 +template +inline void SetValue(TensorProto& t_proto, T value, + typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value>::type* = nullptr) { + t_proto.mutable_int32_data()->Add(value); +} + +template +void RunTypedTest(TensorProto::DataType dt, T value) { + OpTester test("ConstantOfShape", 9); + + TensorProto t_proto; + t_proto.set_data_type(dt); + t_proto.mutable_dims()->Add(1); + SetValue(t_proto, value); + test.AddAttribute("value", t_proto); + + // By default the output will be FLOAT zeros + std::vector input_dims{2}; + std::vector input{2, 6}; + test.AddInput("input", input_dims, input); + + std::vector output_dims(input); + std::vector output; + output.resize(2 * 6); + std::fill_n(output.begin(), output.size(), value); + test.AddOutput("output", output_dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ConstantOfShape, TypeTests) { + // bool can not be tested due to a shortcoming of + // our test infrastructure which makes use of + // std::vector which has a specialization for bool + // and does not have a continuous buffer implementation + //RunTypedTest(TensorProto::BOOL, true); + + // The following two types even though supported by the + // operator cause a failure at + // onnx\onnx\checker.cc tensor_checker() where these + // two types are not listed among those that a tensor may + // contain + //RunTypedTest(TensorProto::INT8, int8_t(8)); + //RunTypedTest(TensorProto::INT16, int16_t(16)); + + RunTypedTest(TensorProto::FLOAT, 1.f); + RunTypedTest(TensorProto::FLOAT16, MLFloat16(5)); + RunTypedTest(TensorProto::DOUBLE, 1.0); + RunTypedTest(TensorProto::INT32, int32_t(32)); + RunTypedTest(TensorProto::INT64, int64_t(64)); + RunTypedTest(TensorProto::UINT8, uint8_t(8U)); + RunTypedTest(TensorProto::UINT16, uint16_t(6U)); + RunTypedTest(TensorProto::UINT32, uint32_t(32U)); + RunTypedTest(TensorProto::UINT64, uint64_t(64U)); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 739854a696..0c0e4523a4 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -22,8 +22,6 @@ backend_test.exclude(r'(' '|^test_cast_FLOAT_to_DOUBLE_cpu.*' '|^test_cast_FLOAT_to_STRING_cpu.*' '|^test_cast_STRING_to_FLOAT_cpu.*' -'|^test_constantofshape_float_ones_cpu.*' -'|^test_constantofshape_int_zeros_cpu.*' '|^test_convtranspose_1d_cpu.*' '|^test_convtranspose_3d_cpu.*' '|^test_eyelike_populate_off_main_diagonal_cpu.*'