diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 109a7f44829..2a6d66f2209 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -12,10 +12,21 @@ from collections import defaultdict, deque
 from contextlib import contextmanager
 from dataclasses import dataclass, fields, is_dataclass
 from enum import auto, Enum
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import (
+    Any,
+    Callable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    Union,
+)
 
 import torch
 import torch.distributed as dist
+from torch import device
 from torch._utils import _get_device_index
 from torch.autograd import Function, Variable
 from torch.distributed.algorithms.join import Join, Joinable, JoinHook
@@ -633,23 +644,23 @@ class DistributedDataParallel(Module, Joinable):
 
     def __init__(
         self,
-        module,
-        device_ids=None,
-        output_device=None,
-        dim=0,
-        broadcast_buffers=True,
+        module: Module,
+        device_ids: Optional[Sequence[Union[int, device]]] = None,
+        output_device: Optional[Union[int, device]] = None,
+        dim: int = 0,
+        broadcast_buffers: bool = True,
         init_sync=True,
-        process_group=None,
-        bucket_cap_mb=None,
-        find_unused_parameters=False,
-        check_reduction=False,
-        gradient_as_bucket_view=False,
-        static_graph=False,
+        process_group: Optional[Any] = None,
+        bucket_cap_mb: float = 25,
+        find_unused_parameters: bool = False,
+        check_reduction: bool = False,
+        gradient_as_bucket_view: bool = False,
+        static_graph: bool = False,
         delay_all_reduce_named_params=None,
         param_to_hook_all_reduce=None,
         mixed_precision: Optional[_MixedPrecision] = None,
         device_mesh=None,
-    ):
+    ) -> None:
         super().__init__()
         Joinable.__init__(self)
         self.logger: Optional[dist.Logger] = None