Add type to init_sync and address linter errors

This commit is contained in:
Mauricio Villegas 2025-02-10 06:20:23 +01:00
parent 0673a3d876
commit a9e65284d4

View file

@ -12,7 +12,7 @@ from collections import defaultdict, deque
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum from enum import auto, Enum
from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING, Union from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -639,11 +639,13 @@ class DistributedDataParallel(Module, Joinable):
def __init__( def __init__(
self, self,
module: Module, module: Module,
device_ids: Optional[Union[List[int], List[device], List[Union[int, device]]]] = None, device_ids: Optional[
Union[List[int], List[device], List[Union[int, device]]]
] = None,
output_device: Optional[Union[int, device]] = None, output_device: Optional[Union[int, device]] = None,
dim: int = 0, dim: int = 0,
broadcast_buffers: bool = True, broadcast_buffers: bool = True,
init_sync=True, init_sync: bool = True,
process_group: "Optional[ProcessGroup]" = None, process_group: "Optional[ProcessGroup]" = None,
bucket_cap_mb: float = 25, bucket_cap_mb: float = 25,
find_unused_parameters: bool = False, find_unused_parameters: bool = False,
@ -700,7 +702,7 @@ class DistributedDataParallel(Module, Joinable):
self._delay_all_reduce_params = [] self._delay_all_reduce_params = []
if hasattr(module, "_ddp_params_and_buffers_to_ignore"): if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) # type: ignore[arg-type]
else: else:
self.parameters_to_ignore = set() self.parameters_to_ignore = set()
if delay_all_reduce_named_params is not None: if delay_all_reduce_named_params is not None: