mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Optimize cuComputeGradInput performance. (#7479)
Move the checking of gamma to host and specialize both case through template.
This commit is contained in:
parent
6773b4f5dd
commit
e6a3308db7
1 changed files with 41 additions and 15 deletions
|
|
@ -283,7 +283,7 @@ __global__ void cuComputeGradGammaBeta(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, bool use_mean, bool simplified>
|
||||
template <typename T, typename U, bool use_mean, bool use_gamma, bool simplified>
|
||||
__global__ void cuComputeGradInput(
|
||||
const T* __restrict__ dout,
|
||||
const T* __restrict__ input,
|
||||
|
|
@ -305,7 +305,7 @@ __global__ void cuComputeGradInput(
|
|||
const T* k_dout = dout + i1 * n2;
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
if (gamma != NULL) {
|
||||
if (use_gamma) {
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
|
|
@ -398,7 +398,7 @@ __global__ void cuComputeGradInput(
|
|||
U fH = (U)n2;
|
||||
U term1 = (U(1) / fH) * c_invvar;
|
||||
T* k_grad_input = grad_input + i1 * n2;
|
||||
if (gamma != NULL) {
|
||||
if (use_gamma) {
|
||||
for (int l = thrx; l < n2; l += numx) {
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss * U(gamma[l]);
|
||||
|
|
@ -508,18 +508,8 @@ void HostLayerNormGradient(
|
|||
int nshared =
|
||||
threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
|
||||
if (mean == nullptr && !simplified) {
|
||||
cuComputeGradInput<T, U, false, false><<<blocks1, threads1, nshared, stream>>>(
|
||||
dout,
|
||||
input,
|
||||
output,
|
||||
gamma,
|
||||
beta,
|
||||
mean,
|
||||
invvar,
|
||||
n1, n2,
|
||||
grad_input);
|
||||
} else {
|
||||
cuComputeGradInput<T, U, true, simplified><<<blocks1, threads1, nshared, stream>>>(
|
||||
if (gamma == nullptr) {
|
||||
cuComputeGradInput<T, U, false, false, false><<<blocks1, threads1, nshared, stream>>>(
|
||||
dout,
|
||||
input,
|
||||
output,
|
||||
|
|
@ -529,6 +519,42 @@ void HostLayerNormGradient(
|
|||
invvar,
|
||||
n1, n2,
|
||||
grad_input);
|
||||
} else {
|
||||
cuComputeGradInput<T, U, false, true, false><<<blocks1, threads1, nshared, stream>>>(
|
||||
dout,
|
||||
input,
|
||||
output,
|
||||
gamma,
|
||||
beta,
|
||||
mean,
|
||||
invvar,
|
||||
n1, n2,
|
||||
grad_input);
|
||||
}
|
||||
} else {
|
||||
if (gamma == nullptr) {
|
||||
cuComputeGradInput<T, U, true, false, simplified><<<blocks1, threads1, nshared, stream>>>(
|
||||
dout,
|
||||
input,
|
||||
output,
|
||||
gamma,
|
||||
beta,
|
||||
mean,
|
||||
invvar,
|
||||
n1, n2,
|
||||
grad_input);
|
||||
} else {
|
||||
cuComputeGradInput<T, U, true, true, simplified><<<blocks1, threads1, nshared, stream>>>(
|
||||
dout,
|
||||
input,
|
||||
output,
|
||||
gamma,
|
||||
beta,
|
||||
mean,
|
||||
invvar,
|
||||
n1, n2,
|
||||
grad_input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue