Fix decoder_attention scratch buffer size and prefast warnings (#14808)

(1) Change `GetScratchBuffer<T>(element_count * element_size)` to
`GetScratchBuffer<T>(element_count)` to avoid allocating more memory
than needed.
(2) Fix prefast:Warning C26451: Arithmetic overflow: Using operator '*'
on a 4 byte value and then casting the result to a 8 byte value.
This commit is contained in:
Tianlei Wu 2023-02-23 21:58:28 -08:00 committed by GitHub
parent 928289c414
commit d3785ef8f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -259,7 +259,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
has_key_padding_mask_));
// calculate q
gemm_query_buffer_p = GetScratchBuffer<T>(batch_size * sequence_length * hidden_size * element_size, context->GetComputeStream());
gemm_query_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * sequence_length * hidden_size,
context->GetComputeStream());
m = sequence_length * batch_size;
n = hidden_size;
k = hidden_size;
@ -284,7 +285,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
k = hidden_size;
if (!has_layer_state_ || !use_past_) {
if (!static_kv_) {
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size, context->GetComputeStream());
gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * sequence_length * hidden_size,
context->GetComputeStream());
m = sequence_length * batch_size;
n = 2 * hidden_size;
k = hidden_size;
@ -303,7 +305,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
} else {
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * key_sequence_length * hidden_size * element_size, context->GetComputeStream());
gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * key_sequence_length * hidden_size,
context->GetComputeStream());
m = key_sequence_length * batch_size;
n = 2 * hidden_size;
k = hidden_size;
@ -328,7 +331,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
// key and value cache have identical shape
int cache_sequence_length = static_cast<int>(cache_shape[2]);
if (!static_kv_) {
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size, context->GetComputeStream());
gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * sequence_length * hidden_size,
context->GetComputeStream());
m = sequence_length * batch_size;
kv_sequence_length = cache_sequence_length + sequence_length;
// broadcast bias for key and value: (2*h2, T_S*B)