mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Improve code readability and performance. (#2257)
Improve code readability and performance. (#2257) Remove one-time checks from loops. Move GetType<>() calls out of the loop, as they go through local function statics. Get rid of index calculations from input and output so we can simply advance ptrs and potentially do better pre-fetch. Improve code readability.
This commit is contained in:
parent
ce14b07b1c
commit
88c58c19d4
3 changed files with 41 additions and 26 deletions
|
|
@ -171,34 +171,40 @@ Status MurmurHash3::Compute(OpKernelContext* ctx) const {
|
|||
Tensor* output_tensor = ctx->Output(0, input_shape);
|
||||
|
||||
const MLDataType keys_type = keys->DataType();
|
||||
const int input_element_bytes = static_cast<int>(keys->DataType()->Size());
|
||||
const int output_element_bytes = static_cast<int>(output_tensor->DataType()->Size());
|
||||
const bool is_string = keys_type == DataTypeImpl::GetType<std::string>();
|
||||
|
||||
const auto input_element_bytes = keys->DataType()->Size();
|
||||
const auto output_element_bytes = output_tensor->DataType()->Size();
|
||||
const int64_t input_count = input_shape.Size();
|
||||
for (int i = 0; i < input_count; ++i) {
|
||||
if (DataTypeImpl::GetType<std::string>() == keys_type) {
|
||||
auto input = keys->DataRaw();
|
||||
auto output = output_tensor->MutableDataRaw();
|
||||
auto input_string = reinterpret_cast<const std::string*>(input)[i];
|
||||
MurmurHash3_x86_32(input_string.c_str(),
|
||||
static_cast<int>(input_string.length()),
|
||||
// Output type is inferred by the inference function and it can be of two types int32_t and uint32_t
|
||||
// however, all is needed is a ptr that can step 4 bytes at a time and for that reason we choose
|
||||
// raw data casted to a type of choice.
|
||||
ORT_ENFORCE(sizeof(uint32_t) == output_element_bytes, "Invalid assumption of output element size");
|
||||
auto output = reinterpret_cast<uint32_t*>(output_tensor->MutableDataRaw());
|
||||
|
||||
if (is_string) {
|
||||
auto input = keys->Data<std::string>();
|
||||
const auto input_end = input + input_count;
|
||||
while (input != input_end) {
|
||||
MurmurHash3_x86_32(input->c_str(),
|
||||
static_cast<int>(input->length()),
|
||||
seed_,
|
||||
reinterpret_cast<uint8_t*>(output) + static_cast<int64_t>(i) * output_element_bytes);
|
||||
} else {
|
||||
auto output_type = output_tensor->DataType();
|
||||
if ((DataTypeImpl::GetType<int32_t>() == keys_type || DataTypeImpl::GetType<uint32_t>() == keys_type) &&
|
||||
(DataTypeImpl::GetType<int32_t>() == output_type || DataTypeImpl::GetType<uint32_t>() == output_type)) {
|
||||
auto input = keys->DataRaw();
|
||||
auto output = output_tensor->MutableDataRaw();
|
||||
MurmurHash3_x86_32(reinterpret_cast<const uint8_t*>(input) + static_cast<int64_t>(i) * input_element_bytes,
|
||||
input_element_bytes,
|
||||
seed_,
|
||||
reinterpret_cast<uint8_t*>(output) + static_cast<int64_t>(i) * output_element_bytes);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type not supported.");
|
||||
}
|
||||
output);
|
||||
++input;
|
||||
++output;
|
||||
}
|
||||
} else {
|
||||
auto input = reinterpret_cast<const uint32_t*>(keys->DataRaw());
|
||||
const auto input_end = input + input_count;
|
||||
while (input != input_end) {
|
||||
MurmurHash3_x86_32(input,
|
||||
static_cast<int>(input_element_bytes),
|
||||
seed_,
|
||||
output);
|
||||
++input;
|
||||
++output;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class MurmurHash3 final : public OpKernel {
|
|||
public:
|
||||
MurmurHash3(const OpKernelInfo& info) : OpKernel(info) {
|
||||
seed_ = static_cast<uint32_t>(info.GetAttrOrDefault<int64_t>("seed", 0));
|
||||
is_positive_ = static_cast<int64_t>(info.GetAttrOrDefault<int64_t>("positive", 1));
|
||||
is_positive_ = info.GetAttrOrDefault<int64_t>("positive", 1) == 1;
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
|
@ -23,7 +23,7 @@ private:
|
|||
|
||||
private :
|
||||
uint32_t seed_;
|
||||
int64_t is_positive_{1};
|
||||
bool is_positive_{true};
|
||||
};
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -7,6 +7,15 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(MurmurHash3OpTest, UnsupportedInputType) {
|
||||
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<double>("X", {1}, {3.});
|
||||
test.AddAttribute<int64_t>("positive", 0);
|
||||
test.AddOutput<int32_t>("Y", {1}, {847579505L});
|
||||
// Unsupported input type
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure);
|
||||
}
|
||||
|
||||
TEST(MurmurHash3OpTest, DefaultSeed) {
|
||||
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<int32_t>("X", {1}, {3L});
|
||||
|
|
|
|||
Loading…
Reference in a new issue