mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add type to init_sync and address linter errors
This commit is contained in:
parent
0673a3d876
commit
a9e65284d4
1 changed files with 6 additions and 4 deletions
|
|
@ -12,7 +12,7 @@ from collections import defaultdict, deque
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, fields, is_dataclass
|
from dataclasses import dataclass, fields, is_dataclass
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
@ -639,11 +639,13 @@ class DistributedDataParallel(Module, Joinable):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
module: Module,
|
module: Module,
|
||||||
device_ids: Optional[Union[List[int], List[device], List[Union[int, device]]]] = None,
|
device_ids: Optional[
|
||||||
|
Union[List[int], List[device], List[Union[int, device]]]
|
||||||
|
] = None,
|
||||||
output_device: Optional[Union[int, device]] = None,
|
output_device: Optional[Union[int, device]] = None,
|
||||||
dim: int = 0,
|
dim: int = 0,
|
||||||
broadcast_buffers: bool = True,
|
broadcast_buffers: bool = True,
|
||||||
init_sync=True,
|
init_sync: bool = True,
|
||||||
process_group: "Optional[ProcessGroup]" = None,
|
process_group: "Optional[ProcessGroup]" = None,
|
||||||
bucket_cap_mb: float = 25,
|
bucket_cap_mb: float = 25,
|
||||||
find_unused_parameters: bool = False,
|
find_unused_parameters: bool = False,
|
||||||
|
|
@ -700,7 +702,7 @@ class DistributedDataParallel(Module, Joinable):
|
||||||
|
|
||||||
self._delay_all_reduce_params = []
|
self._delay_all_reduce_params = []
|
||||||
if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
|
if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
|
||||||
self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
|
self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) # type: ignore[arg-type]
|
||||||
else:
|
else:
|
||||||
self.parameters_to_ignore = set()
|
self.parameters_to_ignore = set()
|
||||||
if delay_all_reduce_named_params is not None:
|
if delay_all_reduce_named_params is not None:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue