From 7a79dbdc05351cdc704e57ae0d8574deb66f911c Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Mon, 30 Sep 2024 10:03:26 +0200 Subject: [PATCH] Add type for device_mesh --- torch/nn/parallel/distributed.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 1f95f671b34..7d6c14ebba8 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -56,6 +56,7 @@ if dist.rpc.is_available(): from torch.distributed.rpc import RRef if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh from torch.utils.hooks import RemovableHandle @@ -662,7 +663,7 @@ class DistributedDataParallel(Module, Joinable): delay_all_reduce_named_params=None, param_to_hook_all_reduce=None, mixed_precision: Optional[_MixedPrecision] = None, - device_mesh=None, + device_mesh: "Optional[DeviceMesh]" = None, ) -> None: super().__init__() Joinable.__init__(self)