Improve code readability and performance. (#2257)

Improve code readability and performance. (#2257)  
  Remove one time checks from loops.
  Move out GetType<>() calls from loop as they
  go through local function statics.
  Get rid of index calculations from input and output
  so we can simply advance pointers and potentially get better pre-fetching.
  Improve code readability.
This commit is contained in:
Dmitri Smirnov 2019-10-25 16:19:59 -07:00 committed by GitHub
parent ce14b07b1c
commit 88c58c19d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 26 deletions

View file

@ -171,34 +171,40 @@ Status MurmurHash3::Compute(OpKernelContext* ctx) const {
Tensor* output_tensor = ctx->Output(0, input_shape);
const MLDataType keys_type = keys->DataType();
const int input_element_bytes = static_cast<int>(keys->DataType()->Size());
const int output_element_bytes = static_cast<int>(output_tensor->DataType()->Size());
const bool is_string = keys_type == DataTypeImpl::GetType<std::string>();
const auto input_element_bytes = keys->DataType()->Size();
const auto output_element_bytes = output_tensor->DataType()->Size();
const int64_t input_count = input_shape.Size();
for (int i = 0; i < input_count; ++i) {
if (DataTypeImpl::GetType<std::string>() == keys_type) {
auto input = keys->DataRaw();
auto output = output_tensor->MutableDataRaw();
auto input_string = reinterpret_cast<const std::string*>(input)[i];
MurmurHash3_x86_32(input_string.c_str(),
static_cast<int>(input_string.length()),
// The output type is inferred by the inference function and can be one of two types: int32_t or uint32_t.
// However, all that is needed is a pointer that steps 4 bytes at a time, so the raw output
// data is cast to uint32_t*.
ORT_ENFORCE(sizeof(uint32_t) == output_element_bytes, "Invalid assumption of output element size");
auto output = reinterpret_cast<uint32_t*>(output_tensor->MutableDataRaw());
if (is_string) {
auto input = keys->Data<std::string>();
const auto input_end = input + input_count;
while (input != input_end) {
MurmurHash3_x86_32(input->c_str(),
static_cast<int>(input->length()),
seed_,
reinterpret_cast<uint8_t*>(output) + static_cast<int64_t>(i) * output_element_bytes);
} else {
auto output_type = output_tensor->DataType();
if ((DataTypeImpl::GetType<int32_t>() == keys_type || DataTypeImpl::GetType<uint32_t>() == keys_type) &&
(DataTypeImpl::GetType<int32_t>() == output_type || DataTypeImpl::GetType<uint32_t>() == output_type)) {
auto input = keys->DataRaw();
auto output = output_tensor->MutableDataRaw();
MurmurHash3_x86_32(reinterpret_cast<const uint8_t*>(input) + static_cast<int64_t>(i) * input_element_bytes,
input_element_bytes,
seed_,
reinterpret_cast<uint8_t*>(output) + static_cast<int64_t>(i) * output_element_bytes);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type not supported.");
}
output);
++input;
++output;
}
} else {
auto input = reinterpret_cast<const uint32_t*>(keys->DataRaw());
const auto input_end = input + input_count;
while (input != input_end) {
MurmurHash3_x86_32(input,
static_cast<int>(input_element_bytes),
seed_,
output);
++input;
++output;
}
}
return Status::OK();
}

View file

@ -13,7 +13,7 @@ class MurmurHash3 final : public OpKernel {
public:
MurmurHash3(const OpKernelInfo& info) : OpKernel(info) {
seed_ = static_cast<uint32_t>(info.GetAttrOrDefault<int64_t>("seed", 0));
is_positive_ = static_cast<int64_t>(info.GetAttrOrDefault<int64_t>("positive", 1));
is_positive_ = info.GetAttrOrDefault<int64_t>("positive", 1) == 1;
}
Status Compute(OpKernelContext* context) const override;
@ -23,7 +23,7 @@ private:
private :
uint32_t seed_;
int64_t is_positive_{1};
bool is_positive_{true};
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -7,6 +7,15 @@
namespace onnxruntime {
namespace test {
// Runs MurmurHash3 (opset 1, MS domain) on a single double-typed input and
// expects the run to be reported as a failure.
TEST(MurmurHash3OpTest, UnsupportedInputType) {
  OpTester tester("MurmurHash3", 1, onnxruntime::kMSDomain);

  // Attribute first, then the tensors; these registrations are independent.
  tester.AddAttribute<int64_t>("positive", 0);
  tester.AddInput<double>("X", {1}, {3.});
  tester.AddOutput<int32_t>("Y", {1}, {847579505L});

  // double is used here as the unsupported input type, so the run must fail.
  const auto expected_outcome = OpTester::ExpectResult::kExpectFailure;
  tester.Run(expected_outcome);
}
TEST(MurmurHash3OpTest, DefaultSeed) {
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
test.AddInput<int32_t>("X", {1}, {3L});