mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add again the DistributedDataParallel types that were lost when distributed.pyi was removed.
This commit is contained in:
parent
387c993c3b
commit
53dbe12ea4
1 changed files with 24 additions and 13 deletions
|
|
@ -12,10 +12,21 @@ from collections import defaultdict, deque
|
|||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, fields, is_dataclass
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import device
|
||||
from torch._utils import _get_device_index
|
||||
from torch.autograd import Function, Variable
|
||||
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
|
||||
|
|
@ -633,23 +644,23 @@ class DistributedDataParallel(Module, Joinable):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
device_ids=None,
|
||||
output_device=None,
|
||||
dim=0,
|
||||
broadcast_buffers=True,
|
||||
module: Module,
|
||||
device_ids: Optional[Sequence[Union[int, device]]] = None,
|
||||
output_device: Optional[Union[int, device]] = None,
|
||||
dim: int = 0,
|
||||
broadcast_buffers: bool = True,
|
||||
init_sync=True,
|
||||
process_group=None,
|
||||
bucket_cap_mb=None,
|
||||
find_unused_parameters=False,
|
||||
check_reduction=False,
|
||||
gradient_as_bucket_view=False,
|
||||
static_graph=False,
|
||||
process_group: Optional[Any] = None,
|
||||
bucket_cap_mb: float = 25,
|
||||
find_unused_parameters: bool = False,
|
||||
check_reduction: bool = False,
|
||||
gradient_as_bucket_view: bool = False,
|
||||
static_graph: bool = False,
|
||||
delay_all_reduce_named_params=None,
|
||||
param_to_hook_all_reduce=None,
|
||||
mixed_precision: Optional[_MixedPrecision] = None,
|
||||
device_mesh=None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
Joinable.__init__(self)
|
||||
self.logger: Optional[dist.Logger] = None
|
||||
|
|
|
|||
Loading…
Reference in a new issue