From 88c58c19d476dc3d8534e2518ce88977e41450bd Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 25 Oct 2019 16:19:59 -0700 Subject: [PATCH] Improve code readability and performance. (#2257) Improve code readability and performance. (#2257) Remove one time checks from loops. Move out GetType<>() calls from loop as they go through local function statics. Get rid of index calculations from input and output so we can simlpy advance ptrs and potentially do better pre-fetch. Improve code readability. --- onnxruntime/contrib_ops/cpu/murmur_hash3.cc | 54 ++++++++++--------- onnxruntime/contrib_ops/cpu/murmur_hash3.h | 4 +- .../test/contrib_ops/murmur_hash3_test.cc | 9 ++++ 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/murmur_hash3.cc b/onnxruntime/contrib_ops/cpu/murmur_hash3.cc index ee8f2fbb0b..62f90d22be 100644 --- a/onnxruntime/contrib_ops/cpu/murmur_hash3.cc +++ b/onnxruntime/contrib_ops/cpu/murmur_hash3.cc @@ -171,34 +171,40 @@ Status MurmurHash3::Compute(OpKernelContext* ctx) const { Tensor* output_tensor = ctx->Output(0, input_shape); const MLDataType keys_type = keys->DataType(); - const int input_element_bytes = static_cast(keys->DataType()->Size()); - const int output_element_bytes = static_cast(output_tensor->DataType()->Size()); + const bool is_string = keys_type == DataTypeImpl::GetType(); + + const auto input_element_bytes = keys->DataType()->Size(); + const auto output_element_bytes = output_tensor->DataType()->Size(); const int64_t input_count = input_shape.Size(); - for (int i = 0; i < input_count; ++i) { - if (DataTypeImpl::GetType() == keys_type) { - auto input = keys->DataRaw(); - auto output = output_tensor->MutableDataRaw(); - auto input_string = reinterpret_cast(input)[i]; - MurmurHash3_x86_32(input_string.c_str(), - static_cast(input_string.length()), + // Output type is inferred by the inference function and it can be of two types int32_t and uint32_t + // however, all is needed is a ptr that can step 4 bytes at a time and for that reason we choose + // raw data casted to a type of choice. + ORT_ENFORCE(sizeof(uint32_t) == output_element_bytes, "Invalid assumption of output element size"); + auto output = reinterpret_cast(output_tensor->MutableDataRaw()); + + if (is_string) { + auto input = keys->Data(); + const auto input_end = input + input_count; + while (input != input_end) { + MurmurHash3_x86_32(input->c_str(), + static_cast(input->length()), seed_, - reinterpret_cast(output) + static_cast(i) * output_element_bytes); - } else { - auto output_type = output_tensor->DataType(); - if ((DataTypeImpl::GetType() == keys_type || DataTypeImpl::GetType() == keys_type) && - (DataTypeImpl::GetType() == output_type || DataTypeImpl::GetType() == output_type)) { - auto input = keys->DataRaw(); - auto output = output_tensor->MutableDataRaw(); - MurmurHash3_x86_32(reinterpret_cast(input) + static_cast(i) * input_element_bytes, - input_element_bytes, - seed_, - reinterpret_cast(output) + static_cast(i) * output_element_bytes); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type not supported."); - } + output); + ++input; + ++output; + } + } else { + auto input = reinterpret_cast(keys->DataRaw()); + const auto input_end = input + input_count; + while (input != input_end) { + MurmurHash3_x86_32(input, + static_cast(input_element_bytes), + seed_, + output); + ++input; + ++output; } } - return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/murmur_hash3.h b/onnxruntime/contrib_ops/cpu/murmur_hash3.h index da7f1462c6..ac943aa3a4 100644 --- a/onnxruntime/contrib_ops/cpu/murmur_hash3.h +++ b/onnxruntime/contrib_ops/cpu/murmur_hash3.h @@ -13,7 +13,7 @@ class MurmurHash3 final : public OpKernel { public: MurmurHash3(const OpKernelInfo& info) : OpKernel(info) { seed_ = static_cast(info.GetAttrOrDefault("seed", 0)); - is_positive_ = static_cast(info.GetAttrOrDefault("positive", 1)); + is_positive_ = info.GetAttrOrDefault("positive", 1) == 1; } Status Compute(OpKernelContext* context) const override; @@ -23,7 +23,7 @@ private: private : uint32_t seed_; - int64_t is_positive_{1}; + bool is_positive_{true}; }; } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/murmur_hash3_test.cc b/onnxruntime/test/contrib_ops/murmur_hash3_test.cc index 5676e9963b..e6f1774a7e 100644 --- a/onnxruntime/test/contrib_ops/murmur_hash3_test.cc +++ b/onnxruntime/test/contrib_ops/murmur_hash3_test.cc @@ -7,6 +7,15 @@ namespace onnxruntime { namespace test { +TEST(MurmurHash3OpTest, UnsupportedInputType) { + OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain); + test.AddInput("X", {1}, {3.}); + test.AddAttribute("positive", 0); + test.AddOutput("Y", {1}, {847579505L}); + // Unsupported input type + test.Run(OpTester::ExpectResult::kExpectFailure); +} + TEST(MurmurHash3OpTest, DefaultSeed) { OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain); test.AddInput("X", {1}, {3L});