diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index a8141d8b2f1..1db89a580c5 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -12,7 +12,7 @@ from collections import defaultdict, deque
 from contextlib import contextmanager
 from dataclasses import dataclass, fields, is_dataclass
 from enum import auto, Enum
-from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
+from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.distributed as dist
@@ -639,11 +639,13 @@ class DistributedDataParallel(Module, Joinable):
     def __init__(
         self,
         module: Module,
-        device_ids: Optional[Union[List[int], List[device], List[Union[int, device]]]] = None,
+        device_ids: Optional[
+            Union[List[int], List[device], List[Union[int, device]]]
+        ] = None,
         output_device: Optional[Union[int, device]] = None,
         dim: int = 0,
         broadcast_buffers: bool = True,
-        init_sync=True,
+        init_sync: bool = True,
         process_group: "Optional[ProcessGroup]" = None,
         bucket_cap_mb: float = 25,
         find_unused_parameters: bool = False,
@@ -700,7 +702,7 @@ class DistributedDataParallel(Module, Joinable):
         self._delay_all_reduce_params = []
 
         if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
-            self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
+            self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)  # type: ignore[arg-type]
         else:
             self.parameters_to_ignore = set()
         if delay_all_reduce_named_params is not None: