add rng_state support for custom device (#98069)

Extend the RNG device-related functions to support custom device extensions; the default device type is `cuda`.
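For context, a minimal sketch of the contract being extended here; it runs on CPU alone and uses only the pre-existing `devices`/`enabled` arguments:

```python
import torch

torch.manual_seed(0)
before = torch.rand(3)                    # advances the CPU RNG

torch.manual_seed(0)
with torch.random.fork_rng(devices=[]):   # devices=[] -> fork CPU state only
    _ = torch.rand(100)                   # consumed inside the fork only

after = torch.rand(3)                     # CPU state was restored on exit
assert torch.equal(before, after)
```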
@bdhirsh @kit1980 would you please take a moment to review my changes?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98069
Approved by: https://github.com/bdhirsh
commit f25f85546f (parent a13a63ae9a)
shibo, 2023-04-10 22:36:50 +00:00, committed by PyTorch MergeBot
3 changed files with 82 additions and 37 deletions

test/test_cpp_extensions_open_device_registration.py

@@ -3,6 +3,7 @@
 import os
 import shutil
 import sys
+from typing import Union
 import unittest
 import torch.testing._internal.common_utils as common
@@ -31,9 +32,27 @@ def remove_build_path():
         shutil.rmtree(default_build_root, ignore_errors=True)
 
 
+class DummyModule(object):
+
+    @staticmethod
+    def device_count() -> int:
+        return 1
+
+    @staticmethod
+    def get_rng_state(device: Union[int, str, torch.device] = 'foo') -> torch.Tensor:
+        # create a tensor using our custom device object.
+        return torch.empty(4, 4, device="foo")
+
+    @staticmethod
+    def set_rng_state(new_state: torch.Tensor, device: Union[int, str, torch.device] = 'foo') -> None:
+        pass
+
+
+@unittest.skipIf(IS_ARM64, "Does not work on arm")
 class TestCppExtensionOpenRgistration(common.TestCase):
     """Tests Open Device Registration with C++ extensions.
     """
+    module = None
 
     def setUp(self):
         super().setUp()
@@ -41,6 +60,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         # this file, so we'll change the working directory temporarily
         self.old_working_dir = os.getcwd()
         os.chdir(os.path.dirname(os.path.abspath(__file__)))
+        assert self.module is not None
 
     def tearDown(self):
         super().tearDown()
@@ -50,14 +70,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
     @classmethod
     def setUpClass(cls):
         remove_build_path()
-
-    @classmethod
-    def tearDownClass(cls):
-        remove_build_path()
-
-    @unittest.skipIf(IS_ARM64, "Does not work on arm")
-    def test_open_device_registration(self):
-        module = torch.utils.cpp_extension.load(
+        cls.module = torch.utils.cpp_extension.load(
             name="custom_device_extension",
             sources=[
                 "cpp_extensions/open_registration_extension.cpp",
@@ -67,10 +80,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             verbose=True,
         )
 
-        self.assertFalse(module.custom_add_called())
+    @classmethod
+    def tearDownClass(cls):
+        remove_build_path()
+
+    def test_open_device_registration(self):
+        self.assertFalse(self.module.custom_add_called())
 
         # create a tensor using our custom device object.
-        device = module.custom_device()
+        device = self.module.custom_device()
         x = torch.empty(4, 4, device=device)
         y = torch.empty(4, 4, device=device)
@@ -79,13 +97,13 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         self.assertTrue(x.device == device)
         self.assertFalse(x.is_cpu)
-        self.assertFalse(module.custom_add_called())
+        self.assertFalse(self.module.custom_add_called())
 
         # calls out custom add kernel, registered to the dispatcher
         z = x + y
 
         # check that it was called
-        self.assertTrue(module.custom_add_called())
+        self.assertTrue(self.module.custom_add_called())
 
         z_cpu = z.to(device='cpu')
@@ -98,14 +116,14 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         z2 = z_cpu + z_cpu
 
         # None of our CPU operations should call the custom add function.
-        self.assertFalse(module.custom_add_called())
+        self.assertFalse(self.module.custom_add_called())
 
         # check generator registered before use
         with self.assertRaisesRegex(RuntimeError,
                                     "Please register a generator to the PrivateUse1 dispatch key"):
             gen_ = torch.Generator(device=device)
 
-        module.register_generator()
+        self.module.register_generator()
         gen = torch.Generator(device=device)
         self.assertTrue(gen.device == device)
@@ -113,7 +131,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         # generator can be registered only once
         with self.assertRaisesRegex(RuntimeError,
                                     "Only can register a generator to the PrivateUse1 dispatch key once"):
-            module.register_generator()
+            self.module.register_generator()
 
         # check whether print tensor.type() meets the expectation
         torch.utils.rename_privateuse1_backend('foo')
@@ -132,5 +150,18 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             test_tensor = torch.empty(4, 4, dtype=tt, device=device)
             self.assertTrue(test_tensor.type() == dt)
 
+    def test_open_device_random(self):
+        torch.utils.rename_privateuse1_backend('foo')
+        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
+            torch._register_device_module('xxx', DummyModule)
+
+        with self.assertRaisesRegex(RuntimeError, "torch has no module of"):
+            with torch.random.fork_rng(device_type="foo"):
+                pass
+
+        torch._register_device_module('foo', DummyModule)
+        with torch.random.fork_rng(device_type="foo"):
+            pass
+
 
 if __name__ == "__main__":
     common.run_tests()

torch/random.py

@@ -101,23 +101,29 @@ _fork_rng_warned_already = False
 
 
 @contextlib.contextmanager
-def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices") -> Generator:
+def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
     """
     Forks the RNG, so that when you return, the RNG is reset
     to the state that it was previously in.
 
     Args:
-        devices (iterable of CUDA IDs): CUDA devices for which to fork
-            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
+        devices (iterable of Device IDs): devices for which to fork
+            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
             on all devices, but will emit a warning if your machine has a lot
             of devices, since this function will run very slowly in that case.
             If you explicitly specify devices, this warning will be suppressed
         enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
+        device_type (str): device type string, default is ``cuda``. For custom devices,
+            see details in [Note: support the custom device with privateuse1]
     """
 
     import torch.cuda
+    device_type = torch.device(device_type).type
+    device_mod = getattr(torch, device_type, None)
+    if device_mod is None:
+        raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
+                           "a module by `torch._register_device_module`.")
     global _fork_rng_warned_already
 
     # Internal arguments:
@@ -129,21 +135,20 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices") -> Generator:
         return
 
     if devices is None:
-        num_devices = torch.cuda.device_count()
+        num_devices = device_mod.device_count()
         if num_devices > 1 and not _fork_rng_warned_already:
-            warnings.warn(
-                ("CUDA reports that you have {num_devices} available devices, and you "
-                 "have used {caller} without explicitly specifying which devices are being used. "
-                 "For safety, we initialize *every* CUDA device by default, which "
-                 "can be quite slow if you have a lot of GPUs. If you know that you are only "
-                 "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
-                 "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
-                 "you are actually using. For example, if you are using CPU only, "
-                 "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
-                 "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0]. To initialize "
-                 "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
-                 "to `range(torch.cuda.device_count())`."
-                ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
+            message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
+                       f"you have used {_caller} without explicitly specifying which devices are being used. "
+                       f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
+                       f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only "
+                       f"making use of a few {device_type.upper()} devices, set the environment variable "
+                       f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
+                       "with the set of devices you are actually using. For example, if you are using CPU only, "
+                       f"set {device_type.upper()}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
+                       f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
+                       f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
+                       f"`range(torch.{device_type}.device_count())`.")
+            warnings.warn(message)
             _fork_rng_warned_already = True
         devices = list(range(num_devices))
     else:
@@ -152,13 +157,13 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices") -> Generator:
         devices = list(devices)
 
     cpu_rng_state = torch.get_rng_state()
-    gpu_rng_states = []
+    device_rng_states = []
     for device in devices:
-        gpu_rng_states.append(torch.cuda.get_rng_state(device))
+        device_rng_states.append(device_mod.get_rng_state(device))
 
     try:
         yield
     finally:
         torch.set_rng_state(cpu_rng_state)
-        for device, gpu_rng_state in zip(devices, gpu_rng_states):
-            torch.cuda.set_rng_state(gpu_rng_state, device)
+        for device, device_rng_state in zip(devices, device_rng_states):
+            device_mod.set_rng_state(device_rng_state, device)
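
Usage of the new `device_type` keyword, sketched under the assumption of a CUDA build (with a custom backend you would pass its registered name, e.g. `"foo"`, instead):

```python
import torch

if torch.cuda.is_available():
    with torch.random.fork_rng(devices=[0], device_type="cuda"):
        torch.cuda.manual_seed(123)       # perturb device 0's RNG inside the fork
        _ = torch.rand(8, device="cuda")
    # On exit both the CPU RNG and device 0's RNG are back to their prior states.
```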

torch/utils/backend_registration.py

@@ -47,6 +47,15 @@ def rename_privateuse1_backend(backend_name: str) -> None:
     (2) manual_seed_all(seed: int) -> None
         Sets the seed for generating random numbers for your devices.
+    (3) device_count() -> int:
+        Returns the number of `foo`s available.
+    (4) get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor:
+        Returns the random number generator state of the specified `foo` device as a ByteTensor.
+    (5) set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None:
+        Sets the random number generator state of the specified `foo` device.
+
     For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
     For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
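
Putting the doc above together, a hedged sketch of a backend module that satisfies hooks (1)-(5) and round-trips through the extended `fork_rng`; the `foo` name and the byte-tensor "state" are illustrative assumptions, not a real backend:

```python
import torch

class FooModule:
    # Illustrative RNG "state"; a real backend would expose its generator's state.
    _state = torch.zeros(16, dtype=torch.uint8)

    @staticmethod
    def manual_seed(seed: int) -> None:
        pass  # (1) seed the current device's generator

    @staticmethod
    def manual_seed_all(seed: int) -> None:
        pass  # (2) seed every device's generator

    @staticmethod
    def device_count() -> int:
        return 1  # (3)

    @staticmethod
    def get_rng_state(device=0) -> torch.Tensor:
        return FooModule._state.clone()  # (4)

    @staticmethod
    def set_rng_state(new_state: torch.Tensor, device=0) -> None:
        FooModule._state = new_state.clone()  # (5)

torch.utils.rename_privateuse1_backend('foo')    # make 'foo' a valid device type
torch._register_device_module('foo', FooModule)  # expose it as torch.foo

# fork_rng resolves torch.foo via getattr and saves/restores its state:
with torch.random.fork_rng(device_type="foo"):
    pass
```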