Add ORTModule DeepSpeed ZeRO stage 1 test to the distributed CI pipeline (#6342)

* Add DeepSpeed ZeRO stage 1 POC test with MNIST data to the ORTModule CI pipeline

* Add ORTModule CI pipeline tests for the POC net and the HuggingFace BERT classifier with the --no-cuda argument
This commit is contained in:
baijumeswani 2021-01-19 10:08:13 -08:00 committed by GitHub
parent 0586c610b2
commit 910c5ab655
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 12 deletions

View file

@ -202,6 +202,7 @@ endif()
file(GLOB onnxruntime_python_test_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/python/*.py"
"${ORTTRAINING_SOURCE_DIR}/test/python/*.py"
"${ORTTRAINING_SOURCE_DIR}/test/python/*.json"
)
file(GLOB onnxruntime_python_checkpoint_test_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/test/python/checkpoint/*.py"

View file

@ -24,7 +24,10 @@ def parse_arguments():
def run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log):
log.debug('Running: ORTModule deepspeed zero stage 1 tests')
# TODO: add the actual deepspeed test here
command = ['deepspeed', 'orttraining_test_ortmodule_deepspeed_zero_stage_1.py',
'--deepspeed_config', 'orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json']
run_subprocess(command, cwd=cwd, log=log).check_returncode()
def main():

View file

@ -29,18 +29,22 @@ def run_ortmodule_api_tests(cwd, log):
run_subprocess(command, cwd=cwd, log=log).check_returncode()
def run_ortmodule_poc_net(cwd, log):
log.debug('Running: ORTModule POCNet for MNIST')
def run_ortmodule_poc_net(cwd, log, no_cuda):
log.debug('Running: ORTModule POCNet for MNIST with --no-cuda arg {}.'.format(no_cuda))
command = [sys.executable, 'orttraining_test_ortmodule_poc.py']
if no_cuda:
command.extend(['--no-cuda', '--epochs', str(3)])
run_subprocess(command, cwd=cwd, log=log).check_returncode()
def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log):
log.debug('Running: ORTModule HuggingFace BERT for sequence classification.')
def run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda):
log.debug('Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.'.format(no_cuda))
command = [sys.executable, 'orttraining_test_ortmodule_bert_classifier.py']
if no_cuda:
command.extend(['--no-cuda', '--epochs', str(3)])
run_subprocess(command, cwd=cwd, log=log).check_returncode()
@ -53,9 +57,13 @@ def main():
run_ortmodule_api_tests(cwd, log)
run_ortmodule_poc_net(cwd, log)
run_ortmodule_poc_net(cwd, log, no_cuda=False)
run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log)
run_ortmodule_poc_net(cwd, log, no_cuda=True)
run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=False)
run_ort_module_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True)
return 0

View file

@ -4,8 +4,8 @@ To run on the local GPU(s):
```
$ pip install deepspeed
$ deepspeed orttraining_test_ortmodule_mnist_deepspeed.py \
--deepspeed_config=orttraining_test_ortmodule_mnist_deepspeed_config.json
$ deepspeed orttraining_test_ortmodule_deepspeed_zero_stage_1.py \
--deepspeed_config=orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json
```
"""
import argparse
@ -13,6 +13,7 @@ import logging
import torch
import time
from torchvision import datasets, transforms
import torch.distributed as dist
import onnxruntime
from onnxruntime.training import ORTModule
@ -166,9 +167,15 @@ def main():
device = "cpu"
## Data loader
train_set = datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))]))
dist.init_process_group(backend='nccl')
if args.local_rank == 0:
# download only once on rank 0
datasets.MNIST('./data', download=True)
dist.barrier()
train_set = datasets.MNIST('./data', train=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))]))
test_loader = None
if args.test_batch_size > 0:
test_loader = torch.utils.data.DataLoader(