mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Fix decoder_attention scratch buffer size and prefast warnings (#14808)
(1) Change `GetScratchBuffer<T>(element_count * element_size)` to `GetScratchBuffer<T>(element_count)` to avoid allocating more memory than needed. (2) Fix prefast:Warning C26451: Arithmetic overflow: Using operator '*' on a 4 byte value and then casting the result to a 8 byte value.
This commit is contained in:
parent
928289c414
commit
d3785ef8f6
1 changed files with 8 additions and 4 deletions
|
|
@ -259,7 +259,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
has_key_padding_mask_));
|
||||
|
||||
// calculate q
|
||||
gemm_query_buffer_p = GetScratchBuffer<T>(batch_size * sequence_length * hidden_size * element_size, context->GetComputeStream());
|
||||
gemm_query_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * sequence_length * hidden_size,
|
||||
context->GetComputeStream());
|
||||
m = sequence_length * batch_size;
|
||||
n = hidden_size;
|
||||
k = hidden_size;
|
||||
|
|
@ -284,7 +285,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
k = hidden_size;
|
||||
if (!has_layer_state_ || !use_past_) {
|
||||
if (!static_kv_) {
|
||||
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size, context->GetComputeStream());
|
||||
gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * sequence_length * hidden_size,
|
||||
context->GetComputeStream());
|
||||
m = sequence_length * batch_size;
|
||||
n = 2 * hidden_size;
|
||||
k = hidden_size;
|
||||
|
|
@ -303,7 +305,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
|
||||
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
|
||||
} else {
|
||||
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * key_sequence_length * hidden_size * element_size, context->GetComputeStream());
|
||||
gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * key_sequence_length * hidden_size,
|
||||
context->GetComputeStream());
|
||||
m = key_sequence_length * batch_size;
|
||||
n = 2 * hidden_size;
|
||||
k = hidden_size;
|
||||
|
|
@ -328,7 +331,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// key and value cache have identical shape
|
||||
int cache_sequence_length = static_cast<int>(cache_shape[2]);
|
||||
if (!static_kv_) {
|
||||
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size, context->GetComputeStream());
|
||||
gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * sequence_length * hidden_size,
|
||||
context->GetComputeStream());
|
||||
m = sequence_length * batch_size;
|
||||
kv_sequence_length = cache_sequence_length + sequence_length;
|
||||
// broadcast bias for key and value: (2*h2, T_S*B)
|
||||
|
|
|
|||
Loading…
Reference in a new issue