Support MultiD input data for OneHotEncoder op (#1343)

* Support MultiD input data for OneHotEncoder op

* Fix some nits
This commit is contained in:
Hariharan Seshadri 2019-07-19 15:32:54 -07:00 committed by Ke Zhang
parent 751ee7bb23
commit e5107fd0cb
2 changed files with 38 additions and 12 deletions

View file

@ -52,7 +52,8 @@ template <typename T>
OneHotEncoderOp<T>::OneHotEncoderOp(const OpKernelInfo& info) : OpKernel(info), zeros_(info.GetAttrOrDefault<int64_t>("zeros", 1)), num_categories_(0) {
std::vector<int64_t> tmp_cats_int64s = info.GetAttrsOrDefault<int64_t>("cats_int64s");
std::vector<std::string> tmp_cats_strings = info.GetAttrsOrDefault<string>("cats_strings");
ORT_ENFORCE(tmp_cats_int64s.empty() || tmp_cats_strings.empty());
ORT_ENFORCE(tmp_cats_int64s.empty() || tmp_cats_strings.empty(),
"One and only one of the 'cats_*' attributes must be defined");
if (!tmp_cats_int64s.empty()) {
num_categories_ = tmp_cats_int64s.size();
for (size_t idx = 0, end = tmp_cats_int64s.size(); idx < end; ++idx) {
@ -71,18 +72,18 @@ template <typename T>
common::Status OneHotEncoderOp<T>::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
const TensorShape& input_shape = X->Shape();
ORT_ENFORCE(input_shape.NumDimensions() <= 2);
std::vector<int64_t> output_shape(input_shape.GetDims());
output_shape.push_back(num_categories_);
Tensor* Y = context->Output(0, TensorShape(output_shape));
auto y_data = Y->template MutableData<float>();
auto* y_data = Y->template MutableData<float>();
std::fill_n(y_data, Y->Shape().Size(), 0.0f);
auto x_data = X->template Data<T>();
const auto* x_data = X->template Data<T>();
const auto x_size = input_shape.Size();
std::unordered_map<int64_t, size_t>::const_iterator idx;
for (int64_t i = 0; i < input_shape.Size(); ++i) {
for (int64_t i = 0; i < x_size; ++i) {
auto int_idx = cats_int64s_.find(static_cast<int64_t>(x_data[i]));
if (int_idx != cats_int64s_.cend())
y_data[i * num_categories_ + int_idx->second] = 1.0f;
@ -96,17 +97,17 @@ template <>
common::Status OneHotEncoderOp<std::string>::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
const TensorShape& input_shape = X->Shape();
ORT_ENFORCE(input_shape.NumDimensions() <= 2);
std::vector<int64_t> output_shape(input_shape.GetDims());
output_shape.push_back(num_categories_);
Tensor* Y = context->Output(0, TensorShape(output_shape));
auto y_data = Y->template MutableData<float>();
auto* y_data = Y->template MutableData<float>();
std::fill_n(y_data, Y->Shape().Size(), 0.0f);
auto x_data = X->template Data<std::string>();
for (int64_t i = 0; i < input_shape.Size(); ++i) {
const auto* x_data = X->template Data<std::string>();
const auto x_size = input_shape.Size();
for (int64_t i = 0; i < x_size; ++i) {
auto str_idx = cats_strings_.find(x_data[i]);
if (str_idx != cats_strings_.cend())
y_data[i * num_categories_ + str_idx->second] = 1.0f;

View file

@ -41,6 +41,18 @@ void TestIntCategory(std::vector<T>& input) {
test_vector.AddAttribute("zeros", int64_t{0});
test_vector.Run(OpTester::ExpectResult::kExpectFailure);
// Test MultiDimensional [:, :, Labels]
OpTester test_multiD("OneHotEncoder", 1, onnxruntime::kMLDomain);
test_multiD.AddAttribute("cats_int64s", categories);
test_multiD.AddInput<T>("X", {1, 1, 7}, input);
test_multiD.AddOutput<float>("Y", {1, 1, 7, 8}, expected_output);
test_multiD.AddAttribute("zeros", int64_t{1});
test_multiD.Run();
test_multiD.AddAttribute("zeros", int64_t{0});
test_multiD.Run(OpTester::ExpectResult::kExpectFailure);
}
TEST(OneHotEncoderOpTest, IntegerWithInt64) {
@ -49,17 +61,18 @@ TEST(OneHotEncoderOpTest, IntegerWithInt64) {
}
/*
// TODO: Support int32_t type kernel for the op and uncomment the test
TEST(OneHotEncoderOpTest, IntegerWithInt32) {
vector<int> input{ 8, 1, 0, 0, 3, 7, 4 };
TestIntCategory<int>(input);
vector<int32_t> input{ 8, 1, 0, 0, 3, 7, 4 };
TestIntCategory<int32_t>(input);
}
*/
TEST(OneHotEncoderOpTest, IntegerWithDouble) {
vector<double> input{ 8.1f, 1.2f, 0.0f, 0.7f, 3.4f, 7.9f, 4.4f };
TestIntCategory<double>(input);
}
*/
TEST(OneHotEncoderOpTest, String) {
std::vector<std::string> categories{"Apple", "Orange", "Watermelon", "Blueberry", "Coconut", "Mango", "Tangerine"};
vector<std::string> input{"Watermelon", "Orange", "Tangerine", "Apple", "Kit"};
@ -95,6 +108,18 @@ TEST(OneHotEncoderOpTest, String) {
test_vector.AddAttribute("zeros", int64_t{0});
test_vector.Run(OpTester::ExpectResult::kExpectFailure);
// Test MultiDimensional [:, Labels, :]
OpTester test_multiD("OneHotEncoder", 1, onnxruntime::kMLDomain);
test_multiD.AddAttribute("cats_strings", categories);
test_multiD.AddInput<string>("X", {1, 5, 1}, input);
test_multiD.AddOutput<float>("Y", {1, 5, 1, 7}, expected_output);
test_multiD.AddAttribute("zeros", int64_t{1});
test_multiD.Run();
test_multiD.AddAttribute("zeros", int64_t{0});
test_multiD.Run(OpTester::ExpectResult::kExpectFailure);
}
} // namespace test