mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
using GPU_WARP_SIZE to make kernel portable between AMD and Nvidia GPU (#5173)
This commit is contained in:
parent
84589c7e05
commit
b49f6a5e2c
1 changed files with 12 additions and 17 deletions
|
|
@ -10,7 +10,6 @@
|
|||
#include "reduction_utils.cuh"
|
||||
|
||||
#define NUM_ELEMENTS_PER_THREAD 4
|
||||
#define NUM_THREADS_PER_WARP 32
|
||||
#define NUM_WARPS_PER_BLOCK 8
|
||||
#define MAX_NUM_BLOCKS 256
|
||||
|
||||
|
|
@ -21,8 +20,8 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
std::pair<int, int> compute_block_size(int size) {
|
||||
int x = NUM_THREADS_PER_WARP;
|
||||
int y = std::min(NUM_WARPS_PER_BLOCK, std::max(1, size / (NUM_ELEMENTS_PER_THREAD * NUM_THREADS_PER_WARP)));
|
||||
int x = GPU_WARP_SIZE;
|
||||
int y = std::min(NUM_WARPS_PER_BLOCK, std::max(1, size / (NUM_ELEMENTS_PER_THREAD * GPU_WARP_SIZE)));
|
||||
return std::make_pair(x, y);
|
||||
}
|
||||
|
||||
|
|
@ -48,11 +47,11 @@ __global__ void reduce_all_kernel(const int size, const TIn * data, TOut* output
|
|||
|
||||
// Warp-level indexes:
|
||||
// Warp index of thread.
|
||||
const int wid_in_block = tid_in_block / NUM_THREADS_PER_WARP;
|
||||
const int wid_in_block = tid_in_block / GPU_WARP_SIZE;
|
||||
// Lane index of thread.
|
||||
const int lid_in_block = tid_in_block % NUM_THREADS_PER_WARP;
|
||||
const int lid_in_block = tid_in_block % GPU_WARP_SIZE;
|
||||
// Warp count per block.
|
||||
const int num_warps_in_block = num_threads_in_block / NUM_THREADS_PER_WARP;
|
||||
const int num_warps_in_block = num_threads_in_block / GPU_WARP_SIZE;
|
||||
|
||||
// Grid-level indexes:
|
||||
// Linear index of block in grid.
|
||||
|
|
@ -98,7 +97,7 @@ __global__ void reduce_all_kernel(const int size, const TIn * data, TOut* output
|
|||
// reduction, each block holds num_warps_in_block values in the shared memory.
|
||||
TOut value_ = value;
|
||||
#pragma unroll
|
||||
for (int stride = NUM_THREADS_PER_WARP / 2; stride > 0; stride /= 2) {
|
||||
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
|
||||
value_ += WARP_SHFL_DOWN(value_, stride);
|
||||
}
|
||||
|
||||
|
|
@ -211,7 +210,7 @@ void call_reduce_all_kernel(const TIn *data, TOut *output, int size, TOut *buffe
|
|||
cudaMemset(buffer + num_blocks, 0, sizeof(int));
|
||||
}
|
||||
|
||||
const int shared_mem_size = sizeof(TOut) * block_size.first * block_size.second / NUM_THREADS_PER_WARP;
|
||||
const int shared_mem_size = sizeof(TOut) * block_size.first * block_size.second / GPU_WARP_SIZE;
|
||||
reduce_all_kernel<TIn, TOut, TOp, TFinalOp, DivideResultBySize><<<grid, block, shared_mem_size>>>(size, data, output, buffer);
|
||||
}
|
||||
|
||||
|
|
@ -313,12 +312,11 @@ __global__ void reduce_matrix_rows_kernel(const TIn* input, TOut* output, int m,
|
|||
|
||||
for (int col = tid_x_in_grid; col < n; col += x_grid_stride) {
|
||||
shared_memory[tid_in_block] = TBuf(0.0f);
|
||||
|
||||
TBuf sum = TBuf(0.0f);
|
||||
// This loops load multiple blockDim.y-by-blockDim.x sub-tensors from the input.
|
||||
for (int row = tid_y_in_grid; row < m; row += y_grid_stride) {
|
||||
TBuf sum = 0.0f;
|
||||
// Thread-level reduction. Each thread loads y_load_count_per_thread values
|
||||
// and aggregrate them.
|
||||
// Thread-level reduction. Each thread loads y_load_count_per_thread values
|
||||
// and aggregrate them.
|
||||
#pragma unroll(y_load_count_per_thread)
|
||||
for (int row_inner = 0; row_inner < y_load_count_per_thread; ++row_inner) {
|
||||
int row_final = row + row_inner * t_count_y_in_grid;
|
||||
|
|
@ -327,9 +325,9 @@ __global__ void reduce_matrix_rows_kernel(const TIn* input, TOut* output, int m,
|
|||
sum += TBuf(input[row_final * n_int64 + col_final]);
|
||||
}
|
||||
}
|
||||
// Write thread-level reduction result into shared memory.
|
||||
shared_memory[tid_in_block] += sum;
|
||||
}
|
||||
// Write thread-level reduction result into shared memory.
|
||||
shared_memory[tid_in_block] = sum;
|
||||
|
||||
// Wait all threads to finish their thread-level reductions.
|
||||
__syncthreads();
|
||||
|
|
@ -347,9 +345,6 @@ __global__ void reduce_matrix_rows_kernel(const TIn* input, TOut* output, int m,
|
|||
if (threadIdx.y == 0) {
|
||||
atomic_add(output + col, TOut(shared_memory[threadIdx.x]));
|
||||
}
|
||||
|
||||
// Make sure all values in shared memory have been written into the output memory.
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue