diff --git a/torch/random.py b/torch/random.py index e4795907a3a..5bd9323995d 100644 --- a/torch/random.py +++ b/torch/random.py @@ -43,6 +43,8 @@ def manual_seed(seed) -> torch._C.Generator: if not torch.mps._is_in_bad_fork(): torch.mps.manual_seed(seed) + _seed_custom_device(seed) + return default_generator.manual_seed(seed) @@ -60,9 +62,34 @@ def seed() -> int: if not torch.mps._is_in_bad_fork(): torch.mps.manual_seed(seed) + _seed_custom_device(seed) + return seed +def _seed_custom_device(seed) -> None: + r"""Sets the seed to generate random numbers for custom device. + + Args: + seed (int): The desired seed. + + See [Note: support the custom device with privateuse1] + """ + seed = int(seed) + custom_backend_name = torch._C._get_privateuse1_backend_name() + if hasattr(torch, custom_backend_name): + custom_device_mod = getattr(torch, custom_backend_name) + _bad_fork_name = "_is_in_bad_fork" + _seed_all_name = "manual_seed_all" + if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name): + if not getattr(custom_device_mod, _bad_fork_name)(): + getattr(custom_device_mod, _seed_all_name)(seed) + else: + message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's " + message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module." + warnings.warn(message, UserWarning, stacklevel=3) + + def initial_seed() -> int: r"""Returns the initial seed for generating random numbers as a Python `long`. diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index cbc4f5c9c37..3058e0a6450 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -4,6 +4,7 @@ def rename_privateuse1_backend(backend_name: str) -> None: r""" rename_privateuse1_backend(backend_name) -> None + Note: support the custom device with privateuse1 This is a registration API for external backends that would like to register their own device and C++ kernels out of tree. @@ -17,7 +18,7 @@ def rename_privateuse1_backend(backend_name: str) -> None: Note: this API can only be called once per process. Attempting to change the external backend after it's already been set will result in an error. - Note: and if you want to support AMP on your device, you can register a custom backend module. + Note(AMP): If you want to support AMP on your device, you can register a custom backend module. The backend must register a custom backend module with `torch._register_device_module("foo", BackendModule)`. BackendModule needs to have the following API's: @@ -38,6 +39,14 @@ def rename_privateuse1_backend(backend_name: str) -> None: set the supported dtype on your `foo` device in AMP, and the dtype be contained in the dtypes got from `get_amp_supported_dtype`. + Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's: + + (1) _is_in_bad_fork() -> bool + Return `True` if now it is in bad_fork, else return `False`. + + (2) manual_seed_all(seed: int) -> None + Sets the seed for generating random numbers for your devices. + For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example