Enable PyTorch Lightning basic test on CI (#6809)

This commit is contained in:
Thiago Crepaldi 2021-02-27 09:35:42 -08:00 committed by GitHub
parent 059ed1c241
commit f71d93ea2b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 14 additions and 14 deletions

View file

@ -29,7 +29,7 @@ def run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, data_dir):
'--deepspeed_config', 'orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json']
if data_dir:
command.extend(['--data_dir', data_dir])
command.extend(['--data-dir', data_dir])
run_subprocess(command, cwd=cwd, log=log).check_returncode()

View file

@ -46,7 +46,7 @@ def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir):
command.extend(['--no-cuda', '--epochs', str(3)])
if data_dir:
command.extend(['--data_dir', data_dir])
command.extend(['--data-dir', data_dir])
run_subprocess(command, cwd=cwd, log=log).check_returncode()
@ -58,12 +58,12 @@ def run_ortmodule_torch_lightning(cwd, log, data_dir):
'--epochs=2', '--batch-size=256']
if data_dir:
command.extend(['--data_dir', data_dir])
command.extend(['--data-dir', data_dir])
run_subprocess(command, cwd=cwd, log=log).check_returncode()
def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir, transformers_cache):
def run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir, transformers_cache):
log.debug('Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.'.format(no_cuda))
env = get_env_with_transformers_cache(transformers_cache)
@ -73,7 +73,7 @@ def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log,
command.extend(['--no-cuda', '--epochs', str(3)])
if data_dir:
command.extend(['--data_dir', data_dir])
command.extend(['--data-dir', data_dir])
run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode()
@ -90,14 +90,13 @@ def main():
run_ortmodule_poc_net(cwd, log, no_cuda=True, data_dir=args.mnist)
run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=False,
run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=False,
data_dir=args.bert_data, transformers_cache=args.transformers_cache)
run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True,
run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True,
data_dir=args.bert_data, transformers_cache=args.transformers_cache)
# TODO: Re-enable when PyTorch Lightning works with newer torchtext (nightlies after 2021-02-19)
# run_ortmodule_torch_lightning(cwd, log, args.args.mnist)
run_ortmodule_torch_lightning(cwd, log, args.mnist)
return 0

View file

@ -337,7 +337,7 @@ def main():
help='Log level (default: WARNING)')
parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H',
help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)')
parser.add_argument('--data_dir', type=str, default='./cola_public/raw',
parser.add_argument('--data-dir', type=str, default='./cola_public/raw',
help='Path to the bert data directory')
args = parser.parse_args()

View file

@ -146,7 +146,7 @@ def main():
help='number of epochs to train (default: 10)')
parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING',
help='Log level (default: WARNING)')
parser.add_argument('--data_dir', type=str, default='./mnist',
parser.add_argument('--data-dir', type=str, default='./mnist',
help='Path to the mnist data directory')
# DeepSpeed-related settings

View file

@ -136,7 +136,7 @@ def main():
help='number of epochs to train (default: 10)')
parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING',
help='Log level (default: WARNING)')
parser.add_argument('--data_dir', type=str, default='./mnist',
parser.add_argument('--data-dir', type=str, default='./mnist',
help='Path to the mnist data directory')
args = parser.parse_args()

View file

@ -30,8 +30,7 @@ class LitAutoEncoder(pl.LightningModule):
)
if use_ortmodule:
self.encoder = ORTModule(self.encoder)
# TODO: Remove this comment below when multiple ORTModule instances is supported
# self.decoder = ORTModule(self.decoder)
self.decoder = ORTModule(self.decoder)
def forward(self, x):
# in lightning, forward defines the prediction/inference actions

View file

@ -11,3 +11,5 @@ torchtext
tensorboard==v2.0.0
h5py
wget
# PyTorch Lightning (nightly) is used for CI tests only
https://github.com/PyTorchLightning/pytorch-lightning/archive/master.zip