Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-15 21:01:19 +00:00
Smdistributed trainer (#9798)
* Add a debug print
* Adapt Trainer to use smdistributed if available
* Forgotten parenthesis
* Real check for sagemaker
* Don't forget to define device...
* Woopsie, local_rank is defined differently
* Update since local_rank has the proper value
* Remove debug statement
* More robust check for smdistributed
* Quality
* Deal with key not present error
parent 897a24c869
commit 0d0efd3a0e
5 changed files with 85 additions and 29 deletions
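For context, the new runtime check keys off the environment SageMaker sets up when its data parallel library is enabled for a job. A minimal sketch of opting in, assuming the SageMaker Python SDK's PyTorch estimator; the entry point, role, versions, and instance settings below are illustrative:

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",            # hypothetical script that builds a Trainer
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role
    instance_count=2,
    instance_type="ml.p3.16xlarge",    # the library targets multi-GPU instances
    framework_version="1.6.0",
    py_version="py36",
    # This flag is what makes SM_FRAMEWORK_PARAMS carry
    # "sagemaker_distributed_dataparallel_enabled": true inside the container.
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
)
estimator.fit()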
@@ -297,6 +297,20 @@ def is_pandas_available():
     return importlib.util.find_spec("pandas") is not None
 
 
+def is_sagemaker_distributed_available():
+    # Get the sagemaker specific env variable.
+    sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
+    try:
+        # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
+        sagemaker_params = json.loads(sagemaker_params)
+        if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
+            return False
+    except json.JSONDecodeError:
+        return False
+    # Lastly, check if the `smdistributed` module is present.
+    return importlib.util.find_spec("smdistributed") is not None
+
+
 def torch_only_method(fn):
     def wrapper(*args, **kwargs):
         if not _torch_available:
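The helper returns True only when the env variable parses as JSON, the flag is set, and the smdistributed package is importable. A minimal sketch of the environment it expects; the JSON value shown is illustrative:

import json
import os

# Inside a SageMaker data-parallel job the container sets something like:
os.environ["SM_FRAMEWORK_PARAMS"] = json.dumps(
    {"sagemaker_distributed_dataparallel_enabled": True}  # illustrative value
)

from transformers.file_utils import is_sagemaker_distributed_available

# True only if the flag above is set and smdistributed is installed.
print(is_sagemaker_distributed_available())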
@@ -51,7 +51,14 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler
 
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
-from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
+from .file_utils import (
+    WEIGHTS_NAME,
+    is_apex_available,
+    is_datasets_available,
+    is_in_notebook,
+    is_sagemaker_distributed_available,
+    is_torch_tpu_available,
+)
 from .modeling_utils import PreTrainedModel
 from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from .optimization import Adafactor, AdamW, get_scheduler
@@ -125,6 +132,11 @@ if is_fairscale_available():
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
 
+if is_sagemaker_distributed_available():
+    import smdistributed.dataparallel.torch.distributed as dist
+    from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
+else:
+    import torch.distributed as dist
 
 if TYPE_CHECKING:
     import optuna
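The alias above is what lets the rest of the Trainer stay backend-agnostic: smdistributed.dataparallel.torch.distributed answers the same calls used below (get_rank, get_world_size, barrier, all_gather) and additionally provides get_local_rank(). A minimal sketch of the pattern, assuming the process group is already initialized:

from transformers.file_utils import is_sagemaker_distributed_available

# Same selection as in the hunk above; either module answers the same calls.
if is_sagemaker_distributed_available():
    import smdistributed.dataparallel.torch.distributed as dist
else:
    import torch.distributed as dist

# Assumes init_process_group() has already run in this process.
rank, world_size = dist.get_rank(), dist.get_world_size()
print(f"process {rank} of {world_size}")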
@@ -428,9 +440,12 @@ class Trainer:
             if self.args.parallel_mode == ParallelMode.TPU:
                 num_processes = xm.xrt_world_size()
                 process_index = xm.get_ordinal()
-            elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
-                num_processes = torch.distributed.get_world_size()
-                process_index = torch.distributed.get_rank()
+            elif (
+                self.args.parallel_mode == ParallelMode.DISTRIBUTED
+                or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
+            ):
+                num_processes = dist.get_world_size()
+                process_index = dist.get_rank()
             else:
                 num_processes = 1
                 process_index = 0
@@ -743,6 +758,8 @@ class Trainer:
         # Distributed training (should be after apex fp16 initialization)
         if self.sharded_dpp:
             model = ShardedDDP(model, self.optimizer)
+        elif is_sagemaker_distributed_available():
+            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
             model = torch.nn.parallel.DistributedDataParallel(
                 model,
@@ -767,14 +784,13 @@ class Trainer:
 
         # Train!
         if is_torch_tpu_available():
-            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
+            world_size = xm.xrt_world_size()
+        elif self.args.local_rank != -1:
+            world_size = dist.get_world_size()
         else:
-            total_train_batch_size = (
-                self.args.train_batch_size
-                * self.args.gradient_accumulation_steps
-                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
-            )
+            world_size = 1
 
+        total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
         num_examples = (
             self.num_examples(train_dataloader)
             if train_dataset_is_sized
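With the refactor above, the effective batch size is a single formula for every backend: per-device batch size times gradient accumulation steps times world size (train_batch_size equals the per-device size when each process owns one GPU). A quick worked example with hypothetical values:

# Hypothetical run: 2 nodes x 4 GPUs, one process per GPU.
train_batch_size = 4               # per-device batch size in this setup
gradient_accumulation_steps = 8
world_size = 8                     # dist.get_world_size() or xm.xrt_world_size()

# Mirrors: total_train_batch_size = train_batch_size * grad_accum * world_size
total_train_batch_size = train_batch_size * gradient_accumulation_steps * world_size
print(total_train_batch_size)      # 256 examples consumed per optimizer step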
@@ -1302,7 +1318,7 @@ class Trainer:
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=False)
         else:
-            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
+            return self.args.local_rank == -1 or dist.get_rank() == 0
 
     def save_model(self, output_dir: Optional[str] = None):
         """
@@ -1542,7 +1558,7 @@ class Trainer:
         if is_torch_tpu_available():
             world_size = xm.xrt_world_size()
         elif self.args.local_rank != -1:
-            world_size = torch.distributed.get_world_size()
+            world_size = dist.get_world_size()
         world_size = max(1, world_size)
 
         eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
@@ -28,10 +28,16 @@ from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
 
-from .file_utils import is_torch_tpu_available
+from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available
 from .utils import logging
 
 
+if is_sagemaker_distributed_available():
+    import smdistributed.dataparallel.torch.distributed as dist
+else:
+    import torch.distributed as dist
+
+
 if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm
 
@@ -121,8 +127,8 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
     try:
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
-        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output_tensors, tensor)
+        output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
+        dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)
 
         # truncate the dummy elements added by SequentialDistributedSampler
@@ -138,8 +144,8 @@ def distributed_broadcast_scalars(
 ) -> torch.Tensor:
     try:
         tensorized_scalar = torch.tensor(scalars).cuda()
-        output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output_tensors, tensorized_scalar)
+        output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
+        dist.all_gather(output_tensors, tensorized_scalar)
         concat = torch.cat(output_tensors, dim=0)
 
         # truncate the dummy elements added by SequentialDistributedSampler
@@ -167,10 +173,10 @@ def torch_distributed_zero_first(local_rank: int):
         local_rank (:obj:`int`): The rank of the local process.
     """
     if local_rank not in [-1, 0]:
-        torch.distributed.barrier()
+        dist.barrier()
     yield
     if local_rank == 0:
-        torch.distributed.barrier()
+        dist.barrier()
 
 
 class SequentialDistributedSampler(Sampler):
|
@ -185,13 +191,13 @@ class SequentialDistributedSampler(Sampler):
|
|||
|
||||
def __init__(self, dataset, num_replicas=None, rank=None):
|
||||
if num_replicas is None:
|
||||
if not torch.distributed.is_available():
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = torch.distributed.get_world_size()
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not torch.distributed.is_available():
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = torch.distributed.get_rank()
|
||||
rank = dist.get_rank()
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
|
|
@@ -480,13 +486,13 @@ class DistributedLengthGroupedSampler(DistributedSampler):
         lengths: Optional[List[int]] = None,
     ):
         if num_replicas is None:
-            if not torch.distributed.is_available():
+            if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            num_replicas = torch.distributed.get_world_size()
+            num_replicas = dist.get_world_size()
         if rank is None:
-            if not torch.distributed.is_available():
+            if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
-            rank = torch.distributed.get_rank()
+            rank = dist.get_rank()
         self.dataset = dataset
         self.batch_size = batch_size
         self.num_replicas = num_replicas
@@ -25,7 +25,7 @@ from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
 
-from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
+from .file_utils import is_sagemaker_distributed_available, is_tf_available, is_torch_available, is_torch_tpu_available
 from .tokenization_utils_base import ExplicitEnum
 
 
@@ -187,6 +187,10 @@ def total_processes_number(local_rank):
         import torch_xla.core.xla_model as xm
 
         return xm.xrt_world_size()
+    elif is_sagemaker_distributed_available():
+        import smdistributed.dataparallel.torch.distributed as dist
+
+        return dist.get_world_size()
     elif local_rank != -1 and is_torch_available():
         import torch
 
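The hunk above cuts off after the `import torch` branch. For reference, a minimal sketch of the full resolution order the helper now follows (TPU, then SageMaker data parallel, then torch.distributed, then a single process); this is an illustration, not the file's verbatim tail:

from transformers.file_utils import (
    is_sagemaker_distributed_available,
    is_torch_available,
    is_torch_tpu_available,
)


def total_processes_number_sketch(local_rank):
    # Illustrative mirror of the branching added in this commit.
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        return xm.xrt_world_size()
    elif is_sagemaker_distributed_available():
        import smdistributed.dataparallel.torch.distributed as dist

        return dist.get_world_size()
    elif local_rank != -1 and is_torch_available():
        import torch

        return torch.distributed.get_world_size()
    return 1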
@@ -18,7 +18,13 @@ from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
-from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
+from .file_utils import (
+    cached_property,
+    is_sagemaker_distributed_available,
+    is_torch_available,
+    is_torch_tpu_available,
+    torch_required,
+)
 from .trainer_utils import EvaluationStrategy, SchedulerType
 from .utils import logging
 
@@ -493,6 +499,13 @@ class TrainingArguments:
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
+        elif is_sagemaker_distributed_available():
+            import smdistributed.dataparallel.torch.distributed as dist
+
+            dist.init_process_group()
+            self.local_rank = dist.get_local_rank()
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -566,6 +579,8 @@ class TrainingArguments:
         """
         if is_torch_tpu_available():
             return ParallelMode.TPU
+        elif is_sagemaker_distributed_available():
+            return ParallelMode.SAGEMAKER_DISTRIBUTED
         elif self.local_rank != -1:
            return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:
@@ -607,4 +622,5 @@ class ParallelMode(Enum):
     NOT_PARALLEL = "not_parallel"
     NOT_DISTRIBUTED = "not_distributed"
     DISTRIBUTED = "distributed"
+    SAGEMAKER_DISTRIBUTED = "sm_distributed"
     TPU = "tpu"
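Taken together, on a SageMaker data-parallel worker the device setup and parallel_mode property above resolve without passing --local_rank. A small illustrative check, meant to run inside such a job (output_dir is a placeholder):

from transformers import TrainingArguments
from transformers.training_args import ParallelMode

args = TrainingArguments(output_dir="output")  # placeholder output_dir

# Accessing .device runs the setup shown above: under smdistributed it calls
# dist.init_process_group() and pins this process to its local GPU.
print(args.device)   # e.g. cuda:0 for the first process on each node
print(args.n_gpu)    # 1 per process in this mode
assert args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED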