Improve softmax's performance in CUDA (#144679)

Fixes #144645

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144679
Approved by: https://github.com/eqy
This commit is contained in:
Wenqin Yang 2025-01-23 00:02:57 +00:00 committed by PyTorch MergeBot
parent d0a2e11284
commit 1e32842324
2 changed files with 109 additions and 4 deletions

View file

@ -302,7 +302,7 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
@ -342,7 +342,8 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele
LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128
LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256
LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512
LAUNCH_SOFTMAX_WARP_FORWARD(10); ; // 1024
LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
LAUNCH_SOFTMAX_WARP_FORWARD(11); // 2048
default:
break;
}

View file

@ -493,6 +493,12 @@ ilpReduce(index_t shift,
return threadVal;
}
// Rough estimate of how many registers per thread are needed to cache one
// row of `dim_size` elements when `thread_count` threads cooperate on it,
// i.e. ceil(dim_size / thread_count). It is only a heuristic used to decide
// whether the register-caching softmax kernel is worth launching.
//
// Returns 0 when dim_size is 0; assumes thread_count > 0.
int32_t potential_register_count(int32_t dim_size, int32_t thread_count){
// Ceiling division; int32_t matches the declared return type (the original
// used a bare `int` local, which is inconsistent with the signature).
int32_t reg_cnt = (dim_size + thread_count - 1) / thread_count;
return reg_cnt;
}
/**
* This will apply the Epilogue with vectorized reads & writes when input & output have the same shift
*/
@ -694,6 +700,61 @@ cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes)
}
}
// Register-cached softmax forward kernel: each thread stages up to reg_cnt
// elements of its row in a local array (`reg`) so the row is read from
// global memory only once across the three passes (max reduction,
// exp-sum reduction, epilogue write-back).
//
// Launch layout assumed by the indexing below:
//  - blockIdx.x selects one row of `classes` elements (input/output are
//    advanced by blockIdx.x * classes), so grid.x == number of rows.
//  - Threads stride over the row by blockDim.x; offsets at or beyond
//    reg_cnt * blockDim.x are never touched, so the launcher presumably
//    guarantees reg_cnt >= ceil(classes / blockDim.x) — confirm at call site.
//  - Dynamic shared memory (`smem`) is scratch for blockReduceWarp; its
//    required size is defined by that helper (not visible in this file view).
template <typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue, typename index_t, int32_t reg_cnt>
__global__ void
cunn_SoftMaxForwardReg(outscalar_t *output, const scalar_t *input, index_t classes)
{
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<accscalar_t*>(smem);
// Per-thread register cache. Entries whose offset falls outside the row are
// left uninitialized but never read: every loop guards on offset < classes.
scalar_t reg[reg_cnt];
// Advance to this block's row; the int64_t cast avoids 32-bit overflow in
// blockIdx.x * classes for large batch counts.
input += static_cast<int64_t>(blockIdx.x) * classes;
output += static_cast<int64_t>(blockIdx.x) * classes;
accscalar_t threadMax = -at::numeric_limits<accscalar_t>::max();
accscalar_t threadExp = static_cast<accscalar_t>(0);
// Pass 1: load the elements from gmem into reg, and get the max for the
// current thread.
MaxFloat<scalar_t, accscalar_t> maxFunc;
#pragma unroll
for(int reg_idx = 0; reg_idx < reg_cnt; reg_idx ++){
int offset = threadIdx.x + reg_idx * blockDim.x;
if(offset < classes) {
reg[reg_idx] = input[offset];
threadMax = maxFunc(threadMax, reg[reg_idx]);
}
}
// Reduce per-thread maxima to the max for the whole block.
accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(sdata, threadMax,
Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
SumExpFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
// Pass 2: accumulate this thread's SumExpFloat partial over the cached
// values (the functor is seeded with the block max).
#pragma unroll
for(int reg_idx = 0; reg_idx < reg_cnt; reg_idx ++){
int offset = threadIdx.x + reg_idx * blockDim.x;
if(offset < classes) {
threadExp = sumExpFunc(threadExp, reg[reg_idx]);
}
}
// Reduce the partials to the block-wide sum, then build the normalization
// epilogue from (max, sum).
accscalar_t sumAll = blockReduceWarp<Add, accscalar_t>(sdata, threadExp,
Add<accscalar_t>(), static_cast<accscalar_t>(0));
Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
// Pass 3: apply the epilogue to the cached values and write the result back
// to global memory.
#pragma unroll
for(int reg_idx = 0; reg_idx < reg_cnt; reg_idx ++){
int offset = threadIdx.x + reg_idx * blockDim.x;
if(offset < classes) {
output[offset] = epilogue(reg[reg_idx]);
}
}
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
__global__ void
@ -846,7 +907,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
if (!half_to_float) {
auto output_ptr = output.mutable_data_ptr<scalar_t>();
auto input_ptr = input.const_data_ptr<scalar_t>();
if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
if (dim_size <= 2048 && dim_size*sizeof(scalar_t) <= 8192) {
int64_t remaining = outer_size;
int64_t chunk_size = (1L << 30L) / dim_size;
while(remaining > 0) {
@ -868,7 +929,50 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
can_use_smem &= !(dim_size % ILP);
if (can_use_smem) {
int32_t potential_reg_cnt = potential_register_count(dim_size, block.x);
if(potential_reg_cnt < 10){
TORCH_INTERNAL_ASSERT(potential_reg_cnt > 0, "potential_reg_cnt for softmax with register should be greater than 0.");
switch (potential_reg_cnt) {
// TODO(Wenqin): investigate why we couldn't use a macro for the code below;
// on MSVC the macro did not appear to expand correctly.
case 1:
cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 1>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
break;
case 2:
cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 2>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
break;
case 3:
cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 3>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
break;
case 4:
cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 4>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
break;
case 5:
cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 5>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
break;
case 6:
cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 6>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
break;
case 7:
cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 7>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
break;
case 8:
cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 8>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
break;
case 9:
cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 9>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
break;
}
} else if (can_use_smem) {
size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
<<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);