mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Enable PyTorch Lightning basic test on CI (#6809)
This commit is contained in:
parent
059ed1c241
commit
f71d93ea2b
7 changed files with 14 additions and 14 deletions
|
|
@ -29,7 +29,7 @@ def run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, data_dir):
|
|||
'--deepspeed_config', 'orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json']
|
||||
|
||||
if data_dir:
|
||||
command.extend(['--data_dir', data_dir])
|
||||
command.extend(['--data-dir', data_dir])
|
||||
|
||||
run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir):
|
|||
command.extend(['--no-cuda', '--epochs', str(3)])
|
||||
|
||||
if data_dir:
|
||||
command.extend(['--data_dir', data_dir])
|
||||
command.extend(['--data-dir', data_dir])
|
||||
|
||||
run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
|
||||
|
|
@ -58,12 +58,12 @@ def run_ortmodule_torch_lightning(cwd, log, data_dir):
|
|||
'--epochs=2', '--batch-size=256']
|
||||
|
||||
if data_dir:
|
||||
command.extend(['--data_dir', data_dir])
|
||||
command.extend(['--data-dir', data_dir])
|
||||
|
||||
run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
|
||||
|
||||
def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir, transformers_cache):
|
||||
def run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir, transformers_cache):
|
||||
log.debug('Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.'.format(no_cuda))
|
||||
|
||||
env = get_env_with_transformers_cache(transformers_cache)
|
||||
|
|
@ -73,7 +73,7 @@ def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log,
|
|||
command.extend(['--no-cuda', '--epochs', str(3)])
|
||||
|
||||
if data_dir:
|
||||
command.extend(['--data_dir', data_dir])
|
||||
command.extend(['--data-dir', data_dir])
|
||||
|
||||
run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode()
|
||||
|
||||
|
|
@ -90,14 +90,13 @@ def main():
|
|||
|
||||
run_ortmodule_poc_net(cwd, log, no_cuda=True, data_dir=args.mnist)
|
||||
|
||||
run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=False,
|
||||
run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=False,
|
||||
data_dir=args.bert_data, transformers_cache=args.transformers_cache)
|
||||
|
||||
run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True,
|
||||
run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True,
|
||||
data_dir=args.bert_data, transformers_cache=args.transformers_cache)
|
||||
|
||||
# TODO: Re-enable when PyTorch Lightning works with newer torchtext (nightlies after 2021-02-19)
|
||||
# run_ortmodule_torch_lightning(cwd, log, args.args.mnist)
|
||||
run_ortmodule_torch_lightning(cwd, log, args.mnist)
|
||||
|
||||
return 0
|
||||
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ def main():
|
|||
help='Log level (default: WARNING)')
|
||||
parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H',
|
||||
help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)')
|
||||
parser.add_argument('--data_dir', type=str, default='./cola_public/raw',
|
||||
parser.add_argument('--data-dir', type=str, default='./cola_public/raw',
|
||||
help='Path to the bert data directory')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ def main():
|
|||
help='number of epochs to train (default: 10)')
|
||||
parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING',
|
||||
help='Log level (default: WARNING)')
|
||||
parser.add_argument('--data_dir', type=str, default='./mnist',
|
||||
parser.add_argument('--data-dir', type=str, default='./mnist',
|
||||
help='Path to the mnist data directory')
|
||||
|
||||
# DeepSpeed-related settings
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ def main():
|
|||
help='number of epochs to train (default: 10)')
|
||||
parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING',
|
||||
help='Log level (default: WARNING)')
|
||||
parser.add_argument('--data_dir', type=str, default='./mnist',
|
||||
parser.add_argument('--data-dir', type=str, default='./mnist',
|
||||
help='Path to the mnist data directory')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -30,8 +30,7 @@ class LitAutoEncoder(pl.LightningModule):
|
|||
)
|
||||
if use_ortmodule:
|
||||
self.encoder = ORTModule(self.encoder)
|
||||
# TODO: Remove this comment below when multiple ORTModule instances is supported
|
||||
# self.decoder = ORTModule(self.decoder)
|
||||
self.decoder = ORTModule(self.decoder)
|
||||
|
||||
def forward(self, x):
|
||||
# in lightning, forward defines the prediction/inference actions
|
||||
|
|
|
|||
|
|
@ -11,3 +11,5 @@ torchtext
|
|||
tensorboard==v2.0.0
|
||||
h5py
|
||||
wget
|
||||
# PyTorch Lightning (nightly) is used for CI tests only
|
||||
https://github.com/PyTorchLightning/pytorch-lightning/archive/master.zip
|
||||
Loading…
Reference in a new issue