From e6a3308db7c03a13e0f08b221b6770e17fc3a4ef Mon Sep 17 00:00:00 2001 From: sabreshao Date: Thu, 29 Apr 2021 08:08:31 +0800 Subject: [PATCH] Optimize cuComputeGradInput performance. (#7479) Move the checking of gamma to host and specialize both case through template. --- .../training_ops/cuda/nn/layer_norm_impl.cu | 56 ++++++++++++++----- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu index 99f818ff56..66a06a9634 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu @@ -283,7 +283,7 @@ __global__ void cuComputeGradGammaBeta( } } -template +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, @@ -305,7 +305,7 @@ __global__ void cuComputeGradInput( const T* k_dout = dout + i1 * n2; const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL) { + if (use_gamma) { int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { @@ -398,7 +398,7 @@ __global__ void cuComputeGradInput( U fH = (U)n2; U term1 = (U(1) / fH) * c_invvar; T* k_grad_input = grad_input + i1 * n2; - if (gamma != NULL) { + if (use_gamma) { for (int l = thrx; l < n2; l += numx) { const U c_loss = static_cast(k_dout[l]); U f_grad_input = fH * c_loss * U(gamma[l]); @@ -508,18 +508,8 @@ void HostLayerNormGradient( int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; if (mean == nullptr && !simplified) { - cuComputeGradInput<<>>( - dout, - input, - output, - gamma, - beta, - mean, - invvar, - n1, n2, - grad_input); - } else { - cuComputeGradInput<<>>( + if (gamma == nullptr) { + cuComputeGradInput<<>>( dout, input, output, @@ -529,6 +519,42 @@ void HostLayerNormGradient( invvar, n1, n2, grad_input); + } else { + cuComputeGradInput<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); + } + } else { + if (gamma == nullptr) { + cuComputeGradInput<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); + } else { + cuComputeGradInput<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); + } } }