From 65968ab817db323a532f50a2f2ea131ae27dada5 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Thu, 29 Apr 2021 17:51:06 -0700 Subject: [PATCH] Revert "Remove sync for randperm on small tensors. (#54113)" (#57299) Summary: This reverts commit e8c268746b297efa988e03abc61ff22203bf3980. It occasionally produces wrong results. Pull Request resolved: https://github.com/pytorch/pytorch/pull/57299 Reviewed By: wat3rBro Differential Revision: D28102706 Pulled By: ngimel fbshipit-source-id: d7618e104d854c3b96aa502fb4e30041b9aab5df --- aten/src/ATen/native/cuda/Randperm.cu | 182 +++++------------- aten/src/ATen/native/cuda/Randperm.cuh | 56 ------ aten/src/ATen/test/cuda_distributions_test.cu | 57 ------ test/test_tensor_creation_ops.py | 6 +- 4 files changed, 54 insertions(+), 247 deletions(-) delete mode 100644 aten/src/ATen/native/cuda/Randperm.cuh diff --git a/aten/src/ATen/native/cuda/Randperm.cu b/aten/src/ATen/native/cuda/Randperm.cu index f6e3ac25d57..1702be2fd28 100644 --- a/aten/src/ATen/native/cuda/Randperm.cu +++ b/aten/src/ATen/native/cuda/Randperm.cu @@ -3,161 +3,81 @@ #include #include #include -#include #include namespace at { namespace native { -// [Algorithm of randperm] -// -// randperm is implemented by sorting an arange tensor of size n with randomly -// generated keys. When random keys are different from each other, all different -// permutations have the same probability. -// -// However, there is a pitfall here: -// For better performance, these N random keys are generated independently, -// and there is no effort to make sure they are different at the time of generation. -// When two keys are identical, stable sorting algorithms will not permute these two keys. -// As a result, (0, 1) will appear more often than (1, 0). -// -// To overcome this pitfall we first carefully choose the number of bits in these keys, -// so that the probability of having duplicate keys is under a threshold. Let q be the -// threshold probability for having non-duplicate keys, then it can be proved that[1] -// the number of bits required is: ceil(log2(n - (6 n^2 + 1) / (12 log(q)))) -// -// Then after sort, we lauch a separate kernel that additionally shuffles any islands -// of values whose keys matched. The algorithm of this kernel is as follows: -// Each thread reads its key and the keys of its neighbors to tell if it's part of an island. -// For each island, the first thread in the island sees a key match at index i+1 but not index i-1. -// This thread considers itself the "island leader". The island leader then reads more indices to -// the right to figure out how big the island is. Most likely, the island will be very small, -// just a few values. The island leader then rolls that many RNG, uses them to additionally -// shuffle values within the island using serial Fisher-Yates, and writes them out. -// -// Reference -// [1] https://osf.io/af2hy/ - Tensor& randperm_out_cuda(int64_t n, c10::optional generator, Tensor& result) { TORCH_CHECK(n >= 0, "n must be non-negative, got", n); TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()), "Expected a '", result.device(), "' generator device but found '", generator->device(), "'"); - TORCH_CHECK(n <= std::numeric_limits::max(), - "randperm of tensors larger than INT_MAX is not supported yet in pytorch"); check_supported_max_int_with_precision(n, result); result.resize_({n}); - auto range = at::arange(n, result.options()); - - // shuffled_data points to the underlying data of the output tensor if the tensor is contiguous; otherwise it - // points to a new tensor. - Tensor shuffled; - void *shuffled_data; - if (result.is_contiguous()) { - shuffled_data = result.data_ptr(); - } else { - shuffled = at::empty(n, result.options()); - shuffled_data = shuffled.data_ptr(); + if (n < 30000) { // For small inputs, we offload it to CPU instead. + auto result_cpu = at::empty({n}, result.options().device(kCPU)); + randperm_out(result_cpu, n, generator); + return result.copy_(result_cpu); } - auto opt = TensorOptions().device(result.device()); - - // See note [Algorithm of randperm] - const double log_threshold_12 = std::log(0.9) * 12; - double nd = static_cast(n); - -#if !defined(_MSC_VER) - constexpr bool is_reduced_bits = true; - int bits = std::min(64, - static_cast(std::ceil(std::log2(nd - (6 * nd * nd + 1) / log_threshold_12)))); -#else - // For some unknown reason, randperm_handle_duplicate_keys is causing test failures. - // Without this additional permutation kernel, we should sort on as much bits as we can. - constexpr bool is_reduced_bits = false; - int bits = 64; +#if 0 + // This if condition should never be true because if n >= 30000 and the tensor has a Half type, + // check_supported_max_int_with_precision should have reported an error. This snippet is commented out but left here + // for the sake of clarity, because Half in thrust is spotty, and we do not want future change unaware of this. + if (result.scalar_type() == at::ScalarType::Half) { // Half in thrust is spotty. Avoid. + auto result_float = at::empty({n}, initialTensorOptions().device(Device(DeviceType::CUDA))); + return result.copy_(randperm_out_cuda(result_float, n, generator)); + } #endif - if (n == 0) { - return result; - } else if (bits <= 8) { - auto keys = at::empty(result.sizes(), opt.dtype(kByte)).random_(generator); - auto keys_tmp = at::empty_like(keys); - auto keys_out = keys_tmp.data_ptr(); - AT_DISPATCH_ALL_TYPES_AND(kHalf, result.scalar_type(), "randperm_out_cuda", [&] { - auto shuffled_data_ = reinterpret_cast(shuffled_data); - at::cuda::cub::sort_pairs( - keys.data_ptr(), keys_out, - range.data_ptr(), shuffled_data_, - n, false, 0, bits); + // Generate random values for the keys array + AT_DISPATCH_ALL_TYPES( + result.scalar_type(), "randperm_out_cuda", [&] { + TORCH_CHECK(n <= std::numeric_limits::max(), + "randperm of tensors larger than INT_MAX is not supported yet in pytorch"); - if (is_reduced_bits) { - // This causes failing tests on MSVC for unknown reason. - randperm_handle_duplicate_keys(keys_out, shuffled_data_, bits, n, generator); + auto keys = at::empty(result.sizes(), result.options()).random_(generator); + auto range = at::arange(n, result.options()); + auto keys_tmp = at::empty_like(keys); + + // shuffled_data points to the underlying data of the output tensor if the tensor is contiguous; otherwise it + // points to a new tensor. + Tensor shuffled; + scalar_t *shuffled_data; + if (result.is_contiguous()) { + shuffled_data = result.data_ptr(); + } else { + shuffled = at::empty(n, result.options()); + shuffled_data = shuffled.data_ptr(); } - }); - } else if (bits <= 16) { - auto keys = at::empty(result.sizes(), opt.dtype(kShort)).random_( - std::numeric_limits::min(), std::numeric_limits::max(), generator); - auto keys_tmp = at::empty_like(keys); - auto keys_out = keys_tmp.data_ptr(); - AT_DISPATCH_ALL_TYPES_AND(kHalf, result.scalar_type(), "randperm_out_cuda", [&] { - auto shuffled_data_ = reinterpret_cast(shuffled_data); - at::cuda::cub::sort_pairs( - keys.data_ptr(), keys_out, - range.data_ptr(), shuffled_data_, - n, false, 0, bits); + // Use the sorted order of keys to rearrange the result array + size_t temp_storage_bytes = 0; - if (is_reduced_bits) { - // This causes failing tests on MSVC for unknown reason. - randperm_handle_duplicate_keys(keys_out, shuffled_data_, bits, n, generator); + cub::DeviceRadixSort::SortPairs( + nullptr, temp_storage_bytes, + keys.data_ptr(), keys_tmp.data_ptr(), + range.data_ptr(), shuffled_data, n, + 0, sizeof(scalar_t) * 8, at::cuda::getCurrentCUDAStream()); + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(temp_storage_bytes); + cub::DeviceRadixSort::SortPairs( + dataPtr.get(), temp_storage_bytes, + keys.data_ptr(), keys_tmp.data_ptr(), + range.data_ptr(), shuffled_data, n, + 0, sizeof(scalar_t) * 8, at::cuda::getCurrentCUDAStream()); + + if (!result.is_contiguous()) { + result.copy_(shuffled); } - - }); - } else if (bits <= 32) { - auto keys = at::empty(result.sizes(), opt.dtype(kInt)).random_( - std::numeric_limits::min(), std::numeric_limits::max(), generator); - auto keys_tmp = at::empty_like(keys); - auto keys_out = keys_tmp.data_ptr(); - AT_DISPATCH_ALL_TYPES_AND(kHalf, result.scalar_type(), "randperm_out_cuda", [&] { - auto shuffled_data_ = reinterpret_cast(shuffled_data); - at::cuda::cub::sort_pairs( - keys.data_ptr(), keys_out, - range.data_ptr(), shuffled_data_, - n, false, 0, bits); - - if (is_reduced_bits) { - // This causes failing tests on MSVC for unknown reason. - randperm_handle_duplicate_keys(keys_out, shuffled_data_, bits, n, generator); - } - - }); - } else { - auto keys = at::empty(result.sizes(), opt.dtype(kLong)).random_( - std::numeric_limits::min(), std::numeric_limits::max(), generator); - auto keys_tmp = at::empty_like(keys); - auto keys_out = keys_tmp.data_ptr(); - AT_DISPATCH_ALL_TYPES_AND(kHalf, result.scalar_type(), "randperm_out_cuda", [&] { - auto shuffled_data_ = reinterpret_cast(shuffled_data); - at::cuda::cub::sort_pairs( - keys.data_ptr(), keys_out, - range.data_ptr(), shuffled_data_, - n, false, 0, bits); - - if (is_reduced_bits) { - // This causes failing tests on MSVC for unknown reason. - randperm_handle_duplicate_keys(keys_out, shuffled_data_, bits, n, generator); - } - - }); - } - - if (!result.is_contiguous()) { - result.copy_(shuffled); - } + } + ); return result; } + + }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Randperm.cuh b/aten/src/ATen/native/cuda/Randperm.cuh deleted file mode 100644 index 6f20d949cfa..00000000000 --- a/aten/src/ATen/native/cuda/Randperm.cuh +++ /dev/null @@ -1,56 +0,0 @@ -#include -#include -#include - -#include -#include -#include - -namespace { - -// See note [Algorithm of randperm] -template -__global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int n, at::PhiloxCudaState philox_args) { - int tid = threadIdx.x + blockDim.x * blockIdx.x; - - // find the beginning of islands - if (tid >= n - 1) return; // out of range - if ((keys[tid] & mask) != (keys[tid + 1] & mask)) return; // not in an island - if (tid != 0 && (keys[tid] & mask) == (keys[tid - 1] & mask)) return; // not the beginning of an island - - // find the size of islands - int island_size = 0; - while ((keys[tid + ++island_size] & mask) == (keys[tid] & mask)); - - // do random permutation inside each island. - data += tid; - auto seeds = at::cuda::philox::unpack(philox_args); - curandStatePhilox4_32_10_t state; - curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state); - for (int i = island_size - 1; i > 0; i--) { - unsigned int r = curand(&state) % (i + 1); - if (i != r) { - scalar_t tmp = data[i]; - data[i] = data[r]; - data[r] = tmp; - } - } -} - -// See note [Algorithm of randperm] -template -void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, c10::optional &gen_) { - auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); - int64_t counter_offset = n; - at::PhiloxCudaState rng_engine_inputs; - { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_cuda_state(counter_offset); - } - T mask = static_cast((1UL << bits) - 1); - randperm_handle_duplicate_keys_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>( - keys, data, mask, n, rng_engine_inputs); -} - -} diff --git a/aten/src/ATen/test/cuda_distributions_test.cu b/aten/src/ATen/test/cuda_distributions_test.cu index 70a35d18425..56cc824faea 100644 --- a/aten/src/ATen/test/cuda_distributions_test.cu +++ b/aten/src/ATen/test/cuda_distributions_test.cu @@ -2,11 +2,8 @@ #include #include -#include - #include #include -#include #include #include @@ -151,57 +148,3 @@ TEST(DistributionsTest, TestPhiloxIncrementSmallMultinomialTensor) { // expected uniforms will start from counter offset of 4 assert_with_expected_uniforms(4); } - -__managed__ int keys[] = { - 1, (1 << 15) + 1, (1 << 16) + 1, - 2, (1 << 14) + 2 -}; - -__managed__ int values[] = { 1, 2, 3, 4, 5 }; - -std::vector> valid_perms1 = { - {1, 2, 3}, {1, 3, 2}, {2, 1, 3}, {2, 3, 1}, {3, 1, 2}, {3, 2, 1} -}; -std::vector> valid_perms2 = { - {4, 5}, {5, 4} -}; - -TEST(RandomPermutationTest, TestIslandShuffle) { - if (!at::cuda::is_available()) return; - at::manual_seed(123); - - bool shuffled1 = false; - bool shuffled2 = false; - for (int i = 0; i < 100; i++) { - cudaDeviceSynchronize(); - c10::optional gen = c10::nullopt; - randperm_handle_duplicate_keys(keys, values, 8, 5, gen); - cudaDeviceSynchronize(); - std::vector slice1 = {values[0], values[1], values[2]}; - std::vector slice2 = {values[3], values[4]}; - if (slice1 != valid_perms1[0]) { - shuffled1 = true; - } - if (slice2 != valid_perms2[0]) { - shuffled2 = true; - } - bool passed1 = false; - bool passed2 = false; - for (auto &i : valid_perms1) { - if (i == slice1) { - passed1 = true; - break; - } - } - for (auto &i : valid_perms2) { - if (i == slice2) { - passed2 = true; - break; - } - } - ASSERT_TRUE(passed1); - ASSERT_TRUE(passed2); - } - ASSERT_TRUE(shuffled1); - ASSERT_TRUE(shuffled2); -} diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 2776f5833f5..b989228e4b4 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -3262,9 +3262,9 @@ class TestRandomTensorCreation(TestCase): # see https://github.com/pytorch/pytorch/issues/54282 rng_device = [device] - # Test core functionality. On CUDA, different value of n has different - # code path - for n in (5, 100, 50000, 100000): + # Test core functionality. On CUDA, for small n, randperm is offloaded to CPU instead. For large n, randperm is + # executed on GPU. + for n in (100, 50000, 100000): # Ensure both integer and floating-point numbers are tested. Half follows an execution path that is # different from others on CUDA. for dtype in (torch.long, torch.half, torch.float):