diff --git a/aten/src/ATen/core/DistributionsHelper.h b/aten/src/ATen/core/DistributionsHelper.h index 576b4a39b45..2bf98cf95f8 100644 --- a/aten/src/ATen/core/DistributionsHelper.h +++ b/aten/src/ATen/core/DistributionsHelper.h @@ -40,11 +40,7 @@ struct uniform_int_from_to_distribution { template C10_HOST_DEVICE inline T operator()(RNG generator) { - if (( - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v) && range_ >= 1ULL << 32) + if (range_ >= 1ULL << 25) // allow approx 1% skew in uniform int generation using % { return transformation::uniform_int_from_to(generator->random64(), range_, base_); } else { diff --git a/aten/src/ATen/native/cuda/DistributionTemplates.h b/aten/src/ATen/native/cuda/DistributionTemplates.h index b685a67ae5d..58c9c28f6f4 100644 --- a/aten/src/ATen/native/cuda/DistributionTemplates.h +++ b/aten/src/ATen/native/cuda/DistributionTemplates.h @@ -280,11 +280,7 @@ namespace cuda { template void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) { AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] { - if (( - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v) && range >= 1ULL << 32) + if (range >= 1ULL << 25) // allow approx 1% skew in uniform int generation using % { // define lambda to mod with range and add base auto random_func = [range, base] __device__ (uint64_t rand) { diff --git a/aten/src/ATen/test/rng_test.h b/aten/src/ATen/test/rng_test.h index 250d54f20b2..c4a8953fad3 100644 --- a/aten/src/ATen/test/rng_test.h +++ b/aten/src/ATen/test/rng_test.h @@ -137,7 +137,9 @@ void test_random_from_to(const at::Device& device) { range = static_cast(max_to) - static_cast(from) + 1; from_case_covered = true; } - if (range < (1ULL << 32)) { + // this is leaking details of implementation into test + // we are starting to use random64() at 2^25 to minimize skew due to % + if (range < (1ULL << 25)) { exp = static_cast(static_cast((static_cast(val) % range + from))); } else { exp = static_cast(static_cast((val % range + from))); diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index bbc5710c8f3..3a7770bef89 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -8551,6 +8551,26 @@ class CommonTemplate: self.assertGreater(c0.max(), 2**40) self.assertLess(c0.max(), 2**50) + def test_randint_distribution(self): + @torch.compile(fullgraph=True) + def fn(n_argsmax, size): + return torch.randint(n_max, (size,), device=self.device) + + def bin(index, max_size): + return index // (max_size // n_bins) + + size = 1_000_000 + n_max = int(0.75 * 2**32) + n_bins = 8 + + res = fn(n_max, size) + bins = bin(res, n_max).float().cpu() + hist, _ = bins.histogram(8, range=(0, n_bins)) + expected_bin = res.shape[0] / 8 + expected_error = math.sqrt(expected_bin) / expected_bin * 3 + error = (hist - expected_bin).abs().max() / expected_bin + self.assertTrue(error < expected_error) + @config.patch(fallback_random=True) def test_like_rands(self): def fn(x): diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index f04452fa832..0173fe9c4de 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -231,6 +231,7 @@ test_failures = { "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_polar_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), + "test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)), "test_randn_generator_dynamic_shapes": TestFailure(("cpu",)), "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_single_elem_dynamic_shapes": TestFailure(("cpu",)), diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 68d0e07c143..ba6b9ab4711 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -59,6 +59,7 @@ test_failures = { ("cpu", "cuda", "xpu") ), "test_conv_inference_heuristics_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)), } if TEST_WITH_ROCM: diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index c6e375471f1..5d46f6e1d61 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -3495,6 +3495,24 @@ class TestRandomTensorCreation(TestCase): self.assertTrue((res1 < 6).all().item()) self.assertTrue((res1 >= 0).all().item()) + + def test_randint_distribution(self, device): + size = 1_000_000 + n_max = int(0.75 * 2 ** 32) + n_bins = 8 + + def bin(index, max_size): + return index // (max_size // n_bins) + res = torch.randint(n_max, (size,), device=device) + # histogram implemented for float only + bins = bin(res, n_max).float().cpu() + hist, _ = bins.histogram(8, range=(0, n_bins)) + expected_bin = res.shape[0] / 8 + expected_error = math.sqrt(expected_bin) / expected_bin * 3 + error = (hist - expected_bin).abs().max() / expected_bin + self.assertTrue(error < expected_error) + + @dtypes(torch.half, torch.float, torch.bfloat16, torch.double, torch.complex32, torch.complex64, torch.complex128) def test_randn(self, device, dtype):