mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
add rng_state support for custom device (#98069)
Fixes #ISSUE_NUMBER. Extends the RNG device-related functions to support custom device extensions; the default device is `cuda`. @bdhirsh @kit1980 would you please take a moment to review my changes? Pull Request resolved: https://github.com/pytorch/pytorch/pull/98069 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
a13a63ae9a
commit
f25f85546f
3 changed files with 82 additions and 37 deletions
|
|
@ -3,6 +3,7 @@
|
|||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from typing import Union
|
||||
import unittest
|
||||
|
||||
import torch.testing._internal.common_utils as common
|
||||
|
|
@ -31,9 +32,27 @@ def remove_build_path():
|
|||
shutil.rmtree(default_build_root, ignore_errors=True)
|
||||
|
||||
|
||||
class DummyModule(object):
|
||||
|
||||
@staticmethod
|
||||
def device_count() -> int:
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def get_rng_state(device: Union[int, str, torch.device] = 'foo') -> torch.Tensor:
|
||||
# create a tensor using our custom device object.
|
||||
return torch.empty(4, 4, device="foo")
|
||||
|
||||
@staticmethod
|
||||
def set_rng_state(new_state: torch.Tensor, device: Union[int, str, torch.device] = 'foo') -> None:
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipIf(IS_ARM64, "Does not work on arm")
|
||||
class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
"""Tests Open Device Registration with C++ extensions.
|
||||
"""
|
||||
module = None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
|
@ -41,6 +60,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
# this file, so we'll change the working directory temporarily
|
||||
self.old_working_dir = os.getcwd()
|
||||
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
||||
assert self.module is not None
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
|
@ -50,14 +70,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
@classmethod
|
||||
def setUpClass(cls):
|
||||
remove_build_path()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
remove_build_path()
|
||||
|
||||
@unittest.skipIf(IS_ARM64, "Does not work on arm")
|
||||
def test_open_device_registration(self):
|
||||
module = torch.utils.cpp_extension.load(
|
||||
cls.module = torch.utils.cpp_extension.load(
|
||||
name="custom_device_extension",
|
||||
sources=[
|
||||
"cpp_extensions/open_registration_extension.cpp",
|
||||
|
|
@ -67,10 +80,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
verbose=True,
|
||||
)
|
||||
|
||||
self.assertFalse(module.custom_add_called())
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
remove_build_path()
|
||||
|
||||
def test_open_device_registration(self):
|
||||
self.assertFalse(self.module.custom_add_called())
|
||||
|
||||
# create a tensor using our custom device object.
|
||||
device = module.custom_device()
|
||||
device = self.module.custom_device()
|
||||
|
||||
x = torch.empty(4, 4, device=device)
|
||||
y = torch.empty(4, 4, device=device)
|
||||
|
|
@ -79,13 +97,13 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
self.assertTrue(x.device == device)
|
||||
self.assertFalse(x.is_cpu)
|
||||
|
||||
self.assertFalse(module.custom_add_called())
|
||||
self.assertFalse(self.module.custom_add_called())
|
||||
|
||||
# calls out custom add kernel, registered to the dispatcher
|
||||
z = x + y
|
||||
|
||||
# check that it was called
|
||||
self.assertTrue(module.custom_add_called())
|
||||
self.assertTrue(self.module.custom_add_called())
|
||||
|
||||
z_cpu = z.to(device='cpu')
|
||||
|
||||
|
|
@ -98,14 +116,14 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
z2 = z_cpu + z_cpu
|
||||
|
||||
# None of our CPU operations should call the custom add function.
|
||||
self.assertFalse(module.custom_add_called())
|
||||
self.assertFalse(self.module.custom_add_called())
|
||||
|
||||
# check generator registered before use
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Please register a generator to the PrivateUse1 dispatch key"):
|
||||
gen_ = torch.Generator(device=device)
|
||||
|
||||
module.register_generator()
|
||||
self.module.register_generator()
|
||||
|
||||
gen = torch.Generator(device=device)
|
||||
self.assertTrue(gen.device == device)
|
||||
|
|
@ -113,7 +131,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
# generator can be registered only once
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Only can register a generator to the PrivateUse1 dispatch key once"):
|
||||
module.register_generator()
|
||||
self.module.register_generator()
|
||||
|
||||
# check whether print tensor.type() meets the expectation
|
||||
torch.utils.rename_privateuse1_backend('foo')
|
||||
|
|
@ -132,5 +150,18 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
test_tensor = torch.empty(4, 4, dtype=tt, device=device)
|
||||
self.assertTrue(test_tensor.type() == dt)
|
||||
|
||||
def test_open_device_random(self):
|
||||
torch.utils.rename_privateuse1_backend('foo')
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
|
||||
torch._register_device_module('xxx', DummyModule)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "torch has no module of"):
|
||||
with torch.random.fork_rng(device_type="foo"):
|
||||
pass
|
||||
torch._register_device_module('foo', DummyModule)
|
||||
|
||||
with torch.random.fork_rng(device_type="foo"):
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
common.run_tests()
|
||||
|
|
|
|||
|
|
@ -101,23 +101,29 @@ _fork_rng_warned_already = False
|
|||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices") -> Generator:
|
||||
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
|
||||
"""
|
||||
Forks the RNG, so that when you return, the RNG is reset
|
||||
to the state that it was previously in.
|
||||
|
||||
Args:
|
||||
devices (iterable of CUDA IDs): CUDA devices for which to fork
|
||||
the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
|
||||
devices (iterable of Device IDs): devices for which to fork
|
||||
the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
|
||||
on all devices, but will emit a warning if your machine has a lot
|
||||
of devices, since this function will run very slowly in that case.
|
||||
If you explicitly specify devices, this warning will be suppressed
|
||||
enabled (bool): if ``False``, the RNG is not forked. This is a convenience
|
||||
argument for easily disabling the context manager without having
|
||||
to delete it and unindent your Python code under it.
|
||||
device_type(str): device type str, default is `cuda`. As for custom device,
|
||||
see details in [Note: support the custom device with privateuse1]
|
||||
"""
|
||||
|
||||
import torch.cuda
|
||||
device_type = torch.device(device_type).type
|
||||
device_mod = getattr(torch, device_type, None)
|
||||
if device_mod is None:
|
||||
raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
|
||||
"a module by `torch._register_device_module`.")
|
||||
global _fork_rng_warned_already
|
||||
|
||||
# Internal arguments:
|
||||
|
|
@ -129,21 +135,20 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
|
|||
return
|
||||
|
||||
if devices is None:
|
||||
num_devices = torch.cuda.device_count()
|
||||
num_devices = device_mod.device_count()
|
||||
if num_devices > 1 and not _fork_rng_warned_already:
|
||||
warnings.warn(
|
||||
("CUDA reports that you have {num_devices} available devices, and you "
|
||||
"have used {caller} without explicitly specifying which devices are being used. "
|
||||
"For safety, we initialize *every* CUDA device by default, which "
|
||||
"can be quite slow if you have a lot of GPUs. If you know that you are only "
|
||||
"making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
|
||||
"or the '{devices_kw}' keyword argument of {caller} with the set of devices "
|
||||
"you are actually using. For example, if you are using CPU only, "
|
||||
"set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
|
||||
"GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0]. To initialize "
|
||||
"all devices and suppress this warning, set the '{devices_kw}' keyword argument "
|
||||
"to `range(torch.cuda.device_count())`."
|
||||
).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
|
||||
message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
|
||||
f"you have used {_caller} without explicitly specifying which devices are being used. "
|
||||
f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
|
||||
"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
|
||||
" making use of a few {device_type.upper()} devices, set the environment variable "
|
||||
f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
|
||||
"with the set of devices you are actually using. For example, if you are using CPU only, "
|
||||
"set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
|
||||
f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
|
||||
"and suppress this warning, set the '{_devices_kw}' keyword argument to "
|
||||
"`range(torch.{device_type}.device_count())`.")
|
||||
warnings.warn(message)
|
||||
_fork_rng_warned_already = True
|
||||
devices = list(range(num_devices))
|
||||
else:
|
||||
|
|
@ -152,13 +157,13 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
|
|||
devices = list(devices)
|
||||
|
||||
cpu_rng_state = torch.get_rng_state()
|
||||
gpu_rng_states = []
|
||||
device_rng_states = []
|
||||
for device in devices:
|
||||
gpu_rng_states.append(torch.cuda.get_rng_state(device))
|
||||
device_rng_states.append(device_mod.get_rng_state(device))
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.set_rng_state(cpu_rng_state)
|
||||
for device, gpu_rng_state in zip(devices, gpu_rng_states):
|
||||
torch.cuda.set_rng_state(gpu_rng_state, device)
|
||||
for device, device_rng_state in zip(devices, device_rng_states):
|
||||
device_mod.set_rng_state(device_rng_state, device)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,15 @@ def rename_privateuse1_backend(backend_name: str) -> None:
|
|||
(2) manual_seed_all(seed: int) -> None
|
||||
Sets the seed for generating random numbers for your devices.
|
||||
|
||||
(3) device_count() -> int:
|
||||
Returns the number of `foo`s available.
|
||||
|
||||
(4) get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor:
|
||||
Returns a list of ByteTensor representing the random number states of all devices.
|
||||
|
||||
(5) set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None:
|
||||
Sets the random number generator state of the specified `foo` device.
|
||||
|
||||
For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
|
||||
For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue