From e4a42bcf5645a97705a2bb7b28408b6884dcb79d Mon Sep 17 00:00:00 2001
From: shibo19 <18207133434@163.com>
Date: Wed, 7 Jun 2023 13:59:20 +0000
Subject: [PATCH] add foreach support for custom device (#102047)

Fixes #ISSUE_NUMBER
For custom devices, we want to support foreach, so this adds helper
functions that return the device types on which the foreach/fused
kernels are supported; the default is cuda, plus the renamed
privateuse1 backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
---
 torch/optim/adam.py                 | 11 +++++++----
 torch/optim/adamw.py                |  8 ++++++--
 torch/optim/optimizer.py            | 12 ++++++++++--
 torch/utils/_foreach_utils.py       | 13 ++++++++++++-
 torch/utils/backend_registration.py | 12 +++++++++++-
 5 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index e723a1955c0..6d18b98d717 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -5,6 +5,7 @@ from torch import Tensor
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
                         _dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
                         _differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
+from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
 
 __all__ = ['Adam', 'adam']
 
@@ -36,14 +37,16 @@ class Adam(Optimizer):
                 raise RuntimeError("`fused` does not support `differentiable`")
             self._step_supports_amp_scaling = True
             # TODO(crcrpar): [low prec params & their higher prec copy]
-            # Suppor AMP with FP16/BF16 model params which would need
+            # Support AMP with FP16/BF16 model params which would need
             # higher prec copy of params to do update math in higher prec to
             # alleviate the loss of information.
+            fused_supported_devices = _get_fused_kernels_supported_devices()
             if not all(
-                p.is_cuda and torch.is_floating_point(p)
-                for pg in self.param_groups for p in pg['params']
+                p.device.type in fused_supported_devices and
+                torch.is_floating_point(p) for pg in self.param_groups for p in pg['params']
             ):
-                raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
+                raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
+                                   f"supported devices: {fused_supported_devices}.")
             if foreach:
                 raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
 
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index d23961bedf1..27b8101ecd6 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -4,6 +4,7 @@ from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _di
                         _stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
                         _fused_doc, _maximize_doc, _default_to_fused_or_foreach)
 from typing import List, Optional
+from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
 
 __all__ = ["AdamW", "adamw"]
 
@@ -56,11 +57,14 @@ class AdamW(Optimizer):
             # Suppor AMP with FP16/BF16 model params which would need
             # higher prec copy of params to do update math in higher prec to
             # alleviate the loss of information.
+            fused_supported_devices = _get_fused_kernels_supported_devices()
             if not all(
-                p.is_cuda and torch.is_floating_point(p)
+                p.device.type in fused_supported_devices and
+                torch.is_floating_point(p)
                 for pg in self.param_groups for p in pg['params']
             ):
-                raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
+                raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
+                                   f"supported devices: {fused_supported_devices}.")
             if foreach:
                 raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
 
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 60daab502da..0e48cfa164c 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -10,6 +10,8 @@ from typing import Callable, Dict, List, Tuple
 
 import torch.utils.hooks as hooks
 from torch.utils.hooks import RemovableHandle
+from torch.utils._foreach_utils import (_get_fused_kernels_supported_devices,
+                                        _get_foreach_kernels_supported_devices)
 from torch._utils import is_compiling
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 
@@ -67,11 +69,17 @@ def _default_to_fused_or_foreach(params: List[torch.Tensor],
                                  use_fused: bool = False) -> Tuple[bool, bool]:
     if torch.jit.is_scripting() or differentiable:
         return False, False
+
+    fused_supported_devices = _get_fused_kernels_supported_devices()
+    foreach_supported_devices = _get_foreach_kernels_supported_devices()
     fused = use_fused and all(
-        p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in params
+        p is None or (type(p) in _foreach_supported_types and
+                      p.device.type in fused_supported_devices and
+                      torch.is_floating_point(p)) for p in params
     )
     foreach = not fused and all(
-        p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in params
+        p is None or (type(p) in _foreach_supported_types and
+                      p.device.type in foreach_supported_devices) for p in params
     )
     return fused, foreach
 
diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py
index af8c74e2b0c..400066529ae 100644
--- a/torch/utils/_foreach_utils.py
+++ b/torch/utils/_foreach_utils.py
@@ -5,6 +5,17 @@ import torch
 from torch import Tensor
 from torch.autograd.grad_mode import no_grad
 
+def _get_foreach_kernels_supported_devices() -> List[str]:
+    r"""
+    Return the device type list that supports foreach kernels.
+    """
+    return ["cuda", torch.utils.backend_registration._privateuse1_backend_name]
+
+def _get_fused_kernels_supported_devices() -> List[str]:
+    r"""
+    Return the device type list that supports fused kernels in optimizer.
+    """
+    return ["cuda", torch.utils.backend_registration._privateuse1_backend_name]
 
 # This util function splits tensors into groups by device and dtype, which is useful before sending
 # tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
@@ -37,6 +48,6 @@ def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
     return per_device_and_dtype_tensors
 
 def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
-    if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
+    if device.type not in set(_get_foreach_kernels_supported_devices() + ["cpu"]) or torch.jit.is_scripting():
         return False
     return all(t is None or type(t) == torch.Tensor for t in tensors)
diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py
index 7650719024a..716ebd66c27 100644
--- a/torch/utils/backend_registration.py
+++ b/torch/utils/backend_registration.py
@@ -4,6 +4,14 @@ from typing import List, Optional, Union
 
 __all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
 
+# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
+# renamed-backend name for `privateuse1`, but the func will cause a
+# graph break in inductor test, so we use the global variable named
+# `_privateuse1_backend_name`. Once we solve the graph break, we need
+# to remove the variable and use the func named
+# `torch._C._get_privateuse1_backend_name()` instead.
+_privateuse1_backend_name = "privateuseone"
+
 def rename_privateuse1_backend(backend_name: str) -> None:
     r"""
     rename_privateuse1_backend(backend_name) -> None
@@ -77,7 +85,9 @@ def rename_privateuse1_backend(backend_name: str) -> None:
         # to implement torch.ones.
         >>> a = torch.ones(2, device="foo")
     """
-    return _rename_privateuse1_backend(backend_name)
+    _rename_privateuse1_backend(backend_name)
+    global _privateuse1_backend_name
+    _privateuse1_backend_name = backend_name
 
 
 def _check_register_once(module, attr):
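
Usage sketch (not part of the patch itself): under stated assumptions, this is how an out-of-tree backend would pick up the new helpers. The backend name "npu" is purely illustrative, and a real integration would also register its C++ device module and provide the foreach/fused kernels before tensors could actually live on, and be stepped on, that device.

from torch.utils.backend_registration import rename_privateuse1_backend
from torch.utils._foreach_utils import (
    _get_foreach_kernels_supported_devices,
    _get_fused_kernels_supported_devices,
)

# Before renaming, the privateuse1 slot reports its default name.
print(_get_foreach_kernels_supported_devices())  # ['cuda', 'privateuseone']
print(_get_fused_kernels_supported_devices())    # ['cuda', 'privateuseone']

# "npu" is a placeholder used only for this sketch; a real backend would
# also register the corresponding C++ device implementation.
rename_privateuse1_backend("npu")

# Both helpers now report the renamed backend, so the checks patched above
# accept parameters of that device type for foreach=True / fused=True.
print(_get_foreach_kernels_supported_devices())  # ['cuda', 'npu']
print(_get_fused_kernels_supported_devices())    # ['cuda', 'npu']

Because _default_to_fused_or_foreach in torch/optim/optimizer.py consults the same lists, parameters on the renamed device can default to the foreach path just as CUDA parameters do, provided the backend implements those kernels.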