diff --git a/onnxruntime/core/providers/cuda/math/softmax.cc b/onnxruntime/core/providers/cuda/math/softmax.cc index 56e4e052bb..3e0ebfedc9 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.cc +++ b/onnxruntime/core/providers/cuda/math/softmax.cc @@ -26,8 +26,9 @@ Status SoftMaxComputeHelper( int64_t D = input_shape.SizeFromDimension(axis); auto Y_data = reinterpret_cast(Y); auto X_data = reinterpret_cast(X); - - if (D <= 1024 && D * sizeof(T) <= 4096) { + // According to nsight compute profiling, softmax_warp_forward_resource_efficient is better than dispatch_blockwise_softmax_forward when 1024 < D <=2048 and N >= 8192. + const bool use_softmax_warp_forward_resource_efficient = 1024 < D && D <= 2048 && D * sizeof(T) <= 4096 && N >= 8192; + if ((D <= 1024 && D * sizeof(T) <= 4096) or use_softmax_warp_forward_resource_efficient) { return dispatch_warpwise_softmax_forward< CudaT_IN, CudaT_OUT, AccumulationType_t, is_log_softmax>( stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N)); diff --git a/onnxruntime/core/providers/cuda/math/softmax_impl.cu b/onnxruntime/core/providers/cuda/math/softmax_impl.cu index ddf07803fc..a9e2b080ef 100644 --- a/onnxruntime/core/providers/cuda/math/softmax_impl.cu +++ b/onnxruntime/core/providers/cuda/math/softmax_impl.cu @@ -39,67 +39,59 @@ Status dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, con // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. int warp_size = (next_power_of_two < GPU_WARP_SIZE_HOST) ? next_power_of_two : GPU_WARP_SIZE_HOST; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int threads_per_block, shared_memory_size; + // there are 2 options to save one row of the input matrix: register or shared memory + // when the number of elements is small, we use register; otherwise, we use shared memory; int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - + if (log2_elements <= 10){ + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + // use 128 threads per block to maximimize gpu utilization + threads_per_block = 128; + shared_memory_size = 0; + } else{ + // setting the number of threads per block to 32 will make index offset calculations easier, + // under this setting, the cuda block number will be equal to batch size. + threads_per_block = 32; + // use shared memory to contain one row of elements + // TODO: one more optimization can be done here: we actually not need to save next_power_of_two elements, we can just save the valid elements + shared_memory_size = next_power_of_two * sizeof(input_t); + } int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; int blocks = (batch_count + batches_per_block - 1) / batches_per_block; dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - softmax_warp_forward - <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } + +#define LAUNCH_KERNEL(kernel_name, log2_elements_value) \ + kernel_name \ + <<>>(dst, src, batch_count, softmax_elements_stride, softmax_elements); + +#define CASE_LOG2_ELEMENTS(log2_elements_value) \ + case log2_elements_value: { \ + if constexpr (log2_elements_value <= 10) { \ + LAUNCH_KERNEL(softmax_warp_forward, log2_elements_value) \ + } else { \ + LAUNCH_KERNEL(softmax_warp_forward_resource_efficient, log2_elements_value) \ + } \ + } break + + CASE_LOG2_ELEMENTS(0); + CASE_LOG2_ELEMENTS(1); + CASE_LOG2_ELEMENTS(2); + CASE_LOG2_ELEMENTS(3); + CASE_LOG2_ELEMENTS(4); + CASE_LOG2_ELEMENTS(5); + CASE_LOG2_ELEMENTS(6); + CASE_LOG2_ELEMENTS(7); + CASE_LOG2_ELEMENTS(8); + CASE_LOG2_ELEMENTS(9); + CASE_LOG2_ELEMENTS(10); + CASE_LOG2_ELEMENTS(11); // start to use softmax_warp_forward_resource_efficient instead of softmax_warp_forward for better performance +#undef LAUNCH_KERNEL +#undef CASE_LOG2_ELEMENTS + } // switch + } // else return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh b/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh index 448e516f1c..c1b3d6ada8 100644 --- a/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh @@ -163,5 +163,74 @@ __global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batc } } + +// softmax_warp_forward uses register to store data in fp32 even when data is fp16, which will cause register resource oversubscription when data is large, +// and will lead to low CUDA warp occupancy and thus a poor kernel performance. +// softmax_warp_forward_resource_efficient is implemented to solve the issue, it caches data in original data type, and casts it into fp32 when needed, +// the idea is like we use recomputation to save resource usage. +template +__global__ void softmax_warp_forward_resource_efficient(output_t* dst, const input_t* src, int batch_size, int stride, int element_count) { + // 1 cuda block only processes one row and contains 1 cuda warp only. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + + int local_idx = threadIdx.x; + src += blockIdx.x * stride + local_idx; + dst += blockIdx.x * stride + local_idx; + extern __shared__ unsigned char smem[]; + input_t (&elements)[WARP_ITERATIONS][WARP_SIZE] = *reinterpret_cast(smem); +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + elements[it][local_idx] = src[it * WARP_SIZE]; + } else { + elements[it][local_idx] = -std::numeric_limits::infinity(); + } + } + // compute max_value + input_t max_value = elements[0][local_idx]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value = (max_value > elements[it][local_idx]) ? max_value : elements[it][local_idx]; + } + warp_reduce(&max_value); + // compute sum + acc_t sum{0.0f}; + // #pragma unroll + // "exp" contains many instructions, if we unroll the loop then cuda warp will be stalled because of icache miss and thus lower the perf. + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index >= element_count) + break; + if (is_log_softmax) { + sum += std::exp((float)(elements[it][local_idx] - max_value)); + } else { + acc_t tmp = std::exp((float)(elements[it][local_idx] - max_value)); + elements[it][local_idx] = tmp; + sum += tmp; + } + } + warp_reduce(&sum); + // store result + if (is_log_softmax) sum = static_cast(max_value) + std::log((float)(sum)); + // do the reciprocal once, so the div operation can be replaced by mul. + acc_t invsum = static_cast(1.0f / sum); +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * WARP_SIZE; + if (element_index < element_count) { + if (is_log_softmax) { + dst[it * WARP_SIZE] = (float)elements[it][local_idx] - sum; + } else { + dst[it * WARP_SIZE] = (float)elements[it][local_idx] * invsum; + } + } else { + break; + } + } +} + } // namespace cuda } // namespace onnxruntime