Improve softmax's perf in cuda (#144679)
Fixes #144645
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144679
Approved by: https://github.com/eqy
parent d0a2e11284
commit 1e32842324

2 changed files with 109 additions and 4 deletions
@@ -302,7 +302,7 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad,
 template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
 void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false)
 {
-  TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
+  TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
   if (softmax_elements == 0) {
     return;
   } else {
@@ -342,7 +342,8 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele
       LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128
       LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256
       LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512
-      LAUNCH_SOFTMAX_WARP_FORWARD(10); ; // 1024
+      LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
+      LAUNCH_SOFTMAX_WARP_FORWARD(11); // 2048
       default:
         break;
     }
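For context: LAUNCH_SOFTMAX_WARP_FORWARD(L) is defined outside this hunk; each invocation is a case of a switch over log2_elements, launching the warp-level kernel specialized for rows of up to 2^L elements, so the new case 11 extends the persistent warp path from 1024 to 2048 elements. A standalone sketch of the log2 bucketing (illustrative names, not the PyTorch macro):

#include <cstdio>

// Round a positive element count up to the exponent of the next power of two;
// this picks which kernel specialization a row is dispatched to.
int log2_ceil(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}

int main() {
  // 2048 now lands in bucket 11, which this change adds to the switch.
  for (int n : {128, 600, 1024, 1025, 2048}) {
    std::printf("softmax_elements=%4d -> log2_elements=%2d (bucket size %d)\n",
                n, log2_ceil(n), 1 << log2_ceil(n));
  }
  return 0;
}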
@@ -493,6 +493,12 @@ ilpReduce(index_t shift,
   return threadVal;
 }
 
+int32_t potential_register_count(int32_t dim_size, int32_t thread_count){
+  // Rough estimate of the register count the ilpReduce-style path would need per thread.
+  int reg_cnt = (dim_size + thread_count - 1) / thread_count;
+  return reg_cnt;
+}
+
 /**
  * This will apply the Epilogue with vectorized reads & writes when input & output have the same shift
  */
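potential_register_count is a plain ceiling division: with one block per row, each thread must hold ceil(dim_size / thread_count) elements to cache the whole row in registers. A quick standalone check of the arithmetic:

#include <cassert>
#include <cstdint>

// Standalone copy of the helper above, for a sanity check of the ceil division.
int32_t potential_register_count(int32_t dim_size, int32_t thread_count) {
  return (dim_size + thread_count - 1) / thread_count;
}

int main() {
  assert(potential_register_count(2048, 256) == 8);  // exact multiple
  assert(potential_register_count(2000, 256) == 8);  // 7.81 rounds up to 8
  assert(potential_register_count(100, 256) == 1);   // small rows still need one
  return 0;
}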
@@ -694,6 +700,61 @@ cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes)
   }
 }
 
+template <typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue, typename index_t, int32_t reg_cnt>
+__global__ void
+cunn_SoftMaxForwardReg(outscalar_t *output, const scalar_t *input, index_t classes)
+{
+  extern __shared__ unsigned char smem[];
+  auto sdata = reinterpret_cast<accscalar_t*>(smem);
+
+  scalar_t reg[reg_cnt];
+
+  input += static_cast<int64_t>(blockIdx.x) * classes;
+  output += static_cast<int64_t>(blockIdx.x) * classes;
+
+  accscalar_t threadMax = -at::numeric_limits<accscalar_t>::max();
+  accscalar_t threadExp = static_cast<accscalar_t>(0);
+
+  // Load the elements from gmem into reg, and get the max for the current thread.
+  MaxFloat<scalar_t, accscalar_t> maxFunc;
+
+#pragma unroll
+  for(int reg_idx = 0; reg_idx < reg_cnt; reg_idx++){
+    int offset = threadIdx.x + reg_idx * blockDim.x;
+    if(offset < classes) {
+      reg[reg_idx] = input[offset];
+      threadMax = maxFunc(threadMax, reg[reg_idx]);
+    }
+  }
+
+  // Reduce to the max for the block.
+  accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(sdata, threadMax,
+    Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
+
+  SumExpFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
+  // Reduce all values.
+#pragma unroll
+  for(int reg_idx = 0; reg_idx < reg_cnt; reg_idx++){
+    int offset = threadIdx.x + reg_idx * blockDim.x;
+    if(offset < classes) {
+      threadExp = sumExpFunc(threadExp, reg[reg_idx]);
+    }
+  }
+  accscalar_t sumAll = blockReduceWarp<Add, accscalar_t>(sdata, threadExp,
+    Add<accscalar_t>(), static_cast<accscalar_t>(0));
+
+  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
+
+  // Write back the values.
+#pragma unroll
+  for(int reg_idx = 0; reg_idx < reg_cnt; reg_idx++){
+    int offset = threadIdx.x + reg_idx * blockDim.x;
+    if(offset < classes) {
+      output[offset] = epilogue(reg[reg_idx]);
+    }
+  }
+}
+
 template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
           template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
 __global__ void
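The structure above is the core of the optimization: for rows that fit, each thread keeps its slice of the row in a register array (reg[reg_cnt]), so the three passes (max, sum of exponentials, normalize-and-write) read global memory only once. PyTorch's MaxFloat/SumExpFloat functors, blockReduceWarp, and the Epilogue (which also covers log-softmax) are not reproduced here; the following is a minimal self-contained sketch of the same register-caching pattern, substituting a naive shared-memory reduction and plain softmax, for illustration only.

#include <cuda_runtime.h>
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <vector>

// Naive shared-memory block reduction (BLOCK must be a power of two and equal
// blockDim.x); the committed kernel uses PyTorch's warp-shuffle blockReduceWarp.
template <int BLOCK, typename Op>
__device__ float block_reduce(float* sdata, float val, Op op) {
  sdata[threadIdx.x] = val;
  __syncthreads();
  for (int s = BLOCK / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) sdata[threadIdx.x] = op(sdata[threadIdx.x], sdata[threadIdx.x + s]);
    __syncthreads();
  }
  float result = sdata[0];
  __syncthreads();  // sdata is reused by the next reduction
  return result;
}

struct MaxOp { __device__ float operator()(float a, float b) const { return fmaxf(a, b); } };
struct AddOp { __device__ float operator()(float a, float b) const { return a + b; } };

// One block per row. Each thread caches up to REG_CNT elements in registers,
// so the row is read from global memory exactly once across all three passes.
template <int BLOCK, int REG_CNT>
__global__ void softmax_reg(float* out, const float* in, int classes) {
  __shared__ float sdata[BLOCK];
  float reg[REG_CNT];
  in  += static_cast<long long>(blockIdx.x) * classes;
  out += static_cast<long long>(blockIdx.x) * classes;

  // Pass 1: load into registers and take the thread-local max.
  float tmax = -FLT_MAX;
  #pragma unroll
  for (int r = 0; r < REG_CNT; ++r) {
    int i = threadIdx.x + r * BLOCK;
    if (i < classes) { reg[r] = in[i]; tmax = fmaxf(tmax, reg[r]); }
  }
  float row_max = block_reduce<BLOCK>(sdata, tmax, MaxOp{});

  // Pass 2: sum exp(x - max) from registers only.
  float tsum = 0.0f;
  #pragma unroll
  for (int r = 0; r < REG_CNT; ++r) {
    int i = threadIdx.x + r * BLOCK;
    if (i < classes) tsum += expf(reg[r] - row_max);
  }
  float row_sum = block_reduce<BLOCK>(sdata, tsum, AddOp{});

  // Pass 3: normalize and write back.
  #pragma unroll
  for (int r = 0; r < REG_CNT; ++r) {
    int i = threadIdx.x + r * BLOCK;
    if (i < classes) out[i] = expf(reg[r] - row_max) / row_sum;
  }
}

int main() {
  constexpr int BLOCK = 256, REG_CNT = 4;   // handles rows of up to 1024 elements
  const int rows = 2, classes = 1000;
  std::vector<float> h(rows * classes, 1.0f);
  float *d_in, *d_out;
  cudaMalloc(&d_in, h.size() * sizeof(float));
  cudaMalloc(&d_out, h.size() * sizeof(float));
  cudaMemcpy(d_in, h.data(), h.size() * sizeof(float), cudaMemcpyHostToDevice);
  softmax_reg<BLOCK, REG_CNT><<<rows, BLOCK>>>(d_out, d_in, classes);
  cudaMemcpy(h.data(), d_out, h.size() * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("out[0] = %f (uniform input, expect 1/1000 = 0.001)\n", h[0]);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}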
@@ -846,7 +907,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
     if (!half_to_float) {
       auto output_ptr = output.mutable_data_ptr<scalar_t>();
       auto input_ptr = input.const_data_ptr<scalar_t>();
-      if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
+      if (dim_size <= 2048 && dim_size*sizeof(scalar_t) <= 8192) {
         int64_t remaining = outer_size;
         int64_t chunk_size = (1L << 30L) / dim_size;
         while(remaining > 0) {
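The widened gate admits rows of up to 2048 elements as long as a row still fits in 8 KiB. A standalone sanity check of what that means per scalar type (not PyTorch code):

#include <cstddef>

// With the new bounds, float (and half) rows qualify across the whole
// 0..2048 range, while double rows above 1024 elements exceed the byte cap
// and fall through to the general path.
static_assert(2048 * sizeof(float) <= 8192, "float fits the fast path up to 2048");
static_assert(1025 * sizeof(double) > 8192, "double stops qualifying past 1024");

int main() { return 0; }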
@@ -868,7 +929,50 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
           can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
           can_use_smem &= !(dim_size % ILP);
 
-          if (can_use_smem) {
+          int32_t potential_reg_cnt = potential_register_count(dim_size, block.x);
+          if(potential_reg_cnt < 10){
+            TORCH_INTERNAL_ASSERT(potential_reg_cnt > 0, "potential_reg_cnt for softmax with register should be greater than 0.");
+            switch (potential_reg_cnt) {
+              // TODO(Wenqin): investigate why we couldn't use a macro for the code
+              // below; on MSVS the macro expansion didn't seem to work correctly.
+              case 1:
+                cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 1>
+                  <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+                break;
+              case 2:
+                cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 2>
+                  <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+                break;
+              case 3:
+                cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 3>
+                  <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+                break;
+              case 4:
+                cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 4>
+                  <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+                break;
+              case 5:
+                cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 5>
+                  <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+                break;
+              case 6:
+                cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 6>
+                  <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+                break;
+              case 7:
+                cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 7>
+                  <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+                break;
+              case 8:
+                cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 8>
+                  <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+                break;
+              case 9:
+                cunn_SoftMaxForwardReg<scalar_t, accscalar_t, scalar_t, Epilogue, int64_t, 9>
+                  <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+                break;
+            }
+          } else if (can_use_smem) {
             size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
             cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
               <<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
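The nine near-identical cases exist because reg_cnt is a non-type template parameter: it sizes the per-thread register array, so it must be a compile-time constant, and the runtime potential_reg_cnt has to be mapped onto a fixed set of instantiations. A minimal sketch of that dispatch pattern with hypothetical names:

#include <cstdio>

// reg_cnt must be a compile-time constant because it sizes an array that the
// compiler is expected to keep in registers.
template <int reg_cnt>
void kernel_stub(int dim_size) {
  std::printf("dim_size=%d handled by the reg_cnt=%d instantiation\n",
              dim_size, reg_cnt);
}

// Runtime-to-compile-time dispatch, mirroring the switch in the hunk above.
void dispatch(int reg_cnt, int dim_size) {
  switch (reg_cnt) {
    case 1: kernel_stub<1>(dim_size); break;
    case 2: kernel_stub<2>(dim_size); break;
    case 3: kernel_stub<3>(dim_size); break;
    default: std::printf("falling back to the shared-memory path\n"); break;
  }
}

int main() {
  dispatch((500 + 256 - 1) / 256, 500);   // reg_cnt = 2
  dispatch((2048 + 64 - 1) / 64, 2048);   // reg_cnt = 32 -> fallback
  return 0;
}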