diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py
index f365e469c..fd22962f0 100644
--- a/src/transformers/file_utils.py
+++ b/src/transformers/file_utils.py
@@ -297,6 +297,20 @@ def is_pandas_available():
     return importlib.util.find_spec("pandas") is not None
 
 
+def is_sagemaker_distributed_available():
+    # Get the SageMaker-specific env variable.
+    sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
+    try:
+        # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
+        sagemaker_params = json.loads(sagemaker_params)
+        if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
+            return False
+    except json.JSONDecodeError:
+        return False
+    # Lastly, check if the `smdistributed` module is present.
+    return importlib.util.find_spec("smdistributed") is not None
+
+
 def torch_only_method(fn):
     def wrapper(*args, **kwargs):
         if not _torch_available:
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 856b9e1af..e64a1e3fb 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -51,7 +51,14 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler
 
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
-from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
+from .file_utils import (
+    WEIGHTS_NAME,
+    is_apex_available,
+    is_datasets_available,
+    is_in_notebook,
+    is_sagemaker_distributed_available,
+    is_torch_tpu_available,
+)
 from .modeling_utils import PreTrainedModel
 from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from .optimization import Adafactor, AdamW, get_scheduler
@@ -125,6 +132,11 @@ if is_fairscale_available():
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
 
+if is_sagemaker_distributed_available():
+    import smdistributed.dataparallel.torch.distributed as dist
+    from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
+else:
+    import torch.distributed as dist
 
 if TYPE_CHECKING:
     import optuna
@@ -428,9 +440,12 @@ class Trainer:
         if self.args.parallel_mode == ParallelMode.TPU:
             num_processes = xm.xrt_world_size()
             process_index = xm.get_ordinal()
-        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
-            num_processes = torch.distributed.get_world_size()
-            process_index = torch.distributed.get_rank()
+        elif (
+            self.args.parallel_mode == ParallelMode.DISTRIBUTED
+            or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
+        ):
+            num_processes = dist.get_world_size()
+            process_index = dist.get_rank()
         else:
             num_processes = 1
             process_index = 0
@@ -743,6 +758,8 @@ class Trainer:
         # Distributed training (should be after apex fp16 initialization)
         if self.sharded_dpp:
             model = ShardedDDP(model, self.optimizer)
+        elif is_sagemaker_distributed_available():
+            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
             model = torch.nn.parallel.DistributedDataParallel(
                 model,
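Aside (not part of the patch): a minimal, runnable sketch of what the new `is_sagemaker_distributed_available()` helper checks. The `SM_FRAMEWORK_PARAMS` value set here is illustrative; on a real SageMaker data-parallel job the platform injects it.

```python
import importlib.util
import json
import os

# Illustrative only: SageMaker itself sets this variable on data-parallel jobs.
os.environ["SM_FRAMEWORK_PARAMS"] = json.dumps(
    {"sagemaker_distributed_dataparallel_enabled": True}
)

params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
enabled = params.get("sagemaker_distributed_dataparallel_enabled", False)
# The helper additionally requires the `smdistributed` package to be importable.
print(enabled and importlib.util.find_spec("smdistributed") is not None)
```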
@@ -767,14 +784,13 @@ class Trainer:
 
         # Train!
         if is_torch_tpu_available():
-            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
+            world_size = xm.xrt_world_size()
+        elif self.args.local_rank != -1:
+            world_size = dist.get_world_size()
         else:
-            total_train_batch_size = (
-                self.args.train_batch_size
-                * self.args.gradient_accumulation_steps
-                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
-            )
+            world_size = 1
+
+        total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
         num_examples = (
             self.num_examples(train_dataloader) if train_dataset_is_sized
@@ -1302,7 +1318,7 @@ class Trainer:
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=False)
         else:
-            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
+            return self.args.local_rank == -1 or dist.get_rank() == 0
 
     def save_model(self, output_dir: Optional[str] = None):
         """
@@ -1542,7 +1558,7 @@ class Trainer:
         if is_torch_tpu_available():
             world_size = xm.xrt_world_size()
         elif self.args.local_rank != -1:
-            world_size = torch.distributed.get_world_size()
+            world_size = dist.get_world_size()
         world_size = max(1, world_size)
 
         eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index db7a08008..1a406eb00 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -28,10 +28,16 @@ from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
 
-from .file_utils import is_torch_tpu_available
+from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available
 from .utils import logging
 
 
+if is_sagemaker_distributed_available():
+    import smdistributed.dataparallel.torch.distributed as dist
+else:
+    import torch.distributed as dist
+
+
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm
 
@@ -121,8 +127,8 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
     try:
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
-        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output_tensors, tensor)
+        output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
+        dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)
 
         # truncate the dummy elements added by SequentialDistributedSampler
@@ -138,8 +144,8 @@ def distributed_broadcast_scalars(
 ) -> torch.Tensor:
     try:
         tensorized_scalar = torch.tensor(scalars).cuda()
-        output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output_tensors, tensorized_scalar)
+        output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
+        dist.all_gather(output_tensors, tensorized_scalar)
         concat = torch.cat(output_tensors, dim=0)
 
         # truncate the dummy elements added by SequentialDistributedSampler
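Both `all_gather` hunks above rewrite the same gather-and-truncate pattern against the backend-agnostic `dist` alias. A condensed sketch of that pattern, assuming an already-initialized process group (the function name is mine, not the library's):

```python
import torch
import torch.distributed as dist  # or smdistributed.dataparallel.torch.distributed

def gather_and_truncate(tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
    # Every rank contributes its shard; every rank receives all shards.
    output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
    dist.all_gather(output_tensors, tensor)
    concat = torch.cat(output_tensors, dim=0)
    # Drop the padding rows SequentialDistributedSampler added to even out shards.
    return concat[:num_total_examples]
```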
""" if local_rank not in [-1, 0]: - torch.distributed.barrier() + dist.barrier() yield if local_rank == 0: - torch.distributed.barrier() + dist.barrier() class SequentialDistributedSampler(Sampler): @@ -185,13 +191,13 @@ class SequentialDistributedSampler(Sampler): def __init__(self, dataset, num_replicas=None, rank=None): if num_replicas is None: - if not torch.distributed.is_available(): + if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") - num_replicas = torch.distributed.get_world_size() + num_replicas = dist.get_world_size() if rank is None: - if not torch.distributed.is_available(): + if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") - rank = torch.distributed.get_rank() + rank = dist.get_rank() self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -480,13 +486,13 @@ class DistributedLengthGroupedSampler(DistributedSampler): lengths: Optional[List[int]] = None, ): if num_replicas is None: - if not torch.distributed.is_available(): + if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") - num_replicas = torch.distributed.get_world_size() + num_replicas = dist.get_world_size() if rank is None: - if not torch.distributed.is_available(): + if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") - rank = torch.distributed.get_rank() + rank = dist.get_rank() self.dataset = dataset self.batch_size = batch_size self.num_replicas = num_replicas diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index e032ede0e..2f11cda19 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -25,7 +25,7 @@ from typing import Any, Dict, NamedTuple, Optional, Tuple, Union import numpy as np -from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available +from .file_utils import is_sagemaker_distributed_available, is_tf_available, is_torch_available, is_torch_tpu_available from .tokenization_utils_base import ExplicitEnum @@ -187,6 +187,10 @@ def total_processes_number(local_rank): import torch_xla.core.xla_model as xm return xm.xrt_world_size() + elif is_sagemaker_distributed_available(): + import smdistributed.dataparallel.torch.distributed as dist + + return dist.get_world_size() elif local_rank != -1 and is_torch_available(): import torch diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5a7aa99bc..50e7a00ca 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -18,7 +18,13 @@ from dataclasses import asdict, dataclass, field from enum import Enum from typing import Any, Dict, List, Optional -from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required +from .file_utils import ( + cached_property, + is_sagemaker_distributed_available, + is_torch_available, + is_torch_tpu_available, + torch_required, +) from .trainer_utils import EvaluationStrategy, SchedulerType from .utils import logging @@ -493,6 +499,13 @@ class TrainingArguments: elif is_torch_tpu_available(): device = xm.xla_device() self._n_gpu = 0 + elif is_sagemaker_distributed_available(): + import smdistributed.dataparallel.torch.distributed as dist + + dist.init_process_group() + self.local_rank = dist.get_local_rank() + device = torch.device("cuda", self.local_rank) + self._n_gpu = 1 elif self.local_rank == -1: # if n_gpu is > 1 we'll use nn.DataParallel. 
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 5a7aa99bc..50e7a00ca 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -18,7 +18,13 @@ from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
-from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
+from .file_utils import (
+    cached_property,
+    is_sagemaker_distributed_available,
+    is_torch_available,
+    is_torch_tpu_available,
+    torch_required,
+)
 from .trainer_utils import EvaluationStrategy, SchedulerType
 from .utils import logging
 
@@ -493,6 +499,13 @@ class TrainingArguments:
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
+        elif is_sagemaker_distributed_available():
+            import smdistributed.dataparallel.torch.distributed as dist
+
+            dist.init_process_group()
+            self.local_rank = dist.get_local_rank()
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -566,6 +579,8 @@ class TrainingArguments:
         """
         if is_torch_tpu_available():
             return ParallelMode.TPU
+        elif is_sagemaker_distributed_available():
+            return ParallelMode.SAGEMAKER_DISTRIBUTED
         elif self.local_rank != -1:
             return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:
@@ -607,4 +622,5 @@ class ParallelMode(Enum):
     NOT_PARALLEL = "not_parallel"
     NOT_DISTRIBUTED = "not_distributed"
     DISTRIBUTED = "distributed"
+    SAGEMAKER_DISTRIBUTED = "sm_distributed"
     TPU = "tpu"
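Finally, a hypothetical usage sketch of the new enum value (the `TrainingArguments` construction is minimal and illustrative; which mode is reported depends on the runtime environment):

```python
from transformers import TrainingArguments
from transformers.training_args import ParallelMode

args = TrainingArguments(output_dir="out")
# On a SageMaker data-parallel job, parallel_mode resolves to SAGEMAKER_DISTRIBUTED;
# downstream code can treat both distributed modes uniformly.
if args.parallel_mode in (ParallelMode.DISTRIBUTED, ParallelMode.SAGEMAKER_DISTRIBUTED):
    print("multi-process run:", args.parallel_mode.value)
```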