diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index f79d4c5d6e..aaa764cf2e 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -4,6 +4,7 @@ #include "attention_quantization.h" #include "attention_quantization_impl.cuh" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" #include "core/providers/cuda/shared_inc/integer_gemm.h" @@ -112,18 +113,18 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { past_tensor)); const auto& shape = input->Shape(); - int batch_size = static_cast(shape[0]); - int sequence_length = static_cast(shape[1]); - int input_hidden_size = static_cast(shape[2]); + int batch_size = SafeInt(shape[0]); + int sequence_length = SafeInt(shape[1]); + int input_hidden_size = SafeInt(shape[2]); const auto& bias_shape = bias->Shape(); - const int hidden_size = static_cast(bias_shape.GetDims()[0]) / 3; + const int hidden_size = SafeInt(bias_shape.GetDims()[0]) / 3; const int head_size = hidden_size / num_heads_; TensorShapeVector output_shape(3); output_shape[0] = shape[0]; output_shape[1] = shape[1]; - output_shape[2] = static_cast(hidden_size); + output_shape[2] = SafeInt(hidden_size); Tensor* output = context->Output(0, output_shape); cublasHandle_t cublas = CublasHandle(); @@ -133,8 +134,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { int m = batch_size * sequence_length; int n = 3 * hidden_size; int k = input_hidden_size; - auto gemm_buffer = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size * element_size); - auto gemm_buffer_quantized = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size); + auto gemm_buffer = GetScratchBuffer(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size); + auto gemm_buffer_quantized = GetScratchBuffer(SafeInt(batch_size) * sequence_length * 3 * hidden_size); typedef typename ToCudaType::MappedType CudaT; diff --git a/onnxruntime/core/providers/cuda/integer_gemm.cc b/onnxruntime/core/providers/cuda/integer_gemm.cc index 09058edca7..7529735abf 100644 --- a/onnxruntime/core/providers/cuda/integer_gemm.cc +++ b/onnxruntime/core/providers/cuda/integer_gemm.cc @@ -3,13 +3,14 @@ #include "core/providers/cuda/shared_inc/integer_gemm.h" +#include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/cuda_call.h" namespace onnxruntime { namespace cuda { -inline int roundoff(int v, int d) { +constexpr int roundoff(int v, int d) { return (v + d - 1) / d * d; } @@ -27,12 +28,12 @@ Status GemmInt8(int m, int n, int k, // 1. leading dimension is multiples of 4 // 2. A, B is 32-bit aligned - const int mask = 0x1F; + constexpr int mask = 0x1F; int lda_aligned = lda; IAllocatorUniquePtr a_padded; if ((mask & lda_aligned) != 0) { lda_aligned = roundoff(lda, 32); - a_padded = cuda_kernel->GetScratchBuffer(m * lda_aligned); + a_padded = cuda_kernel->GetScratchBuffer(SafeInt(m) * lda_aligned); cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, stream); } @@ -40,7 +41,7 @@ Status GemmInt8(int m, int n, int k, IAllocatorUniquePtr b_padded; if ((mask & ldb_aligned) != 0) { ldb_aligned = roundoff(ldb, 32); - b_padded = cuda_kernel->GetScratchBuffer(k * ldb_aligned); + b_padded = cuda_kernel->GetScratchBuffer(SafeInt(k) * ldb_aligned); cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, stream); }