mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Fix inclusive sum overflow when applied on int8_t buffer in Compress (#9295)
Use thrust::transform_iterator when feeding input to cub::DeviceScan::InclusiveSum() to make sure the accumulator type is wide enough not to overflow.
This commit is contained in:
parent
29379db432
commit
7b61bca6df
4 changed files with 53 additions and 10 deletions
|
|
@ -48,14 +48,16 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const {
|
|||
int64_t compress_input_length = has_axis_ ? input_dimensions[axis] : input_size;
|
||||
int64_t valid_condition_length = compress_input_length < condition_length ? compress_input_length : condition_length;
|
||||
|
||||
auto condition_cumulative_sum_buffer = GetScratchBuffer<int32_t>(valid_condition_length);
|
||||
auto condition_cumulative_sum_buffer = GetScratchBuffer<int32_t>(gsl::narrow<size_t>(valid_condition_length));
|
||||
auto condition_cumulative_sum = condition_cumulative_sum_buffer.get();
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
CUDA_RETURN_IF_ERROR(CompressCalcPrefixSumTempStorageBytes(Stream(),
|
||||
reinterpret_cast<const int8_t*>(condition_data),
|
||||
condition_cumulative_sum,
|
||||
static_cast<int>(valid_condition_length),
|
||||
gsl::narrow<int>(valid_condition_length),
|
||||
temp_storage_bytes));
|
||||
|
||||
auto temp_buffer = GetScratchBuffer<uint8_t>(temp_storage_bytes);
|
||||
auto d_temp_storage = temp_buffer.get();
|
||||
CUDA_RETURN_IF_ERROR(CompressInclusivePrefixSum(Stream(),
|
||||
|
|
@ -63,7 +65,7 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const {
|
|||
temp_storage_bytes,
|
||||
reinterpret_cast<const int8_t*>(condition_data),
|
||||
condition_cumulative_sum,
|
||||
static_cast<int>(valid_condition_length)));
|
||||
gsl::narrow<int>(valid_condition_length)));
|
||||
|
||||
// cudaMemcpyAsync from device memory to pageable host memory will return only once the copy has completed.
|
||||
int32_t positive_condition_count = 0;
|
||||
|
|
|
|||
|
|
@ -13,16 +13,32 @@
|
|||
|
||||
#include "core/providers/cuda/tensor/compress_impl.h"
|
||||
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int* condition_cumulative_sum, int length, size_t& temp_storage_bytes) {
|
||||
return cub::DeviceScan::InclusiveSum(
|
||||
nullptr, temp_storage_bytes, condition_data, condition_cumulative_sum, length, stream);
|
||||
// This cast is for transform iterator. This type affects the accumulator type width
|
||||
// in InclusiveSum(). By default, the accumulator type matches the input, but for int8_t
|
||||
// the sum overflows quickly, so we want the source type to match the output (int32_t).
|
||||
// see https://github.com/NVIDIA/cub/issues/384
|
||||
struct CastToInt32 : public thrust::unary_function<int8_t, int32_t> {
|
||||
__host__ __device__ int32_t operator()(int8_t v) const {
|
||||
return static_cast<int32_t>(v);
|
||||
}
|
||||
};
|
||||
|
||||
cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int32_t* condition_cumulative_sum, int length, size_t& temp_storage_bytes) {
|
||||
auto input_iter = thrust::make_transform_iterator(condition_data, CastToInt32());
|
||||
return cub::DeviceScan::InclusiveSum(
|
||||
nullptr, temp_storage_bytes, input_iter, condition_cumulative_sum, length, stream);
|
||||
}
|
||||
cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, const int8_t* condition_data, int* condition_cumulative_sum, int length) {
|
||||
|
||||
cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, const int8_t* condition_data, int32_t* condition_cumulative_sum, int length) {
|
||||
auto input_iter = thrust::make_transform_iterator(condition_data, CastToInt32());
|
||||
return cub::DeviceScan::InclusiveSum(
|
||||
d_temp_storage, temp_storage_bytes, condition_data, condition_cumulative_sum, length, stream);
|
||||
d_temp_storage, temp_storage_bytes, input_iter, condition_cumulative_sum, length, stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -9,8 +9,10 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data, int* condition_cumulative_sum, int length, size_t& temp_storage_bytes);
|
||||
cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes, const int8_t* condition_data, int* condition_cumulative_sum, int length);
|
||||
cudaError_t CompressCalcPrefixSumTempStorageBytes(cudaStream_t stream, const int8_t* condition_data,
|
||||
int32_t* condition_cumulative_sum, int length, size_t& temp_storage_bytes);
|
||||
cudaError_t CompressInclusivePrefixSum(cudaStream_t stream, void* d_temp_storage, size_t temp_storage_bytes,
|
||||
const int8_t* condition_data, int32_t* condition_cumulative_sum, int length);
|
||||
|
||||
Status CompressImpl(cudaStream_t stream,
|
||||
const size_t element_bytes,
|
||||
|
|
|
|||
|
|
@ -125,6 +125,29 @@ TEST(CompressTest, Compress_default_axis) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// Test that we accumulate to a buffer that does not overflow
|
||||
TEST(CompressTest, Compress_default_axis_issue_9247_cumulative_sum_overflow) {
|
||||
OpTester test("Compress", 9);
|
||||
|
||||
// Generate input longer than 127
|
||||
constexpr size_t elements = 150;
|
||||
std::vector<float> input;
|
||||
for (size_t i = 0; i < elements; ++i) {
|
||||
input.push_back(static_cast<float>(i));
|
||||
}
|
||||
|
||||
// Let's select all of the elements
|
||||
std::unique_ptr<bool[]> all_true = std::make_unique<bool[]>(elements);
|
||||
std::fill_n(all_true.get(), elements, true);
|
||||
std::vector<int64_t> output_shape{static_cast<int64_t>(elements)};
|
||||
|
||||
test.AddInput<float>("input", {2, 75}, input);
|
||||
test.AddInput<bool>("condition", output_shape, all_true.get(), elements);
|
||||
// Should get all of the input
|
||||
test.AddOutput<float>("output", output_shape, input);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(CompressTest, Compress0_string) {
|
||||
OpTester test("Compress", 9);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue