Introduce register-efficient warp-wise Softmax (#15266)

improve softmax forward when number of elem to do softmax is between
(1024,2048]

several optimizations done in the PR:
1. originally ort will call softmax_block_forward when shape is 1500,
this will cause 5.53ms, however ort has another implementation called
softmax_warp_forward, this function will only need 4.74ms, so i modified
the function selection logic to call the faster version.
2. softmax_warp_forward will use register to cache the input in fp32
mode, this will consume many registers when data number is large and
will make warp occupancy quite low, also compiler can do some of its
optimizations, so the pr implements another version of
softmax_warp_forward, it will use shared memory instead of register to
cache the input; also when the for loop in the function has many
iterations, actually disable loop unrolling will make kernel faster
further.

the perf table between softmax_warp_forward1(the original version) and
softmax_warp_forward2

![image](https://user-images.githubusercontent.com/43435212/228491963-cf87e3b3-e69e-454c-bab6-7e62a25bf76b.png)


in open-ai whisper case, the kernel gain will be 5.53ms/3.03ms = 82%
(softmax_block_forward vs softmax_warp_forward2)
This commit is contained in:
zhijiang 2023-05-22 08:26:03 +08:00 committed by GitHub
parent 0204594f90
commit 5607a7151a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 118 additions and 56 deletions

View file

@ -26,8 +26,9 @@ Status SoftMaxComputeHelper(
int64_t D = input_shape.SizeFromDimension(axis);
auto Y_data = reinterpret_cast<CudaT_OUT*>(Y);
auto X_data = reinterpret_cast<const CudaT_IN*>(X);
if (D <= 1024 && D * sizeof(T) <= 4096) {
// According to nsight compute profiling, softmax_warp_forward_resource_efficient is better than dispatch_blockwise_softmax_forward when 1024 < D <=2048 and N >= 8192.
const bool use_softmax_warp_forward_resource_efficient = 1024 < D && D <= 2048 && D * sizeof(T) <= 4096 && N >= 8192;
if ((D <= 1024 && D * sizeof(T) <= 4096) or use_softmax_warp_forward_resource_efficient) {
return dispatch_warpwise_softmax_forward<
CudaT_IN, CudaT_OUT, AccumulationType_t<CudaT_ACCUM>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));

View file

@ -39,67 +39,59 @@ Status dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, con
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < GPU_WARP_SIZE_HOST) ? next_power_of_two : GPU_WARP_SIZE_HOST;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int threads_per_block, shared_memory_size;
// there are 2 options to save one row of the input matrix: register or shared memory
// when the number of elements is small, we use register; otherwise, we use shared memory;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
if (log2_elements <= 10){
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
// use 128 threads per block to maximimize gpu utilization
threads_per_block = 128;
shared_memory_size = 0;
} else{
// setting the number of threads per block to 32 will make index offset calculations easier,
// under this setting, the cuda block number will be equal to batch size.
threads_per_block = 32;
// use shared memory to contain one row of elements
// TODO: one more optimization can be done here: we actually not need to save next_power_of_two elements, we can just save the valid elements
shared_memory_size = next_power_of_two * sizeof(input_t);
}
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
softmax_warp_forward<input_t, output_t, acc_t, 0, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
softmax_warp_forward<input_t, output_t, acc_t, 1, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
softmax_warp_forward<input_t, output_t, acc_t, 2, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
softmax_warp_forward<input_t, output_t, acc_t, 3, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
softmax_warp_forward<input_t, output_t, acc_t, 4, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
softmax_warp_forward<input_t, output_t, acc_t, 5, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
softmax_warp_forward<input_t, output_t, acc_t, 6, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
softmax_warp_forward<input_t, output_t, acc_t, 7, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
softmax_warp_forward<input_t, output_t, acc_t, 8, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
softmax_warp_forward<input_t, output_t, acc_t, 9, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
softmax_warp_forward<input_t, output_t, acc_t, 10, is_log_softmax>
<<<blocks, threads, 0, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
#define LAUNCH_KERNEL(kernel_name, log2_elements_value) \
kernel_name<input_t, output_t, acc_t, log2_elements_value, is_log_softmax> \
<<<blocks, threads, shared_memory_size, stream>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
#define CASE_LOG2_ELEMENTS(log2_elements_value) \
case log2_elements_value: { \
if constexpr (log2_elements_value <= 10) { \
LAUNCH_KERNEL(softmax_warp_forward, log2_elements_value) \
} else { \
LAUNCH_KERNEL(softmax_warp_forward_resource_efficient, log2_elements_value) \
} \
} break
CASE_LOG2_ELEMENTS(0);
CASE_LOG2_ELEMENTS(1);
CASE_LOG2_ELEMENTS(2);
CASE_LOG2_ELEMENTS(3);
CASE_LOG2_ELEMENTS(4);
CASE_LOG2_ELEMENTS(5);
CASE_LOG2_ELEMENTS(6);
CASE_LOG2_ELEMENTS(7);
CASE_LOG2_ELEMENTS(8);
CASE_LOG2_ELEMENTS(9);
CASE_LOG2_ELEMENTS(10);
CASE_LOG2_ELEMENTS(11); // start to use softmax_warp_forward_resource_efficient instead of softmax_warp_forward for better performance
#undef LAUNCH_KERNEL
#undef CASE_LOG2_ELEMENTS
} // switch
} // else
return CUDA_CALL(cudaGetLastError());
}

View file

@ -163,5 +163,74 @@ __global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batc
}
}
// softmax_warp_forward uses register to store data in fp32 even when data is fp16, which will cause register resource oversubscription when data is large,
// and will lead to low CUDA warp occupancy and thus a poor kernel performance.
// softmax_warp_forward_resource_efficient is implemented to solve the issue, it caches data in original data type, and casts it into fp32 when needed,
// the idea is like we use recomputation to save resource usage.
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void softmax_warp_forward_resource_efficient(output_t* dst, const input_t* src, int batch_size, int stride, int element_count) {
// 1 cuda block only processes one row and contains 1 cuda warp only.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
int local_idx = threadIdx.x;
src += blockIdx.x * stride + local_idx;
dst += blockIdx.x * stride + local_idx;
extern __shared__ unsigned char smem[];
input_t (&elements)[WARP_ITERATIONS][WARP_SIZE] = *reinterpret_cast<input_t (*)[WARP_ITERATIONS][WARP_SIZE]>(smem);
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
elements[it][local_idx] = src[it * WARP_SIZE];
} else {
elements[it][local_idx] = -std::numeric_limits<input_t>::infinity();
}
}
// compute max_value
input_t max_value = elements[0][local_idx];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value = (max_value > elements[it][local_idx]) ? max_value : elements[it][local_idx];
}
warp_reduce<input_t, 1, WARP_SIZE, Max>(&max_value);
// compute sum
acc_t sum{0.0f};
// #pragma unroll
// "exp" contains many instructions, if we unroll the loop then cuda warp will be stalled because of icache miss and thus lower the perf.
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index >= element_count)
break;
if (is_log_softmax) {
sum += std::exp((float)(elements[it][local_idx] - max_value));
} else {
acc_t tmp = std::exp((float)(elements[it][local_idx] - max_value));
elements[it][local_idx] = tmp;
sum += tmp;
}
}
warp_reduce<acc_t, 1, WARP_SIZE, Add>(&sum);
// store result
if (is_log_softmax) sum = static_cast<acc_t>(max_value) + std::log((float)(sum));
// do the reciprocal once, so the div operation can be replaced by mul.
acc_t invsum = static_cast<acc_t>(1.0f / sum);
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
if (is_log_softmax) {
dst[it * WARP_SIZE] = (float)elements[it][local_idx] - sum;
} else {
dst[it * WARP_SIZE] = (float)elements[it][local_idx] * invsum;
}
} else {
break;
}
}
}
} // namespace cuda
} // namespace onnxruntime