Change device_ids type from Sequence to List

This commit is contained in:
Mauricio Villegas 2024-10-21 09:48:06 +02:00
parent 6061dd7ad6
commit 2e90a1aaf3

View file

@ -12,17 +12,7 @@ from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum
from typing import (
Any,
Callable,
List,
Optional,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -649,7 +639,7 @@ class DistributedDataParallel(Module, Joinable):
def __init__(
self,
module: Module,
device_ids: Optional[Sequence[Union[int, device]]] = None,
device_ids: Optional[List[Union[int, device]]] = None,
output_device: Optional[Union[int, device]] = None,
dim: int = 0,
broadcast_buffers: bool = True,