diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index d48716109e..ed0c8f8115 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -259,7 +259,8 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { has_key_padding_mask_)); // calculate q - gemm_query_buffer_p = GetScratchBuffer(batch_size * sequence_length * hidden_size * element_size, context->GetComputeStream()); + gemm_query_buffer_p = GetScratchBuffer(static_cast(batch_size) * sequence_length * hidden_size, + context->GetComputeStream()); m = sequence_length * batch_size; n = hidden_size; k = hidden_size; @@ -284,7 +285,8 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { k = hidden_size; if (!has_layer_state_ || !use_past_) { if (!static_kv_) { - gemm_kv_buffer_p = GetScratchBuffer(batch_size * 2 * sequence_length * hidden_size * element_size, context->GetComputeStream()); + gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * sequence_length * hidden_size, + context->GetComputeStream()); m = sequence_length * batch_size; n = 2 * hidden_size; k = hidden_size; @@ -303,7 +305,8 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { - gemm_kv_buffer_p = GetScratchBuffer(batch_size * 2 * key_sequence_length * hidden_size * element_size, context->GetComputeStream()); + gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * key_sequence_length * hidden_size, + context->GetComputeStream()); m = key_sequence_length * batch_size; n = 2 * hidden_size; k = hidden_size; @@ -328,7 +331,8 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { // key and value cache have identical shape int cache_sequence_length = static_cast(cache_shape[2]); if (!static_kv_) { - gemm_kv_buffer_p = GetScratchBuffer(batch_size * 2 * sequence_length * hidden_size * element_size, context->GetComputeStream()); + gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * sequence_length * hidden_size, + context->GetComputeStream()); m = sequence_length * batch_size; kv_sequence_length = cache_sequence_length + sequence_length; // broadcast bias for key and value: (2*h2, T_S*B)