mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add type to init_sync and address linter errors
This commit is contained in:
parent
0673a3d876
commit
a9e65284d4
1 changed files with 6 additions and 4 deletions
|
|
@ -12,7 +12,7 @@ from collections import defaultdict, deque
|
|||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, fields, is_dataclass
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -639,11 +639,13 @@ class DistributedDataParallel(Module, Joinable):
|
|||
def __init__(
|
||||
self,
|
||||
module: Module,
|
||||
device_ids: Optional[Union[List[int], List[device], List[Union[int, device]]]] = None,
|
||||
device_ids: Optional[
|
||||
Union[List[int], List[device], List[Union[int, device]]]
|
||||
] = None,
|
||||
output_device: Optional[Union[int, device]] = None,
|
||||
dim: int = 0,
|
||||
broadcast_buffers: bool = True,
|
||||
init_sync=True,
|
||||
init_sync: bool = True,
|
||||
process_group: "Optional[ProcessGroup]" = None,
|
||||
bucket_cap_mb: float = 25,
|
||||
find_unused_parameters: bool = False,
|
||||
|
|
@ -700,7 +702,7 @@ class DistributedDataParallel(Module, Joinable):
|
|||
|
||||
self._delay_all_reduce_params = []
|
||||
if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
|
||||
self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
|
||||
self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) # type: ignore[arg-type]
|
||||
else:
|
||||
self.parameters_to_ignore = set()
|
||||
if delay_all_reduce_named_params is not None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue