diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 1f95f671b34..7d6c14ebba8 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -56,6 +56,7 @@ if dist.rpc.is_available():
     from torch.distributed.rpc import RRef
 
 if TYPE_CHECKING:
+    from torch.distributed.device_mesh import DeviceMesh
     from torch.utils.hooks import RemovableHandle
 
 
@@ -662,7 +663,7 @@ class DistributedDataParallel(Module, Joinable):
         delay_all_reduce_named_params=None,
         param_to_hook_all_reduce=None,
         mixed_precision: Optional[_MixedPrecision] = None,
-        device_mesh=None,
+        device_mesh: "Optional[DeviceMesh]" = None,
     ) -> None:
         super().__init__()
         Joinable.__init__(self)