Fix inclusive sum overlfow when applied on int8_t buffer in Compress (#9295)

Use thrust::transform_iterator when feeding input to cub::DeviceScan::InclusiveScan() to make sure the accumulator type is wide enough not to overflow.
This commit is contained in:
Dmitri Smirnov 2021-10-08 11:29:28 -07:00 committed by GitHub
parent 29379db432
commit 7b61bca6df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 10 deletions

View file

@ -48,14 +48,16 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const {
int64_t compress_input_length = has_axis_ ? input_dimensions[axis] : input_size;
int64_t valid_condition_length = compress_input_length < condition_length ? compress_input_length : condition_length;
auto condition_cumulative_sum_buffer = GetScratchBuffer<int32_t>(valid_condition_length);
auto condition_cumulative_sum_buffer = GetScratchBuffer<int32_t>(gsl::narrow<size_t>(valid_condition_length));
auto condition_cumulative_sum = condition_cumulative_sum_buffer.get();
size_t temp_storage_bytes = 0;
CUDA_RETURN_IF_ERROR(CompressCalcPrefixSumTempStorageBytes(Stream(),
reinterpret_cast<const int8_t*>(condition_data),
condition_cumulative_sum,
static_cast<int>(valid_condition_length),
gsl::narrow<int>(valid_condition_length),
temp_storage_bytes));
auto temp_buffer = GetScratchBuffer<uint8_t>(temp_storage_bytes);
auto d_temp_storage = temp_buffer.get();
CUDA_RETURN_IF_ERROR(CompressInclusivePrefixSum(Stream(),
@ -63,7 +65,7 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const {
temp_storage_bytes,
reinterpret_cast<const int8_t*>(condition_data),
condition_cumulative_sum,
static_cast<int>(valid_condition_length)));
gsl::narrow<int>(valid_condition_length)));
// cudaMemcpyAsync from device memory to pageable host memory will return only once the copy has completed.
int32_t positive_condition_count = 0;

View file

@ -13,16 +13,32 @@
#include "core/providers/cuda/tensor/compress_impl.h"
#include <thrust/functional.h>
#include <thrust/iterator/transform_iterator.h>
namespace onnxruntime {
namespace cuda {
cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int* condition_cumulative_sum, int length, size_t& temp_storage_bytes) {
return cub::DeviceScan::InclusiveSum(
nullptr, temp_storage_bytes, condition_data, condition_cumulative_sum, length, stream);
// This cast is for transform iterator. This type affects the accumulator type width
// in InclusiveSum(). By default, the accumulator type matches the input, but for int8_t
// the sum overflows quickly, so we want the source type to match the output (int32_t).
// see https://github.com/NVIDIA/cub/issues/384
struct CastToInt32 : public thrust::unary_function<int8_t, int32_t> {
__host__ __device__ int32_t operator()(int8_t v) const {
return static_cast<int32_t>(v);
}
};
cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int32_t* condition_cumulative_sum, int length, size_t& temp_storage_bytes) {
auto input_iter = thrust::make_transform_iterator(condition_data, CastToInt32());
return cub::DeviceScan::InclusiveSum(
nullptr, temp_storage_bytes, input_iter, condition_cumulative_sum, length, stream);
}
cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, const int8_t* condition_data, int* condition_cumulative_sum, int length) {
cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, const int8_t* condition_data, int32_t* condition_cumulative_sum, int length) {
auto input_iter = thrust::make_transform_iterator(condition_data, CastToInt32());
return cub::DeviceScan::InclusiveSum(
d_temp_storage, temp_storage_bytes, condition_data, condition_cumulative_sum, length, stream);
d_temp_storage, temp_storage_bytes, input_iter, condition_cumulative_sum, length, stream);
}
template <typename T>

View file

@ -9,8 +9,10 @@
namespace onnxruntime {
namespace cuda {
cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int* condition_cumulative_sum, int length, size_t& temp_storage_bytes);
cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, const int8_t* condition_data, int* condition_cumulative_sum, int length);
cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data,
int32_t* condition_cumulative_sum, int length, size_t& temp_storage_bytes);
cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes,
const int8_t* condition_data, int32_t* condition_cumulative_sum, int length);
Status CompressImpl(cudaStream_t stream,
const size_t element_bytes,

View file

@ -125,6 +125,29 @@ TEST(CompressTest, Compress_default_axis) {
test.Run();
}
// Test that we accumulate to a buffer that does not overflow
TEST(CompressTest, Compress_default_axis_issue_9247_cumulative_sum_overflow) {
OpTester test("Compress", 9);
// Generate input longer than 127
constexpr size_t elements = 150;
std::vector<float> input;
for (size_t i = 0; i < elements; ++i) {
input.push_back(static_cast<float>(i));
}
// Let's select all of the elements
std::unique_ptr<bool[]> all_true = std::make_unique<bool[]>(elements);
std::fill_n(all_true.get(), elements, true);
std::vector<int64_t> output_shape{static_cast<int64_t>(elements)};
test.AddInput<float>("input", {2, 75}, input);
test.AddInput<bool>("condition", output_shape, all_true.get(), elements);
// Should get all of the input
test.AddOutput<float>("output", output_shape, input);
test.Run();
}
TEST(CompressTest, Compress0_string) {
OpTester test("Compress", 9);