Improve fast_divmod (#4224)

* improve fast_divmod

BERT-L throughput is improved about ~1.8%

* fix Win build.

Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
This commit is contained in:
Weixing Zhang 2020-06-16 03:03:58 -07:00 committed by GitHub
parent 825392c25b
commit 7ccce4379e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -5,126 +5,56 @@
#pragma once
#include <iostream>
#include <limits>
#include <cuda_runtime.h>
#include <cmath>
#include "core/common/common.h"
namespace onnxruntime {
namespace cuda {
__host__ __device__ __inline__ int mulhi(const int M, const int n) {
// The code below is based on section 4 Unsigned division of paper https://gmplib.org/~tege/divcnst-pldi94.pdf
// In current ORT, fast_divmod is used for calculating the position of a element in tensor,
// so unsigned integer division from the paper is good enough for ORT. The advantage is that div is very simple,
// then GPU compiler can do loop unroll easilly when divmod is called in a loop.
struct fast_divmod {
fast_divmod(int d = 1) {
d_ = d == 0 ? 1 : d;
ORT_ENFORCE(d_ >= 1 && d_ <= static_cast<uint32_t>(std::numeric_limits<int>::max()));
for (l_ = 0; l_ < 32; l_++) if ((1U << l_) >= d_) break;
uint64_t one = 1;
uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1;
M_ = static_cast<uint32_t>(m);
// according to paper, the value of m' should fit in a unsigned integer.
ORT_ENFORCE(M_ > 0 && M_ == m);
}
__host__ __device__ inline int div(int n) const {
#ifdef __CUDA_ARCH__
return __mulhi(M, n);
uint32_t t = __umulhi(M_, n);
return (t + n) >> l_;
#else
return (((unsigned long long)((long long)M * (long long)n)) >> 32);
// Using uint64_t for t, then t + n won't overflow.
uint64_t t = ((uint64_t) M_ * n) >> 32;
return static_cast<int>((t + n) >> l_);
#endif
}
// Based on code from Chapter 10 of "Hacker's Delight, 2nd ed."
class fast_divmod {
public:
fast_divmod(int d = 1) : d_(d), a_(0) { find_magic_numbers(); }
fast_divmod(const fast_divmod& other) : d_(other.d_), M_(other.M_), s_(other.s_), a_(other.a_){};
__host__ __device__ __inline__ int div(int n) const {
// get high 32 bits of M * n
int q = mulhi(M_, n);
// deal with add / subs if needed
q += a_ * n;
// shift if necessary
if (s_ >= 0) {
q >>= s_;
q += ((unsigned int)q >> 31);
}
return q;
}
__host__ __device__ __inline__ void divmod(int n, int& q, int& r) const {
// handle special cases
if (d_ == 1) {
q = n;
r = 0;
} else if (d_ == -1) {
q = -n;
r = 0;
} else {
// general case
q = div(n);
r = n - q * d_;
}
__host__ __device__ inline int mod(int n) const {
return n - div(n) * d_;
}
public:
int d_, M_, s_, a_;
private:
// Based on code from Hacker's delight 2.ed
// Chapter 10, figure 10-1
void find_magic_numbers() {
// special case for d = 1, -1
if (d_ == 1) {
M_ = 0;
s_ = 0;
a_ = 1;
return;
} else if (d_ == -1) {
M_ = 0;
s_ = -1;
a_ = -1;
return;
}
// general case
const unsigned two31 = 0x80000000;
unsigned abs_d = (d_ == 0) ? 1 : abs(d_);
unsigned t = two31 + ((unsigned)d_ >> 31); // t = 2^31 + (d < 0) ? 1 : 0
unsigned abs_nc = t - 1 - (t % abs_d); // |n_c| = t - 1 - rem(t, |d|)
int p = 31;
unsigned q1 = two31 / abs_nc; // Init q_1 = 2^31 / |n_c|
unsigned r1 = two31 - q1 * abs_nc; // Init r_1 = rem(q_1, |n_c|)
unsigned q2 = two31 / abs_d; // Init q_2 = 2^31 / |d|
unsigned r2 = two31 - q2 * abs_d; // Init r_2 = rem(q_2, |d|)
unsigned delta;
// iterate p until
// 2^p < n_c * (d - rem(2^p, d)) is satisfied
do {
++p;
q1 *= 2;
r1 *= 2;
if (r1 >= abs_nc) {
q1 += 1;
r1 -= abs_nc;
}
q2 *= 2;
r2 *= 2;
if (r2 >= abs_d) {
q2 += 1;
r2 -= abs_d;
}
delta = abs_d - r2;
} while (q1 < delta ||
(q1 == delta && r1 == 0));
// store magic numbers
M_ = q2 + 1;
if (d_ < 0) M_ = -M_;
s_ = p - 32;
// generate sentinel for correct adds / subs
// "generate the add if d > 0 and M < 0"
if ((d_ > 0) && (M_ < 0)) a_ = 1;
// "generate the sub if d < 0 and M > 0"
else if ((d_ < 0) && (M_ > 0))
a_ = -1;
// Otherwise no add / sub needed
else
a_ = 0;
__host__ __device__ inline void divmod(int n, int& q, int& r) const {
q = div(n);
r = n - q * d_;
}
uint32_t d_; // divisor
uint32_t M_; // m' in the paper.
uint32_t l_; // l_ = ceil(log2(d_))
};
} // namespace cuda