Add type for device_mesh

This commit is contained in:
Mauricio Villegas 2024-09-30 10:03:26 +02:00
parent 45cd195f24
commit 7a79dbdc05

View file

@ -56,6 +56,7 @@ if dist.rpc.is_available():
from torch.distributed.rpc import RRef
if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.hooks import RemovableHandle
@ -662,7 +663,7 @@ class DistributedDataParallel(Module, Joinable):
delay_all_reduce_named_params=None,
param_to_hook_all_reduce=None,
mixed_precision: Optional[_MixedPrecision] = None,
device_mesh=None,
device_mesh: "Optional[DeviceMesh]" = None,
) -> None:
super().__init__()
Joinable.__init__(self)