improve perf for softmax (#6128)

* improve perf for both gathergrad and softmax

* revert the change in gathergrad and will be done in another PR.

* address comments from code review.
This commit is contained in:
Weixing Zhang 2020-12-21 14:15:54 -08:00 committed by GitHub
parent ea9cfa554a
commit 53307a5f2e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 406 additions and 0 deletions

View file

@ -0,0 +1,210 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Modifications Copyright (c) Microsoft. */
// The code below is mostly copied from Pytorch PersistentSoftmax.cuh
#include "hip/hip_runtime.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/math/softmax_impl.cuh"
#include "core/providers/rocm/math/softmax.h"
#include <limits>
namespace onnxruntime {
namespace rocm {
// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
// The template arguments have the following meaning:
// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples.
// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small.
// A "WARP" contains "GPU_WARP_SIZE" threads, these threads are guaranteed to belong to the same warp.
// This is important because it means only __shfl_ instructions are required for reductions.
// Note that this means WARP_SIZE must be a power of two and <= architecture warp size.
// The hardware warp (wavefront) width is taken from GPU_WARP_SIZE and must not be assumed to be 32: AMD GCN/CDNA architectures use 64-wide wavefronts, and the width may differ again on other or future architectures.
// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed.
// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t.
// This allows SoftMax to be fused with a cast immediately following the SoftMax.
// For instance:
// input_t=half, acc_t=float, output_t=half => read half tensor, float accumulators, write half tensor.
// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor.
// input_t=float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor.
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batch_size, int stride, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
// constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int WARP_BATCH = 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + local_idx;
dst += first_batch * stride + local_idx;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
elements[i][it] = src[i * element_count + it * WARP_SIZE];
} else {
elements[i][it] = -std::numeric_limits<acc_t>::infinity();
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (is_log_softmax) {
sum[i] += expf((float)(elements[i][it] - max_value[i]));
} else {
elements[i][it] = expf((float)(elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
if (is_log_softmax) sum[i] = max_value[i] + logf((float)(sum[i]));
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
if (is_log_softmax) {
dst[i * element_count + it * WARP_SIZE] = elements[i][it] - sum[i];
} else {
dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
}
} else {
break;
}
}
}
}
// Host-side launcher for softmax_warp_forward.
//
// Parameters:
//   dst                     - output buffer
//   src                     - input buffer
//   softmax_elements        - number of elements per row
//   softmax_elements_stride - distance in elements between consecutive rows
//   batch_count             - number of rows
//
// Nothing is launched when there is no work (softmax_elements <= 0 or batch_count <= 0).
// Rows longer than 1024 elements are not supported by the warp kernel: in that case no
// kernel is launched and dst is left untouched, so callers must route such shapes elsewhere.
// The kernel is enqueued on stream 0 with no dynamic shared memory.
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_softmax_forward(output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
  // An empty batch would produce a zero-sized grid, which is an invalid launch
  // configuration; a non-positive row length has no meaningful log2_ceil.
  if (softmax_elements <= 0 || batch_count <= 0) {
    return;
  }

  const int log2_elements = log2_ceil(softmax_elements);
  const int next_power_of_two = 1 << log2_elements;

  // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
  const int warp_size = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;

  // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
  // int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
  const int batches_per_warp = 1;

  // use 256 threads per block to maximize gpu utilization
  constexpr int threads_per_block = 256;

  const int warps_per_block = (threads_per_block / warp_size);
  const int batches_per_block = warps_per_block * batches_per_warp;
  const int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
  dim3 threads(warp_size, warps_per_block, 1);

  // One case per supported log2(row length); the kernel needs it as a compile-time constant.
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E)                                                                       \
  case L2E:                                                                                                    \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax>),   \
                       dim3(blocks), dim3(threads), 0, 0,                                                      \
                       dst, src, batch_count, softmax_elements_stride, softmax_elements);                      \
    break;

  switch (log2_elements) {
    LAUNCH_SOFTMAX_WARP_FORWARD(0)   // 1
    LAUNCH_SOFTMAX_WARP_FORWARD(1)   // 2
    LAUNCH_SOFTMAX_WARP_FORWARD(2)   // 4
    LAUNCH_SOFTMAX_WARP_FORWARD(3)   // 8
    LAUNCH_SOFTMAX_WARP_FORWARD(4)   // 16
    LAUNCH_SOFTMAX_WARP_FORWARD(5)   // 32
    LAUNCH_SOFTMAX_WARP_FORWARD(6)   // 64
    LAUNCH_SOFTMAX_WARP_FORWARD(7)   // 128
    LAUNCH_SOFTMAX_WARP_FORWARD(8)   // 256
    LAUNCH_SOFTMAX_WARP_FORWARD(9)   // 512
    LAUNCH_SOFTMAX_WARP_FORWARD(10)  // 1024
    default:
      // softmax_elements > 1024: unsupported by this kernel, nothing is launched.
      break;
  }

#undef LAUNCH_SOFTMAX_WARP_FORWARD
}
// Explicitly instantiates dispatch_softmax_forward for one (input, output, accumulator)
// type triple, once for log-softmax and once for plain softmax.
#define SPECIALIZED_SOFTMAX_IMPL(in_t, out_t, accum_t) \
template void dispatch_softmax_forward<in_t, out_t, accum_t, true>(out_t * dst, const in_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \
template void dispatch_softmax_forward<in_t, out_t, accum_t, false>(out_t * dst, const in_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);

// Type triples provided by this translation unit.
SPECIALIZED_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_IMPL(double, double, double)
}
}

View file

@ -0,0 +1,194 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Modifications Copyright (c) Microsoft. */
// The code below is mostly copied from Pytorch PersistentSoftmax.cuh
#include "hip/hip_runtime.h"
#include "orttraining/training_ops/rocm/math/softmax_grad.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/math/softmax_impl.cuh"
namespace onnxruntime {
namespace rocm {
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void softmax_warp_backward(output_t* gradInput, const input_t* grad, const input_t* output, int batch_size, int stride, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
// constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int WARP_BATCH = 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x % WARP_SIZE;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
// but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
// the nested loops.
// This should have no impact on performance because the loops are unrolled anyway.
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t grad_output_reg[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE];
output_reg[i][it] = output[i * element_count + it * WARP_SIZE];
grad_output_reg[i][it] = grad_reg[i][it] * output_reg[i][it];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
grad_output_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
if (!is_log_softmax) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_output_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_output_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
}
else {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
}
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
if (is_log_softmax) {
gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - expf(output_reg[i][it]) * sum[i]);
} else {
gradInput[i * element_count + it * WARP_SIZE] = (grad_reg[i][it] - sum[i] ) * output_reg[i][it];
}
}
}
}
}
// Host-side launcher for softmax_warp_backward.
//
// Parameters:
//   grad_input              - output buffer for the gradient w.r.t. the forward input
//   grad                    - incoming gradient w.r.t. the forward output
//   output                  - result of the forward pass
//   softmax_elements        - number of elements per row
//   softmax_elements_stride - distance in elements between consecutive rows
//   batch_count             - number of rows
//
// Nothing is launched when there is no work (softmax_elements <= 0 or batch_count <= 0).
// Rows longer than 1024 elements are not supported by the warp kernel: in that case no
// kernel is launched and grad_input is left untouched, so callers must route such shapes elsewhere.
// The kernel is enqueued on stream 0 with no dynamic shared memory.
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_softmax_backward(output_t* grad_input, const input_t* grad, const input_t* output, int softmax_elements, int softmax_elements_stride, int batch_count) {
  // An empty batch would produce a zero-sized grid, which is an invalid launch
  // configuration; a non-positive row length has no meaningful log2_ceil.
  if (softmax_elements <= 0 || batch_count <= 0) {
    return;
  }

  const int log2_elements = log2_ceil(softmax_elements);
  const int next_power_of_two = 1 << log2_elements;

  // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
  const int warp_size = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;

  // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
  // int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
  const int batches_per_warp = 1;

  // use 256 threads per block to maximize gpu utilization
  constexpr int threads_per_block = 256;

  const int warps_per_block = (threads_per_block / warp_size);
  const int batches_per_block = warps_per_block * batches_per_warp;
  const int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
  dim3 threads(warp_size, warps_per_block, 1);

  // One case per supported log2(row length); the kernel needs it as a compile-time constant.
#define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E)                                                                      \
  case L2E:                                                                                                    \
    hipLaunchKernelGGL(HIP_KERNEL_NAME(softmax_warp_backward<input_t, output_t, acc_t, L2E, is_log_softmax>),  \
                       dim3(blocks), dim3(threads), 0, 0,                                                      \
                       grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);      \
    break;

  switch (log2_elements) {
    LAUNCH_SOFTMAX_WARP_BACKWARD(0)   // 1
    LAUNCH_SOFTMAX_WARP_BACKWARD(1)   // 2
    LAUNCH_SOFTMAX_WARP_BACKWARD(2)   // 4
    LAUNCH_SOFTMAX_WARP_BACKWARD(3)   // 8
    LAUNCH_SOFTMAX_WARP_BACKWARD(4)   // 16
    LAUNCH_SOFTMAX_WARP_BACKWARD(5)   // 32
    LAUNCH_SOFTMAX_WARP_BACKWARD(6)   // 64
    LAUNCH_SOFTMAX_WARP_BACKWARD(7)   // 128
    LAUNCH_SOFTMAX_WARP_BACKWARD(8)   // 256
    LAUNCH_SOFTMAX_WARP_BACKWARD(9)   // 512
    LAUNCH_SOFTMAX_WARP_BACKWARD(10)  // 1024
    default:
      // softmax_elements > 1024: unsupported by this kernel, nothing is launched.
      break;
  }

#undef LAUNCH_SOFTMAX_WARP_BACKWARD
}
// Explicitly instantiates dispatch_softmax_backward for one (input, output, accumulator)
// type triple, once for plain softmax and once for log-softmax.
// The parameter types mirror the template declaration exactly (output_t* grad_input,
// const input_t* grad/output) so that a triple with input_t != output_t still matches it.
#define SPECIALIZED_SOFTMAX_GRAD_IMPL(input_t, output_t, acc_t) \
template void dispatch_softmax_backward<input_t, output_t, acc_t, false>(output_t * grad_input, const input_t* grad, const input_t* output, int softmax_elements, int softmax_elements_stride, int batch_count); \
template void dispatch_softmax_backward<input_t, output_t, acc_t, true>(output_t * grad_input, const input_t* grad, const input_t* output, int softmax_elements, int softmax_elements_stride, int batch_count);
SPECIALIZED_SOFTMAX_GRAD_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_GRAD_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_GRAD_IMPL(double, double, double)
}
}

View file

@ -105,6 +105,7 @@ provider_excluded_files = [
'math/matmul_integer.cu',
'math/matmul_integer.cuh',
'math/matmul_integer.h',
'math/softmax_impl.cu',
'math/softmax.cc',
'math/topk.cc',
'math/topk.h',
@ -260,6 +261,7 @@ training_ops_excluded_files = [
'math/scale.cc',
'math/scale.cu',
'math/scale.h',
'math/softmax_grad_impl.cu',
'math/softmax_grad.cc',
'nn/batch_norm_grad.cc',
'nn/batch_norm_grad.h',