Revert "fix randint distribution for large max (#143787)"

This reverts commit 8059d56ec3.

Reverted https://github.com/pytorch/pytorch/pull/143787 on behalf of https://github.com/wdvr due to failing internal tests, to be fixed first ([comment](https://github.com/pytorch/pytorch/pull/143787#issuecomment-2563493323))
This commit is contained in:
PyTorch MergeBot 2024-12-27 09:16:36 +00:00
parent f6801ba4b3
commit 3571476739
7 changed files with 11 additions and 45 deletions

View file

@ -40,7 +40,11 @@ struct uniform_int_from_to_distribution {
template <typename RNG>
C10_HOST_DEVICE inline T operator()(RNG generator) {
if (range_ >= 1ULL << 25) // allow approx 1% skew in uniform int generation using %
if ((
std::is_same_v<T, int64_t> ||
std::is_same_v<T, double> ||
std::is_same_v<T, float> ||
std::is_same_v<T, at::BFloat16>) && range_ >= 1ULL << 32)
{
return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
} else {

View file

@ -280,7 +280,11 @@ namespace cuda {
template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
if (range >= 1ULL << 25) // allow approx 1% skew in uniform int generation using %
if ((
std::is_same_v<scalar_t, int64_t> ||
std::is_same_v<scalar_t, double> ||
std::is_same_v<scalar_t, float> ||
std::is_same_v<scalar_t, at::BFloat16>) && range >= 1ULL << 32)
{
// define lambda to mod with range and add base
auto random_func = [range, base] __device__ (uint64_t rand) {

View file

@ -137,9 +137,7 @@ void test_random_from_to(const at::Device& device) {
range = static_cast<uint64_t>(max_to) - static_cast<uint64_t>(from) + 1;
from_case_covered = true;
}
// this is leaking details of implementation into test
// we are starting to use random64() at 2^25 to minimize skew due to %
if (range < (1ULL << 25)) {
if (range < (1ULL << 32)) {
exp = static_cast<T>(static_cast<int64_t>((static_cast<uint32_t>(val) % range + from)));
} else {
exp = static_cast<T>(static_cast<int64_t>((val % range + from)));

View file

@ -8551,26 +8551,6 @@ class CommonTemplate:
self.assertGreater(c0.max(), 2**40)
self.assertLess(c0.max(), 2**50)
def test_randint_distribution(self):
@torch.compile(fullgraph=True)
def fn(n_argsmax, size):
return torch.randint(n_max, (size,), device=self.device)
def bin(index, max_size):
return index // (max_size // n_bins)
size = 1_000_000
n_max = int(0.75 * 2**32)
n_bins = 8
res = fn(n_max, size)
bins = bin(res, n_max).float().cpu()
hist, _ = bins.histogram(8, range=(0, n_bins))
expected_bin = res.shape[0] / 8
expected_error = math.sqrt(expected_bin) / expected_bin * 3
error = (hist - expected_bin).abs().max() / expected_bin
self.assertTrue(error < expected_error)
@config.patch(fallback_random=True)
def test_like_rands(self):
def fn(x):

View file

@ -231,7 +231,6 @@ test_failures = {
"test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_polar_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
"test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
"test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
"test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_single_elem_dynamic_shapes": TestFailure(("cpu",)),

View file

@ -59,7 +59,6 @@ test_failures = {
("cpu", "cuda", "xpu")
),
"test_conv_inference_heuristics_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
}
if TEST_WITH_ROCM:

View file

@ -3495,24 +3495,6 @@ class TestRandomTensorCreation(TestCase):
self.assertTrue((res1 < 6).all().item())
self.assertTrue((res1 >= 0).all().item())
def test_randint_distribution(self, device):
size = 1_000_000
n_max = int(0.75 * 2 ** 32)
n_bins = 8
def bin(index, max_size):
return index // (max_size // n_bins)
res = torch.randint(n_max, (size,), device=device)
# histogram implemented for float only
bins = bin(res, n_max).float().cpu()
hist, _ = bins.histogram(8, range=(0, n_bins))
expected_bin = res.shape[0] / 8
expected_error = math.sqrt(expected_bin) / expected_bin * 3
error = (hist - expected_bin).abs().max() / expected_bin
self.assertTrue(error < expected_error)
@dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
torch.complex32, torch.complex64, torch.complex128)
def test_randn(self, device, dtype):