Autoselect default device in FSDP construction. (#127609)

There are still some differences between CUDA and non-CUDA custom devices when
constructing FSDP, because CUDA is selected as the default device. For example,
when FSDP is constructed from a CPU model and device_id is not passed, the
device handle chooses CUDA as the default device. This PR autoselects the real
device as the default device.
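
A minimal sketch of the intended behavior (hedged: the process-group setup is
elided, and XPU stands in for any non-CUDA accelerator):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Model starts on CPU and no device_id is passed to FSDP.
    model = nn.Linear(8, 8)

    # Before this PR: the device handle defaulted to CUDA even on a host whose
    # only accelerator is, e.g., XPU. After this PR: the handle follows the
    # detected accelerator, so sharding happens on the real device.
    fsdp_model = FSDP(model)  # assumes torch.distributed is already initialized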

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127609
Approved by: https://github.com/awgu
daitian1995 2024-08-08 05:25:15 +00:00 committed by PyTorch MergeBot
parent 4a1edbe475
commit aff48f7378
2 changed files with 28 additions and 22 deletions


@@ -149,7 +149,7 @@ alongside a CPU to speed up computation. These device use an asynchronous execut
 using :class:`torch.Stream` and :class:`torch.Event` as their main way to perform synchronization.
 We also assume that only one such accelerator can be available at once on a given host. This allows
 us to use the current accelerator as the default device for relevant concepts such as pinned memory,
-Stream device_type, etc.
+Stream device_type, FSDP, etc.
 As of today, accelerator devices are (in no particular order) :doc:`"CUDA" <cuda>`, :doc:`"MTIA" <mtia>`,
 :doc:`"XPU" <xpu>`, and PrivateUse1 (many device not in the PyTorch repo itself).
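
For illustration (the exact return value is a detail of a private helper), the
accelerator query that backs this added sentence can be probed directly:

    import torch

    # torch._C._get_accelerator() is the internal helper the FSDP change below
    # now relies on; it reports the host's accelerator as a torch.device,
    # falling back to "cpu" when no accelerator is present.
    acc = torch._C._get_accelerator()
    print(acc.type)  # e.g. "cuda", "xpu", "mtia", a PrivateUse1 backend, or "cpu"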


@@ -371,7 +371,9 @@ def _init_device_handle(
     If a device is specified by ``device_id``,
     then returns device handle corresponds to that device type. Otherwise, If the
     module is already on a non-CPU device, then the device type is that non-CPU device type.
-    If the module is on CPU or meta, then the device type is the current cuda device.
+    If the module is on CPU or meta, then the device type is the current accelerator device.
+    See the :ref:`Accelerators<accelerators>` for details.
+
     This method will be called once ignored paramters was determined, as the device handle maybe needed
     for other initialization.
@@ -395,9 +397,11 @@ def _init_device_handle(
                     f"FSDP does not support modules with different device types "
                     f"but got params on {determined_device.type} and {param.device.type}"
                 )
-        determined_device = determined_device or torch.device(
-            "cuda", torch.cuda.current_device()
-        )
+        determined_device = determined_device or torch._C._get_accelerator()
+        if determined_device.type == "cpu":
+            raise RuntimeError(
+                "FSDP needs a non-CPU accelerator device, but no accelerator device is detected."
+            )
     state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
     return state
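
A sketch of what the new fallback does (the import path of _FSDPDeviceHandle is
an assumption; it is internal, non-public API):

    import torch
    # Assumed internal location of the handle class (not a public interface).
    from torch.distributed.fsdp._common_utils import _FSDPDeviceHandle

    # The handle is now built from the detected accelerator rather than a
    # hard-coded "cuda" device; on a CPU-only host the RuntimeError added
    # above fires first inside FSDP, so this point is never reached there.
    handle = _FSDPDeviceHandle.from_device(torch._C._get_accelerator())
    print(handle.current_device())  # current device index for that backend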
@@ -555,7 +559,9 @@ def _init_param_handle_from_module(
 ) -> _FSDPState:
     """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``."""
     _check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
-    device_from_device_id = _get_device_from_device_id(device_id, state.rank)
+    device_from_device_id = _get_device_from_device_id(
+        device_id, state.rank, state._device_handle
+    )
     is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
         fully_sharded_module, state._ignored_params, state._ignored_modules
     )
@@ -566,7 +572,10 @@ def _init_param_handle_from_module(
         )
     elif is_meta_module:
         _materialize_meta_module(
-            fully_sharded_module, device_id, state._ignored_modules
+            fully_sharded_module,
+            device_id,
+            state._ignored_modules,
+            state._device_handle,
         )
     elif is_torchdistX_deferred_init:
         deferred_init.materialize_module(
@@ -592,6 +601,7 @@ def _init_param_handle_from_module(
         state._ignored_params,
         device_from_device_id,
         state.rank,
+        state._device_handle,
     )
     managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
@@ -798,6 +808,7 @@ def _check_single_device_module(
 def _get_device_from_device_id(
     device_id: Optional[Union[int, torch.device]],
     rank: int,
+    device_handle: _FSDPDeviceHandle,
 ) -> Optional[torch.device]:
     """
     Return a ``torch.device`` for the specified ``device_id``.
@@ -810,18 +821,16 @@ def _get_device_from_device_id(
     device = (
         device_id if isinstance(device_id, torch.device) else torch.device(device_id)
     )
-    if device in [torch.device("cuda"), torch.device("hpu")]:
-        backend = getattr(torch, device.type)
-        device_idx = backend.current_device()
+    if device.type != "cpu" and device.index is None:
         warnings.warn(
             f"FSDP got the argument `device_id` {device_id} on rank "
             f"{rank}, which does not have an explicit index. "
-            f"FSDP will use the current device {device_idx}. "
-            "If this is incorrect, please explicitly call torch.{device_type}.set_device() "
+            f"FSDP will use the current device {device_handle.current_device()}. "
+            f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` "
             "before FSDP initialization or pass in the explicit device "
             "index as the `device_id` argument."
         )
-        device = torch.device(device.type, device_idx)
+        device = torch.device(device_handle.current_device())
     return device
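
For example (a hypothetical call: in FSDP the rank and handle come from the
construction state, and the module path of the helper is assumed):

    import torch
    from torch.distributed.fsdp._init_utils import _get_device_from_device_id

    # An index-less device_id on, say, an XPU host; `handle` would be the
    # FSDP state's _device_handle.
    device = _get_device_from_device_id(torch.device("xpu"), 0, handle)
    # This warns "FSDP will use the current device 0. ..." and resolves the
    # index through handle.current_device(); the old code special-cased only
    # "cuda" and "hpu".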
@@ -873,10 +882,11 @@ def _materialize_meta_module(
     root_module: nn.Module,
     device_from_device_id: Optional[torch.device],
     ignored_modules: Set[nn.Module],
+    device_handle: _FSDPDeviceHandle,
 ):
     # Run default meta device initialization
     materialization_device = device_from_device_id or torch.device(
-        torch.cuda.current_device()
+        device_handle.current_device()
     )
     modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
     module = None
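
A hedged usage sketch for this path: a module built on the meta device is now
materialized on the handle's current device when no device_id is given
(process-group setup elided):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Meta-device parameters carry shapes but no storage.
    with torch.device("meta"):
        model = nn.Linear(1024, 1024)

    # During construction, _materialize_meta_module() allocates real storage
    # on device_handle.current_device() (the detected accelerator) instead of
    # torch.cuda.current_device().
    fsdp_model = FSDP(model)  # assumes an initialized process group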
@@ -1027,19 +1037,18 @@ def _get_compute_device(
     ignored_params: Set[nn.Parameter],
     device_from_device_id: Optional[torch.device],
     rank: int,
+    device_handle: _FSDPDeviceHandle,
 ) -> torch.device:
     """
     Determine and return this FSDP instance's compute device.
 
-    If a device is
-    specified by ``device_id``, then returns that device. Otherwise, If the
-    module is already on a non-CPU device, then the compute device is that non-CPU
+    If the module is already on a non-CPU device, then the compute device is that non-CPU
     device. If the module is on CPU, then the compute device is the current
     device.
 
     Since this method should be called after materializing the module, any
     non-CPU device should not be meta device. For now, the compute device is
-    always a CUDA GPU device with its explicit index.
+    always a CUDA or CUDA-like device with its explicit index.
 
     Precondition: ``_check_single_device_module()`` and
     ``_move_module_to_device()``.
@@ -1048,10 +1057,7 @@ def _get_compute_device(
     if param is not None and param.device.type != "cpu":
         compute_device = param.device  # Determined by model param placement
     else:
-        if device_from_device_id is not None and device_from_device_id.type != "cuda":
-            compute_device = device_from_device_id  # Determined by custom backend
-        else:
-            compute_device = torch.device("cuda", torch.cuda.current_device())
+        compute_device = torch.device(device_handle.current_device())
     if device_from_device_id is not None and compute_device != device_from_device_id:
         raise ValueError(
             f"Inconsistent compute device and `device_id` on rank {rank}: "