diff --git a/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py
index d5b256c97df..8dab7b15aef 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py
@@ -17,7 +17,7 @@ class _AllreduceUpcastHookState:
     """
 
     ddp_weakref: Any
-    upcast_stream: torch.cuda.Stream
+    upcast_stream: torch.Stream
     wait_for_stream_enqueued: bool = False
 
 
@@ -46,7 +46,7 @@ def _reducer_allreduce_and_upcast_hook(
     fut = reducer._run_allreduce_hook(bucket)
     ret_fut = torch.futures.Future()
     stream = hook_state.upcast_stream
-    with torch.cuda.stream(stream):
+    with torch.get_device_module().stream(stream):
         fut.wait()
         bucket.buffer().div_(process_group.size())
         ret_fut.set_result(bucket.buffer())
@@ -61,7 +61,7 @@ def _reducer_allreduce_and_upcast_hook(
 
     # enqueue a callback to wait for this stream at end of backward
     def wait_for_stream_cb():
-        torch.cuda.current_stream().wait_stream(stream)
+        torch.accelerator.current_stream().wait_stream(stream)
         # Remove post-backward hooks since they are re-installed in next
         # iteration, similar to FSDP.
         # Parameters that don't require grad still needed to be casted since
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py
index 5ae242b04a9..5de3b25d2ba 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py
@@ -41,7 +41,7 @@ class _OptimizerHookState:
 
 @dataclass
 class _OptimInBackwardHookState:
-    optim_stream: torch.cuda.Stream
+    optim_stream: torch.Stream
     wait_for_optim_stream_enqueued: bool
 
 
@@ -57,7 +57,7 @@ def _apply_optim_in_backward_hook(
     step for parameters after gradient communication has taken place.
     """
     optim_in_bwd_state = _OptimInBackwardHookState(
-        optim_stream=torch.cuda.Stream(),
+        optim_stream=torch.Stream(),
         wait_for_optim_stream_enqueued=False,
     )
 
@@ -72,7 +72,7 @@ def _apply_optim_in_backward_hook(
         reducer, process_group = ddp_inst.reducer, ddp_inst.process_group
         fut = reducer._run_allreduce_hook(bucket)
         optimizer_stream = optim_stream_state.optim_stream
-        with torch.cuda.stream(optimizer_stream):
+        with torch.get_device_module().stream(optimizer_stream):
             fut.wait()
             # Apply gradient division since C++ side only allreduces and does
             # not average. TODO: (rohan-varma) the div factor may be different
@@ -99,7 +99,9 @@ def _apply_optim_in_backward_hook(
         # enqueue a callback to wait for this optimizer stream at the end of
         # backward and set all DDP managed grads to None.
         def wait_for_optim_stream_callback():
-            torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream)
+            torch.accelerator.current_stream().wait_stream(
+                optim_stream_state.optim_stream
+            )
             # Set DDP managed grads to None
             for param in ddp_inst._get_data_parallel_params(ddp_inst.module):
                 if hasattr(param, "_in_backward_optimizers"):
diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py
index 20f75152f0b..c03a62f0620 100644
--- a/torch/distributed/algorithms/model_averaging/utils.py
+++ b/torch/distributed/algorithms/model_averaging/utils.py
@@ -40,8 +40,8 @@ def average_parameters(
     flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
     flat_params /= dist.get_world_size(group_to_use)
     # Make sure the allreduce will not conflict with any other ongoing process group.
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
+    if torch.accelerator.is_available():
+        torch.accelerator.synchronize()
     dist.all_reduce(flat_params, group=group_to_use)
 
     offset = 0
diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py
index cc3194bb463..22e2659e9d8 100644
--- a/torch/distributed/fsdp/_common_utils.py
+++ b/torch/distributed/fsdp/_common_utils.py
@@ -535,10 +535,11 @@ def _override_module_mixed_precision(
 
 
 def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
-    # FIXME record_stream doesn't work with non-cuda/mtia tensors
+    # FIXME record_stream doesn't work with non-cuda/mtia/xpu tensors
     if tensor.device.type not in [
         "cuda",
         "mtia",
+        "xpu",
         torch._C._get_privateuse1_backend_name(),
     ]:
         return
diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py
index 988ecb3533f..8f62c3e6772 100644
--- a/torch/distributed/fsdp/sharded_grad_scaler.py
+++ b/torch/distributed/fsdp/sharded_grad_scaler.py
@@ -22,6 +22,7 @@ def _is_supported_device(tensor: torch.Tensor) -> bool:
         "cpu",
         "hpu",
         "mtia",
+        "xpu",
         torch._C._get_privateuse1_backend_name(),
     )
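Note (reviewer sketch, not part of the patch): every hunk above applies the same migration, swapping CUDA-specific stream APIs for their device-agnostic equivalents. Below is a minimal standalone sketch of that pattern, assuming an accelerator backend (CUDA, XPU, etc.) is present at runtime; the function and variable names are illustrative only and do not appear in the patched files.

    import torch

    def run_on_side_stream(t: torch.Tensor) -> torch.Tensor:
        # torch.Stream() replaces torch.cuda.Stream() and binds to the
        # current accelerator backend (CUDA, XPU, ...) automatically.
        side_stream = torch.Stream()
        # torch.get_device_module() resolves to torch.cuda, torch.xpu, etc.,
        # so the stream context manager no longer hard-codes CUDA.
        with torch.get_device_module().stream(side_stream):
            out = t * 2  # work enqueued on the side stream
        # torch.accelerator.current_stream() is the backend-agnostic
        # replacement for torch.cuda.current_stream(); make the default
        # stream wait for the side-stream work before the result is reused.
        torch.accelerator.current_stream().wait_stream(side_stream)
        return out

    if torch.accelerator.is_available():
        x = torch.ones(4, device=torch.accelerator.current_accelerator())
        print(run_on_side_stream(x))

The same three substitutions (torch.Stream, torch.get_device_module().stream, torch.accelerator.current_stream) are what let the DDP comm hooks run unchanged on non-CUDA accelerators such as XPU.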