diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst index 080fcc222..537c8df84 100644 --- a/docs/source/main_classes/trainer.rst +++ b/docs/source/main_classes/trainer.rst @@ -830,6 +830,28 @@ Here is an example of the ``amp`` configuration: } +Gradient Accumulation +======================================================================================================================= + +While normally DeepSpeed gets gradient accumulation configured with: + +.. code-block:: json + + { + "gradient_accumulation_steps": 3, + } + +in this case, to enable gradient accumulation, pass the command line `--gradient_accumulation_steps` argument as normal +and it will get injected into the DeepSpeed configuration. + +If you try to add it directly to the configuration file, you will receive an error from the Trainer - this is because +this setting is needed by the Trainer too, and so this approach ensures that there is a single way of setting this +value and thus avoids potential subtle errors. 
+ + + + + Gradient Clipping ======================================================================================================================= diff --git a/examples/tests/deepspeed/ds_config.json b/examples/tests/deepspeed/ds_config.json index 9b6f35610..24034d1f1 100644 --- a/examples/tests/deepspeed/ds_config.json +++ b/examples/tests/deepspeed/ds_config.json @@ -3,6 +3,7 @@ "enabled": true, "loss_scale": 0, "loss_scale_window": 1000, + "initial_scale_power": 32, "hysteresis": 2, "min_loss_scale": 1 }, diff --git a/examples/tests/deepspeed/test_deepspeed.py b/examples/tests/deepspeed/test_deepspeed.py index b606376fb..cb8192b75 100644 --- a/examples/tests/deepspeed/test_deepspeed.py +++ b/examples/tests/deepspeed/test_deepspeed.py @@ -23,7 +23,7 @@ from transformers.testing_utils import ( TestCasePlus, execute_subprocess_async, get_gpu_count, - mockenv, + mockenv_context, require_torch_gpu, require_torch_multi_gpu, slow, @@ -31,6 +31,11 @@ from transformers.testing_utils import ( from transformers.trainer_utils import set_seed +bindir = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(f"{bindir}/../../../tests") +from test_trainer import get_regression_trainer # noqa + + set_seed(42) MBART_TINY = "sshleifer/tiny-mbart" @@ -51,32 +56,96 @@ def require_deepspeed(test_case): return test_case +@require_deepspeed +@require_torch_gpu +class TrainerIntegrationDeepSpeed(TestCasePlus): + """ This class is for testing directly via get_regression_trainer """ + + def setUp(self): + super().setUp() + self.dist_env_1_gpu = dict( + MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1" + ) + self.ds_config_file = f"{self.test_file_dir_str}/ds_config.json" + + def test_fake_notebook_no_launcher(self): + + # this setup emulates a notebook where a launcher needs to be emulated by hand + + with CaptureStd() as cs: + with mockenv_context(**self.dist_env_1_gpu): + trainer = get_regression_trainer(local_rank=0, 
deepspeed=self.ds_config_file) + trainer.train() + assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none" + + def test_gradient_accumulation(self): + + # this test measures that we get identical weights and similar loss with: + # 1. per_device_train_batch_size=8, gradient_accumulation_steps=1 + # 2. per_device_train_batch_size=4, gradient_accumulation_steps=2 + # since the 2nd should produce the effective batch of 1st, with the same results + # + # I can get an identical loss for a small train_len=32, plus the power of the initial + # dynamic loss scale value set to: + # "fp16.initial_scale_power": 1 + # plus having the same WarmupLR's warmup_min_lr == warmup_max_lr in the config file + # but for some reason going to train_len=64 the weights start to mismatch with this setup. + # the culprit seems to be `initial_scale_power` - putting it back to its default 32 keeps the weights identical + + train_len = 64 + a = b = 0.0 + + with mockenv_context(**self.dist_env_1_gpu): + no_grad_accum_trainer = get_regression_trainer( + a=a, + b=b, + local_rank=0, + train_len=train_len, + deepspeed=self.ds_config_file, + per_device_train_batch_size=8, + gradient_accumulation_steps=1, + ) + no_grad_accum_result = no_grad_accum_trainer.train() + no_grad_accum_loss = no_grad_accum_result.training_loss + no_grad_accum_a = no_grad_accum_trainer.model.a.item() + no_grad_accum_b = no_grad_accum_trainer.model.b.item() + # make sure the optimizer kicked in - if it hasn't changed from the original value of a then make train_len bigger + self.assertNotEqual(no_grad_accum_a, a) + + with mockenv_context(**self.dist_env_1_gpu): + yes_grad_accum_trainer = get_regression_trainer( + a=a, + b=b, + local_rank=0, + train_len=train_len, + deepspeed=self.ds_config_file, + per_device_train_batch_size=4, + gradient_accumulation_steps=2, + ) + yes_grad_accum_result = yes_grad_accum_trainer.train() + yes_grad_accum_loss = yes_grad_accum_result.training_loss + 
yes_grad_accum_a = yes_grad_accum_trainer.model.a.item() + yes_grad_accum_b = yes_grad_accum_trainer.model.b.item() + self.assertNotEqual(yes_grad_accum_a, a) + + # training with half the batch size but accumulation steps as 2 should give the same weights + self.assertEqual(no_grad_accum_a, yes_grad_accum_a) + self.assertEqual(no_grad_accum_b, yes_grad_accum_b) + + # see the note above how to get identical loss on a small bs + self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5) + + @slow @require_deepspeed @require_torch_gpu class TestDeepSpeed(TestCasePlus): - - # this setup emulates a notebook where a launcher needs to be emulated by hand - @mockenv(MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1") - def test_fake_notebook_no_launcher(self): - sys.path.append(self.tests_dir_str) - from test_trainer import get_regression_trainer - - del sys.path[-1] # restore - ds_config_file = f"{self.test_file_dir_str}/ds_config.json" - with CaptureStd() as cs: - trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_file) - trainer.train() - assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none" + """ This class is for testing via an external script """ @require_torch_multi_gpu def test_basic_distributed(self): self.run_quick(distributed=True) - @require_torch_multi_gpu - def test_grad_acum(self): - self.run_quick(distributed=True, extra_args_str="--gradient_accumulation_steps 2") - def test_do_eval_no_train(self): # we should not fail if train is skipped output_dir = self.run_trainer( diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 0d16d8c07..b2ed86ce2 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import contextlib import inspect import logging import os @@ -830,14 +831,49 @@ class TestCasePlus(unittest.TestCase): def mockenv(**kwargs): """ - this is a convenience wrapper, that allows this: + this is a convenience wrapper, that allows this :: + + @mockenv(RUN_SLOW=True, USE_TF=False) + def test_something(): + run_slow = os.getenv("RUN_SLOW", False) + use_tf = os.getenv("USE_TF", False) - @mockenv(RUN_SLOW=True, USE_TF=False) def test_something(): run_slow = os.getenv("RUN_SLOW", False) use_tf = - os.getenv("USE_TF", False) """ return unittest.mock.patch.dict(os.environ, kwargs) +# from https://stackoverflow.com/a/34333710/9201239 +@contextlib.contextmanager +def mockenv_context(*remove, **update): + """ + Temporarily updates the ``os.environ`` dictionary in-place. Similar to mockenv + + The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations. + + Args: + remove: Environment variables to remove. + update: Dictionary of environment variables and values to add/update. + """ + env = os.environ + update = update or {} + remove = remove or [] + + # List of environment variables being updated or removed. + stomped = (set(update.keys()) | set(remove)) & set(env.keys()) + # Environment variables and values to restore on exit. + update_after = {k: env[k] for k in stomped} + # Environment variables and values to remove on exit. 
+ remove_after = frozenset(k for k in update if k not in env) + + try: + env.update(update) + [env.pop(k, None) for k in remove] + yield + finally: + env.update(update_after) + [env.pop(k) for k in remove_after] + + # --- pytest conf functions --- # # to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a24cbfe71..02837d3ee 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -718,7 +718,7 @@ class Trainer: def _wrap_model(self, model, training=True): # already initialized its own DDP and AMP if self.deepspeed: - return model + return self.deepspeed # Mixed precision training with apex (torch < 1.6) if self.use_apex and training: @@ -996,6 +996,10 @@ class Trainer: tr_loss += self.training_step(model, inputs) self._total_flos += float(self.floating_point_ops(inputs)) + # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps + if self.deepspeed: + self.deepspeed.step() + if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps steps_in_epoch <= self.args.gradient_accumulation_steps @@ -1021,7 +1025,7 @@ class Trainer: # Optimizer step if self.deepspeed: - self.deepspeed.step() + pass # called outside the loop elif is_torch_tpu_available(): xm.optimizer_step(self.optimizer) elif self.use_amp: @@ -1030,7 +1034,9 @@ class Trainer: else: self.optimizer.step() - self.lr_scheduler.step() + if not self.deepspeed: + self.lr_scheduler.step() + model.zero_grad() self.state.global_step += 1 self.state.epoch = epoch + (step + 1) / steps_in_epoch @@ -1388,7 +1394,6 @@ class Trainer: Return: :obj:`torch.Tensor`: The tensor with training loss on this batch. 
""" - model.train() inputs = self._prepare_inputs(inputs) @@ -1401,7 +1406,8 @@ class Trainer: if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training - if self.args.gradient_accumulation_steps > 1: + if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: + # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` loss = loss / self.args.gradient_accumulation_steps if self.use_amp: @@ -1410,7 +1416,8 @@ class Trainer: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() elif self.deepspeed: - self.deepspeed.backward(loss) + # loss gets scaled under gradient_accumulation_steps in deepspeed + loss = self.deepspeed.backward(loss) else: loss.backward()