From e5107fd0cbf2abaf1856d1ae38cc13b3b4ae4eda Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 19 Jul 2019 15:32:54 -0700 Subject: [PATCH] Support MultiD input data for OneHotEncoder op (#1343) * Support MultiD input data for OneHotEncoder op * Fix some nits --- .../core/providers/cpu/ml/onehotencoder.cc | 19 ++++++------ .../providers/cpu/ml/onehotencoder_test.cc | 31 +++++++++++++++++-- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/onehotencoder.cc b/onnxruntime/core/providers/cpu/ml/onehotencoder.cc index 1392e66072..4030bb77a8 100644 --- a/onnxruntime/core/providers/cpu/ml/onehotencoder.cc +++ b/onnxruntime/core/providers/cpu/ml/onehotencoder.cc @@ -52,7 +52,8 @@ template OneHotEncoderOp::OneHotEncoderOp(const OpKernelInfo& info) : OpKernel(info), zeros_(info.GetAttrOrDefault("zeros", 1)), num_categories_(0) { std::vector tmp_cats_int64s = info.GetAttrsOrDefault("cats_int64s"); std::vector tmp_cats_strings = info.GetAttrsOrDefault("cats_strings"); - ORT_ENFORCE(tmp_cats_int64s.empty() || tmp_cats_strings.empty()); + ORT_ENFORCE(tmp_cats_int64s.empty() || tmp_cats_strings.empty(), + "One and only one of the 'cats_*' attributes must be defined"); if (!tmp_cats_int64s.empty()) { num_categories_ = tmp_cats_int64s.size(); for (size_t idx = 0, end = tmp_cats_int64s.size(); idx < end; ++idx) { @@ -71,18 +72,18 @@ template common::Status OneHotEncoderOp::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); const TensorShape& input_shape = X->Shape(); - ORT_ENFORCE(input_shape.NumDimensions() <= 2); std::vector output_shape(input_shape.GetDims()); output_shape.push_back(num_categories_); Tensor* Y = context->Output(0, TensorShape(output_shape)); - auto y_data = Y->template MutableData(); + auto* y_data = Y->template MutableData(); std::fill_n(y_data, Y->Shape().Size(), 0.0f); - auto x_data = X->template Data(); + const auto* x_data = X->template Data(); + const auto x_size = input_shape.Size(); std::unordered_map::const_iterator idx; - for (int64_t i = 0; i < input_shape.Size(); ++i) { + for (int64_t i = 0; i < x_size; ++i) { auto int_idx = cats_int64s_.find(static_cast(x_data[i])); if (int_idx != cats_int64s_.cend()) y_data[i * num_categories_ + int_idx->second] = 1.0f; @@ -96,17 +97,17 @@ template <> common::Status OneHotEncoderOp::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); const TensorShape& input_shape = X->Shape(); - ORT_ENFORCE(input_shape.NumDimensions() <= 2); std::vector output_shape(input_shape.GetDims()); output_shape.push_back(num_categories_); Tensor* Y = context->Output(0, TensorShape(output_shape)); - auto y_data = Y->template MutableData(); + auto* y_data = Y->template MutableData(); std::fill_n(y_data, Y->Shape().Size(), 0.0f); - auto x_data = X->template Data(); - for (int64_t i = 0; i < input_shape.Size(); ++i) { + const auto* x_data = X->template Data(); + const auto x_size = input_shape.Size(); + for (int64_t i = 0; i < x_size; ++i) { auto str_idx = cats_strings_.find(x_data[i]); if (str_idx != cats_strings_.cend()) y_data[i * num_categories_ + str_idx->second] = 1.0f; diff --git a/onnxruntime/test/providers/cpu/ml/onehotencoder_test.cc b/onnxruntime/test/providers/cpu/ml/onehotencoder_test.cc index 05af2af22c..a4ce3f3f1b 100644 --- a/onnxruntime/test/providers/cpu/ml/onehotencoder_test.cc +++ b/onnxruntime/test/providers/cpu/ml/onehotencoder_test.cc @@ -41,6 +41,18 @@ void TestIntCategory(std::vector& input) { test_vector.AddAttribute("zeros", int64_t{0}); test_vector.Run(OpTester::ExpectResult::kExpectFailure); + + // Test MultiDimensional [:, :, Labels] + OpTester test_multiD("OneHotEncoder", 1, onnxruntime::kMLDomain); + test_multiD.AddAttribute("cats_int64s", categories); + test_multiD.AddInput("X", {1, 1, 7}, input); + test_multiD.AddOutput("Y", {1, 1, 7, 8}, expected_output); + + test_multiD.AddAttribute("zeros", int64_t{1}); + test_multiD.Run(); + + test_multiD.AddAttribute("zeros", int64_t{0}); + test_multiD.Run(OpTester::ExpectResult::kExpectFailure); } TEST(OneHotEncoderOpTest, IntegerWithInt64) { @@ -49,17 +61,18 @@ TEST(OneHotEncoderOpTest, IntegerWithInt64) { } /* +// TODO: Support int32_t type kernel for the op and uncomment the test TEST(OneHotEncoderOpTest, IntegerWithInt32) { - vector input{ 8, 1, 0, 0, 3, 7, 4 }; - TestIntCategory(input); + vector input{ 8, 1, 0, 0, 3, 7, 4 }; + TestIntCategory(input); } +*/ TEST(OneHotEncoderOpTest, IntegerWithDouble) { vector input{ 8.1f, 1.2f, 0.0f, 0.7f, 3.4f, 7.9f, 4.4f }; TestIntCategory(input); } -*/ TEST(OneHotEncoderOpTest, String) { std::vector categories{"Apple", "Orange", "Watermelon", "Blueberry", "Coconut", "Mango", "Tangerine"}; vector input{"Watermelon", "Orange", "Tangerine", "Apple", "Kit"}; @@ -95,6 +108,18 @@ TEST(OneHotEncoderOpTest, String) { test_vector.AddAttribute("zeros", int64_t{0}); test_vector.Run(OpTester::ExpectResult::kExpectFailure); + + // Test MultiDimensional [:, Labels, :] + OpTester test_multiD("OneHotEncoder", 1, onnxruntime::kMLDomain); + test_multiD.AddAttribute("cats_strings", categories); + test_multiD.AddInput("X", {1, 5, 1}, input); + test_multiD.AddOutput("Y", {1, 5, 1, 7}, expected_output); + + test_multiD.AddAttribute("zeros", int64_t{1}); + test_multiD.Run(); + + test_multiD.AddAttribute("zeros", int64_t{0}); + test_multiD.Run(OpTester::ExpectResult::kExpectFailure); } } // namespace test