mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add again the DistributedDataParallel types that were lost when distributed.pyi was removed.
This commit is contained in:
parent
387c993c3b
commit
53dbe12ea4
1 changed file with 24 additions and 13 deletions
|
|
@ -12,10 +12,21 @@ from collections import defaultdict, deque
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, fields, is_dataclass
|
from dataclasses import dataclass, fields, is_dataclass
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from torch import device
|
||||||
from torch._utils import _get_device_index
|
from torch._utils import _get_device_index
|
||||||
from torch.autograd import Function, Variable
|
from torch.autograd import Function, Variable
|
||||||
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
|
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
|
||||||
|
|
@ -633,23 +644,23 @@ class DistributedDataParallel(Module, Joinable):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
module,
|
module: Module,
|
||||||
device_ids=None,
|
device_ids: Optional[Sequence[Union[int, device]]] = None,
|
||||||
output_device=None,
|
output_device: Optional[Union[int, device]] = None,
|
||||||
dim=0,
|
dim: int = 0,
|
||||||
broadcast_buffers=True,
|
broadcast_buffers: bool = True,
|
||||||
init_sync=True,
|
init_sync=True,
|
||||||
process_group=None,
|
process_group: Optional[Any] = None,
|
||||||
bucket_cap_mb=None,
|
bucket_cap_mb: float = 25,
|
||||||
find_unused_parameters=False,
|
find_unused_parameters: bool = False,
|
||||||
check_reduction=False,
|
check_reduction: bool = False,
|
||||||
gradient_as_bucket_view=False,
|
gradient_as_bucket_view: bool = False,
|
||||||
static_graph=False,
|
static_graph: bool = False,
|
||||||
delay_all_reduce_named_params=None,
|
delay_all_reduce_named_params=None,
|
||||||
param_to_hook_all_reduce=None,
|
param_to_hook_all_reduce=None,
|
||||||
mixed_precision: Optional[_MixedPrecision] = None,
|
mixed_precision: Optional[_MixedPrecision] = None,
|
||||||
device_mesh=None,
|
device_mesh=None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
Joinable.__init__(self)
|
Joinable.__init__(self)
|
||||||
self.logger: Optional[dist.Logger] = None
|
self.logger: Optional[dist.Logger] = None
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue