diff --git a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh
index 4553276bab6..f0871fa0ead 100644
--- a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh
+++ b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh
@@ -302,7 +302,7 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
 void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false)
 {
-  TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
+  TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
   if (softmax_elements == 0) {
     return;
   } else {
@@ -342,7 +342,8 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele
       LAUNCH_SOFTMAX_WARP_FORWARD(7);  // 128
       LAUNCH_SOFTMAX_WARP_FORWARD(8);  // 256
       LAUNCH_SOFTMAX_WARP_FORWARD(9);  // 512
-      LAUNCH_SOFTMAX_WARP_FORWARD(10); ; // 1024
+      LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
+      LAUNCH_SOFTMAX_WARP_FORWARD(11); // 2048
       default:
         break;
     }
diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu
index 066a8772beb..bc0c86f1c52 100644
--- a/aten/src/ATen/native/cuda/SoftMax.cu
+++ b/aten/src/ATen/native/cuda/SoftMax.cu
@@ -493,6 +493,12 @@ ilpReduce(index_t shift,
   return threadVal;
 }
 
+int32_t potential_register_count(int32_t dim_size, int32_t thread_count) {
+  // This method calculates the potential register count for the ilpReduce method (it's just a rough estimate).
+  int reg_cnt = (dim_size + thread_count - 1) / thread_count;
+  return reg_cnt;
+}
+
 /**
  * This will apply the Epilogue with vectorized reads & writes when input & output have the same shift
  */
@@ -694,6 +700,61 @@ cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes)
   }
 }
 
+template <typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue, typename index_t, int32_t reg_cnt>
+__global__ void
+cunn_SoftMaxForwardReg(outscalar_t *output, const scalar_t *input, index_t classes)
+{
+  extern __shared__ unsigned char smem[];
+  auto sdata = reinterpret_cast<accscalar_t*>(smem);
+
+  scalar_t reg[reg_cnt];
+
+  input += static_cast<int64_t>(blockIdx.x) * classes;
+  output += static_cast<int64_t>(blockIdx.x) * classes;
+
+  accscalar_t threadMax = -at::numeric_limits<accscalar_t>::max();
+  accscalar_t threadExp = static_cast<accscalar_t>(0);
+
+  // Load the elements from gmem into registers and compute the max for the current thread.
+  MaxFloat<scalar_t, accscalar_t> maxFunc;
+
+  #pragma unroll
+  for (int reg_idx = 0; reg_idx < reg_cnt; reg_idx++) {
+    int offset = threadIdx.x + reg_idx * blockDim.x;
+    if (offset < classes) {
+      reg[reg_idx] = input[offset];
+      threadMax = maxFunc(threadMax, reg[reg_idx]);
+    }
+  }
+
+  // Reduce to the block-wide max.
+  accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(sdata, threadMax,
+      Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
+
+  SumExpFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
+  // Reduce to the block-wide sum of exponentials.
+  #pragma unroll
+  for (int reg_idx = 0; reg_idx < reg_cnt; reg_idx++) {
+    int offset = threadIdx.x + reg_idx * blockDim.x;
+    if (offset < classes) {
+      threadExp = sumExpFunc(threadExp, reg[reg_idx]);
+    }
+  }
+  accscalar_t sumAll = blockReduceWarp<Add, accscalar_t>(sdata, threadExp,
+      Add<accscalar_t>(), static_cast<accscalar_t>(0));
+
+  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
+
+  // Apply the epilogue and write the values back to gmem.
+  #pragma unroll
+  for (int reg_idx = 0; reg_idx < reg_cnt; reg_idx++) {
+    int offset = threadIdx.x + reg_idx * blockDim.x;
+    if (offset < classes) {
+      output[offset] = epilogue(reg[reg_idx]);
+    }
+  }
+}
+
 template <typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue,
           typename index_t = int32_t>
 __global__ void
@@ -846,7 +907,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   if (!half_to_float) {
     auto output_ptr = output.mutable_data_ptr<scalar_t>();
     auto input_ptr = input.const_data_ptr<scalar_t>();
-    if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
+    if (dim_size <= 2048 && dim_size*sizeof(scalar_t) <= 8192) {
       int64_t remaining = outer_size;
       int64_t chunk_size = (1L << 30L) / dim_size;
       while(remaining > 0) {
@@ -868,7 +929,50 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
       can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
       can_use_smem &= !(dim_size % ILP);
 
-      if (can_use_smem) {
+      int32_t potential_reg_cnt = potential_register_count(dim_size, block.x);
+      if (potential_reg_cnt < 10) {
+        TORCH_INTERNAL_ASSERT(potential_reg_cnt > 0, "potential_reg_cnt for softmax with register should be greater than 0.");
+        switch (potential_reg_cnt) {
+          // TODO(Wenqin): investigate why we couldn't use a macro for the code below;
+          // it seems that on MSVC the macro did not expand correctly.
+          case 1:
+            cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int32_t, 1>
+              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+            break;
+          case 2:
+            cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int32_t, 2>
+              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+            break;
+          case 3:
+            cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int32_t, 3>
+              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+            break;
+          case 4:
+            cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int32_t, 4>
+              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+            break;
+          case 5:
+            cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int32_t, 5>
+              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+            break;
+          case 6:
+            cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int32_t, 6>
+              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+            break;
+          case 7:
+            cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int32_t, 7>
+              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+            break;
+          case 8:
+            cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int32_t, 8>
+              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+            break;
+          case 9:
+            cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int32_t, 9>
+              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+            break;
+        }
+      } else if (can_use_smem) {
         size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
         cunn_SoftMaxForwardSmem<scalar_t, accscalar_t, scalar_t, Epilogue>
             <<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
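
For readers who want to see the technique in isolation: below is a minimal standalone CUDA sketch of the register-caching idea behind cunn_SoftMaxForwardReg. It is not part of the patch; the names (reg_softmax_kernel, launch_reg_softmax, block_reduce, kThreads) are hypothetical, the reduction is a plain shared-memory tree rather than PyTorch's blockReduceWarp, and the epilogue is hard-coded to plain (non-log) softmax.

// reg_softmax.cu -- illustrative sketch only; names and structure are NOT from the patch.
#include <cstdio>
#include <cmath>
#include <vector>
#include <cuda_runtime.h>

constexpr int kThreads = 128;  // one block per row; must be a power of two here

struct MaxOp { __device__ float operator()(float a, float b) const { return fmaxf(a, b); } };
struct AddOp { __device__ float operator()(float a, float b) const { return a + b; } };

// Plain shared-memory tree reduction; a stand-in for PyTorch's blockReduceWarp.
template <typename Op>
__device__ float block_reduce(float val, Op op) {
  __shared__ float buf[kThreads];
  buf[threadIdx.x] = val;
  __syncthreads();
  for (int s = kThreads / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) buf[threadIdx.x] = op(buf[threadIdx.x], buf[threadIdx.x + s]);
    __syncthreads();
  }
  return buf[0];  // every thread reads the result after the final barrier
}

// REG_CNT is a compile-time constant, so reg[] can live entirely in registers
// instead of spilling to local memory -- the point of the switch-based dispatch.
template <int REG_CNT>
__global__ void reg_softmax_kernel(float* out, const float* in, int classes) {
  float reg[REG_CNT];  // each row is read from global memory exactly once
  in  += static_cast<int64_t>(blockIdx.x) * classes;
  out += static_cast<int64_t>(blockIdx.x) * classes;

  // Pass 1: load into registers, tracking the per-thread max.
  float tmax = -INFINITY;
  #pragma unroll
  for (int i = 0; i < REG_CNT; i++) {
    int off = threadIdx.x + i * blockDim.x;
    if (off < classes) { reg[i] = in[off]; tmax = fmaxf(tmax, reg[i]); }
  }
  float row_max = block_reduce(tmax, MaxOp{});

  // Pass 2: sum of exponentials, reading from registers rather than gmem.
  float tsum = 0.f;
  #pragma unroll
  for (int i = 0; i < REG_CNT; i++) {
    int off = threadIdx.x + i * blockDim.x;
    if (off < classes) tsum += expf(reg[i] - row_max);
  }
  float row_sum = block_reduce(tsum, AddOp{});

  // Pass 3: normalize and write back.
  #pragma unroll
  for (int i = 0; i < REG_CNT; i++) {
    int off = threadIdx.x + i * blockDim.x;
    if (off < classes) out[off] = expf(reg[i] - row_max) / row_sum;
  }
}

// Host-side dispatch mirroring the patch's switch over potential_reg_cnt;
// the ceil-div is the same arithmetic as potential_register_count().
void launch_reg_softmax(float* out, const float* in, int rows, int classes) {
  int reg_cnt = (classes + kThreads - 1) / kThreads;
  switch (reg_cnt) {
    case 1: reg_softmax_kernel<1><<<rows, kThreads>>>(out, in, classes); break;
    case 2: reg_softmax_kernel<2><<<rows, kThreads>>>(out, in, classes); break;
    case 3: reg_softmax_kernel<3><<<rows, kThreads>>>(out, in, classes); break;
    case 4: reg_softmax_kernel<4><<<rows, kThreads>>>(out, in, classes); break;
    default: break;  // wider rows would fall back to a smem/gmem kernel, as the patch does
  }
}

int main() {
  const int rows = 2, classes = 300;  // 300 elements / 128 threads -> reg_cnt == 3
  std::vector<float> h(rows * classes);
  for (size_t i = 0; i < h.size(); i++) h[i] = static_cast<float>(i % 7);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, h.size() * sizeof(float));
  cudaMalloc(&d_out, h.size() * sizeof(float));
  cudaMemcpy(d_in, h.data(), h.size() * sizeof(float), cudaMemcpyHostToDevice);

  launch_reg_softmax(d_out, d_in, rows, classes);
  cudaMemcpy(h.data(), d_out, h.size() * sizeof(float), cudaMemcpyDeviceToHost);

  float row0 = 0.f;
  for (int j = 0; j < classes; j++) row0 += h[j];
  printf("row 0 sums to %f (expect ~1)\n", row0);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

The switch-based dispatch exists because the register count must be a compile-time template parameter: a local array indexed only by fully unrolled loop indices can be kept in registers, whereas a runtime-sized array would be placed in local memory and the advantage over the smem/gmem paths would be lost. The patch's `potential_reg_cnt < 10` cutoff presumably bounds register pressure; beyond roughly nine values per thread, larger rows take the shared-memory or global-memory fallback instead.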