From 910c5ab65508c82fdd40d4b6b53ec4df7978e9dc Mon Sep 17 00:00:00 2001 From: baijumeswani Date: Tue, 19 Jan 2021 10:08:13 -0800 Subject: [PATCH] Add ORTModule deepspeed zero stage 1 test to the distributed CI pipeline (#6342) * Add deepspeed zero stage 1 poc test with MNIST data to the ORTModule CI pipeline * Add ORTModule CI pipeline tests for POC and hf BERT classifier with --no-cuda arg --- cmake/onnxruntime_python.cmake | 1 + ...orttraining_ortmodule_distributed_tests.py | 5 ++++- .../python/orttraining_ortmodule_tests.py | 20 +++++++++++++------ ..._test_ortmodule_deepspeed_zero_stage_1.py} | 17 +++++++++++----- ...module_deepspeed_zero_stage_1_config.json} | 0 5 files changed, 31 insertions(+), 12 deletions(-) rename orttraining/orttraining/test/python/{orttraining_test_ortmodule_mnist_deepspeed.py => orttraining_test_ortmodule_deepspeed_zero_stage_1.py} (93%) rename orttraining/orttraining/test/python/{orttraining_test_ortmodule_mnist_deepspeed_config.json => orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json} (100%) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index f6b2e2f2f3..7b93b652a5 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -202,6 +202,7 @@ endif() file(GLOB onnxruntime_python_test_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/python/*.py" "${ORTTRAINING_SOURCE_DIR}/test/python/*.py" + "${ORTTRAINING_SOURCE_DIR}/test/python/*.json" ) file(GLOB onnxruntime_python_checkpoint_test_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/test/python/checkpoint/*.py" diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py index 06e3c93753..d14812d0d9 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py @@ -24,7 +24,10 @@ def parse_arguments(): def run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log): log.debug('Running: ORTModule deepspeed zero stage 1 tests') - # TODO: add the actual deepspeed test here + command = ['deepspeed', 'orttraining_test_ortmodule_deepspeed_zero_stage_1.py', + '--deepspeed_config', 'orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json'] + + run_subprocess(command, cwd=cwd, log=log).check_returncode() def main(): diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index c2696e6747..c147962930 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -29,18 +29,22 @@ def run_ortmodule_api_tests(cwd, log): run_subprocess(command, cwd=cwd, log=log).check_returncode() -def run_ortmodule_poc_net(cwd, log): - log.debug('Running: ORTModule POCNet for MNIST') +def run_ortmodule_poc_net(cwd, log, no_cuda): + log.debug('Running: ORTModule POCNet for MNIST with --no-cuda arg {}.'.format(no_cuda)) command = [sys.executable, 'orttraining_test_ortmodule_poc.py'] + if no_cuda: + command.extend(['--no-cuda', '--epochs', str(3)]) run_subprocess(command, cwd=cwd, log=log).check_returncode() -def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log): - log.debug('Running: ORTModule HuggingFace BERT for sequence classification.') +def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda): + log.debug('Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.'.format(no_cuda)) command = [sys.executable, 'orttraining_test_ortmodule_bert_classifier.py'] + if no_cuda: + command.extend(['--no-cuda', '--epochs', str(3)]) run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -53,9 +57,13 @@ def main(): run_ortmodule_api_tests(cwd, log) - run_ortmodule_poc_net(cwd, log) + run_ortmodule_poc_net(cwd, log, no_cuda=False) - run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log) + run_ortmodule_poc_net(cwd, log, no_cuda=True) + + run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=False) + + run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True) return 0 diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist_deepspeed.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py similarity index 93% rename from orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist_deepspeed.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py index 22e9624786..dc58a0fe05 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist_deepspeed.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py @@ -4,8 +4,8 @@ To run on the local GPU(s): ``` $ pip install deepspeed -$ deepspeed orttraining_test_ortmodule_mnist_deepspeed.py \ - --deepspeed_config=orttraining_test_ortmodule_mnist_deepspeed_config.json +$ deepspeed orttraining_test_ortmodule_deepspeed_zero_stage_1.py \ + --deepspeed_config=orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json ``` """ import argparse @@ -13,6 +13,7 @@ import logging import torch import time from torchvision import datasets, transforms +import torch.distributed as dist import onnxruntime from onnxruntime.training import ORTModule @@ -166,9 +167,15 @@ def main(): device = "cpu" ## Data loader - train_set = datasets.MNIST('./data', train=True, download=True, - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))])) + dist.init_process_group(backend='nccl') + if args.local_rank == 0: + # download only once on rank 0 + datasets.MNIST('./data', download=True) + dist.barrier() + train_set = datasets.MNIST('./data', train=True, + transform=transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,))])) + test_loader = None if args.test_batch_size > 0: test_loader = torch.utils.data.DataLoader( diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist_deepspeed_config.json b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_ortmodule_mnist_deepspeed_config.json rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json