support random for custom device (#97420)

Fixes #ISSUE_NUMBER
set seed for custom device, @bdhirsh , please review my changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97420
Approved by: https://github.com/bdhirsh
This commit is contained in:
shibo 2023-03-30 02:12:52 +00:00 committed by PyTorch MergeBot
parent 3eecca764a
commit 7fc100a290
2 changed files with 37 additions and 1 deletions

View file

@ -43,6 +43,8 @@ def manual_seed(seed) -> torch._C.Generator:
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
_seed_custom_device(seed)
return default_generator.manual_seed(seed)
@ -60,9 +62,34 @@ def seed() -> int:
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
_seed_custom_device(seed)
return seed
def _seed_custom_device(seed) -> None:
r"""Sets the seed to generate random numbers for custom device.
Args:
seed (int): The desired seed.
See [Note: support the custom device with privateuse1]
"""
seed = int(seed)
custom_backend_name = torch._C._get_privateuse1_backend_name()
if hasattr(torch, custom_backend_name):
custom_device_mod = getattr(torch, custom_backend_name)
_bad_fork_name = "_is_in_bad_fork"
_seed_all_name = "manual_seed_all"
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
if not getattr(custom_device_mod, _bad_fork_name)():
getattr(custom_device_mod, _seed_all_name)(seed)
else:
message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's "
message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module."
warnings.warn(message, UserWarning, stacklevel=3)
def initial_seed() -> int:
r"""Returns the initial seed for generating random numbers as a
Python `long`.

View file

@ -4,6 +4,7 @@ def rename_privateuse1_backend(backend_name: str) -> None:
r"""
rename_privateuse1_backend(backend_name) -> None
Note: support the custom device with privateuse1
This is a registration API for external backends that would like to register their
own device and C++ kernels out of tree.
@ -17,7 +18,7 @@ def rename_privateuse1_backend(backend_name: str) -> None:
Note: this API can only be called once per process. Attempting to change
the external backend after it's already been set will result in an error.
Note: and if you want to support AMP on your device, you can register a custom backend module.
Note(AMP): If you want to support AMP on your device, you can register a custom backend module.
The backend must register a custom backend module with `torch._register_device_module("foo", BackendModule)`.
BackendModule needs to have the following API's:
@ -38,6 +39,14 @@ def rename_privateuse1_backend(backend_name: str) -> None:
set the supported dtype on your `foo` device in AMP, and the dtype be contained in the dtypes got
from `get_amp_supported_dtype`.
Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:
(1) _is_in_bad_fork() -> bool
Return `True` if now it is in bad_fork, else return `False`.
(2) manual_seed_all(seed: int) -> None
Sets the seed for generating random numbers for your devices.
For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example