diff --git a/orttraining/orttraining/core/framework/communication/mpi/mpi_context.cc b/orttraining/orttraining/core/framework/communication/mpi/mpi_context.cc index 87da7477c8..f7e7fdcaf9 100644 --- a/orttraining/orttraining/core/framework/communication/mpi/mpi_context.cc +++ b/orttraining/orttraining/core/framework/communication/mpi/mpi_context.cc @@ -4,6 +4,7 @@ #define SHARED_PROVIDER_TODO 0 #include "orttraining/core/framework/communication/mpi/mpi_context.h" +#include "core/common/safeint.h" #ifndef _WIN32 #include #include @@ -72,7 +73,9 @@ MPIContext::MPIContext() : world_rank_(0), int world_rank; MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &world_rank)); - int* ranks = (int*)malloc(sizeof(int) * world_size); + SafeInt alloc_size = world_size; + alloc_size *= sizeof(int); + int* ranks = (int*)malloc(alloc_size); MPI_Allgather(&world_rank, 1, MPI_INT, ranks, 1, MPI_INT, MPI_COMM_WORLD);