mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add type for device_mesh
This commit is contained in:
parent
45cd195f24
commit
7a79dbdc05
1 changed files with 2 additions and 1 deletions
|
|
@ -56,6 +56,7 @@ if dist.rpc.is_available():
|
|||
from torch.distributed.rpc import RRef
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
|
||||
|
|
@ -662,7 +663,7 @@ class DistributedDataParallel(Module, Joinable):
|
|||
delay_all_reduce_named_params=None,
|
||||
param_to_hook_all_reduce=None,
|
||||
mixed_precision: Optional[_MixedPrecision] = None,
|
||||
device_mesh=None,
|
||||
device_mesh: "Optional[DeviceMesh]" = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
Joinable.__init__(self)
|
||||
|
|
|
|||
Loading…
Reference in a new issue