diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 2a6d66f2209..1f95f671b34 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -40,6 +40,7 @@ if dist.is_available():
     from torch.distributed.distributed_c10d import (
         _get_default_group,
         _rank_not_in_group,
+        ProcessGroup,
         ReduceOp,
     )
     from torch.distributed.utils import (
@@ -642,6 +643,8 @@ class DistributedDataParallel(Module, Joinable):
     # used to track whether the given thread is inside ddp forward for torchdynamo purposes
     _active_ddp_module: Optional["DistributedDataParallel"] = None

+    reducer: dist.Reducer
+
     def __init__(
         self,
         module: Module,
@@ -650,7 +653,7 @@ class DistributedDataParallel(Module, Joinable):
         dim: int = 0,
         broadcast_buffers: bool = True,
         init_sync=True,
-        process_group: Optional[Any] = None,
+        process_group: "Optional[ProcessGroup]" = None,
         bucket_cap_mb: float = 25,
         find_unused_parameters: bool = False,
         check_reduction: bool = False,
@@ -680,7 +683,7 @@ class DistributedDataParallel(Module, Joinable):
         elif process_group is None and device_mesh is None:
             self.process_group = _get_default_group()
         elif device_mesh is None:
-            self.process_group = process_group
+            self.process_group = process_group  # type: ignore[assignment]
         else:
             if device_mesh.ndim != 1:
                 raise RuntimeError(
@@ -814,7 +817,7 @@ class DistributedDataParallel(Module, Joinable):
         if bucket_cap_mb is None:
             # default case (bucket cap is 25 MiB)
             bucket_cap_mb = 25
-            self.bucket_bytes_cap_default = True
+            self.bucket_bytes_cap_default: bool = True
         else:
             self.bucket_bytes_cap_default = False
         self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
@@ -926,7 +929,7 @@ class DistributedDataParallel(Module, Joinable):
             )
         if self._use_python_reducer:
             torch._inductor.config._fuse_ddp_communication = True
-            torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
+            torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb  # type: ignore[assignment]
             # Directly adding this to the trace rule will disturb the users
             # who are using DDPOptimizer.
             torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add(
@@ -1095,6 +1098,7 @@ class DistributedDataParallel(Module, Joinable):
                 # Do not cast DDP ignored parameters.
                 if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
                     continue
+                assert hasattr(param, "_mp_param")
                 _alloc_storage(param._mp_param, param.size())
                 # copy() implicitly casts to low precision
                 with torch.no_grad():