diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 7921315ab5..6b66f1d84e 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -64,8 +64,12 @@ __global__ void Dequantize4BitsKernel( int block_size, int blocks_per_K, int blocks_per_threadblock, + int total_blks, int shift) { int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); + if (block_id >= total_blks) { + return; + } int n_idx = block_id / blocks_per_K; int kb_idx = block_id % blocks_per_K; int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); @@ -96,6 +100,7 @@ Status Dequantize4Bits( constexpr int element_per_thread = 8; int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; int blocks_per_K = k / block_size; + int total_blks = n * blocks_per_K; int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); int shift = static_cast(log2f(float(block_size))); @@ -107,6 +112,7 @@ Status Dequantize4Bits( block_size, blocks_per_K, blocks_per_threadblock, + total_blks, shift); return Status::OK();