diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py
index cf16e69a1..24e56f752 100644
--- a/examples/seq2seq/test_finetune_trainer.py
+++ b/examples/seq2seq/test_finetune_trainer.py
@@ -18,7 +18,7 @@ import unittest
 from unittest.mock import patch
 
 from transformers import BertTokenizer, EncoderDecoderModel
-from transformers.file_utils import is_datasets_available
+from transformers.file_utils import is_apex_available, is_datasets_available
 from transformers.integrations import is_fairscale_available
 from transformers.testing_utils import (
     TestCasePlus,
@@ -51,6 +51,17 @@ def require_fairscale(test_case):
         return test_case
 
 
+# a candidate for testing_utils
+def require_apex(test_case):
+    """
+    Decorator marking a test that requires apex
+    """
+    if not is_apex_available():
+        return unittest.skip("test requires apex")(test_case)
+    else:
+        return test_case
+
+
 class TestFinetuneTrainer(TestCasePlus):
     def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
         output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
@@ -72,6 +83,7 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp(self):
         self.finetune_trainer_quick(distributed=True)
 
+    # it's crucial to test --sharded_ddp w/ and w/o --fp16
     @require_torch_multi_gpu
     @require_fairscale
     def test_finetune_trainer_ddp_sharded_ddp(self):
@@ -82,6 +94,10 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
         self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
 
+    @require_apex
+    def test_finetune_trainer_apex(self):
+        self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex")
+
     @slow
     def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0bad52217..ee918e267 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler
 
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
-from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
+from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
 from .modeling_utils import PreTrainedModel
 from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from .optimization import AdamW, get_linear_schedule_with_warmup
@@ -104,13 +104,10 @@ if is_in_notebook():
     DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
 
 
-# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
-if version.parse(torch.__version__) < version.parse("1.6"):
-    from .file_utils import is_apex_available
+if is_apex_available():
+    from apex import amp
 
-    if is_apex_available():
-        from apex import amp
-else:
+if version.parse(torch.__version__) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
 
@@ -309,6 +306,7 @@ class Trainer:
                 backend = "amp" if _is_native_amp_available else "apex"
             else:
                 backend = args.fp16_backend
+            logger.info(f"Using {backend} fp16 backend")
 
             if backend == "amp":
                 self.use_amp = True