batch size tests (#5508)

* batch size tests

Co-authored-by: liqun <liqun@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
liqunfu 2020-10-28 15:55:40 -07:00 committed by GitHub
parent 50582abe93
commit 5129b4d5bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 213 additions and 93 deletions

View file

@ -70,11 +70,13 @@ def bert_model_description(config):
('attention_mask', ['batch', 'max_seq_len_in_batch'],),
('token_type_ids', ['batch', 'max_seq_len_in_batch'],),
('masked_lm_labels', ['batch', 'max_seq_len_in_batch'],),
('next_sentence_label', ['batch', ],)],
('next_sentence_label', ['batch', ],)
],
'outputs': [
('loss', [], True),
('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size],),
('seq_relationship_scores', ['batch', 2],)]}
('seq_relationship_scores', ['batch', 2],)
]}
return new_model_desc
@ -119,6 +121,47 @@ class pretraining_dataset(Dataset):
return [input_ids, segment_ids, input_mask,
masked_lm_labels, next_sentence_labels]
import argparse
def parse_arguments(argv=None):
    """Parse the command-line options of the batch size test.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so callers that
            pass nothing keep the original behavior.

    Returns:
        argparse.Namespace with the attributes ``enable_mixed_precision``,
        ``sequence_length``, ``max_predictions_per_seq``, ``max_batch_size``,
        ``gelu_recompute``, ``attn_dropout_recompute`` and
        ``transformer_layer_recompute``.
    """
    parser = argparse.ArgumentParser()

    # batch size test config parameters
    parser.add_argument("--enable_mixed_precision",
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument("--sequence_length",
                        default=512,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--max_predictions_per_seq",
                        default=80,
                        type=int,
                        help="The maximum total of masked tokens in input sequence")
    parser.add_argument("--max_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")

    # Memory-saving recompute switches; 'store_true' already defaults to False.
    parser.add_argument("--gelu_recompute",
                        action='store_true',
                        help="Whether to enable recomputing Gelu activation output to save memory.")
    parser.add_argument("--attn_dropout_recompute",
                        action='store_true',
                        help="Whether to enable recomputing attention dropout to save memory.")
    parser.add_argument("--transformer_layer_recompute",
                        action='store_true',
                        help="Whether to enable recomputing transformer layerwise to save memory.")

    return parser.parse_args(argv)
@dataclass
class PretrainArguments:
"""
@ -207,6 +250,19 @@ class PretrainArguments:
metadata={"help": "Whether to use 16-bit float precision instead of 32-bit."}
)
gelu_recompute: bool = field(
default=False,
metadata={"help": "Whether to enable recomputing Gelu activation output to save memory."}
)
attn_dropout_recompute: bool = field(
default=False,
metadata={"help": "Whether to enable recomputing attention dropout to save memory."}
)
transformer_layer_recompute: bool = field(
default=False,
metadata={"help": "Whether to enable recomputing transformer layerwise to save memory."}
)
loss_scale: Optional[float] = field(
default=0.0,
metadata={"help": "Loss scaling, positive power of 2 values can improve fp16 convergence."}
@ -345,8 +401,8 @@ def setup_torch_distributed(world_rank, world_size):
return
def prepare_model(args, device):
config = BertConfig.from_pretrained('bert-base-uncased', cache_dir=args.cache_dir)
config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir)
# config.num_hidden_layers = 12
if args.force_num_hidden_layers:
logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
@ -367,6 +423,11 @@ def prepare_model(args, device):
'mixed_precision': {
'enabled': args.fp16,
'loss_scaler': loss_scaler},
'graph_transformer': {
'attn_dropout_recompute': args.attn_dropout_recompute,
'gelu_recompute': args.gelu_recompute,
'transformer_layer_recompute': args.transformer_layer_recompute,
},
'debug': {'deterministic_compute': True, },
'utils': {
'grad_norm_clip': True},
@ -524,41 +585,41 @@ class ORTBertPretrainTest(unittest.TestCase):
self.allreduce_post_accumulation = True
self.tensorboard_dir = '/bert_data/hf_data/test_out'
def test_pretrain_throughput(self):
# setting train_batch_size and gradient_accumulation_steps to maximize per gpu memory usage under 16GB
# to validate throughput regression.
# train_batch_size is initially configured as per optimization batch size,
# taking into consideration of world_size and gradient_accumulation_steps:
# train_batch_size = world_size * gradient_accumulation_steps * batch_size_per_gpu
# in the code later train_batch_size is translated to per gpu batch size:
# args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps // args.world_size
def test_pretrain_throughput(self, process_args=None):
if process_args.sequence_length == 128:
input_dir = '/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train'
else:
input_dir = '/bert_data/hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train'
# the LAMB batch size of 64k
optimization_batch_size = 64 * 1024
per_gpu_batch_size = 32
print("process_args.enable_mixed_precision: ", process_args.enable_mixed_precision)
print("process_args.sequence_length: ", process_args.sequence_length)
print("process_args.max_batch_size: ", process_args.max_batch_size)
print("process_args.max_predictions_per_seq: ", process_args.max_predictions_per_seq)
print("process_args.gelu_recompute: ", process_args.gelu_recompute)
print("process_args.attn_dropout_recompute: ", process_args.attn_dropout_recompute)
print("process_args.transformer_layer_recompute: ", process_args.transformer_layer_recompute)
self.train_batch_size = optimization_batch_size
self.gradient_accumulation_steps = optimization_batch_size // per_gpu_batch_size // self.world_size
logger.info("self.gradient_accumulation_steps = %d", self.gradient_accumulation_steps)
# only to run on optimization step because we only want to make sure there is no throughput regression
self.max_steps = 1
args = PretrainArguments(
output_dir=self.output_dir,
bert_model=self.bert_model,
input_dir=input_dir,
output_dir='/bert_data/hf_data/test_out/bert_pretrain_results',
bert_model='bert-large-uncased',
local_rank=self.local_rank,
world_rank=self.world_rank,
world_size=self.world_size,
max_steps=self.max_steps,
learning_rate=self.learning_rate,
max_seq_length=self.max_seq_length,
max_predictions_per_seq=self.max_predictions_per_seq,
train_batch_size=self.train_batch_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
input_dir=self.input_dir,
fp16=self.fp16,
allreduce_post_accumulation=self.allreduce_post_accumulation)
max_steps=10,
learning_rate=5e-4,
max_seq_length=process_args.sequence_length,
max_predictions_per_seq=process_args.max_predictions_per_seq,
train_batch_size=process_args.max_batch_size,
gradient_accumulation_steps=1,
fp16=process_args.enable_mixed_precision,
gelu_recompute=process_args.gelu_recompute,
attn_dropout_recompute=process_args.attn_dropout_recompute,
transformer_layer_recompute=process_args.transformer_layer_recompute,
allreduce_post_accumulation=True,
# TODO: remove
force_num_hidden_layers=2,
)
do_pretrain(args)
def test_pretrain_convergence(self):
@ -621,8 +682,8 @@ class ORTBertPretrainTest(unittest.TestCase):
fp16=self.fp16,
allreduce_post_accumulation=self.allreduce_post_accumulation,
force_num_hidden_layers=self.force_num_hidden_layers,
deepspeed_zero_stage = self.deepspeed_zero_stage,
save_checkpoint = True)
deepspeed_zero_stage=self.deepspeed_zero_stage,
save_checkpoint=True)
train_loss = do_pretrain(args)
# ensure all workers reach this point before loading the checkpointed state
@ -633,7 +694,7 @@ class ORTBertPretrainTest(unittest.TestCase):
checkpoint_files = _list_checkpoint_files(self.output_dir, "ORT_checkpoint")
ckpt_agg = _CombineZeroCheckpoint(checkpoint_files)
final_state_dict = ckpt_agg.aggregate_checkpoints()
args.init_state_dict = final_state_dict
torch.distributed.barrier()
@ -646,22 +707,31 @@ class ORTBertPretrainTest(unittest.TestCase):
return final_loss
# to do parallel training:
# python -m torch.distributed.launch --nproc_per_node 4 orttraining_run_bert_pretrain.py
if __name__ == "__main__":
import sys
logger.warning("sys.argv: %s", sys.argv)
# usage:
# mpirun -n 4 python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput
# mpirun -n 4 python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence
# mpirun -n 4 python orttraining_run_bert_pretrain.py # to run real BERT convergence test
# pytorch.distributed.launch will not work because ort backend requires MPI to broadcast ncclUniqueId
# data parallel training
# mpirun -n 4 python orttraining_run_bert_pretrain.py
#
# single gpu:
# python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput
# [batch size test arguments]
# python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence
#
# torch.distributed.launch will not work because the ORT backend requires MPI to broadcast ncclUniqueId
# calling unpublished get_mpi_context_xxx to get rank/size numbers.
from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_local_size, get_mpi_context_world_rank, get_mpi_context_world_size
world_size = get_mpi_context_world_size()
if world_size > 1:
print ('get_mpi_context_world_size(): ', world_size)
try:
# In case ORT is not built with MPI/NCCL, there are no get_mpi_context_xxx internal apis.
from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_local_size,\
get_mpi_context_world_rank, get_mpi_context_world_size
has_get_mpi_context_internal_api = True
except ImportError:
has_get_mpi_context_internal_api = False
pass
if has_get_mpi_context_internal_api and get_mpi_context_world_size() > 1:
world_size = get_mpi_context_world_size()
print('get_mpi_context_world_size(): ', world_size)
local_rank = get_mpi_context_local_rank()
if local_rank == 0:
@ -673,19 +743,7 @@ if __name__ == "__main__":
test.world_rank = local_rank
test.world_size = world_size
if len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_throughput':
logger.info("running ORTBertPretrainTest.test_pretrain_throughput()...")
test.test_pretrain_throughput()
logger.info("ORTBertPretrainTest.test_pretrain_throughput() passed")
elif len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_convergence':
logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...")
test.max_steps = 200
test.force_num_hidden_layers = 8
final_loss = test.test_pretrain_convergence()
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
test.assertLess(final_loss, 8.5)
logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed")
elif len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_zero':
if len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_zero':
logger.info("running ORTBertPretrainTest.test_pretrain_zero()...")
final_loss = test.test_pretrain_zero()
logger.info("ORTBertPretrainTest.test_pretrain_zero() rank = %i final loss = %f", local_rank, final_loss)
@ -694,37 +752,23 @@ if __name__ == "__main__":
else:
test.assertGreater(final_loss, 11.0)
logger.info("ORTBertPretrainTest.test_pretrain_zero() passed")
elif len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_convergence':
logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...")
test.max_steps = 200
test.force_num_hidden_layers = 8
final_loss = test.test_pretrain_convergence()
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
test.assertLess(final_loss, 8.5)
logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed")
else:
# https://microsoft.sharepoint.com/teams/ONNX2/_layouts/15/Doc.aspx?sourcedoc={170774be-e1c6-4f8b-a3ae-984f211fe410}&action=edit&wd=target%28ONNX%20Training.one%7C8176133b-c7cb-4ef2-aa9d-3fdad5344c40%2FGitHub%20Master%20Merge%20Schedule%7Cb67f0db1-e3a0-4add-80a6-621d67fd8107%2F%29
# to make equivalent args for cpp convergence test
# ngpu=4
# seq_len=128
# max_predictions_per_seq=20
# batch_size=64
# grad_acc=16
# num_train_steps=1000000
# optimizer=adam
# lr=5e-4
# warmup_ratio=0.1
# warmup_mode=Linear
# effective_batch_size=$((ngpu * batch_size * grad_acc))
# commit=$(git rev-parse HEAD | cut -c1-8)
# time_now=$(date +%m%d%H%M)
# run_name=ort_${commit}_nvbertbase_bookwiki128_fp16_${optimizer}_lr${lr}_${warmup_mode}${warmup_ratio}_g${ngpu}_bs${batch_size}_acc${grad_acc}_efbs${effective_batch_size}_steps${num_train_steps}_fp16allreduce_${time_now}
# mixed precision
# mpirun -n ${ngpu} ./onnxruntime_training_bert --model_name /bert_ort/bert_models/nv/bert-base/bert-base-uncased_L_12_H_768_A_12_V_30528_S_512_Dp_0.1_optimized_layer_norm
# --train_data_dir /bert_data/128/books_wiki_en_corpus/train --test_data_dir /bert_data/128/books_wiki_en_corpus/test
# --train_batch_size ${batch_size} --mode train --num_train_steps ${num_train_steps} --display_loss_steps 5
# --log_dir ./logs/bert_base/${run_name} --optimizer ${optimizer} --learning_rate ${lr} --warmup_ratio ${warmup_ratio} --warmup_mode ${warmup_mode}
# --gradient_accumulation_steps ${grad_acc} --max_predictions_per_seq=${max_predictions_per_seq} --use_mixed_precision --allreduce_in_fp16 --lambda 0
# --use_nccl | tee ${run_name}.log
test.max_seq_length = 128
test.max_predictions_per_seq = 20
test.gradient_accumulation_steps = 16
test.train_batch_size = 64 * test.gradient_accumulation_steps * test.world_size # cpp_batch_size (=64) * grad_acc * world_size
# cpp_batch_size (=64) * grad_acc * world_size
test.train_batch_size = 64 * test.gradient_accumulation_steps * test.world_size
test.max_steps = 300000
test.force_num_hidden_layers = None
@ -736,4 +780,23 @@ if __name__ == "__main__":
final_loss = test.test_pretrain_convergence()
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
else:
unittest.main()
# unittest does not accept user defined arguments
# we need to run this script with user defined arguments
if len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_throughput':
run_test_pretrain_throughput, run_test_pretrain_convergence = True, False
sys.argv.remove('ORTBertPretrainTest.test_pretrain_throughput')
elif len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_convergence':
run_test_pretrain_throughput, run_test_pretrain_convergence = False, True
sys.argv.remove('ORTBertPretrainTest.test_pretrain_convergence')
else:
run_test_pretrain_throughput, run_test_pretrain_convergence = True, True
process_args = parse_arguments()
test = ORTBertPretrainTest()
test.setUp()
if run_test_pretrain_throughput:
logger.info("running single GPU ORTBertPretrainTest.test_pretrain_throughput()...")
test.test_pretrain_throughput(process_args)
logger.info("single GPU ORTBertPretrainTest.test_pretrain_throughput() passed")
# unittest.main()

View file

@ -0,0 +1,56 @@
import sys
import collections
import subprocess
# One row of the batch size test matrix. Each field is forwarded by
# run_with_config to the pretrain script as the command-line option of the
# same name.
Config = collections.namedtuple(
    "Config",
    (
        "enable_mixed_precision",
        "sequence_length",
        "max_batch_size",
        "max_predictions_per_seq",
        "gelu_recompute",
        "attn_dropout_recompute",
        "transformer_layer_recompute",
    ),
)

# Column order follows the Config fields above.
configs = [
    # fp16, no recompute
    Config(True, 128, 46, 20, False, False, False),
    Config(True, 512, 8, 80, False, False, False),
    # fp32, no recompute
    Config(False, 128, 26, 20, False, False, False),
    Config(False, 512, 4, 80, False, False, False),
    # fp16, sequence length 128, one recompute option enabled per row
    Config(True, 128, 50, 20, True, False, False),
    Config(True, 128, 50, 20, False, True, False),
    Config(True, 128, 76, 20, False, False, True),
    # fp16, sequence length 512, one recompute option enabled per row
    Config(True, 512, 8, 80, True, False, False),
    Config(True, 512, 9, 80, False, True, False),
    Config(True, 512, 15, 80, False, False, True),
]
def run_with_config(config):
    """Run the pretrain throughput test in a child process for one config.

    Prints a banner describing the configuration, then launches
    ``orttraining_run_bert_pretrain.py`` (resolved relative to the current
    working directory) with the current interpreter.

    Args:
        config: Object exposing the ``Config`` fields as attributes.

    Raises:
        subprocess.CalledProcessError: the child exited with a non-zero code.
        subprocess.TimeoutExpired: the child ran longer than 1200 seconds.
    """
    precision = "fp16" if config.enable_mixed_precision else "fp32"
    print("##### testing name - {}-{} #####".format(precision, config.sequence_length))

    recompute_options = ("gelu_recompute", "attn_dropout_recompute", "transformer_layer_recompute")
    for option in recompute_options:
        print(option + ": ", getattr(config, option))

    command = [
        sys.executable,
        'orttraining_run_bert_pretrain.py',
        "ORTBertPretrainTest.test_pretrain_throughput",
    ]
    # Valued options are always passed.
    for option in ("sequence_length", "max_batch_size", "max_predictions_per_seq"):
        command += ["--" + option, str(getattr(config, option))]
    # Boolean switches are appended only when enabled, in this fixed order.
    for switch in ("enable_mixed_precision",) + recompute_options:
        if getattr(config, switch):
            command.append("--" + switch)

    # check=True raises CalledProcessError on a non-zero exit status.
    subprocess.run(command, timeout=1200, check=True)
if __name__ == "__main__":
    # Run the whole test matrix only when executed as a script, so that
    # importing this module does not spawn training subprocesses.
    # run_with_config raises on the first failing or timed-out configuration,
    # which stops the remaining runs.
    for config in configs:
        run_with_config(config)

View file

@ -124,7 +124,7 @@ class ORTGlueTest(unittest.TestCase):
def test_bert_fp16_with_mrpc(self):
expected_acc = 0.84
expected_f1 = 0.88
expected_loss = 0.40
expected_loss = 0.44
results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True)

View file

@ -90,6 +90,7 @@ class ORTMultipleChoiceTest(unittest.TestCase):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "multiple_choice_test_output/")
self.cache_dir = '/tmp/multiple_choice/'
self.logging_steps = 10
self.rtol = 2e-01
def test_bert_with_swag(self):
expected_acc = 0.75

View file

@ -1180,16 +1180,15 @@ def run_training_python_frontend_e2e_tests(cwd):
# frontend tests are to be added here:
log.info("Running python frontend e2e tests.")
run_subprocess(
[sys.executable, 'orttraining_run_frontend_batch_size_test.py', '-v'],
cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'})
import torch
ngpus = torch.cuda.device_count()
if ngpus > 1:
bert_pretrain_script = 'orttraining_run_bert_pretrain.py'
log.debug('RUN: mpirun -n {} ''-x' 'NCCL_DEBUG=INFO'' {} {} {}'.format(
ngpus, sys.executable, bert_pretrain_script, 'ORTBertPretrainTest.test_pretrain_throughput'))
run_subprocess([
'mpirun', '-n', str(ngpus), '-x', 'NCCL_DEBUG=INFO', sys.executable,
bert_pretrain_script, 'ORTBertPretrainTest.test_pretrain_throughput'], cwd=cwd)
# TODO: this test will be replaced with convergence test ported from backend
log.debug('RUN: mpirun -n {} ''-x' 'NCCL_DEBUG=INFO'' {} {} {}'.format(
ngpus, sys.executable, bert_pretrain_script, 'ORTBertPretrainTest.test_pretrain_convergence'))
run_subprocess([
@ -1231,7 +1230,8 @@ def run_training_python_frontend_e2e_tests(cwd):
sys.executable, 'orttraining_test_transformers.py',
'BertModelTest.test_for_pretraining_mixed_precision'], cwd=cwd)
# this test is not stable. need to skip to unblock release
# This test is not stable: it occasionally causes a segfault due to its session creation/release pattern.
# It is skipped to unblock the release.
# run_subprocess([
# sys.executable, 'orttraining_test_transformers.py',
# 'BertModelTest.test_for_pretraining_mixed_precision_with_gradient_accumulation'], cwd=cwd)