mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Use device-agnostic runtime API in distributed DDP/FSDP instead of CUDA-device-specific API. (#137678)
# Motivation This PR aims to use the device-agnostic runtime API in distributed DDP/FSDP instead of the `cuda`-device-specific API. cc [@jgong5](https://github.com/jgong5) [@gujinghui](https://github.com/gujinghui) [@EikanWang](https://github.com/EikanWang) [@fengyuan14](https://github.com/fengyuan14) [@guangyey](https://github.com/guangyey) Pull Request resolved: https://github.com/pytorch/pytorch/pull/137678 Approved by: https://github.com/kwen2501, https://github.com/guangyey, https://github.com/jgong5
This commit is contained in:
parent
4c6eebf4e2
commit
1886e33f60
5 changed files with 14 additions and 10 deletions
|
|
@ -17,7 +17,7 @@ class _AllreduceUpcastHookState:
|
|||
"""
|
||||
|
||||
ddp_weakref: Any
|
||||
upcast_stream: torch.cuda.Stream
|
||||
upcast_stream: torch.Stream
|
||||
wait_for_stream_enqueued: bool = False
|
||||
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ def _reducer_allreduce_and_upcast_hook(
|
|||
fut = reducer._run_allreduce_hook(bucket)
|
||||
ret_fut = torch.futures.Future()
|
||||
stream = hook_state.upcast_stream
|
||||
with torch.cuda.stream(stream):
|
||||
with torch.get_device_module().stream(stream):
|
||||
fut.wait()
|
||||
bucket.buffer().div_(process_group.size())
|
||||
ret_fut.set_result(bucket.buffer())
|
||||
|
|
@ -61,7 +61,7 @@ def _reducer_allreduce_and_upcast_hook(
|
|||
|
||||
# enqueue a callback to wait for this stream at end of backward
|
||||
def wait_for_stream_cb():
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
torch.accelerator.current_stream().wait_stream(stream)
|
||||
# Remove post-backward hooks since they are re-installed in next
|
||||
# iteration, similar to FSDP.
|
||||
# Parameters that don't require grad still needed to be casted since
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class _OptimizerHookState:
|
|||
|
||||
@dataclass
|
||||
class _OptimInBackwardHookState:
|
||||
optim_stream: torch.cuda.Stream
|
||||
optim_stream: torch.Stream
|
||||
wait_for_optim_stream_enqueued: bool
|
||||
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ def _apply_optim_in_backward_hook(
|
|||
step for parameters after gradient communication has taken place.
|
||||
"""
|
||||
optim_in_bwd_state = _OptimInBackwardHookState(
|
||||
optim_stream=torch.cuda.Stream(),
|
||||
optim_stream=torch.Stream(),
|
||||
wait_for_optim_stream_enqueued=False,
|
||||
)
|
||||
|
||||
|
|
@ -72,7 +72,7 @@ def _apply_optim_in_backward_hook(
|
|||
reducer, process_group = ddp_inst.reducer, ddp_inst.process_group
|
||||
fut = reducer._run_allreduce_hook(bucket)
|
||||
optimizer_stream = optim_stream_state.optim_stream
|
||||
with torch.cuda.stream(optimizer_stream):
|
||||
with torch.get_device_module().stream(optimizer_stream):
|
||||
fut.wait()
|
||||
# Apply gradient division since C++ side only allreduces and does
|
||||
# not average. TODO: (rohan-varma) the div factor may be different
|
||||
|
|
@ -99,7 +99,9 @@ def _apply_optim_in_backward_hook(
|
|||
# enqueue a callback to wait for this optimizer stream at the end of
|
||||
# backward and set all DDP managed grads to None.
|
||||
def wait_for_optim_stream_callback():
|
||||
torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream)
|
||||
torch.accelerator.current_stream().wait_stream(
|
||||
optim_stream_state.optim_stream
|
||||
)
|
||||
# Set DDP managed grads to None
|
||||
for param in ddp_inst._get_data_parallel_params(ddp_inst.module):
|
||||
if hasattr(param, "_in_backward_optimizers"):
|
||||
|
|
|
|||
|
|
@ -40,8 +40,8 @@ def average_parameters(
|
|||
flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
|
||||
flat_params /= dist.get_world_size(group_to_use)
|
||||
# Make sure the allreduce will not conflict with any other ongoing process group.
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
if torch.accelerator.is_available():
|
||||
torch.accelerator.synchronize()
|
||||
dist.all_reduce(flat_params, group=group_to_use)
|
||||
|
||||
offset = 0
|
||||
|
|
|
|||
|
|
@ -535,10 +535,11 @@ def _override_module_mixed_precision(
|
|||
|
||||
|
||||
def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
|
||||
# FIXME record_stream doesn't work with non-cuda/mtia tensors
|
||||
# FIXME record_stream doesn't work with non-cuda/mtia/xpu tensors
|
||||
if tensor.device.type not in [
|
||||
"cuda",
|
||||
"mtia",
|
||||
"xpu",
|
||||
torch._C._get_privateuse1_backend_name(),
|
||||
]:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ def _is_supported_device(tensor: torch.Tensor) -> bool:
|
|||
"cpu",
|
||||
"hpu",
|
||||
"mtia",
|
||||
"xpu",
|
||||
torch._C._get_privateuse1_backend_name(),
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue