diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index 1045bd92c7..f301c80d5a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -29,6 +29,7 @@ limitations under the License. #include "longformer_attention_impl.h" #include "attention_impl.h" #include "longformer_attention_softmax.h" +#include "core/common/safeint.h" using namespace onnxruntime::cuda; using namespace cub; @@ -62,7 +63,7 @@ namespace cuda { // [scratch1: BxNxSxS] [scratch2: BxNxSxS] size_t GetScratch1Size(size_t element_size, int batch_size, int num_heads, int sequence_length, int window) { - return (5 * sequence_length - 3 * window) * window * num_heads * batch_size * element_size; + return SafeInt(5 * sequence_length - 3 * window) * window * num_heads * batch_size * element_size; } constexpr size_t GetScratch2Size() { @@ -81,7 +82,7 @@ size_t GetLongformerSoftmaxWorkspaceSize( size_t scratch2_size = 10 * (sizeof(void*) + sizeof(size_t)); return scratch1_size + scratch2_size; } else { - return 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length); + return SafeInt(2) * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length); } } @@ -100,7 +101,7 @@ size_t GetLongformerAttentionWorkspaceSize( sequence_length, window, disable_compact_memory); - size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size; + size_t qkv_size = SafeInt(3) * batch_size * sequence_length * num_heads * head_size * element_size; size_t global_qkv_size = max_num_global > 0 ? qkv_size : 0; return softmax_size + qkv_size + global_qkv_size; }