Fix buffer overrun in 4b dequant cuda (#18780)

### Description
Bugfix: Dequantize4BitsKernel overruns its buffers when the input matrix
contains fewer blocks than a single thread block can handle.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Chen Fu 2023-12-11 15:05:41 -08:00 committed by GitHub
parent ce1fed6ddf
commit 68c832d53b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -64,8 +64,12 @@ __global__ void Dequantize4BitsKernel(
int block_size,
int blocks_per_K,
int blocks_per_threadblock,
int total_blks,
int shift) {
int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
if (block_id >= total_blks) {
return;
}
int n_idx = block_id / blocks_per_K;
int kb_idx = block_id % blocks_per_K;
int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
@@ -96,6 +100,7 @@ Status Dequantize4Bits(
constexpr int element_per_thread = 8;
int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
int blocks_per_K = k / block_size;
int total_blks = n * blocks_per_K;
int blocks_per_grid = static_cast<int>(CeilDiv(n * blocks_per_K, blocks_per_threadblock));
int shift = static_cast<int>(log2f(float(block_size)));
@@ -107,6 +112,7 @@ Status Dequantize4Bits(
block_size,
blocks_per_K,
blocks_per_threadblock,
total_blks,
shift);
return Status::OK();