add foreach support for custom device (#102047)
Fixes #ISSUE_NUMBER. For custom devices we want to support foreach, so this adds functions that report which device types the foreach and fused kernels support; the default remains cuda, with the renamed PrivateUse1 backend included as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
This commit is contained in:
parent 07104ca99c
commit e4a42bcf56

5 changed files with 46 additions and 10 deletions
torch/optim/adam.py

@@ -5,6 +5,7 @@ from torch import Tensor
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
                         _dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
                         _differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
+from torch.utils._foreach_utils import _get_fused_kernels_supported_devices

 __all__ = ['Adam', 'adam']

@@ -36,14 +37,16 @@ class Adam(Optimizer):
                 raise RuntimeError("`fused` does not support `differentiable`")
             self._step_supports_amp_scaling = True
             # TODO(crcrpar): [low prec params & their higher prec copy]
-            # Suppor AMP with FP16/BF16 model params which would need
+            # Support AMP with FP16/BF16 model params which would need
             # higher prec copy of params to do update math in higher prec to
             # alleviate the loss of information.
+            fused_supported_devices = _get_fused_kernels_supported_devices()
             if not all(
-                p.is_cuda and torch.is_floating_point(p)
-                for pg in self.param_groups for p in pg['params']
+                p.device.type in fused_supported_devices and
+                torch.is_floating_point(p) for pg in self.param_groups for p in pg['params']
             ):
-                raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
+                raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
+                                   f"supported devices: {fused_supported_devices}.")
             if foreach:
                 raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

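The effect of the changed check is that a fused parameter only needs to live on a device type listed by _get_fused_kernels_supported_devices(), rather than strictly on CUDA. A minimal sketch of that validation in isolation (the helper function name and the example parameters are illustrative, not part of this change):

    import torch
    from torch.utils._foreach_utils import _get_fused_kernels_supported_devices

    def check_fused_params(params):
        # Mirrors the updated condition: device type must be in the supported
        # list and each tensor must be floating point.
        fused_supported_devices = _get_fused_kernels_supported_devices()
        if not all(p.device.type in fused_supported_devices and torch.is_floating_point(p)
                   for p in params):
            raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
                               f"supported devices: {fused_supported_devices}.")

    # CUDA params still pass; a CPU float param would now trip the
    # device-list-aware error instead of the old CUDA-only message.
    if torch.cuda.is_available():
        check_fused_params([torch.zeros(2, device="cuda")])
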
torch/optim/adamw.py

@@ -4,6 +4,7 @@ from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _di
                         _stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
                         _fused_doc, _maximize_doc, _default_to_fused_or_foreach)
 from typing import List, Optional
+from torch.utils._foreach_utils import _get_fused_kernels_supported_devices

 __all__ = ["AdamW", "adamw"]

@@ -56,11 +57,14 @@ class AdamW(Optimizer):
             # Suppor AMP with FP16/BF16 model params which would need
             # higher prec copy of params to do update math in higher prec to
             # alleviate the loss of information.
+            fused_supported_devices = _get_fused_kernels_supported_devices()
             if not all(
-                p.is_cuda and torch.is_floating_point(p)
+                p.device.type in fused_supported_devices and
+                torch.is_floating_point(p)
                 for pg in self.param_groups for p in pg['params']
             ):
-                raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
+                raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
+                                   f"supported devices: {fused_supported_devices}.")
             if foreach:
                 raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

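From the user's side nothing changes for CUDA; the difference is only in which device types pass the fused gate and in the error text otherwise. A minimal sketch of the failure path (the tiny CPU parameter is illustrative; at this commit cpu is not in the fused device list):

    import torch
    from torch.optim import AdamW

    params = [torch.nn.Parameter(torch.zeros(4))]  # CPU float parameters
    try:
        AdamW(params, lr=1e-3, fused=True)
    except RuntimeError as err:
        # Expected at this commit: "`fused=True` requires all the params to be
        # floating point Tensors of supported devices: [...]"
        print(err)
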
torch/optim/optimizer.py

@@ -10,6 +10,8 @@ from typing import Callable, Dict, List, Tuple

 import torch.utils.hooks as hooks
 from torch.utils.hooks import RemovableHandle
+from torch.utils._foreach_utils import (_get_fused_kernels_supported_devices,
+                                        _get_foreach_kernels_supported_devices)
 from torch._utils import is_compiling
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

@@ -67,11 +69,17 @@ def _default_to_fused_or_foreach(params: List[torch.Tensor],
                                  use_fused: bool = False) -> Tuple[bool, bool]:
     if torch.jit.is_scripting() or differentiable:
         return False, False

+    fused_supported_devices = _get_fused_kernels_supported_devices()
+    foreach_supported_devices = _get_foreach_kernels_supported_devices()
     fused = use_fused and all(
-        p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in params
+        p is None or (type(p) in _foreach_supported_types and
+                      p.device.type in fused_supported_devices and
+                      torch.is_floating_point(p)) for p in params
     )
     foreach = not fused and all(
-        p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in params
+        p is None or (type(p) in _foreach_supported_types and
+                      p.device.type in foreach_supported_devices) for p in params
     )
     return fused, foreach

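This is the piece that lets optimizers default to the foreach path on a custom backend: the decision is driven by the two device lists instead of p.is_cuda. A minimal sketch of calling the private helper directly to see the decision it makes (internal API, shown only for illustration):

    import torch
    from torch.optim.optimizer import _default_to_fused_or_foreach

    # CPU tensors: expected (False, False) at this commit, since "cpu" is not
    # in the foreach device list returned by _get_foreach_kernels_supported_devices().
    print(_default_to_fused_or_foreach([torch.zeros(3)], differentiable=False))

    if torch.cuda.is_available():
        # use_fused defaults to False, so CUDA params default to the foreach path: (False, True).
        print(_default_to_fused_or_foreach([torch.zeros(3, device="cuda")], differentiable=False))
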
torch/utils/_foreach_utils.py

@@ -5,6 +5,17 @@ import torch
 from torch import Tensor
 from torch.autograd.grad_mode import no_grad

+def _get_foreach_kernels_supported_devices() -> List[str]:
+    r"""
+    Return the device type list that supports foreach kernels.
+    """
+    return ["cuda", torch.utils.backend_registration._privateuse1_backend_name]
+
+def _get_fused_kernels_supported_devices() -> List[str]:
+    r"""
+    Return the device type list that supports fused kernels in optimizer.
+    """
+    return ["cuda", torch.utils.backend_registration._privateuse1_backend_name]
+
 # This util function splits tensors into groups by device and dtype, which is useful before sending
 # tensors off to a foreach implementation, which requires tensors to be on one device and dtype.

@@ -37,6 +48,6 @@ def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
     return per_device_and_dtype_tensors

 def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
-    if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
+    if device.type not in set(_get_foreach_kernels_supported_devices() + ["cpu"]) or torch.jit.is_scripting():
         return False
     return all(t is None or type(t) == torch.Tensor for t in tensors)

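These two new helpers are the single source of truth the optimizers consult. A minimal sketch of poking at them directly (internal API, shown only for illustration):

    import torch
    from torch.utils._foreach_utils import (
        _get_foreach_kernels_supported_devices,
        _get_fused_kernels_supported_devices,
        _has_foreach_support,
    )

    # With no renamed backend, both lists are ["cuda", "privateuseone"].
    print(_get_foreach_kernels_supported_devices())
    print(_get_fused_kernels_supported_devices())

    # _has_foreach_support additionally allows cpu, so plain CPU tensors qualify.
    print(_has_foreach_support([torch.zeros(2)], device=torch.device("cpu")))   # True
    print(_has_foreach_support([torch.zeros(2)], device=torch.device("meta")))  # False
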
torch/utils/backend_registration.py

@@ -4,6 +4,14 @@ from typing import List, Optional, Union

 __all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]

+# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
+# renamed-backend name for `privateuse1`, but the func will cause a
+# graph break in inductor test, so we use the global variable named
+# `_privateuse1_backend_name`. Once we solve the graph break, we need
+# to remove the variable and use the func named
+# `torch._C._get_privateuse1_backend_name()` instead.
+_privateuse1_backend_name = "privateuseone"
+
 def rename_privateuse1_backend(backend_name: str) -> None:
     r"""
     rename_privateuse1_backend(backend_name) -> None

@@ -77,7 +85,9 @@ def rename_privateuse1_backend(backend_name: str) -> None:
     # to implement torch.ones.
     >>> a = torch.ones(2, device="foo")
     """
-    return _rename_privateuse1_backend(backend_name)
+    _rename_privateuse1_backend(backend_name)
+    global _privateuse1_backend_name
+    _privateuse1_backend_name = backend_name


 def _check_register_once(module, attr):
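After this change, renaming the PrivateUse1 backend also updates the module-level name, so the foreach and fused device lists automatically include the custom backend. A minimal sketch, assuming an out-of-tree backend that is actually registered under the name "foo" (renaming requires such a backend, so this is illustrative only):

    import torch
    from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices

    # Requires a registered PrivateUse1 backend; "foo" is a placeholder name.
    torch.utils.rename_privateuse1_backend("foo")
    print(torch.utils.backend_registration._privateuse1_backend_name)  # "foo"
    print(_get_foreach_kernels_supported_devices())                    # ["cuda", "foo"]
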