mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Fix buffer overrun in 4b dequant cuda (#18780)
### Description Bugfix: Dequantize4BitsKernel buffer overrun when the input matrix has less than the number of blocks that a single thread block can handle. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
ce1fed6ddf
commit
68c832d53b
1 changed files with 6 additions and 0 deletions
|
|
@ -64,8 +64,12 @@ __global__ void Dequantize4BitsKernel(
|
|||
int block_size,
|
||||
int blocks_per_K,
|
||||
int blocks_per_threadblock,
|
||||
int total_blks,
|
||||
int shift) {
|
||||
int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
|
||||
if (block_id >= total_blks) {
|
||||
return;
|
||||
}
|
||||
int n_idx = block_id / blocks_per_K;
|
||||
int kb_idx = block_id % blocks_per_K;
|
||||
int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
|
||||
|
|
@ -96,6 +100,7 @@ Status Dequantize4Bits(
|
|||
constexpr int element_per_thread = 8;
|
||||
int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
|
||||
int blocks_per_K = k / block_size;
|
||||
int total_blks = n * blocks_per_K;
|
||||
int blocks_per_grid = static_cast<int>(CeilDiv(n * blocks_per_K, blocks_per_threadblock));
|
||||
int shift = static_cast<int>(log2f(float(block_size)));
|
||||
|
||||
|
|
@ -107,6 +112,7 @@ Status Dequantize4Bits(
|
|||
block_size,
|
||||
blocks_per_K,
|
||||
blocks_per_threadblock,
|
||||
total_blks,
|
||||
shift);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
Loading…
Reference in a new issue