From a9e65284d4dfceda8814276d796dbffb032574d4 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Mon, 10 Feb 2025 06:20:23 +0100 Subject: [PATCH] Add type to init_sync and address linter errors --- torch/nn/parallel/distributed.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index a8141d8b2f1..1db89a580c5 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -12,7 +12,7 @@ from collections import defaultdict, deque from contextlib import contextmanager from dataclasses import dataclass, fields, is_dataclass from enum import auto, Enum -from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist @@ -639,11 +639,13 @@ class DistributedDataParallel(Module, Joinable): def __init__( self, module: Module, - device_ids: Optional[Union[List[int], List[device], List[Union[int, device]]]] = None, + device_ids: Optional[ + Union[List[int], List[device], List[Union[int, device]]] + ] = None, output_device: Optional[Union[int, device]] = None, dim: int = 0, broadcast_buffers: bool = True, - init_sync=True, + init_sync: bool = True, process_group: "Optional[ProcessGroup]" = None, bucket_cap_mb: float = 25, find_unused_parameters: bool = False, @@ -700,7 +702,7 @@ class DistributedDataParallel(Module, Joinable): self._delay_all_reduce_params = [] if hasattr(module, "_ddp_params_and_buffers_to_ignore"): - self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) + self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) # type: ignore[arg-type] else: self.parameters_to_ignore = set() if delay_all_reduce_named_params is not None: