add foreach support for custom device (#102047)

Fixes #ISSUE_NUMBER
For a custom device we want to support `foreach`, so I add helper functions that report which device types the foreach/fused kernels support; a custom backend registered via PrivateUse1 is included alongside the default, cuda.
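
As a quick illustration (not part of this PR's diff), a minimal sketch of the intended usage, assuming a custom backend extension is already built and registered for the PrivateUse1 dispatch key; `"foo"` is a hypothetical backend name:

```python
import torch

# Hypothetical setup: the C++ extension implementing the custom device is
# already imported and registered for the PrivateUse1 dispatch key.
torch.utils.rename_privateuse1_backend("foo")

# Assuming the backend implements the ops involved, params on "foo" are now
# treated as foreach/fused-capable by the optimizer device checks, not just CUDA.
model = torch.nn.Linear(4, 4).to("foo")
optim = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)
```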

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
shibo19 2023-06-07 13:59:20 +00:00 committed by PyTorch MergeBot
parent 07104ca99c
commit e4a42bcf56
5 changed files with 46 additions and 10 deletions

torch/optim/adam.py

@@ -5,6 +5,7 @@ from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
_dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
_differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
+from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
__all__ = ['Adam', 'adam']
@@ -36,14 +37,16 @@ class Adam(Optimizer):
raise RuntimeError("`fused` does not support `differentiable`")
self._step_supports_amp_scaling = True
# TODO(crcrpar): [low prec params & their higher prec copy]
-# Suppor AMP with FP16/BF16 model params which would need
+# Support AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
+fused_supported_devices = _get_fused_kernels_supported_devices()
if not all(
-p.is_cuda and torch.is_floating_point(p)
-for pg in self.param_groups for p in pg['params']
+p.device.type in fused_supported_devices and
+torch.is_floating_point(p) for pg in self.param_groups for p in pg['params']
):
-raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
+raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
+f"supported devices: {fused_supported_devices}.")
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

torch/optim/adamw.py

@@ -4,6 +4,7 @@ from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _di
_stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
_fused_doc, _maximize_doc, _default_to_fused_or_foreach)
from typing import List, Optional
+from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
__all__ = ["AdamW", "adamw"]
@@ -56,11 +57,14 @@ class AdamW(Optimizer):
# Suppor AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
+fused_supported_devices = _get_fused_kernels_supported_devices()
if not all(
-p.is_cuda and torch.is_floating_point(p)
+p.device.type in fused_supported_devices and
+torch.is_floating_point(p)
for pg in self.param_groups for p in pg['params']
):
-raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
+raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
+f"supported devices: {fused_supported_devices}.")
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
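
To illustrate the new check in Adam/AdamW (an illustrative sketch, not part of the diff): at this commit, `"cpu"` is not in the fused-kernel device list, so a fused optimizer built from CPU params raises the new error message:

```python
import torch

params = [torch.zeros(2, requires_grad=True)]  # CPU tensors; "cpu" is not in the fused list here
try:
    torch.optim.AdamW(params, fused=True)
except RuntimeError as err:
    # `fused=True` requires all the params to be floating point Tensors of supported devices: [...]
    print(err)
```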

torch/optim/optimizer.py

@@ -10,6 +10,8 @@ from typing import Callable, Dict, List, Tuple
import torch.utils.hooks as hooks
from torch.utils.hooks import RemovableHandle
+from torch.utils._foreach_utils import (_get_fused_kernels_supported_devices,
+_get_foreach_kernels_supported_devices)
from torch._utils import is_compiling
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
@@ -67,11 +69,17 @@ def _default_to_fused_or_foreach(params: List[torch.Tensor],
use_fused: bool = False) -> Tuple[bool, bool]:
if torch.jit.is_scripting() or differentiable:
return False, False
+fused_supported_devices = _get_fused_kernels_supported_devices()
+foreach_supported_devices = _get_foreach_kernels_supported_devices()
fused = use_fused and all(
-p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in params
+p is None or (type(p) in _foreach_supported_types and
+p.device.type in fused_supported_devices and
+torch.is_floating_point(p)) for p in params
)
foreach = not fused and all(
-p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in params
+p is None or (type(p) in _foreach_supported_types and
+p.device.type in foreach_supported_devices) for p in params
)
return fused, foreach
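
For context, a sketch of how this defaulting behaves (`_default_to_fused_or_foreach` is a private helper; the call below assumes its current signature `(params, differentiable, use_fused)`):

```python
import torch
from torch.optim.optimizer import _default_to_fused_or_foreach

cpu_params = [torch.zeros(2)]
print(_default_to_fused_or_foreach(cpu_params, differentiable=False, use_fused=False))
# (False, False): "cpu" is in neither supported-device list at this commit

# Assuming a CUDA device is available:
cuda_params = [torch.zeros(2, device="cuda")]
print(_default_to_fused_or_foreach(cuda_params, differentiable=False, use_fused=False))
# (False, True): foreach is chosen by default for supported devices
```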

torch/utils/_foreach_utils.py

@@ -5,6 +5,17 @@ import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
+def _get_foreach_kernels_supported_devices() -> List[str]:
+    r"""
+    Return the device type list that supports foreach kernels.
+    """
+    return ["cuda", torch.utils.backend_registration._privateuse1_backend_name]
+
+def _get_fused_kernels_supported_devices() -> List[str]:
+    r"""
+    Return the device type list that supports fused kernels in optimizer.
+    """
+    return ["cuda", torch.utils.backend_registration._privateuse1_backend_name]
# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
@@ -37,6 +48,6 @@ def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
return per_device_and_dtype_tensors
def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
-if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
+if device.type not in set(_get_foreach_kernels_supported_devices() + ["cpu"]) or torch.jit.is_scripting():
return False
return all(t is None or type(t) == torch.Tensor for t in tensors)
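
For reference, an illustrative peek at the new private helpers; before any renaming, the PrivateUse1 slot reports its default name:

```python
from torch.utils._foreach_utils import (
    _get_foreach_kernels_supported_devices,
    _get_fused_kernels_supported_devices,
)

print(_get_foreach_kernels_supported_devices())  # ['cuda', 'privateuseone']
print(_get_fused_kernels_supported_devices())    # ['cuda', 'privateuseone']
```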

torch/utils/backend_registration.py

@@ -4,6 +4,14 @@ from typing import List, Optional, Union
__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
+# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
+# renamed-backend name for `privateuse1`, but the func will cause a
+# graph break in inductor test, so we use the global variable named
+# `_privateuse1_backend_name`. Once we solve the graph break, we need
+# to remove the variable and use the func named
+# `torch._C._get_privateuse1_backend_name()` instead.
+_privateuse1_backend_name = "privateuseone"
def rename_privateuse1_backend(backend_name: str) -> None:
r"""
rename_privateuse1_backend(backend_name) -> None
@@ -77,7 +85,9 @@ def rename_privateuse1_backend(backend_name: str) -> None:
# to implement torch.ones.
>>> a = torch.ones(2, device="foo")
"""
-return _rename_privateuse1_backend(backend_name)
+_rename_privateuse1_backend(backend_name)
+global _privateuse1_backend_name
+_privateuse1_backend_name = backend_name
def _check_register_once(module, attr):
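
Putting the pieces together (a sketch with a hypothetical backend name `"foo"`): the rename now also updates the Python-side global that the foreach/fused device lists read:

```python
import torch
from torch.utils import backend_registration
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices

torch.utils.rename_privateuse1_backend("foo")          # "foo" is a hypothetical name
print(backend_registration._privateuse1_backend_name)  # 'foo'
print(_get_foreach_kernels_supported_devices())        # ['cuda', 'foo']
```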