Optimize cuComputeGradInput performance. (#7479)

Move the checking of gamma to host and specialize both case through template.
This commit is contained in:
sabreshao 2021-04-29 08:08:31 +08:00 committed by GitHub
parent 6773b4f5dd
commit e6a3308db7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -283,7 +283,7 @@ __global__ void cuComputeGradGammaBeta(
}
}
template <typename T, typename U, bool use_mean, bool simplified>
template <typename T, typename U, bool use_mean, bool use_gamma, bool simplified>
__global__ void cuComputeGradInput(
const T* __restrict__ dout,
const T* __restrict__ input,
@ -305,7 +305,7 @@ __global__ void cuComputeGradInput(
const T* k_dout = dout + i1 * n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
if (use_gamma) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
@ -398,7 +398,7 @@ __global__ void cuComputeGradInput(
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
if (use_gamma) {
for (int l = thrx; l < n2; l += numx) {
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * U(gamma[l]);
@ -508,18 +508,8 @@ void HostLayerNormGradient(
int nshared =
threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
if (mean == nullptr && !simplified) {
cuComputeGradInput<T, U, false, false><<<blocks1, threads1, nshared, stream>>>(
dout,
input,
output,
gamma,
beta,
mean,
invvar,
n1, n2,
grad_input);
} else {
cuComputeGradInput<T, U, true, simplified><<<blocks1, threads1, nshared, stream>>>(
if (gamma == nullptr) {
cuComputeGradInput<T, U, false, false, false><<<blocks1, threads1, nshared, stream>>>(
dout,
input,
output,
@ -529,6 +519,42 @@ void HostLayerNormGradient(
invvar,
n1, n2,
grad_input);
} else {
cuComputeGradInput<T, U, false, true, false><<<blocks1, threads1, nshared, stream>>>(
dout,
input,
output,
gamma,
beta,
mean,
invvar,
n1, n2,
grad_input);
}
} else {
if (gamma == nullptr) {
cuComputeGradInput<T, U, true, false, simplified><<<blocks1, threads1, nshared, stream>>>(
dout,
input,
output,
gamma,
beta,
mean,
invvar,
n1, n2,
grad_input);
} else {
cuComputeGradInput<T, U, true, true, simplified><<<blocks1, threads1, nshared, stream>>>(
dout,
input,
output,
gamma,
beta,
mean,
invvar,
n1, n2,
grad_input);
}
}
}