Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
Autoselect default device in FSDP construction. (#127609)
There are still some differences between CUDA and non-CUDA custom devices when constructing FSDP, because CUDA is selected as the default device. For example, when FSDP is constructed from a CPU model and device_id is not passed, device_handle chooses CUDA as the default device. This PR autoselects the real device as the default device.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127609
Approved by: https://github.com/awgu
This commit is contained in:
parent 4a1edbe475
commit aff48f7378
2 changed files with 28 additions and 22 deletions
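Before the diff, a minimal sketch (not from the PR) of the construction path the commit message describes: a CPU model wrapped without device_id. It assumes a single-rank process group and a build where an accelerator is present; after this change, FSDP raises instead of silently assuming CUDA when no accelerator is detected.

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main() -> None:
    # Single-rank group for illustration; real jobs launch via torchrun. Leaving the
    # backend unset lets torch.distributed pick one matching the available device.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(rank=0, world_size=1)

    model = nn.Linear(8, 8)  # stays on CPU; no device_id is passed to FSDP below
    print("current accelerator:", torch._C._get_accelerator())

    # Previously the device handle here defaulted to CUDA; it now follows the
    # current accelerator (CUDA, MTIA, XPU, or a PrivateUse1 backend).
    fsdp_model = FSDP(model)
    print("compute device:", fsdp_model.compute_device)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()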
@@ -149,7 +149,7 @@ alongside a CPU to speed up computation. These device use an asynchronous execut
  using :class:`torch.Stream` and :class:`torch.Event` as their main way to perform synchronization.
  We also assume that only one such accelerator can be available at once on a given host. This allows
  us to use the current accelerator as the default device for relevant concepts such as pinned memory,
- Stream device_type, etc.
+ Stream device_type, FSDP, etc.

  As of today, accelerator devices are (in no particular order) :doc:`"CUDA" <cuda>`, :doc:`"MTIA" <mtia>`,
  :doc:`"XPU" <xpu>`, and PrivateUse1 (many device not in the PyTorch repo itself).
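The documentation hunk above adds FSDP to the list of features keyed off the current accelerator. Purely as an illustration (not part of the diff), device-generic code can query that accelerator instead of hard-coding torch.cuda; this assumes a recent build where the generic torch.Stream class mentioned in the note accepts a device argument.

import torch

acc = torch._C._get_accelerator()  # cuda, mtia, xpu, a PrivateUse1 backend, or cpu
if acc.type != "cpu":
    stream = torch.Stream(device=acc)            # no torch.cuda.Stream hard-coding
    pinned = torch.empty(1024, pin_memory=True)  # pinned relative to the current accelerator
    print(acc, stream, pinned.is_pinned())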
@@ -371,7 +371,9 @@ def _init_device_handle(
  If a device is specified by ``device_id``,
  then returns device handle corresponds to that device type. Otherwise, If the
  module is already on a non-CPU device, then the device type is that non-CPU device type.
- If the module is on CPU or meta, then the device type is the current cuda device.
+ If the module is on CPU or meta, then the device type is the current accelerator device.
+ See the :ref:`Accelerators<accelerators>` for details.
+

  This method will be called once ignored paramters was determined, as the device handle maybe needed
  for other initialization.
@@ -395,9 +397,11 @@ def _init_device_handle(
  f"FSDP does not support modules with different device types "
  f"but got params on {determined_device.type} and {param.device.type}"
  )
- determined_device = determined_device or torch.device(
- "cuda", torch.cuda.current_device()
- )
+ determined_device = determined_device or torch._C._get_accelerator()
+ if determined_device.type == "cpu":
+ raise RuntimeError(
+ "FSDP needs a non-CPU accelerator device, but no accelerator device is detected."
+ )

  state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
  return state
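The two hunks above are the core of the change: the fallback device in _init_device_handle is now the current accelerator, and a CPU-only environment is rejected. A rough standalone sketch of that selection order follows (the pick_default_device helper is hypothetical; the real logic lives in _init_device_handle):

from typing import Optional

import torch

def pick_default_device(
    device_id: Optional[torch.device] = None,
    param_device: Optional[torch.device] = None,
) -> torch.device:
    # 1) An explicit device_id wins.
    if device_id is not None:
        return device_id
    # 2) Otherwise a non-CPU, non-meta parameter device wins.
    if param_device is not None and param_device.type not in ("cpu", "meta"):
        return param_device
    # 3) Otherwise fall back to the current accelerator instead of assuming CUDA.
    determined = torch._C._get_accelerator()
    if determined.type == "cpu":
        raise RuntimeError(
            "FSDP needs a non-CPU accelerator device, but no accelerator device is detected."
        )
    return determined

# On an XPU build, for example, a CPU model with no device_id now resolves to an
# xpu device rather than to CUDA.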
@@ -555,7 +559,9 @@ def _init_param_handle_from_module(
  ) -> _FSDPState:
  """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``."""
  _check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
- device_from_device_id = _get_device_from_device_id(device_id, state.rank)
+ device_from_device_id = _get_device_from_device_id(
+ device_id, state.rank, state._device_handle
+ )
  is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
  fully_sharded_module, state._ignored_params, state._ignored_modules
  )
@@ -566,7 +572,10 @@ def _init_param_handle_from_module(
  )
  elif is_meta_module:
  _materialize_meta_module(
- fully_sharded_module, device_id, state._ignored_modules
+ fully_sharded_module,
+ device_id,
+ state._ignored_modules,
+ state._device_handle,
  )
  elif is_torchdistX_deferred_init:
  deferred_init.materialize_module(
@@ -592,6 +601,7 @@ def _init_param_handle_from_module(
  state._ignored_params,
  device_from_device_id,
  state.rank,
+ state._device_handle,
  )

  managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
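The hunks above and several below mostly thread state._device_handle into helpers that previously called torch.cuda directly. A small sketch of what that handle provides; _FSDPDeviceHandle is a private helper, and the import path below is an assumption that may change between releases.

import torch
from torch.distributed.fsdp._common_utils import _FSDPDeviceHandle  # private API

device = torch._C._get_accelerator()
if device.type != "cpu":
    handle = _FSDPDeviceHandle.from_device(device)
    # The handle proxies torch.cuda / torch.xpu / the custom backend's module, so
    # FSDP internals can call current_device(), set_device(), etc. generically.
    print(handle.current_device())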
@@ -798,6 +808,7 @@ def _check_single_device_module(
  def _get_device_from_device_id(
  device_id: Optional[Union[int, torch.device]],
  rank: int,
+ device_handle: _FSDPDeviceHandle,
  ) -> Optional[torch.device]:
  """
  Return a ``torch.device`` for the specified ``device_id``.
@@ -810,18 +821,16 @@ def _get_device_from_device_id(
  device = (
  device_id if isinstance(device_id, torch.device) else torch.device(device_id)
  )
- if device in [torch.device("cuda"), torch.device("hpu")]:
- backend = getattr(torch, device.type)
- device_idx = backend.current_device()
+ if device.type != "cpu" and device.index is None:
  warnings.warn(
  f"FSDP got the argument `device_id` {device_id} on rank "
  f"{rank}, which does not have an explicit index. "
- f"FSDP will use the current device {device_idx}. "
- "If this is incorrect, please explicitly call torch.{device_type}.set_device() "
+ f"FSDP will use the current device {device_handle.current_device()}. "
+ f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` "
  "before FSDP initialization or pass in the explicit device "
  "index as the `device_id` argument."
  )
- device = torch.device(device.type, device_idx)
+ device = torch.device(device_handle.current_device())
  return device
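In the hunk above, the cuda/hpu allow-list is replaced by a generic check: any non-CPU device_id passed without an explicit index triggers the warning and is resolved through the device handle. A small illustrative mirror of the new condition (needs_index_warning is a hypothetical helper, not part of FSDP):

import torch

def needs_index_warning(device_id) -> bool:
    # Mirrors the new check: warn for any non-CPU device given without an index.
    device = device_id if isinstance(device_id, torch.device) else torch.device(device_id)
    return device.type != "cpu" and device.index is None

print(needs_index_warning("cuda"))                   # True  (previously special-cased)
print(needs_index_warning("xpu"))                    # True  (previously missed by the cuda/hpu list)
print(needs_index_warning(torch.device("cuda", 0)))  # False, explicit index given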
@@ -873,10 +882,11 @@ def _materialize_meta_module(
  root_module: nn.Module,
  device_from_device_id: Optional[torch.device],
  ignored_modules: Set[nn.Module],
+ device_handle: _FSDPDeviceHandle,
  ):
  # Run default meta device initialization
  materialization_device = device_from_device_id or torch.device(
- torch.cuda.current_device()
+ device_handle.current_device()
  )
  modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
  module = None
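The hunk above makes meta-device materialization use the device handle as well. For reference, a minimal sketch of the meta-init pattern this serves, written against public APIs only; choosing the target from the current accelerator here is illustrative, as FSDP derives it internally.

import torch
import torch.nn as nn

# Build the module on the meta device: shapes only, no real storage is allocated.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))

target = torch._C._get_accelerator()
if target.type != "cpu":
    # Allocate empty storage on the real device, then reinitialize parameters.
    model = model.to_empty(device=target)
    model.apply(lambda m: m.reset_parameters() if hasattr(m, "reset_parameters") else None)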
@@ -1027,19 +1037,18 @@ def _get_compute_device(
  ignored_params: Set[nn.Parameter],
  device_from_device_id: Optional[torch.device],
  rank: int,
+ device_handle: _FSDPDeviceHandle,
  ) -> torch.device:
  """
  Determine and return this FSDP instance's compute device.

- If a device is
- specified by ``device_id``, then returns that device. Otherwise, If the
- module is already on a non-CPU device, then the compute device is that non-CPU
+ If the module is already on a non-CPU device, then the compute device is that non-CPU
  device. If the module is on CPU, then the compute device is the current
  device.

  Since this method should be called after materializing the module, any
  non-CPU device should not be meta device. For now, the compute device is
- always a CUDA GPU device with its explicit index.
+ always a CUDA or CUDA-like device with its explicit index.

  Precondition: ``_check_single_device_module()`` and
  ``_move_module_to_device()``.
@@ -1048,10 +1057,7 @@ def _get_compute_device(
  if param is not None and param.device.type != "cpu":
  compute_device = param.device  # Determined by model param placement
  else:
- if device_from_device_id is not None and device_from_device_id.type != "cuda":
- compute_device = device_from_device_id  # Determined by custom backend
- else:
- compute_device = torch.device("cuda", torch.cuda.current_device())
+ compute_device = torch.device(device_handle.current_device())
  if device_from_device_id is not None and compute_device != device_from_device_id:
  raise ValueError(
  f"Inconsistent compute device and `device_id` on rank {rank}: "