diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu index 99f818ff56..66a06a9634 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu @@ -283,7 +283,7 @@ __global__ void cuComputeGradGammaBeta( } } -template +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, @@ -305,7 +305,7 @@ __global__ void cuComputeGradInput( const T* k_dout = dout + i1 * n2; const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL) { + if (use_gamma) { int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { @@ -398,7 +398,7 @@ __global__ void cuComputeGradInput( U fH = (U)n2; U term1 = (U(1) / fH) * c_invvar; T* k_grad_input = grad_input + i1 * n2; - if (gamma != NULL) { + if (use_gamma) { for (int l = thrx; l < n2; l += numx) { const U c_loss = static_cast(k_dout[l]); U f_grad_input = fH * c_loss * U(gamma[l]); @@ -508,18 +508,8 @@ void HostLayerNormGradient( int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; if (mean == nullptr && !simplified) { - cuComputeGradInput<<>>( - dout, - input, - output, - gamma, - beta, - mean, - invvar, - n1, n2, - grad_input); - } else { - cuComputeGradInput<<>>( + if (gamma == nullptr) { + cuComputeGradInput<<>>( dout, input, output, @@ -529,6 +519,42 @@ void HostLayerNormGradient( invvar, n1, n2, grad_input); + } else { + cuComputeGradInput<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); + } + } else { + if (gamma == nullptr) { + cuComputeGradInput<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); + } else { + cuComputeGradInput<<>>( + dout, + input, + output, + gamma, + beta, + mean, + invvar, + n1, n2, + grad_input); + } } }