mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Enable bfloat16 random kernels on Windows (#44918)
Summary: Fixes https://github.com/pytorch/pytorch/issues/33793 Pull Request resolved: https://github.com/pytorch/pytorch/pull/44918 Reviewed By: pbelevich Differential Revision: D23777548 Pulled By: ngimel fbshipit-source-id: 9cf13166d7deba17bc72e402b82ed0afe347cb9b
This commit is contained in:
parent
06389406bb
commit
e255a4e1fd
3 changed files with 3 additions and 57 deletions
|
|
@ -273,12 +273,6 @@ namespace cuda {
|
|||
|
||||
template<typename RNG>
|
||||
void random_from_to_kernel(TensorIterator& iter, uint64_t range, int64_t base, RNG gen) {
|
||||
#ifdef _WIN32
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if (iter.dtype() == ScalarType::BFloat16) {
|
||||
TORCH_CHECK(false, "random_() is not supported for bfloat16 CUDA tensors on Windows. Please see https://github.com/pytorch/pytorch/issues/33793");
|
||||
}
|
||||
#endif
|
||||
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_from_to_kernel_cuda", [&] {
|
||||
if ((
|
||||
std::is_same<scalar_t, int64_t>::value ||
|
||||
|
|
@ -319,12 +313,6 @@ void random_from_to_kernel(TensorIterator& iter, uint64_t range, int64_t base, R
|
|||
// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
|
||||
template<typename RNG>
|
||||
void random_full_64_bits_range_kernel(TensorIterator& iter, RNG gen) {
|
||||
#ifdef _WIN32
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if (iter.dtype() == ScalarType::BFloat16) {
|
||||
TORCH_CHECK(false, "random_() is not supported for bfloat16 CUDA tensors on Windows. Please see https://github.com/pytorch/pytorch/issues/33793");
|
||||
}
|
||||
#endif
|
||||
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
|
||||
if (std::is_same<scalar_t, int64_t>::value ||
|
||||
std::is_same<scalar_t, double>::value ||
|
||||
|
|
@ -361,12 +349,6 @@ struct RandomFromToKernel {
|
|||
|
||||
template<typename RNG>
|
||||
void random_kernel(TensorIterator& iter, RNG gen) {
|
||||
#ifdef _WIN32
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if (iter.dtype() == ScalarType::BFloat16) {
|
||||
TORCH_CHECK(false, "random_() is not supported for bfloat16 CUDA tensors on Windows. Please see https://github.com/pytorch/pytorch/issues/33793");
|
||||
}
|
||||
#endif
|
||||
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
|
||||
if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
|
||||
auto random_func = [] __device__ (uint64_t rand) {
|
||||
|
|
@ -462,12 +444,6 @@ struct NormalKernel {
|
|||
|
||||
template<typename RNG>
|
||||
void uniform_kernel(TensorIterator& iter, double from_, double to_, RNG gen) {
|
||||
#ifdef _WIN32
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if (iter.dtype() == ScalarType::BFloat16) {
|
||||
TORCH_CHECK(false, "uniform_() is not supported for bfloat16 CUDA tensors on Windows. Please see https://github.com/pytorch/pytorch/issues/33793");
|
||||
}
|
||||
#endif
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
|
||||
auto from = static_cast<scalar_t>(from_);
|
||||
auto to = static_cast<scalar_t>(to_);
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import torch
|
|||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests, do_test_empty_full, TEST_NUMPY, suppress_warnings,
|
||||
IS_WINDOWS, torch_to_numpy_dtype_dict, slowTest)
|
||||
torch_to_numpy_dtype_dict, slowTest)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
|
||||
onlyCPU, skipCUDAIfNotRocm, largeCUDATensorTest, precisionOverride, dtypes,
|
||||
|
|
@ -822,10 +822,7 @@ class TestTensorCreation(TestCase):
|
|||
self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device, dtype=dt)).shape)
|
||||
self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device, dtype=dt).shape)
|
||||
|
||||
if dt == torch.bfloat16 and device.startswith('cuda') and IS_WINDOWS:
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt).shape)
|
||||
elif dt == torch.bool:
|
||||
if dt == torch.bool:
|
||||
self.assertEqual(shape, torch.randint(2, shape, device=device, dtype=dt).shape)
|
||||
self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device, dtype=dt), 2).shape)
|
||||
elif dt.is_complex:
|
||||
|
|
|
|||
|
|
@ -10897,10 +10897,6 @@ class TestTorchDeviceType(TestCase):
|
|||
@dtypes(torch.float, torch.double, torch.half)
|
||||
@dtypesIfCUDA(torch.float, torch.double, torch.half, torch.bfloat16)
|
||||
def test_uniform_from_to(self, device, dtype):
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if IS_WINDOWS and device.startswith('cuda') and dtype == torch.bfloat16:
|
||||
raise unittest.SkipTest("Crashes with CUDA error: unspecified launch failure")
|
||||
|
||||
size = 2000
|
||||
alpha = 0.1
|
||||
|
||||
|
|
@ -11119,10 +11115,6 @@ class TestTorchDeviceType(TestCase):
|
|||
@skipIfNoSciPy
|
||||
@dtypes(*torch.testing.get_all_fp_dtypes())
|
||||
def test_uniform_kstest(self, device, dtype):
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if IS_WINDOWS and device.startswith('cuda') and dtype == torch.bfloat16:
|
||||
raise unittest.SkipTest("Crashes with CUDA error: unspecified launch failure")
|
||||
|
||||
from scipy import stats
|
||||
size = 1000
|
||||
for from_ in [-42, 0, 4.2]:
|
||||
|
|
@ -12244,10 +12236,7 @@ class TestTorchDeviceType(TestCase):
|
|||
def test_unfold_all_devices_and_dtypes(self, device):
|
||||
for dt in torch.testing.get_all_dtypes():
|
||||
|
||||
if dt == torch.bfloat16 and device.startswith('cuda') and IS_WINDOWS:
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
self.assertRaises(RuntimeError, lambda: torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device))
|
||||
elif dt == torch.bool:
|
||||
if dt == torch.bool:
|
||||
x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
|
||||
self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
|
||||
else:
|
||||
|
|
@ -17629,10 +17618,6 @@ else:
|
|||
|
||||
@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes()))
|
||||
def test_random_full_range(self, device, dtype):
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if IS_WINDOWS and device.startswith('cuda') and dtype == torch.bfloat16:
|
||||
raise unittest.SkipTest("Crashes with CUDA error: unspecified launch failure")
|
||||
|
||||
size = 2000
|
||||
alpha = 0.1
|
||||
|
||||
|
|
@ -17667,10 +17652,6 @@ else:
|
|||
|
||||
@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes()))
|
||||
def test_random_from_to(self, device, dtype):
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if IS_WINDOWS and device.startswith('cuda') and dtype == torch.bfloat16:
|
||||
raise unittest.SkipTest("Crashes with CUDA error: unspecified launch failure")
|
||||
|
||||
size = 2000
|
||||
alpha = 0.1
|
||||
|
||||
|
|
@ -17760,10 +17741,6 @@ else:
|
|||
|
||||
@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes()))
|
||||
def test_random_to(self, device, dtype):
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if IS_WINDOWS and device.startswith('cuda') and dtype == torch.bfloat16:
|
||||
raise unittest.SkipTest("Crashes with CUDA error: unspecified launch failure")
|
||||
|
||||
size = 2000
|
||||
alpha = 0.1
|
||||
|
||||
|
|
@ -17822,10 +17799,6 @@ else:
|
|||
|
||||
@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes()))
|
||||
def test_random_default(self, device, dtype):
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/33793
|
||||
if IS_WINDOWS and device.startswith('cuda') and dtype == torch.bfloat16:
|
||||
raise unittest.SkipTest("Crashes with CUDA error: unspecified launch failure")
|
||||
|
||||
size = 2000
|
||||
alpha = 0.1
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue