Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
[trainer] apex fixes and tests (#9180)
This commit is contained in:
parent 467e9158b4
commit f06d0fadc9

2 changed files with 22 additions and 8 deletions
@@ -18,7 +18,7 @@ import unittest
 from unittest.mock import patch

 from transformers import BertTokenizer, EncoderDecoderModel
-from transformers.file_utils import is_datasets_available
+from transformers.file_utils import is_apex_available, is_datasets_available
 from transformers.integrations import is_fairscale_available
 from transformers.testing_utils import (
     TestCasePlus,
@@ -51,6 +51,17 @@ def require_fairscale(test_case):
         return test_case


+# a candidate for testing_utils
+def require_apex(test_case):
+    """
+    Decorator marking a test that requires apex
+    """
+    if not is_apex_available():
+        return unittest.skip("test requires apex")(test_case)
+    else:
+        return test_case
+
+
 class TestFinetuneTrainer(TestCasePlus):
     def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
         output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
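The new require_apex helper follows the same availability-gated skip pattern as require_fairscale above. Below is a minimal, self-contained sketch of that pattern; the importlib-based is_apex_available stand-in and the ExampleTests class are illustrative only and are not part of this commit, which uses transformers.file_utils.is_apex_available instead.

import importlib.util
import unittest


def is_apex_available():
    # Stand-in for transformers.file_utils.is_apex_available: apex counts as
    # available if the package can be found on the import path.
    return importlib.util.find_spec("apex") is not None


def require_apex(test_case):
    """Decorator marking a test that requires apex (skipped when apex is absent)."""
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    return test_case


class ExampleTests(unittest.TestCase):
    @require_apex
    def test_needs_apex(self):
        from apex import amp  # only runs when apex is actually installed

        self.assertTrue(hasattr(amp, "initialize"))


if __name__ == "__main__":
    unittest.main()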
@@ -72,6 +83,7 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp(self):
         self.finetune_trainer_quick(distributed=True)

+    # it's crucial to test --sharded_ddp w/ and w/o --fp16
     @require_torch_multi_gpu
     @require_fairscale
     def test_finetune_trainer_ddp_sharded_ddp(self):
@@ -82,6 +94,10 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
         self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")

+    @require_apex
+    def test_finetune_trainer_apex(self):
+        self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex")
+
     @slow
     def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
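The new apex test drives the existing finetune_trainer_quick helper with "--fp16 --fp16_backend=apex". As a rough illustration only, these flags map onto the corresponding TrainingArguments fields when the trainer is configured programmatically (a sketch assuming a transformers version from this era, which already exposes fp16_backend, as the trainer hunk further down shows via args.fp16_backend):

from transformers import TrainingArguments

# Programmatic equivalent of the "--fp16 --fp16_backend=apex" CLI flags used in the test.
args = TrainingArguments(
    output_dir="./out",     # any writable directory
    fp16=True,              # turn on mixed-precision training
    fp16_backend="apex",    # force apex instead of native torch.cuda.amp ("auto" picks for you)
)
print(args.fp16, args.fp16_backend)

The remaining hunks belong to the second of the two changed files, the Trainer implementation itself (the relative imports such as .file_utils and the class Trainer context point there).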
@@ -53,7 +53,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler

 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
-from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
+from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
 from .modeling_utils import PreTrainedModel
 from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from .optimization import AdamW, get_linear_schedule_with_warmup
@@ -104,13 +104,10 @@ if is_in_notebook():

     DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

-# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
-if version.parse(torch.__version__) < version.parse("1.6"):
-    from .file_utils import is_apex_available
+if is_apex_available():
+    from apex import amp

-    if is_apex_available():
-        from apex import amp
-else:
+if version.parse(torch.__version__) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
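The effect of the hunk above: apex's amp is now imported whenever apex is installed, and native AMP availability is tracked purely by the torch version, instead of only importing apex when torch is older than 1.6. A standalone sketch of that module-level pattern, using an importlib-based stand-in for is_apex_available:

import importlib.util

import torch
from packaging import version

_is_native_amp_available = False

# Import apex's amp whenever apex is installed (stand-in for is_apex_available()).
if importlib.util.find_spec("apex") is not None:
    from apex import amp  # noqa: F401

# Native AMP is flagged by torch version alone, independently of apex.
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast  # noqa: F401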
@@ -309,6 +306,7 @@ class Trainer:
                 backend = "amp" if _is_native_amp_available else "apex"
             else:
                 backend = args.fp16_backend
+            logger.info(f"Using {backend} fp16 backend")

             if backend == "amp":
                 self.use_amp = True
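The added log line makes the chosen backend visible. The selection logic it reports reduces to a small rule: with fp16_backend="auto" the trainer prefers native AMP when it is available and otherwise falls back to apex, while an explicit "amp" or "apex" is used as given. A sketch of that rule (the function name is illustrative, not part of the commit):

def resolve_fp16_backend(fp16_backend: str, native_amp_available: bool) -> str:
    # "auto" prefers native AMP (torch >= 1.6) and falls back to apex.
    if fp16_backend == "auto":
        return "amp" if native_amp_available else "apex"
    # Explicit "amp" or "apex" is taken as-is.
    return fp16_backend


assert resolve_fp16_backend("auto", True) == "amp"
assert resolve_fp16_backend("auto", False) == "apex"
assert resolve_fp16_backend("apex", True) == "apex"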