[Deepspeed] add support for bf16 mode (#14569)
* [WIP] add support for bf16 mode
* prep for bf16
* prep for bf16
* fix; zero2/bf16 is ok
* check bf16 is available
* test fixes
* enable zero3_bf16
* config files
* docs
* split stage_dtype; merge back to non-dtype-specific config file
* fix doc
* cleanup
* cleanup
* bfloat16 => bf16 to match the PR changes
* s/zero_gather_fp16_weights_on_model_save/zero_gather_16bit_weights_on_model_save/; s/save_fp16_model/save_16bit_model/
* test fixes/skipping
* move
* fix
* Update docs/source/main_classes/deepspeed.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* backticks
* cleanup
* cleanup
* cleanup
* new version
* add note about grad accum in bf16

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent c1f209dadd
commit 580dd87c55
10 changed files with 214 additions and 113 deletions
@@ -367,7 +367,7 @@ cat <<'EOT' > ds_config_zero3.json
         "stage3_param_persistence_threshold": "auto",
         "stage3_max_live_parameters": 1e9,
         "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_fp16_weights_on_model_save": true
+        "stage3_gather_16bit_weights_on_model_save": true
     },

     "gradient_accumulation_steps": "auto",
@@ -652,7 +652,7 @@ The following is an example of configuration for ZeRO stage 3:
         "stage3_param_persistence_threshold": "auto",
         "stage3_max_live_parameters": 1e9,
         "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_fp16_weights_on_model_save": true
+        "stage3_gather_16bit_weights_on_model_save": true
     }
 }
 ```
@@ -691,7 +691,7 @@ The following configuration values depend on the model's hidden size:
 therefore set these values to `auto` and the [`Trainer`] will automatically assign the recommended
 values. But, of course, feel free to set these explicitly as well.

-`stage3_gather_fp16_weights_on_model_save` enables model fp16 weights consolidation when model gets saved. With large
+`stage3_gather_16bit_weights_on_model_save` enables model fp16 weights consolidation when model gets saved. With large
 models and multiple GPUs this is an expensive operation both in terms of memory and speed. It's currently required if
 you plan to resume the training. Watch out for future updates that will remove this limitation and make things more
 flexible.
@@ -760,8 +760,8 @@ The following configuration example enables NVMe to offload both optimizer state
         "stage3_param_persistence_threshold": "auto",
         "stage3_max_live_parameters": 1e9,
         "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_fp16_weights_on_model_save": true
-    }
+        "stage3_gather_16bit_weights_on_model_save": true
+    },
 }
 ```

@@ -966,7 +966,7 @@ Here is a full ZeRO-3 auto-configuration file `ds_config_zero3.json`:
         "stage3_param_persistence_threshold": "auto",
         "stage3_max_live_parameters": 1e9,
         "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_fp16_weights_on_model_save": true
+        "stage3_gather_16bit_weights_on_model_save": true
     },

     "gradient_accumulation_steps": "auto",
@@ -1029,7 +1029,7 @@ values look like, but we highly recommend using the one with multiple `auto` set
         "stage3_param_persistence_threshold": 1e4,
         "stage3_max_live_parameters": 1e9,
         "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_fp16_weights_on_model_save": true
+        "stage3_gather_16bit_weights_on_model_save": true
     },

     "steps_per_print": 2000,
@@ -1232,6 +1232,7 @@ the much more efficient tf32 format for some operations, but the results will st
 benchmarks, please, see [TensorFloat-32(TF32) on Ampere devices](https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices). The document includes
 instructions on how to disable this automatic conversion if for some reason you prefer not to use it.

+With the 🤗 Trainer you can use `--tf32` to enable it, or disable it with `--tf32 0` or `--no_tf32`. By default the PyTorch default is used.


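For reference (not part of this commit), the `--tf32` flag maps onto PyTorch's global TF32 switches; a minimal sketch of what enabling or disabling it amounts to outside the Trainer, assuming an Ampere-or-newer GPU:

```python
# Sketch only: what --tf32 / --no_tf32 roughly correspond to at the PyTorch level.
import torch

torch.backends.cuda.matmul.allow_tf32 = True  # --tf32
torch.backends.cudnn.allow_tf32 = True

# --tf32 0 or --no_tf32 would correspond to setting both switches to False.
```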
@@ -1241,7 +1242,9 @@ instructions on how to disable this automatic conversion if for some reason you

 You can use automatic mixed precision with either a pytorch-like AMP way or the apex-like way:

-To configure pytorch AMP-like mode set:
+### fp16
+
+To configure pytorch AMP-like mode with fp16 (float16) set:

 ```json
 {
@@ -1259,7 +1262,7 @@ To configure pytorch AMP-like mode set:
 and the [`Trainer`] will automatically enable or disable it based on the value of
 `args.fp16_backend`. The rest of config values are up to you.

-This mode gets enabled when `--fp16 --fp16_backend amp` command line args are passed.
+This mode gets enabled when `--fp16 --fp16_backend amp` or `--fp16_full_eval` command line args are passed.

 You can also enable/disable this mode explicitly:

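As a usage sketch (not from this commit), with `fp16.enabled` left at `"auto"` in the DeepSpeed config, the equivalent programmatic setup would pass the same flags through `TrainingArguments`; the config file name below is hypothetical:

```python
# Hedged sketch: "auto" entries in the DeepSpeed config are filled in from these arguments.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",
    fp16=True,                          # same effect as passing --fp16 on the command line
    deepspeed="ds_config_zero3.json",   # hypothetical config path
)
```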
@@ -1281,6 +1284,43 @@ configuration.

 Here is the [documentation](https://www.deepspeed.ai/docs/config-json/#fp16-training-options).

+### bf16
+
+If bf16 (bfloat16) is desired instead of fp16 then the following configuration section is to be used:
+
+```json
+{
+    "bf16": {
+        "enabled": "auto"
+    }
+}
+```
+
+bf16 has the same dynamic range as fp32 and thus doesn't require loss scaling.
+
+This mode gets enabled when `--bf16` or `--bf16_full_eval` command line args are passed.
+
+You can also enable/disable this mode explicitly:
+
+```json
+{
+    "bf16": {
+        "enabled": true
+    }
+}
+```
+
+<Tip>
+
+As of `deepspeed==0.6.0` the bf16 support is new and experimental.
+
+If you use [gradient accumulation](#gradient-accumulation) with bf16-enabled, you need to be aware that it'll accumulate gradients in bf16, which may not be what you want due to this format's low precision, as it may lead to a lossy accumulation.
+
+</Tip>
+
+
 ### apex

 To configure apex AMP-like mode set:

 ```json
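To make the Tip above concrete, here is a small illustrative sketch (not from the PR) of why accumulating in bf16 is lossy: with only about 8 bits of mantissa, repeatedly adding small values to a growing accumulator rounds much of the contribution away.

```python
# Illustrative only: bf16 accumulation drifts from the fp32 result because each add is rounded.
import torch

steps = torch.full((1000,), 0.001)

acc_fp32 = torch.zeros(1, dtype=torch.float32)
acc_bf16 = torch.zeros(1, dtype=torch.bfloat16)
for step in steps:
    acc_fp32 += step
    acc_bf16 += step.to(torch.bfloat16)

print(acc_fp32.item())          # ~1.0
print(acc_bf16.float().item())  # noticeably below 1.0 due to rounding at each addition
```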
@@ -1411,15 +1451,14 @@ When a model is saved under ZeRO-2, you end up having the normal `pytorch_model.
 they are only the fp16 version of the weights.

 Under ZeRO-3, things are much more complicated, since the model weights are partitioned out over multiple GPUs,
-therefore `"stage3_gather_fp16_weights_on_model_save": true` is required to get the `Trainer` to save the fp16
-version of the weights. If this setting is `False` ``pytorch_model.bin` won't be created. This is because by default DeepSpeed's `state_dict` contains a placeholder and not the real weights. If we were to save this `state_dict`` it
-won't be possible to load it back.
+therefore `"stage3_gather_16bit_weights_on_model_save": true` is required to get the `Trainer` to save the fp16
+version of the weights. If this setting is `False` `pytorch_model.bin` won't be created. This is because by default DeepSpeed's `state_dict` contains a placeholder and not the real weights. If we were to save this `state_dict` it won't be possible to load it back.


 ```json
 {
     "zero_optimization": {
-        "stage3_gather_fp16_weights_on_model_save": true
+        "stage3_gather_16bit_weights_on_model_save": true
     }
 }
 ```
@@ -45,7 +45,7 @@
         "stage3_param_persistence_threshold": "auto",
         "stage3_max_live_parameters": 1e9,
         "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_fp16_weights_on_model_save": true
+        "stage3_gather_16bit_weights_on_model_save": true
     },

     "gradient_accumulation_steps": "auto",
setup.py
@@ -98,7 +98,7 @@ _deps = [
     "cookiecutter==1.7.2",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.5.9",
+    "deepspeed>=0.6.0",
     "fairscale>0.3",
     "faiss-cpu",
     "fastapi",
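Since bf16 requires the newer DeepSpeed pinned above, a quick hedged check that an installed environment meets the bumped minimum:

```python
# Sketch: verify the installed DeepSpeed satisfies the new minimum version pinned above.
from packaging import version
import deepspeed

assert version.parse(deepspeed.__version__) >= version.parse("0.6.0"), deepspeed.__version__
```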
@@ -73,7 +73,7 @@ class HfDeepSpeedConfig:

         # zero stage - this is done as early as possible, before model is created, to allow
         # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
-        # during ``zero.Init()`` which needs whether fp16 is enabled, dtype, etc.
+        # during ``zero.Init()`` which needs to know the dtype, and some other hparams.
         self._stage = self.get_value("zero_optimization.stage", -1)

         # offload
@@ -169,10 +169,12 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):

     def __init__(self, config_file_or_dict):
         super().__init__(config_file_or_dict)
-        self._dtype = torch.float16
+        self._dtype = None
         self.mismatches = []

     def dtype(self):
+        if self._dtype is None:
+            raise ValueError("trainer_config_process() wasn't called yet to tell dtype")
         return self._dtype

     def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
@@ -228,26 +230,33 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
         # total_num_steps - will get set in trainer_config_finalize

         # fp16
-        if args.fp16:
+        if args.fp16 or args.fp16_full_eval:
             fp16_backend = "apex" if args.fp16_backend == "apex" else "amp"
         else:
             fp16_backend = None

         # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
         # any here unless the user did the work
-        self.fill_match("fp16.enabled", fp16_backend == "amp", "fp16+fp16_backend(amp)")
+        self.fill_match(
+            "fp16.enabled",
+            ((args.fp16 or args.fp16_full_eval) and fp16_backend == "amp"),
+            "fp16|fp16_full_eval+fp16_backend(amp)",
+        )

         # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
         # ZeRO features
         self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
         self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")

-        # only if we have an explicit fp16.enabled = False then it's fp32, if it's True or this
-        # whole config section is missing then the fallback is fp16
-        if self.is_false("fp16.enabled"):
+        self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval")
+
+        # deepspeed's default mode is fp16 unless there is a config that says differently
+        if self.is_true("bfoat16.enabled"):
+            self._dtype = torch.bfloat16
+        elif self.is_false("fp16.enabled"):
             self._dtype = torch.float32
-        # later there will be other dtypes besides just fp16 and fp32
         # also not quite sure what dtype should be under apex, defaulting to fp16 for now
+        else:
+            self._dtype = torch.float16

     def trainer_config_finalize(self, args, model, num_training_steps):
         """
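The dtype selection just above follows a simple precedence; a standalone sketch (mirroring the logic of the hunk, not the actual Trainer API) that can be checked in isolation:

```python
# Sketch of the precedence implemented above: bf16 if enabled, fp32 if fp16 is explicitly
# disabled, otherwise the DeepSpeed default of fp16.
import torch

def resolve_dtype(config: dict) -> torch.dtype:
    if config.get("bf16", {}).get("enabled") is True:
        return torch.bfloat16
    if config.get("fp16", {}).get("enabled") is False:
        return torch.float32
    return torch.float16

assert resolve_dtype({"bf16": {"enabled": True}}) is torch.bfloat16
assert resolve_dtype({"fp16": {"enabled": False}}) is torch.float32
assert resolve_dtype({}) is torch.float16
```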
@@ -8,7 +8,7 @@ deps = {
     "cookiecutter": "cookiecutter==1.7.2",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.5.9",
+    "deepspeed": "deepspeed>=0.6.0",
     "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
     "fastapi": "fastapi",
@@ -1687,7 +1687,7 @@ class Trainer:
             self.save_model(output_dir, _internal_call=True)
             if self.deepspeed:
                 # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
-                # config `stage3_gather_fp16_weights_on_model_save` is True
+                # config `stage3_gather_16bit_weights_on_model_save` is True
                 self.deepspeed.save_checkpoint(output_dir)

         # Save optimizer and scheduler
@@ -2101,12 +2101,12 @@ class Trainer:
                     # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                     os.remove(file)

-            # now save the real model if stage3_gather_fp16_weights_on_model_save=True
+            # now save the real model if stage3_gather_16bit_weights_on_model_save=True
             # if false it will not be saved.
             # This must be called on all ranks
-            if not self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME):
+            if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
                 logger.warning(
-                    "deepspeed.save_fp16_model didn't save the model, since stage3_gather_fp16_weights_on_model_save=false. "
+                    "deepspeed.save_16bit_model didn't save the model, since stage3_gather_16bit_weights_on_model_save=false. "
                     "Saving the full checkpoint instead, use zero_to_fp32.py to recover weights"
                 )
                 self.deepspeed.save_checkpoint(output_dir)
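When the gather flag is false only the sharded ZeRO checkpoint is written; DeepSpeed ships helpers to consolidate it back into fp32 weights. A hedged sketch, assuming the `deepspeed.utils.zero_to_fp32` helpers available in recent releases and a hypothetical checkpoint path:

```python
# Sketch: rebuild consolidated fp32 weights from a ZeRO checkpoint directory.
from deepspeed.utils.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    load_state_dict_from_zero_checkpoint,
)

checkpoint_dir = "output_dir/checkpoint-100"  # hypothetical path

# Option 1: materialize a consolidated fp32 state dict in CPU memory
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)

# Option 2: load the consolidated fp32 weights directly into an existing model
# model = load_state_dict_from_zero_checkpoint(model, checkpoint_dir)
```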
@@ -8,6 +8,10 @@
         "min_loss_scale": 1
     },

+    "bf16": {
+        "enabled": "auto"
+    },
+
     "optimizer": {
         "type": "AdamW",
         "params": {
@@ -8,6 +8,10 @@
         "min_loss_scale": 1
     },

+    "bf16": {
+        "enabled": "auto"
+    },
+
     "optimizer": {
         "type": "AdamW",
         "params": {
@@ -45,7 +49,7 @@
         "stage3_param_persistence_threshold": "auto",
         "stage3_max_live_parameters": 1e9,
         "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_fp16_weights_on_model_save": true
+        "stage3_gather_16bit_weights_on_model_save": true
     },

     "gradient_accumulation_steps": "auto",
@@ -14,6 +14,7 @@

 import dataclasses
 import io
+import itertools
 import json
 import os
 import unittest
@@ -23,7 +24,7 @@ from parameterized import parameterized
 from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
 from transformers import AutoModel, TrainingArguments, is_torch_available, logging
 from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
-from transformers.file_utils import WEIGHTS_NAME
+from transformers.file_utils import WEIGHTS_NAME, is_torch_bf16_available
 from transformers.testing_utils import (
     CaptureLogger,
     CaptureStd,
@@ -120,7 +121,26 @@ def get_launcher(distributed=False):

 ZERO2 = "zero2"
 ZERO3 = "zero3"

+FP16 = "fp16"
+BF16 = "bf16"
+
 stages = [ZERO2, ZERO3]
+if is_torch_bf16_available():
+    dtypes = [FP16, BF16]
+else:
+    dtypes = [FP16]
+
+
+def parameterized_custom_name_func(func, param_num, param):
+    # customize the test name generator function as we want both params to appear in the sub-test
+    # name, as by default it shows only the first param
+    param_based_name = parameterized.to_safe_name("_".join(str(x) for x in param.args))
+    return f"{func.__name__}_{param_based_name}"
+
+
+# Cartesian-product of zero stages with models to test
+params = list(itertools.product(stages, dtypes))


 @require_deepspeed
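For illustration (not from the commit), the Cartesian product plus the custom name function yields sub-test names that carry both parameters; a rough sketch approximating `parameterized.to_safe_name`:

```python
# Sketch of the sub-test names produced by parameterized_custom_name_func above.
import itertools

stages = ["zero2", "zero3"]
dtypes = ["fp16", "bf16"]

def to_safe_name(value: str) -> str:
    # rough stand-in for parameterized.to_safe_name
    return "".join(c if c.isalnum() else "_" for c in value)

for args in itertools.product(stages, dtypes):
    print(f"test_save_checkpoints_{to_safe_name('_'.join(args))}")
# test_save_checkpoints_zero2_fp16, test_save_checkpoints_zero2_bf16,
# test_save_checkpoints_zero3_fp16, test_save_checkpoints_zero3_bf16
```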
@@ -138,8 +158,8 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             MASTER_ADDR="localhost", MASTER_PORT=master_port, RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
         )

-    def test_init_zero3(self):
-        # test that zero.Init() works correctly under zero3
+    def test_init_zero3_fp16(self):
+        # test that zero.Init() works correctly under zero3/fp16
         ds_config = {
             "train_batch_size": 1,
             "zero_optimization": {
@@ -216,15 +236,12 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # use self.get_config_dict(stage) to use these to ensure the original is not modified
         with io.open(self.ds_config_file[ZERO2], "r", encoding="utf-8") as f:
             config_zero2 = json.load(f)
-            # by default use fp16
-            config_zero2["fp16"]["enabled"] = True
         with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f:
             config_zero3 = json.load(f)
-            # by default use fp16
-            config_zero3["fp16"]["enabled"] = True
-            # This setting slows things down, so don't enable it by default unless needed by a test.
-            config_zero3["zero_optimization"]["stage3_gather_fp16_weights_on_model_save"] = False
+            # The following setting slows things down, so don't enable it by default unless needed by a test.
+            # It's in the file as a demo for users since we want everything to work out of the box even if slower.
+            config_zero3["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = False

         self.ds_config_dict = dict(
             zero2=config_zero2,
             zero3=config_zero3,
@@ -348,21 +365,23 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):

     # --- These tests need to run on both zero stages --- #

-    @parameterized.expand(stages)
-    def test_hf_optimizer_with_offload(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_hf_optimizer_with_offload(self, stage, dtype):
         # non-DS optimizers can be used with ZERO-offload (as long as they have both CPU and GPU implementation (except LAMB))
         ds_config_dict = self.get_config_dict(stage)
         del ds_config_dict["optimizer"]  # force default HF Trainer optimizer
         # force cpu offload
         ds_config_dict["zero_optimization"]["offload_optimizer"]["device"] = "cpu"
         with mockenv_context(**self.dist_env_1_gpu):
-            trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_dict)
+            kwargs = dict(local_rank=0, deepspeed=ds_config_dict)
+            kwargs[dtype] = True
+            trainer = get_regression_trainer(**kwargs)
             with CaptureLogger(deepspeed_logger) as cl:
                 trainer.train()
             self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")

-    @parameterized.expand(stages)
-    def test_fake_notebook_no_launcher(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_fake_notebook_no_launcher(self, stage, dtype):
         # this setup emulates a notebook where a launcher needs to be emulated by hand

         # note that unittest resets sys.stdout each test, so `CaptureStd` will work here to capture
@@ -370,13 +389,16 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # it's run not as a first test as `sys.stdout` will no longer be the same. So we either have
         # to reset `deepspeed_logger.handlers[0].setStream(sys.stdout)` or directly capture from the deepspeed_logger.
         with mockenv_context(**self.dist_env_1_gpu):
-            trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=self.get_config_dict(stage))
+            kwargs = dict(local_rank=0, deepspeed=self.get_config_dict(stage))
+            kwargs[dtype] = True
+            trainer = get_regression_trainer(**kwargs)

             with CaptureLogger(deepspeed_logger) as cl:
                 trainer.train()
             self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")

-    @parameterized.expand(stages)
-    def test_early_get_last_lr(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_early_get_last_lr(self, stage, dtype):
         # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
         # not run for the first few dozen steps while loss scale is too large, and thus during
         # that time `get_last_lr` will fail if called during that warm up stage,
@@ -385,34 +407,36 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # `self.lr_scheduler.get_last_lr()` and originally it'd fail on the very first step.
         with mockenv_context(**self.dist_env_1_gpu):
             a = b = 0.0
-            trainer = get_regression_trainer(
+            kwargs = dict(
                 a=a,
                 b=b,
                 local_rank=0,
                 train_len=8,
-                fp16=True,
                 deepspeed=self.get_config_dict(stage),
                 per_device_train_batch_size=8,
                 logging_steps=1,
             )
+            kwargs[dtype] = True
+            trainer = get_regression_trainer(**kwargs)

             trainer.train()
             post_train_a = trainer.model.a.item()

-            # XXX: for some reason the following check fails with zero3 - not a broken but a
-            # different qualitative outcome - as if optimizer did run
+            # XXX: for some reason the following check fails with zero3/fp16 and any/bf16 - not a
+            # broken but a different qualitative outcome - as if optimizer did run
             # oddly getting 1.0 for both a and b from 0.0 - there is a bug somewhere
             # print(trainer.model.a.item())
             # print(trainer.model.b.item())
             # need to investigate at some point
-            if stage == ZERO3:
+            if (stage == ZERO3 and dtype == FP16) or (dtype == BF16):
                 return

             # it's enough that train didn't fail for this test, but we must check that
             # optimizer/scheduler didn't run (since if it did this test isn't testing the right thing)
             self.assertEqual(post_train_a, a)

-    @parameterized.expand(stages)
-    def test_gradient_accumulation(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_gradient_accumulation(self, stage, dtype):
         # this test measures that we get identical weights and similar loss with:
         # 1. per_device_train_batch_size=8, gradient_accumulation_steps=1
         # 2. per_device_train_batch_size=4, gradient_accumulation_steps=2
@@ -433,9 +457,9 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             b=b,
             local_rank=0,
             train_len=train_len,
-            fp16=True,
             deepspeed=self.get_config_dict(stage),
         )
+        kwargs[dtype] = True

         with mockenv_context(**self.dist_env_1_gpu):
             no_grad_accum_trainer = get_regression_trainer(
@@ -482,15 +506,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         else:
             raise ValueError(f"unknown stage {stage}")

-        # XXX: this can be recoded and then removed once we require deepspeed>0.3.13
-        from packaging import version
-
-        import deepspeed
-
-        if version.parse(deepspeed.__version__) > version.parse("0.3.13"):
-            ds_file_list.append("zero_pp_rank_0_mp_rank_00_optim_states.pt")
-        else:
-            ds_file_list.append("zero_pp_rank_0_mp_rank_00optim_states.pt")
+        ds_file_list.append("zero_pp_rank_0_mp_rank_00_optim_states.pt")

         for step in range(freq, total, freq):
             checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
@@ -509,37 +525,42 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
                 path = os.path.join(ds_path, filename)
                 self.assertTrue(os.path.isfile(path), f"[{stage}] {path} is not found")

-    @parameterized.expand(stages)
-    def test_save_checkpoints(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_save_checkpoints(self, stage, dtype):
         # adapted from TrainerIntegrationTest.test_save_checkpoints

         freq = 5
         output_dir = self.get_auto_remove_tmp_dir()
         ds_config_dict = self.get_config_dict(stage)
-        ds_config_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
+        if dtype == FP16:
+            ds_config_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
         # XXX:
         if stage == ZERO3:
-            ds_config_dict["zero_optimization"]["stage3_gather_fp16_weights_on_model_save"] = True
+            ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True

         # save checkpoints
         with mockenv_context(**self.dist_env_1_gpu):
-            trainer = get_regression_trainer(
+            kwargs = dict(
                 output_dir=output_dir,
                 save_steps=freq,
-                fp16=True,
                 deepspeed=ds_config_dict,
             )
+            kwargs[dtype] = True
+            trainer = get_regression_trainer(**kwargs)
             trainer.train()

         total = int(self.n_epochs * 64 / self.batch_size)
         self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage)

-    @parameterized.expand(stages)
-    def test_can_resume_training_errors(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_can_resume_training_errors(self, stage, dtype):

         with mockenv_context(**self.dist_env_1_gpu):
             ds_config_dict = self.get_config_dict(stage)
             output_dir = self.get_auto_remove_tmp_dir()
-            trainer = get_regression_trainer(output_dir=output_dir, fp16=True, deepspeed=ds_config_dict)
+            kwargs = dict(output_dir=output_dir, deepspeed=ds_config_dict)
+            kwargs[dtype] = True
+            trainer = get_regression_trainer(**kwargs)

             # 1. fail to find any checkpoint - due a fresh output_dir
             with self.assertRaises(Exception) as context:
@@ -557,19 +578,20 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
                 "Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
             )

-    @parameterized.expand(stages)
-    def test_can_resume_training_normal(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_can_resume_training_normal(self, stage, dtype):
         # adapted from TrainerIntegrationTest.test_can_resume_training
         # test normal resume for each stage separately, error-handling is tested in a different test
-        output_dir = self.get_auto_remove_tmp_dir()
+        output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
         ds_config_dict = self.get_config_dict(stage)
-        ds_config_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
+        if dtype == FP16:
+            ds_config_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
         # XXX:
         if stage == ZERO3:
-            ds_config_dict["zero_optimization"]["stage3_gather_fp16_weights_on_model_save"] = True
+            ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True

-        kwargs = dict(
-            output_dir=output_dir, train_len=128, save_steps=5, learning_rate=0.1, fp16=True, deepspeed=ds_config_dict
-        )
+        kwargs = dict(output_dir=output_dir, train_len=128, save_steps=5, learning_rate=0.1, deepspeed=ds_config_dict)
+        kwargs[dtype] = True

         with mockenv_context(**self.dist_env_1_gpu):
             trainer = get_regression_trainer(**kwargs)
@@ -607,8 +629,8 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # trainer.train(resume_from_checkpoint=checkpoint)
         # a workaround needs to be used that re-creates the deepspeed engine

-    @parameterized.expand(stages)
-    def test_load_state_dict_from_zero_checkpoint(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_load_state_dict_from_zero_checkpoint(self, stage, dtype):
         # test that we can load fp32 weights directly from the zero checkpoint into the current model

         output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False, before=False)
@@ -623,9 +645,9 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             save_strategy="steps",
             save_steps=1,
             learning_rate=0.1,
-            fp16=True,
             deepspeed=ds_config_dict,
         )
+        kwargs[dtype] = True

         with mockenv_context(**self.dist_env_1_gpu):
             trainer = get_regression_trainer(**kwargs)
@@ -648,8 +670,8 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         output_dir = self.get_auto_remove_tmp_dir()
         kwargs = dict(output_dir=output_dir, train_len=8, fp16=True)

-        ds_config_zero3_dict = self.get_config_dict("zero3")
-        ds_config_zero2_dict = self.get_config_dict("zero2")
+        ds_config_zero3_dict = self.get_config_dict(ZERO3)
+        ds_config_zero2_dict = self.get_config_dict(ZERO2)

         with mockenv_context(**self.dist_env_1_gpu):
             trainer = get_regression_trainer(deepspeed=ds_config_zero3_dict, **kwargs)
@@ -698,57 +720,60 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
     #

     @require_torch_multi_gpu
-    @parameterized.expand(stages)
-    def test_basic_distributed(self, stage):
-        self.run_and_check(stage=stage, distributed=True)
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_basic_distributed(self, stage, dtype):
+        self.run_and_check(stage=stage, dtype=dtype, distributed=True)

     def test_do_eval_no_train(self):
         # testing only zero3 since zero2 makes no sense with inference
         self.run_and_check(
             stage=ZERO3,
+            dtype=FP16,
             eval_steps=1,
             distributed=False,
             do_train=False,
             do_eval=True,
         )

-    @parameterized.expand(stages)
-    def test_fp32_non_distributed(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_fp32_non_distributed(self, stage, dtype):
         # real model needs too much GPU memory under stage2+fp32, so using tiny random model here -
         # therefore no quality checks, just basic completion checks are done
         self.run_and_check(
             stage=stage,
+            dtype=dtype,
             model_name=T5_TINY,
             distributed=False,
             do_train=True,
             do_eval=True,
             quality_checks=False,
-            fp16=False,
+            fp32=True,
         )

     @require_torch_multi_gpu
-    @parameterized.expand(stages)
-    def test_fp32_distributed(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_fp32_distributed(self, stage, dtype):
         # real model needs too much GPU memory under stage2+fp32, so using tiny random model here -
         # therefore no quality checks, just basic completion checks are done
         self.run_and_check(
             stage=stage,
+            dtype=dtype,
             model_name=T5_TINY,
             distributed=True,
             do_train=True,
             do_eval=True,
             quality_checks=False,
-            fp16=False,
+            fp32=True,
         )

-    @parameterized.expand(stages)
-    def test_resume_train_not_from_ds_checkpoint(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_resume_train_not_from_ds_checkpoint(self, stage, dtype):
         # do normal training and then resume not from the deepspeed checkpoint but explicitly from
         # the saved model dir

         do_train = True
         do_eval = False
-        kwargs = dict(stage=stage, eval_steps=1, distributed=True, do_train=do_train, do_eval=do_eval)
+        kwargs = dict(stage=stage, dtype=dtype, eval_steps=1, distributed=True, do_train=do_train, do_eval=do_eval)

         # 1. normal training
         output_dir = self.run_and_check(**kwargs)
@@ -760,19 +785,23 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         self.do_checks(output_dir, do_train=do_train, do_eval=do_eval)

     @require_torch_multi_gpu
-    @parameterized.expand(["fp16", "fp32"])
+    @parameterized.expand(["bf16", "fp16", "fp32"])
     def test_inference(self, dtype):
+        if dtype == "bf16" and not is_torch_bf16_available():
+            self.skipTest("test requires bfloat16 hardware support")
+
         # this is just inference, so no optimizer should be loaded
         # it only works for z3 (makes no sense with z1-z2)
-        fp16 = True if dtype == "fp16" else False
+        fp32 = True if dtype == "fp32" else False
         self.run_and_check(
             stage=ZERO3,
+            dtype=FP16,
             model_name=T5_TINY,
             distributed=True,
             do_train=False,
             do_eval=True,
             quality_checks=False,
-            fp16=fp16,
+            fp32=fp32,
         )

     def do_checks(self, output_dir, do_train=True, do_eval=True, quality_checks=True):
@@ -793,13 +822,14 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
     def run_and_check(
         self,
         stage,
+        dtype,
         model_name: str = T5_SMALL,
         eval_steps: int = 10,
         distributed: bool = True,
         do_train: bool = True,
         do_eval: bool = True,
         quality_checks: bool = True,
-        fp16: bool = True,
+        fp32: bool = False,
         extra_args_str: str = None,
         remove_args_str: str = None,
     ):
@@ -807,13 +837,14 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         # we are doing quality testing so using a small real model
         output_dir = self.run_trainer(
             stage=stage,
+            dtype=dtype,
             model_name=model_name,
             eval_steps=eval_steps,
             num_train_epochs=1,
             do_train=do_train,
             do_eval=do_eval,
             distributed=distributed,
-            fp16=fp16,
+            fp32=fp32,
             extra_args_str=extra_args_str,
             remove_args_str=remove_args_str,
         )
@@ -825,13 +856,14 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
     def run_trainer(
         self,
         stage: str,
+        dtype: str,
         model_name: str,
         eval_steps: int = 10,
         num_train_epochs: int = 1,
         do_train: bool = False,
         do_eval: bool = True,
         distributed: bool = True,
-        fp16: bool = True,
+        fp32: bool = False,
         extra_args_str: str = None,
         remove_args_str: str = None,
     ):
@@ -859,8 +891,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         """.split()
         args.extend(["--source_prefix", '"translate English to Romanian: "'])

-        if fp16:
-            args.extend(["--fp16"])
+        if not fp32:
+            args.extend([f"--{dtype}"])

         actions = 0
         if do_train:
@@ -906,8 +938,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):

         return output_dir

-    @parameterized.expand(stages)
-    def test_clm(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_clm(self, stage, dtype):
         # this test exercises model.resize_token_embeddings() which requires param gathering outside
         # of forward - it's not used by `run_translation.py`, but it is in `run_clm.py`

@@ -928,10 +960,11 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
             --num_train_epochs 1
             --warmup_steps 8
             --block_size 64
-            --fp16
             --report_to none
         """.split()

+        args.extend([f"--{dtype}"])
+
         ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
         script = [f"{self.examples_dir_str}/pytorch/language-modeling/run_clm.py"]
         launcher = get_launcher(distributed=True)
@@ -941,7 +974,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
         execute_subprocess_async(cmd, env=self.get_env())

-    def test_clm_from_config_zero3(self):
+    def test_clm_from_config_zero3_fp16(self):
         # this test exercises AutoModel.from_config(config) - to ensure zero.Init is called

         data_dir = self.tests_dir / "fixtures"
@@ -974,8 +1007,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         execute_subprocess_async(cmd, env=self.get_env())
         self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)

-    @parameterized.expand(stages)
-    def test_load_best_model(self, stage):
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_load_best_model(self, stage, dtype):
         # this test exercises --load_best_model_at_end - the key is being able to resume after some training

         data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
@@ -1003,11 +1036,12 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
             --per_device_train_batch_size 1
             --per_device_eval_batch_size 1
             --num_train_epochs 1
-            --fp16
             --report_to none
         """.split()
         args.extend(["--source_prefix", "translate English to Romanian: "])

+        args.extend([f"--{dtype}"])
+
         ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
         script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
         launcher = get_launcher(distributed=False)
@@ -205,8 +205,19 @@ task_cmds = make_task_cmds()

 ZERO2 = "zero2"
 ZERO3 = "zero3"

 stages = [ZERO2, ZERO3]

+# future preparation:
+# for now test just fp16, as these tests are quite slow
+# FP16 = "fp16"
+# BF16 = "bf16"
+#
+# dtypes = [FP16]
+# so just hardcoding --fp16 for now
+# if is_torch_bf16_available():
+#     dtypes += [BF16]
+

 def parameterized_custom_name_func(func, param_num, param):
     # customize the test name generator function as we want both params to appear in the sub-test