mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
support random for custom device (#97420)
Fixes #ISSUE_NUMBER set seed for custom device, @bdhirsh , please review my changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/97420 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
3eecca764a
commit
7fc100a290
2 changed files with 37 additions and 1 deletions
|
|
@ -43,6 +43,8 @@ def manual_seed(seed) -> torch._C.Generator:
|
|||
if not torch.mps._is_in_bad_fork():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
_seed_custom_device(seed)
|
||||
|
||||
return default_generator.manual_seed(seed)
|
||||
|
||||
|
||||
|
|
@ -60,9 +62,34 @@ def seed() -> int:
|
|||
if not torch.mps._is_in_bad_fork():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
_seed_custom_device(seed)
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
def _seed_custom_device(seed) -> None:
|
||||
r"""Sets the seed to generate random numbers for custom device.
|
||||
|
||||
Args:
|
||||
seed (int): The desired seed.
|
||||
|
||||
See [Note: support the custom device with privateuse1]
|
||||
"""
|
||||
seed = int(seed)
|
||||
custom_backend_name = torch._C._get_privateuse1_backend_name()
|
||||
if hasattr(torch, custom_backend_name):
|
||||
custom_device_mod = getattr(torch, custom_backend_name)
|
||||
_bad_fork_name = "_is_in_bad_fork"
|
||||
_seed_all_name = "manual_seed_all"
|
||||
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
|
||||
if not getattr(custom_device_mod, _bad_fork_name)():
|
||||
getattr(custom_device_mod, _seed_all_name)(seed)
|
||||
else:
|
||||
message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's "
|
||||
message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module."
|
||||
warnings.warn(message, UserWarning, stacklevel=3)
|
||||
|
||||
|
||||
def initial_seed() -> int:
|
||||
r"""Returns the initial seed for generating random numbers as a
|
||||
Python `long`.
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ def rename_privateuse1_backend(backend_name: str) -> None:
|
|||
r"""
|
||||
rename_privateuse1_backend(backend_name) -> None
|
||||
|
||||
Note: support the custom device with privateuse1
|
||||
This is a registration API for external backends that would like to register their
|
||||
own device and C++ kernels out of tree.
|
||||
|
||||
|
|
@ -17,7 +18,7 @@ def rename_privateuse1_backend(backend_name: str) -> None:
|
|||
Note: this API can only be called once per process. Attempting to change
|
||||
the external backend after it's already been set will result in an error.
|
||||
|
||||
Note: and if you want to support AMP on your device, you can register a custom backend module.
|
||||
Note(AMP): If you want to support AMP on your device, you can register a custom backend module.
|
||||
The backend must register a custom backend module with `torch._register_device_module("foo", BackendModule)`.
|
||||
BackendModule needs to have the following API's:
|
||||
|
||||
|
|
@ -38,6 +39,14 @@ def rename_privateuse1_backend(backend_name: str) -> None:
|
|||
set the supported dtype on your `foo` device in AMP, and the dtype be contained in the dtypes got
|
||||
from `get_amp_supported_dtype`.
|
||||
|
||||
Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:
|
||||
|
||||
(1) _is_in_bad_fork() -> bool
|
||||
Return `True` if now it is in bad_fork, else return `False`.
|
||||
|
||||
(2) manual_seed_all(seed: int) -> None
|
||||
Sets the seed for generating random numbers for your devices.
|
||||
|
||||
For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
|
||||
For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue