From d3785ef8f623d06ee110a7457b925156bd551be6 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Thu, 23 Feb 2023 21:58:28 -0800
Subject: [PATCH] Fix decoder_attention scratch buffer size and prefast warnings (#14808)

(1) Change `GetScratchBuffer<T>(element_count * element_size)` to `GetScratchBuffer<T>(element_count)` to avoid allocating more memory than needed.
(2) Fix prefast:Warning C26451: Arithmetic overflow: Using operator '*' on a 4 byte value and then casting the result to a 8 byte value.
---
 .../contrib_ops/cuda/bert/decoder_attention.cc | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc
index d48716109e..ed0c8f8115 100644
--- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc
@@ -259,7 +259,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
       has_key_padding_mask_));
 
   // calculate q
-  gemm_query_buffer_p = GetScratchBuffer<T>(batch_size * sequence_length * hidden_size * element_size, context->GetComputeStream());
+  gemm_query_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * sequence_length * hidden_size,
+                                            context->GetComputeStream());
   m = sequence_length * batch_size;
   n = hidden_size;
   k = hidden_size;
@@ -284,7 +285,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
   k = hidden_size;
   if (!has_layer_state_ || !use_past_) {
     if (!static_kv_) {
-      gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size, context->GetComputeStream());
+      gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * sequence_length * hidden_size,
+                                             context->GetComputeStream());
       m = sequence_length * batch_size;
       n = 2 * hidden_size;
       k = hidden_size;
@@ -303,7 +305,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
           &one,
           reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));  // gemm_kv_buffer in col-base: (2*h2, T_S*B)
     } else {
-      gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * key_sequence_length * hidden_size * element_size, context->GetComputeStream());
+      gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * key_sequence_length * hidden_size,
+                                             context->GetComputeStream());
       m = key_sequence_length * batch_size;
       n = 2 * hidden_size;
       k = hidden_size;
@@ -328,7 +331,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
     // key and value cache have identical shape
     int cache_sequence_length = static_cast<int>(cache_shape[2]);
     if (!static_kv_) {
-      gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size, context->GetComputeStream());
+      gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * sequence_length * hidden_size,
+                                             context->GetComputeStream());
       m = sequence_length * batch_size;
       kv_sequence_length = cache_sequence_length + sequence_length;
       // broadcast bias for key and value: (2*h2, T_S*B)