Add again the DistributedDataParallel types that were lost when distributed.pyi was removed.

Mauricio Villegas 2024-09-27 06:03:51 +02:00
parent 387c993c3b
commit 53dbe12ea4

torch/nn/parallel/distributed.py

@@ -12,10 +12,21 @@ from collections import defaultdict, deque
 from contextlib import contextmanager
 from dataclasses import dataclass, fields, is_dataclass
 from enum import auto, Enum
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import (
+    Any,
+    Callable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    Union,
+)
 
 import torch
 import torch.distributed as dist
+from torch import device
 from torch._utils import _get_device_index
 from torch.autograd import Function, Variable
 from torch.distributed.algorithms.join import Join, Joinable, JoinHook
@@ -633,23 +644,23 @@ class DistributedDataParallel(Module, Joinable):
 
     def __init__(
         self,
-        module,
-        device_ids=None,
-        output_device=None,
-        dim=0,
-        broadcast_buffers=True,
+        module: Module,
+        device_ids: Optional[Sequence[Union[int, device]]] = None,
+        output_device: Optional[Union[int, device]] = None,
+        dim: int = 0,
+        broadcast_buffers: bool = True,
         init_sync=True,
-        process_group=None,
-        bucket_cap_mb=None,
-        find_unused_parameters=False,
-        check_reduction=False,
-        gradient_as_bucket_view=False,
-        static_graph=False,
+        process_group: Optional[Any] = None,
+        bucket_cap_mb: float = 25,
+        find_unused_parameters: bool = False,
+        check_reduction: bool = False,
+        gradient_as_bucket_view: bool = False,
+        static_graph: bool = False,
         delay_all_reduce_named_params=None,
         param_to_hook_all_reduce=None,
         mixed_precision: Optional[_MixedPrecision] = None,
         device_mesh=None,
-    ):
+    ) -> None:
         super().__init__()
         Joinable.__init__(self)
         self.logger: Optional[dist.Logger] = None
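
For reference, here is a minimal sketch (not part of this commit) of what the restored annotations let a type checker verify when constructing DistributedDataParallel. The process-group bootstrap and the model below are illustrative assumptions; only the DistributedDataParallel keyword arguments exercise the new types.

# Minimal sketch, not part of the commit: exercises the restored annotations.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Illustrative single-GPU bootstrap; assumes rendezvous env vars are set.
dist.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(10, 10).to("cuda:0")

# device_ids accepts ints or torch.device objects, matching the new
# Optional[Sequence[Union[int, device]]] annotation; output_device
# likewise accepts Union[int, device].
ddp = DistributedDataParallel(
    model,
    device_ids=[torch.device("cuda:0")],
    output_device=0,
    dim=0,
    broadcast_buffers=True,
)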