mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Fix integer overflow in LongformerAttention (#12435)
fix integer overflow
This commit is contained in:
parent
44ec2cf088
commit
97a340bf48
1 changed file with 4 additions and 3 deletions
|
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||
#include "longformer_attention_impl.h"
|
||||
#include "attention_impl.h"
|
||||
#include "longformer_attention_softmax.h"
|
||||
#include "core/common/safeint.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace cub;
|
||||
|
|
@ -62,7 +63,7 @@ namespace cuda {
|
|||
// [scratch1: BxNxSxS] [scratch2: BxNxSxS]
|
||||
|
||||
size_t GetScratch1Size(size_t element_size, int batch_size, int num_heads, int sequence_length, int window) {
|
||||
return (5 * sequence_length - 3 * window) * window * num_heads * batch_size * element_size;
|
||||
return SafeInt<size_t>(5 * sequence_length - 3 * window) * window * num_heads * batch_size * element_size;
|
||||
}
|
||||
|
||||
constexpr size_t GetScratch2Size() {
|
||||
|
|
@ -81,7 +82,7 @@ size_t GetLongformerSoftmaxWorkspaceSize(
|
|||
size_t scratch2_size = 10 * (sizeof(void*) + sizeof(size_t));
|
||||
return scratch1_size + scratch2_size;
|
||||
} else {
|
||||
return 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length);
|
||||
return SafeInt<size_t>(2) * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,7 +101,7 @@ size_t GetLongformerAttentionWorkspaceSize(
|
|||
sequence_length,
|
||||
window,
|
||||
disable_compact_memory);
|
||||
size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
|
||||
size_t qkv_size = SafeInt<size_t>(3) * batch_size * sequence_length * num_heads * head_size * element_size;
|
||||
size_t global_qkv_size = max_num_global > 0 ? qkv_size : 0;
|
||||
return softmax_size + qkv_size + global_qkv_size;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue