Fix integer overflow in LongformerAttention (#12435)

fix integer overflow
This commit is contained in:
Tianlei Wu 2022-08-03 10:29:07 -07:00 committed by GitHub
parent 44ec2cf088
commit 97a340bf48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -29,6 +29,7 @@ limitations under the License.
#include "longformer_attention_impl.h"
#include "attention_impl.h"
#include "longformer_attention_softmax.h"
#include "core/common/safeint.h"
using namespace onnxruntime::cuda;
using namespace cub;
@ -62,7 +63,7 @@ namespace cuda {
// [scratch1: BxNxSxS] [scratch2: BxNxSxS]
size_t GetScratch1Size(size_t element_size, int batch_size, int num_heads, int sequence_length, int window) {
return (5 * sequence_length - 3 * window) * window * num_heads * batch_size * element_size;
return SafeInt<size_t>(5 * sequence_length - 3 * window) * window * num_heads * batch_size * element_size;
}
constexpr size_t GetScratch2Size() {
@ -81,7 +82,7 @@ size_t GetLongformerSoftmaxWorkspaceSize(
size_t scratch2_size = 10 * (sizeof(void*) + sizeof(size_t));
return scratch1_size + scratch2_size;
} else {
return 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length);
return SafeInt<size_t>(2) * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length);
}
}
@ -100,7 +101,7 @@ size_t GetLongformerAttentionWorkspaceSize(
sequence_length,
window,
disable_compact_memory);
size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
size_t qkv_size = SafeInt<size_t>(3) * batch_size * sequence_length * num_heads * head_size * element_size;
size_t global_qkv_size = max_num_global > 0 ? qkv_size : 0;
return softmax_size + qkv_size + global_qkv_size;
}