diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py index 72bd92ca5b..10ca5281fd 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py @@ -29,7 +29,7 @@ def run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, data_dir): '--deepspeed_config', 'orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json'] if data_dir: - command.extend(['--data_dir', data_dir]) + command.extend(['--data-dir', data_dir]) run_subprocess(command, cwd=cwd, log=log).check_returncode() diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index d3ec796da6..2b4fb80cc8 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -46,7 +46,7 @@ def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir): command.extend(['--no-cuda', '--epochs', str(3)]) if data_dir: - command.extend(['--data_dir', data_dir]) + command.extend(['--data-dir', data_dir]) run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -58,12 +58,12 @@ def run_ortmodule_torch_lightning(cwd, log, data_dir): '--epochs=2', '--batch-size=256'] if data_dir: - command.extend(['--data_dir', data_dir]) + command.extend(['--data-dir', data_dir]) run_subprocess(command, cwd=cwd, log=log).check_returncode() -def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir, transformers_cache): +def run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir, transformers_cache): log.debug('Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.'.format(no_cuda)) env = get_env_with_transformers_cache(transformers_cache) @@ -73,7 +73,7 @@ def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, command.extend(['--no-cuda', '--epochs', str(3)]) if data_dir: - command.extend(['--data_dir', data_dir]) + command.extend(['--data-dir', data_dir]) run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode() @@ -90,14 +90,13 @@ def main(): run_ortmodule_poc_net(cwd, log, no_cuda=True, data_dir=args.mnist) - run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=False, + run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=False, data_dir=args.bert_data, transformers_cache=args.transformers_cache) - run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True, + run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True, data_dir=args.bert_data, transformers_cache=args.transformers_cache) - # TODO: Re-enable when PyTorch Lightning works with newer torchtext (nightlies after 2021-02-19) - # run_ortmodule_torch_lightning(cwd, log, args.args.mnist) + run_ortmodule_torch_lightning(cwd, log, args.mnist) return 0 diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index f56385d371..02c7980971 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -337,7 +337,7 @@ def main(): help='Log level (default: WARNING)') parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H', help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)') - parser.add_argument('--data_dir', type=str, default='./cola_public/raw', + parser.add_argument('--data-dir', type=str, default='./cola_public/raw', help='Path to the bert data directory') args = parser.parse_args() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py index ed7a6ee630..52f6dda670 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py @@ -146,7 +146,7 @@ def main(): help='number of epochs to train (default: 10)') parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING', help='Log level (default: WARNING)') - parser.add_argument('--data_dir', type=str, default='./mnist', + parser.add_argument('--data-dir', type=str, default='./mnist', help='Path to the mnist data directory') # DeepSpeed-related settings diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py index a3b2ed3491..1c31f16bbd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py @@ -136,7 +136,7 @@ def main(): help='number of epochs to train (default: 10)') parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING', help='Log level (default: WARNING)') - parser.add_argument('--data_dir', type=str, default='./mnist', + parser.add_argument('--data-dir', type=str, default='./mnist', help='Path to the mnist data directory') args = parser.parse_args() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py index 7c3b41acc2..5e8c2f747b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py @@ -30,8 +30,7 @@ class LitAutoEncoder(pl.LightningModule): ) if use_ortmodule: self.encoder = ORTModule(self.encoder) - # TODO: Remove this comment below when multiple ORTModule instances is supported - # self.decoder = ORTModule(self.decoder) + self.decoder = ORTModule(self.decoder) def forward(self, x): # in lightning, forward defines the prediction/inference actions diff --git a/tools/ci_build/github/linux/docker/scripts/training/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/requirements.txt index 389c33a74c..e375a567d2 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/requirements.txt @@ -11,3 +11,5 @@ torchtext tensorboard==v2.0.0 h5py wget +# PyTorch Lightning (nightly) is used for CI tests only +https://github.com/PyTorchLightning/pytorch-lightning/archive/master.zip \ No newline at end of file