From 7b61bca6dfd6c29f0969fbab3c16698edf638295 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 8 Oct 2021 11:29:28 -0700 Subject: [PATCH] Fix inclusive sum overflow when applied on int8_t buffer in Compress (#9295) Use thrust::transform_iterator when feeding input to cub::DeviceScan::InclusiveSum() to make sure the accumulator type is wide enough not to overflow. --- .../core/providers/cuda/tensor/compress.cc | 8 +++--- .../providers/cuda/tensor/compress_impl.cu | 26 +++++++++++++++---- .../providers/cuda/tensor/compress_impl.h | 6 +++-- .../providers/cpu/tensor/compress_op.test.cc | 23 ++++++++++++++++ 4 files changed, 53 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/compress.cc b/onnxruntime/core/providers/cuda/tensor/compress.cc index 5c9537fdce..45281fb810 100644 --- a/onnxruntime/core/providers/cuda/tensor/compress.cc +++ b/onnxruntime/core/providers/cuda/tensor/compress.cc @@ -48,14 +48,16 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const { int64_t compress_input_length = has_axis_ ? input_dimensions[axis] : input_size; int64_t valid_condition_length = compress_input_length < condition_length ? 
compress_input_length : condition_length; - auto condition_cumulative_sum_buffer = GetScratchBuffer(valid_condition_length); + auto condition_cumulative_sum_buffer = GetScratchBuffer(gsl::narrow(valid_condition_length)); auto condition_cumulative_sum = condition_cumulative_sum_buffer.get(); + size_t temp_storage_bytes = 0; CUDA_RETURN_IF_ERROR(CompressCalcPrefixSumTempStorageBytes(Stream(), reinterpret_cast(condition_data), condition_cumulative_sum, - static_cast(valid_condition_length), + gsl::narrow(valid_condition_length), temp_storage_bytes)); + auto temp_buffer = GetScratchBuffer(temp_storage_bytes); auto d_temp_storage = temp_buffer.get(); CUDA_RETURN_IF_ERROR(CompressInclusivePrefixSum(Stream(), @@ -63,7 +65,7 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const { temp_storage_bytes, reinterpret_cast(condition_data), condition_cumulative_sum, - static_cast(valid_condition_length))); + gsl::narrow(valid_condition_length))); // cudaMemcpyAsync from device memory to pageable host memory will return only once the copy has completed. int32_t positive_condition_count = 0; diff --git a/onnxruntime/core/providers/cuda/tensor/compress_impl.cu b/onnxruntime/core/providers/cuda/tensor/compress_impl.cu index a3212cdf9e..b2c7b60866 100644 --- a/onnxruntime/core/providers/cuda/tensor/compress_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/compress_impl.cu @@ -13,16 +13,32 @@ #include "core/providers/cuda/tensor/compress_impl.h" +#include +#include + namespace onnxruntime { namespace cuda { -cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int* condition_cumulative_sum, int length, size_t& temp_storage_bytes) { - return cub::DeviceScan::InclusiveSum( - nullptr, temp_storage_bytes, condition_data, condition_cumulative_sum, length, stream); +// This cast is for transform iterator. This type affects the accumulator type width +// in InclusiveSum(). 
By default, the accumulator type matches the input, but for int8_t +// the sum overflows quickly, so we want the source type to match the output (int32_t). +// see https://github.com/NVIDIA/cub/issues/384 +struct CastToInt32 : public thrust::unary_function { + __host__ __device__ int32_t operator()(int8_t v) const { + return static_cast(v); + } +}; + +cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int32_t* condition_cumulative_sum, int length, size_t& temp_storage_bytes) { + auto input_iter = thrust::make_transform_iterator(condition_data, CastToInt32()); + return cub::DeviceScan::InclusiveSum( + nullptr, temp_storage_bytes, input_iter, condition_cumulative_sum, length, stream); } -cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, const int8_t* condition_data, int* condition_cumulative_sum, int length) { + +cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, const int8_t* condition_data, int32_t* condition_cumulative_sum, int length) { + auto input_iter = thrust::make_transform_iterator(condition_data, CastToInt32()); return cub::DeviceScan::InclusiveSum( - d_temp_storage, temp_storage_bytes, condition_data, condition_cumulative_sum, length, stream); + d_temp_storage, temp_storage_bytes, input_iter, condition_cumulative_sum, length, stream); } template diff --git a/onnxruntime/core/providers/cuda/tensor/compress_impl.h b/onnxruntime/core/providers/cuda/tensor/compress_impl.h index 3397841476..30b6b00447 100644 --- a/onnxruntime/core/providers/cuda/tensor/compress_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/compress_impl.h @@ -9,8 +9,10 @@ namespace onnxruntime { namespace cuda { -cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int* condition_cumulative_sum, int length, size_t& temp_storage_bytes); -cudaError_t 
CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, const int8_t* condition_data, int* condition_cumulative_sum, int length); +cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, + int32_t* condition_cumulative_sum, int length, size_t& temp_storage_bytes); +cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, + const int8_t* condition_data, int32_t* condition_cumulative_sum, int length); Status CompressImpl(cudaStream_t stream, const size_t element_bytes, diff --git a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc index 7f2782ca06..173dbcf8ce 100644 --- a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc @@ -125,6 +125,29 @@ TEST(CompressTest, Compress_default_axis) { test.Run(); } +// Test that we accumulate to a buffer that does not overflow +TEST(CompressTest, Compress_default_axis_issue_9247_cumulative_sum_overflow) { + OpTester test("Compress", 9); + + // Generate input longer than 127 + constexpr size_t elements = 150; + std::vector input; + for (size_t i = 0; i < elements; ++i) { + input.push_back(static_cast(i)); + } + + // Let's select all of the elements + std::unique_ptr all_true = std::make_unique(elements); + std::fill_n(all_true.get(), elements, true); + std::vector output_shape{static_cast(elements)}; + + test.AddInput("input", {2, 75}, input); + test.AddInput("condition", output_shape, all_true.get(), elements); + // Should get all of the input + test.AddOutput("output", output_shape, input); + test.Run(); +} + TEST(CompressTest, Compress0_string) { OpTester test("Compress", 9);