From f25f85546f156caa4097e9e1458fbc7d213e4065 Mon Sep 17 00:00:00 2001
From: shibo <18207133434@163.com>
Date: Mon, 10 Apr 2023 22:36:50 +0000
Subject: [PATCH] add rng_state support for custom device (#98069)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #ISSUE_NUMBER

Extend the RNG device-related functions to support custom device extensions; the default device type is `cuda`.

@bdhirsh @kit1980 would you please take a moment to review my changes?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98069
Approved by: https://github.com/bdhirsh
---
 ...cpp_extensions_open_device_registration.py | 61 ++++++++++++++-----
 torch/random.py                               | 49 ++++++++-------
 torch/utils/backend_registration.py           |  9 +++
 3 files changed, 82 insertions(+), 37 deletions(-)

diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py
index 96219991ee2..c1a54d2dffe 100644
--- a/test/test_cpp_extensions_open_device_registration.py
+++ b/test/test_cpp_extensions_open_device_registration.py
@@ -3,6 +3,7 @@
 import os
 import shutil
 import sys
+from typing import Union
 import unittest
 
 import torch.testing._internal.common_utils as common
@@ -31,9 +32,27 @@ def remove_build_path():
         shutil.rmtree(default_build_root, ignore_errors=True)
 
 
+class DummyModule:
+
+    @staticmethod
+    def device_count() -> int:
+        return 1
+
+    @staticmethod
+    def get_rng_state(device: Union[int, str, torch.device] = 'foo') -> torch.Tensor:
+        # create a tensor using our custom device object.
+        return torch.empty(4, 4, device="foo")
+
+    @staticmethod
+    def set_rng_state(new_state: torch.Tensor, device: Union[int, str, torch.device] = 'foo') -> None:
+        pass
+
+
+@unittest.skipIf(IS_ARM64, "Does not work on arm")
 class TestCppExtensionOpenRgistration(common.TestCase):
     """Tests Open Device Registration with C++ extensions.
     """
+    module = None
 
     def setUp(self):
         super().setUp()
@@ -41,6 +60,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         # this file, so we'll change the working directory temporarily
         self.old_working_dir = os.getcwd()
         os.chdir(os.path.dirname(os.path.abspath(__file__)))
+        assert self.module is not None
 
     def tearDown(self):
         super().tearDown()
@@ -50,14 +70,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
     @classmethod
     def setUpClass(cls):
         remove_build_path()
-
-    @classmethod
-    def tearDownClass(cls):
-        remove_build_path()
-
-    @unittest.skipIf(IS_ARM64, "Does not work on arm")
-    def test_open_device_registration(self):
-        module = torch.utils.cpp_extension.load(
+        cls.module = torch.utils.cpp_extension.load(
             name="custom_device_extension",
             sources=[
                 "cpp_extensions/open_registration_extension.cpp",
@@ -67,10 +80,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             verbose=True,
         )
 
-        self.assertFalse(module.custom_add_called())
+    @classmethod
+    def tearDownClass(cls):
+        remove_build_path()
+
+    def test_open_device_registration(self):
+        self.assertFalse(self.module.custom_add_called())
 
         # create a tensor using our custom device object.
-        device = module.custom_device()
+        device = self.module.custom_device()
 
         x = torch.empty(4, 4, device=device)
         y = torch.empty(4, 4, device=device)
@@ -79,13 +97,13 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         self.assertTrue(x.device == device)
         self.assertFalse(x.is_cpu)
 
-        self.assertFalse(module.custom_add_called())
+        self.assertFalse(self.module.custom_add_called())
 
         # calls out custom add kernel, registered to the dispatcher
         z = x + y
 
         # check that it was called
-        self.assertTrue(module.custom_add_called())
+        self.assertTrue(self.module.custom_add_called())
 
         z_cpu = z.to(device='cpu')
 
@@ -98,14 +116,14 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         z2 = z_cpu + z_cpu
 
         # None of our CPU operations should call the custom add function.
-        self.assertFalse(module.custom_add_called())
+        self.assertFalse(self.module.custom_add_called())
 
         # check generator registered befor use
         with self.assertRaisesRegex(RuntimeError,
                                     "Please register a generator to the PrivateUse1 dispatch key"):
             gen_ = torch.Generator(device=device)
 
-        module.register_generator()
+        self.module.register_generator()
 
         gen = torch.Generator(device=device)
         self.assertTrue(gen.device == device)
@@ -113,7 +131,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         # generator can be registered only once
         with self.assertRaisesRegex(RuntimeError,
                                     "Only can register a generator to the PrivateUse1 dispatch key once"):
-            module.register_generator()
+            self.module.register_generator()
 
         # check whether print tensor.type() meets the expectation
         torch.utils.rename_privateuse1_backend('foo')
@@ -132,5 +150,18 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             test_tensor = torch.empty(4, 4, dtype=tt, device=device)
             self.assertTrue(test_tensor.type() == dt)
 
+    def test_open_device_random(self):
+        torch.utils.rename_privateuse1_backend('foo')
+        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
+            torch._register_device_module('xxx', DummyModule)
+
+        with self.assertRaisesRegex(RuntimeError, "torch has no module of"):
+            with torch.random.fork_rng(device_type="foo"):
+                pass
+        torch._register_device_module('foo', DummyModule)
+
+        with torch.random.fork_rng(device_type="foo"):
+            pass
+
 if __name__ == "__main__":
     common.run_tests()

diff --git a/torch/random.py b/torch/random.py
index 5bd9323995d..fdea0d3f963 100644
--- a/torch/random.py
+++ b/torch/random.py
@@ -101,23 +101,29 @@ _fork_rng_warned_already = False
 
 
 @contextlib.contextmanager
-def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices") -> Generator:
+def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
     """
     Forks the RNG, so that when you return, the RNG is reset
     to the state that it was previously in.
 
     Args:
-        devices (iterable of CUDA IDs): CUDA devices for which to fork
-            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
+        devices (iterable of Device IDs): devices for which to fork
+            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
             on all devices, but will emit a warning if your machine has a lot
             of devices, since this function will run very slowly in that case.
             If you explicitly specify devices, this warning will be suppressed
         enabled (bool): if ``False``, the RNG is not forked.  This is a convenience
             argument for easily disabling the context manager without having
             to delete it and unindent your Python code under it.
+        device_type (str): device type string, default is `cuda`. For a custom device,
+            see details in [Note: support the custom device with privateuse1]
     """
 
-    import torch.cuda
+    device_type = torch.device(device_type).type
+    device_mod = getattr(torch, device_type, None)
+    if device_mod is None:
+        raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
+                           "a module by `torch._register_device_module`.")
     global _fork_rng_warned_already
 
     # Internal arguments:
@@ -129,21 +135,20 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
         return
 
     if devices is None:
-        num_devices = torch.cuda.device_count()
+        num_devices = device_mod.device_count()
         if num_devices > 1 and not _fork_rng_warned_already:
-            warnings.warn(
-                ("CUDA reports that you have {num_devices} available devices, and you "
-                 "have used {caller} without explicitly specifying which devices are being used. "
-                 "For safety, we initialize *every* CUDA device by default, which "
-                 "can be quite slow if you have a lot of GPUs. If you know that you are only "
-                 "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
-                 "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
-                 "you are actually using. For example, if you are using CPU only, "
-                 "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
-                 "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0]. To initialize "
-                 "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
-                 "to `range(torch.cuda.device_count())`."
-                ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
+            message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
+                       f"you have used {_caller} without explicitly specifying which devices are being used. "
+                       f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
+                       f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are "
+                       f"only making use of a few {device_type.upper()} devices, set the environment variable "
+                       f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
+                       "with the set of devices you are actually using. For example, if you are using CPU only, "
+                       f"set {device_type.upper()}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
+                       f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
+                       f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
+                       f"`range(torch.{device_type}.device_count())`.")
+            warnings.warn(message)
             _fork_rng_warned_already = True
         devices = list(range(num_devices))
     else:
@@ -152,13 +157,13 @@ def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="device
         devices = list(devices)
 
     cpu_rng_state = torch.get_rng_state()
-    gpu_rng_states = []
+    device_rng_states = []
    for device in devices:
-        gpu_rng_states.append(torch.cuda.get_rng_state(device))
+        device_rng_states.append(device_mod.get_rng_state(device))
 
     try:
         yield
     finally:
         torch.set_rng_state(cpu_rng_state)
-        for device, gpu_rng_state in zip(devices, gpu_rng_states):
-            torch.cuda.set_rng_state(gpu_rng_state, device)
+        for device, device_rng_state in zip(devices, device_rng_states):
+            device_mod.set_rng_state(device_rng_state, device)

diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py
index 1a619675101..d0c052b7309 100644
--- a/torch/utils/backend_registration.py
+++ b/torch/utils/backend_registration.py
@@ -47,6 +47,15 @@ def rename_privateuse1_backend(backend_name: str) -> None:
     (2) manual_seed_all(seed: int) -> None
         Sets the seed for generating random numbers for your devices.
 
+    (3) device_count() -> int:
+        Returns the number of `foo` devices available.
+
+    (4) get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor:
+        Returns the random number generator state of the specified `foo` device as a ByteTensor.
+
+    (5) set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None:
+        Sets the random number generator state of the specified `foo` device.
+
     For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
     For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
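
For a backend that provides the methods documented above, wiring into `fork_rng` looks roughly like the sketch below. This is a minimal, hypothetical example: the `FooModule` name and its placeholder RNG-state handling are illustrative only (modeled on the test's DummyModule); a real backend would read and restore genuine device RNG state.

    import torch

    class FooModule:
        @staticmethod
        def device_count() -> int:
            return 1  # pretend there is exactly one `foo` device

        @staticmethod
        def get_rng_state(device='foo') -> torch.Tensor:
            # placeholder: a real backend returns the device's RNG state as a ByteTensor
            return torch.zeros(16, dtype=torch.uint8)

        @staticmethod
        def set_rng_state(new_state: torch.Tensor, device='foo') -> None:
            # placeholder: a real backend restores the device's RNG state here
            pass

    # rename the PrivateUse1 backend and register the device module as torch.foo
    torch.utils.rename_privateuse1_backend('foo')
    torch._register_device_module('foo', FooModule)

    # fork_rng snapshots the CPU and `foo` RNG states on entry and restores them on exit
    with torch.random.fork_rng(device_type="foo"):
        pass  # RNG-consuming work here will not perturb the surrounding RNG state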