Use device-agnostic runtime APIs in distributed DDP/FSDP instead of CUDA-specific ones (#137678)

# Motivation
This PR switches distributed DDP/FSDP to device-agnostic runtime APIs instead of `cuda`-specific ones.
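
For context, each `torch.cuda` call in the diff is mapped to a backend-neutral counterpart. The sketch below is illustrative only (not code from this PR) and assumes a PyTorch build that ships the `torch.accelerator` namespace:

```python
import torch

# Rough mapping applied throughout the diff (illustrative, not from the PR):
#   torch.cuda.Stream           -> torch.Stream
#   torch.cuda.stream(s)        -> torch.get_device_module().stream(s)
#   torch.cuda.current_stream() -> torch.accelerator.current_stream()
#   torch.cuda.is_available()   -> torch.accelerator.is_available()
#   torch.cuda.synchronize()    -> torch.accelerator.synchronize()
if torch.accelerator.is_available():
    # Resolves to cuda, xpu, mtia, ... depending on the build and hardware.
    print(torch.accelerator.current_accelerator())
    print(torch.get_device_module())  # e.g. the torch.cuda or torch.xpu module
```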

cc [@jgong5](https://github.com/jgong5) [@gujinghui](https://github.com/gujinghui) [@EikanWang](https://github.com/EikanWang) [@fengyuan14](https://github.com/fengyuan14) [@guangyey](https://github.com/guangyey)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137678
Approved by: https://github.com/kwen2501, https://github.com/guangyey, https://github.com/jgong5
Author: lzhang2
Date: 2024-11-13 05:32:19 +00:00
Committed by: PyTorch MergeBot
Commit: 1886e33f60
Parent: 4c6eebf4e2
5 changed files with 14 additions and 10 deletions

File 1 of 5:

@@ -17,7 +17,7 @@ class _AllreduceUpcastHookState:
     """
     ddp_weakref: Any
-    upcast_stream: torch.cuda.Stream
+    upcast_stream: torch.Stream
     wait_for_stream_enqueued: bool = False
@@ -46,7 +46,7 @@ def _reducer_allreduce_and_upcast_hook(
     fut = reducer._run_allreduce_hook(bucket)
     ret_fut = torch.futures.Future()
     stream = hook_state.upcast_stream
-    with torch.cuda.stream(stream):
+    with torch.get_device_module().stream(stream):
         fut.wait()
         bucket.buffer().div_(process_group.size())
         ret_fut.set_result(bucket.buffer())
@@ -61,7 +61,7 @@ def _reducer_allreduce_and_upcast_hook(
     # enqueue a callback to wait for this stream at end of backward
     def wait_for_stream_cb():
-        torch.cuda.current_stream().wait_stream(stream)
+        torch.accelerator.current_stream().wait_stream(stream)
         # Remove post-backward hooks since they are re-installed in next
         # iteration, similar to FSDP.
         # Parameters that don't require grad still needed to be casted since
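
The hook above enqueues the post-allreduce work on a dedicated upcast stream and later makes the current stream wait on it. Below is a minimal standalone sketch of that pattern using only the device-agnostic APIs; the variable names and tensor are invented for illustration and this is not code from the PR:

```python
import torch

# Hypothetical sketch of the hook's side-stream pattern (no-op on CPU-only hosts).
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
    buf = torch.ones(1024, device=device)
    side_stream = torch.Stream(device=device)  # backend-neutral stream object
    with torch.get_device_module().stream(side_stream):
        buf.div_(4)  # stand-in for the post-allreduce work enqueued by the hook
    # Make the current (default) stream wait so later kernels see the result,
    # mirroring wait_for_stream_cb above.
    torch.accelerator.current_stream().wait_stream(side_stream)
```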

File 2 of 5:

@@ -41,7 +41,7 @@ class _OptimizerHookState:
 @dataclass
 class _OptimInBackwardHookState:
-    optim_stream: torch.cuda.Stream
+    optim_stream: torch.Stream
     wait_for_optim_stream_enqueued: bool
@@ -57,7 +57,7 @@ def _apply_optim_in_backward_hook(
     step for parameters after gradient communication has taken place.
     """
     optim_in_bwd_state = _OptimInBackwardHookState(
-        optim_stream=torch.cuda.Stream(),
+        optim_stream=torch.Stream(),
         wait_for_optim_stream_enqueued=False,
     )
@@ -72,7 +72,7 @@ def _apply_optim_in_backward_hook(
         reducer, process_group = ddp_inst.reducer, ddp_inst.process_group
         fut = reducer._run_allreduce_hook(bucket)
         optimizer_stream = optim_stream_state.optim_stream
-        with torch.cuda.stream(optimizer_stream):
+        with torch.get_device_module().stream(optimizer_stream):
             fut.wait()
             # Apply gradient division since C++ side only allreduces and does
             # not average. TODO: (rohan-varma) the div factor may be different
@@ -99,7 +99,9 @@ def _apply_optim_in_backward_hook(
         # enqueue a callback to wait for this optimizer stream at the end of
         # backward and set all DDP managed grads to None.
         def wait_for_optim_stream_callback():
-            torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream)
+            torch.accelerator.current_stream().wait_stream(
+                optim_stream_state.optim_stream
+            )
             # Set DDP managed grads to None
             for param in ddp_inst._get_data_parallel_params(ddp_inst.module):
                 if hasattr(param, "_in_backward_optimizers"):

File 3 of 5:

@@ -40,8 +40,8 @@ def average_parameters(
     flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
     flat_params /= dist.get_world_size(group_to_use)
     # Make sure the allreduce will not conflict with any other ongoing process group.
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
+    if torch.accelerator.is_available():
+        torch.accelerator.synchronize()
     dist.all_reduce(flat_params, group=group_to_use)
     offset = 0
File 4 of 5:

@@ -535,10 +535,11 @@ def _override_module_mixed_precision(
 def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
-    # FIXME record_stream doesn't work with non-cuda/mtia tensors
+    # FIXME record_stream doesn't work with non-cuda/mtia/xpu tensors
     if tensor.device.type not in [
         "cuda",
         "mtia",
+        "xpu",
         torch._C._get_privateuse1_backend_name(),
     ]:
         return
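
For reference, a small sketch of the allowlist check used above, with `xpu` and the registered privateuse1 backend included; the helper name is made up for this example:

```python
import torch

def _supports_record_stream(tensor: torch.Tensor) -> bool:
    # Hypothetical helper restating the check above: record_stream-based logic
    # is skipped for backends that are not known to support it.
    return tensor.device.type in (
        "cuda",
        "mtia",
        "xpu",
        torch._C._get_privateuse1_backend_name(),  # out-of-tree backend, if any
    )

print(_supports_record_stream(torch.empty(1)))  # a CPU tensor -> False
```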

File 5 of 5:

@@ -22,6 +22,7 @@ def _is_supported_device(tensor: torch.Tensor) -> bool:
         "cpu",
         "hpu",
         "mtia",
+        "xpu",
         torch._C._get_privateuse1_backend_name(),
     )