mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
Liqun/bert pretrain tb (#5377)
* add tensor board, remove torch.distributed.launch because ort nccl depends on MPI. Use MPI to launch parallel training. Co-authored-by: liqun <liqun@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
b5caa7cb12
commit
773992c7d4
4 changed files with 145 additions and 35 deletions
|
|
@ -8,15 +8,19 @@ import logging
|
|||
import random
|
||||
import h5py
|
||||
from tqdm import tqdm
|
||||
import datetime
|
||||
import numpy as np
|
||||
import dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import Optional, Any, Dict
|
||||
import json
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler, Dataset
|
||||
import torch.distributed as dist
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from transformers import BertForPreTraining, BertConfig, HfArgumentParser
|
||||
|
||||
|
|
@ -24,7 +28,7 @@ from concurrent.futures import ProcessPoolExecutor
|
|||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.training import amp, optim, orttrainer
|
||||
from onnxruntime.training.optim import _LRScheduler, PolyWarmupLRScheduler
|
||||
from onnxruntime.training.optim import PolyWarmupLRScheduler, LinearWarmupLRScheduler
|
||||
|
||||
# need to override torch.onnx.symbolic_opset12.nll_loss to handle ignore_index == -100 cases.
|
||||
# the fix for ignore_index == -100 cases is already in pytorch master.
|
||||
|
|
@ -33,6 +37,11 @@ from onnxruntime.training.optim import _LRScheduler, PolyWarmupLRScheduler
|
|||
# issues are understood and solved.
|
||||
import onnxruntime.capi.pt_patch
|
||||
|
||||
# we cannot make full convergence run in nightly pipeline because of its timeout limit,
|
||||
# max_steps is still needed to calculate learning rate. force_to_stop_max_steps is used to
|
||||
# terminate the training before the pipeline run hits its timeout.
|
||||
force_to_stop_max_steps = 2500
|
||||
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt='%m/%d/%Y %H:%M:%S',
|
||||
level=logging.INFO)
|
||||
|
|
@ -266,6 +275,21 @@ class PretrainArguments:
|
|||
metadata={"help": "Whether to use fp16 gradient accumulators."}
|
||||
)
|
||||
|
||||
def to_json_string(self):
    """
    Return this dataclass instance rendered as a JSON document.

    The instance is first converted to a plain dict with ``dataclasses.asdict``
    and then dumped with a two-space indent.
    """
    fields_as_dict = dataclasses.asdict(self)
    return json.dumps(fields_as_dict, indent=2)
|
||||
|
||||
def to_sanitized_dict(self) -> Dict[str, Any]:
    """
    Return the dataclass fields as a dict suitable for TensorBoard hparams.

    Values whose exact type is bool, int, float, str or torch.Tensor are kept
    unchanged; every other value is replaced by its ``str()`` representation.
    """
    passthrough_types = (bool, int, float, str, torch.Tensor)
    sanitized = {}
    for name, value in dataclasses.asdict(self).items():
        # Exact type match on purpose: instances of subclasses are stringified too.
        if type(value) in passthrough_types:
            sanitized[name] = value
        else:
            sanitized[name] = str(value)
    return sanitized
|
||||
|
||||
|
||||
def setup_training(args):
|
||||
|
||||
assert (torch.cuda.is_available())
|
||||
|
|
@ -296,6 +320,8 @@ def setup_training(args):
|
|||
|
||||
def prepare_model(args, device):
|
||||
config = BertConfig.from_pretrained('bert-base-uncased', cache_dir=args.cache_dir)
|
||||
|
||||
# config.num_hidden_layers = 12
|
||||
if args.force_num_hidden_layers:
|
||||
logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
|
||||
config.num_hidden_layers = args.force_num_hidden_layers
|
||||
|
|
@ -303,7 +329,7 @@ def prepare_model(args, device):
|
|||
model = BertForPreTraining(config)
|
||||
model_desc = bert_model_description(config)
|
||||
|
||||
lr_scheduler = PolyWarmupLRScheduler(total_steps=int(args.max_steps))
|
||||
lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)
|
||||
|
||||
loss_scaler = amp.DynamicLossScaler() if args.fp16 else None
|
||||
|
||||
|
|
@ -317,7 +343,10 @@ def prepare_model(args, device):
|
|||
'utils': {
|
||||
'grad_norm_clip': True},
|
||||
'distributed': {
|
||||
'allreduce_post_accumulation': True},
|
||||
'world_rank': max(0, args.local_rank),
|
||||
'world_size': args.world_size,
|
||||
'local_rank': max(0, args.local_rank),
|
||||
'allreduce_post_accumulation': args.allreduce_post_accumulation},
|
||||
'lr_scheduler': lr_scheduler
|
||||
})
|
||||
|
||||
|
|
@ -352,6 +381,13 @@ def main():
|
|||
|
||||
|
||||
def do_pretrain(args):
|
||||
if is_main_process(args) and args.tensorboard_dir:
|
||||
tb_writer = SummaryWriter(log_dir=args.tensorboard_dir)
|
||||
tb_writer.add_text("args", args.to_json_string())
|
||||
tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
|
@ -398,18 +434,27 @@ def do_pretrain(args):
|
|||
input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
|
||||
|
||||
loss, _, _ = model.train_step(input_ids, input_mask, segment_ids, masked_lm_labels, next_sentence_labels)
|
||||
|
||||
# This is an approximation which misses gradient overflow.
|
||||
# TODO: ORTTrainer to expose global_step.
|
||||
global_step = training_steps / args.gradient_accumulation_steps
|
||||
|
||||
average_loss += loss.item()
|
||||
|
||||
global_step = model._train_step_info.optimization_step
|
||||
if training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
|
||||
if is_main_process(args):
|
||||
print("Step:{} Average Loss = {}".format(global_step, average_loss / (args.log_freq * args.gradient_accumulation_steps)))
|
||||
divisor = args.log_freq * args.gradient_accumulation_steps
|
||||
if tb_writer:
|
||||
lr = model.options.lr_scheduler.get_last_lr()[0]
|
||||
tb_writer.add_scalar('train/summary/scalar/Learning_Rate', lr, global_step)
|
||||
if args.fp16:
|
||||
tb_writer.add_scalar('train/summary/scalar/loss_scale_25', loss, global_step)
|
||||
# TODO: ORTTrainer to expose all_finite
|
||||
# tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step)
|
||||
tb_writer.add_scalar('train/summary/total_loss', average_loss / divisor, global_step)
|
||||
|
||||
print("Step:{} Average Loss = {}".format(global_step, average_loss / divisor))
|
||||
|
||||
if global_step >= args.max_steps or global_step >= force_to_stop_max_steps:
|
||||
if tb_writer:
|
||||
tb_writer.close()
|
||||
|
||||
if global_step >= args.max_steps:
|
||||
final_loss = average_loss / (args.log_freq * args.gradient_accumulation_steps)
|
||||
return final_loss
|
||||
|
||||
|
|
@ -422,6 +467,13 @@ def do_pretrain(args):
|
|||
epoch += 1
|
||||
|
||||
|
||||
def generate_tensorboard_logdir(root_dir, now=None):
    """Build a timestamped TensorBoard log directory path under ``root_dir``.

    Args:
        root_dir: Directory under which the run directory name is placed.
        now: Optional ``datetime.datetime`` used for the timestamp. Defaults
            to the current local time.

    Returns:
        ``root_dir`` joined with ``BERT_pretrain_<yy>_<mm>_<dd>_<HH>_<MM>_<SS>``.
        The path is only computed; no directory is created here.
    """
    if now is None:
        now = datetime.datetime.today()

    # %H (24-hour clock) rather than %I (12-hour clock without an AM/PM marker), so
    # that runs started twelve hours apart do not map to the same directory name.
    run_name = now.strftime('BERT_pretrain_%y_%m_%d_%H_%M_%S')
    return os.path.join(root_dir, run_name)
|
||||
|
||||
|
||||
class ORTBertPretrainTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.output_dir = '/bert_data/hf_data/test_out/bert_pretrain_results'
|
||||
|
|
@ -431,13 +483,14 @@ class ORTBertPretrainTest(unittest.TestCase):
|
|||
self.world_size = 1
|
||||
self.max_steps = 300000
|
||||
self.learning_rate = 5e-4
|
||||
self.max_seq_length = 128
|
||||
self.max_seq_length = 512
|
||||
self.max_predictions_per_seq = 20
|
||||
self.input_dir = '/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train'
|
||||
self.train_batch_size = 4096
|
||||
self.gradient_accumulation_steps = 64
|
||||
self.fp16 = True
|
||||
self.allreduce_post_accumulation = True
|
||||
self.tensorboard_dir = '/bert_data/hf_data/test_out'
|
||||
|
||||
def test_pretrain_throughput(self):
|
||||
# setting train_batch_size and gradient_accumulation_steps to maximize per gpu memory usage under 16GB
|
||||
|
|
@ -477,8 +530,6 @@ class ORTBertPretrainTest(unittest.TestCase):
|
|||
do_pretrain(args)
|
||||
|
||||
def test_pretrain_convergence(self):
|
||||
self.max_steps = 200
|
||||
self.force_num_hidden_layers = 8
|
||||
args = PretrainArguments(
|
||||
output_dir=self.output_dir,
|
||||
bert_model=self.bert_model,
|
||||
|
|
@ -494,7 +545,8 @@ class ORTBertPretrainTest(unittest.TestCase):
|
|||
input_dir=self.input_dir,
|
||||
fp16=self.fp16,
|
||||
allreduce_post_accumulation=self.allreduce_post_accumulation,
|
||||
force_num_hidden_layers=self.force_num_hidden_layers)
|
||||
force_num_hidden_layers=self.force_num_hidden_layers,
|
||||
tensorboard_dir=generate_tensorboard_logdir('/bert_data/hf_data/test_out/'))
|
||||
final_loss = do_pretrain(args)
|
||||
return final_loss
|
||||
|
||||
|
|
@ -504,11 +556,21 @@ class ORTBertPretrainTest(unittest.TestCase):
|
|||
if __name__ == "__main__":
|
||||
import sys
|
||||
logger.warning("sys.argv: %s", sys.argv)
|
||||
if len(sys.argv) >= 2 and sys.argv[1].startswith('--local_rank='):
|
||||
# torch.parallel.launch
|
||||
local_rank = int(sys.argv[1][len('--local_rank='):])
|
||||
world_size = int(os.environ['WORLD_SIZE'])
|
||||
print("torch.parallel.launch, local_rank/world_size: ", local_rank, '/', world_size)
|
||||
# usage:
|
||||
# mpirun -n 4 python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput
|
||||
# mpirun -n 4 python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence
|
||||
# mpirun -n 4 python orttraining_run_bert_pretrain.py # to run real BERT convergence test
|
||||
# torch.distributed.launch will not work because ort backend requires MPI to broadcast ncclUniqueId
|
||||
#
|
||||
# calling unpublished get_mpi_context_xxx to get rank/size numbers.
|
||||
from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_local_size, get_mpi_context_world_rank, get_mpi_context_world_size
|
||||
world_size = get_mpi_context_world_size()
|
||||
if world_size > 1:
|
||||
print ('get_mpi_context_world_size(): ', world_size)
|
||||
local_rank = get_mpi_context_local_rank()
|
||||
|
||||
if local_rank == 0:
|
||||
print('================================================================> os.getpid() = ', os.getpid())
|
||||
|
||||
test = ORTBertPretrainTest()
|
||||
test.setUp()
|
||||
|
|
@ -516,17 +578,58 @@ if __name__ == "__main__":
|
|||
test.world_rank = local_rank
|
||||
test.world_size = world_size
|
||||
|
||||
if sys.argv[2] == 'ORTBertPretrainTest.test_pretrain_throughput':
|
||||
if len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_throughput':
|
||||
logger.info("running ORTBertPretrainTest.test_pretrain_throughput()...")
|
||||
test.test_pretrain_throughput()
|
||||
logger.info("ORTBertPretrainTest.test_pretrain_throughput() passed")
|
||||
else:
|
||||
elif len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_convergence':
|
||||
logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...")
|
||||
test.max_steps = 200
|
||||
test.force_num_hidden_layers = 8
|
||||
final_loss = test.test_pretrain_convergence()
|
||||
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
|
||||
test.assertLess(final_loss, 8.5)
|
||||
logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed")
|
||||
else:
|
||||
# https://microsoft.sharepoint.com/teams/ONNX2/_layouts/15/Doc.aspx?sourcedoc={170774be-e1c6-4f8b-a3ae-984f211fe410}&action=edit&wd=target%28ONNX%20Training.one%7C8176133b-c7cb-4ef2-aa9d-3fdad5344c40%2FGitHub%20Master%20Merge%20Schedule%7Cb67f0db1-e3a0-4add-80a6-621d67fd8107%2F%29
|
||||
# to make equivalent args for cpp convergence test
|
||||
|
||||
# ngpu=4
|
||||
# seq_len=128
|
||||
# max_predictions_per_seq=20
|
||||
# batch_size=64
|
||||
# grad_acc=16
|
||||
# num_train_steps=1000000
|
||||
# optimizer=adam
|
||||
# lr=5e-4
|
||||
# warmup_ratio=0.1
|
||||
# warmup_mode=Linear
|
||||
# effective_batch_size=$((ngpu * batch_size * grad_acc))
|
||||
# commit=$(git rev-parse HEAD | cut -c1-8)
|
||||
# time_now=$(date +%m%d%H%M)
|
||||
# run_name=ort_${commit}_nvbertbase_bookwiki128_fp16_${optimizer}_lr${lr}_${warmup_mode}${warmup_ratio}_g${ngpu}_bs${batch_size}_acc${grad_acc}_efbs${effective_batch_size}_steps${num_train_steps}_fp16allreduce_${time_now}
|
||||
|
||||
# mixed precision
|
||||
# mpirun -n ${ngpu} ./onnxruntime_training_bert --model_name /bert_ort/bert_models/nv/bert-base/bert-base-uncased_L_12_H_768_A_12_V_30528_S_512_Dp_0.1_optimized_layer_norm
|
||||
# --train_data_dir /bert_data/128/books_wiki_en_corpus/train --test_data_dir /bert_data/128/books_wiki_en_corpus/test
|
||||
# --train_batch_size ${batch_size} --mode train --num_train_steps ${num_train_steps} --display_loss_steps 5
|
||||
# --log_dir ./logs/bert_base/${run_name} --optimizer ${optimizer} --learning_rate ${lr} --warmup_ratio ${warmup_ratio} --warmup_mode ${warmup_mode}
|
||||
# --gradient_accumulation_steps ${grad_acc} --max_predictions_per_seq=${max_predictions_per_seq} --use_mixed_precision --allreduce_in_fp16 --lambda 0
|
||||
# --use_nccl | tee ${run_name}.log
|
||||
|
||||
test.max_seq_length = 128
|
||||
test.max_predictions_per_seq = 20
|
||||
test.gradient_accumulation_steps = 16
|
||||
test.train_batch_size = 64 * test.gradient_accumulation_steps * test.world_size # cpp_batch_size (=64) * grad_acc * world_size
|
||||
test.max_steps = 300000
|
||||
|
||||
test.force_num_hidden_layers = None
|
||||
|
||||
# already using Adam (e.g. AdamConfig)
|
||||
test.learning_rate = 5e-4
|
||||
test.warmup_proportion = 0.1
|
||||
|
||||
final_loss = test.test_pretrain_convergence()
|
||||
logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss)
|
||||
else:
|
||||
unittest.main()
|
||||
|
||||
# unittest.main()
|
||||
|
|
@ -1115,19 +1115,25 @@ def run_training_python_frontend_e2e_tests(cwd):
|
|||
import torch
|
||||
ngpus = torch.cuda.device_count()
|
||||
if ngpus > 1:
|
||||
log.debug('RUN: {} -m torch.distributed.launch --nproc_per_node {} \
|
||||
orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput'.format(
|
||||
sys.executable, ngpus))
|
||||
bert_pretrain_script = 'orttraining_run_bert_pretrain.py'
|
||||
log.debug('RUN: mpirun -n {} ''-x' 'NCCL_DEBUG=INFO'' {} {} {}'.format(
|
||||
ngpus, sys.executable, bert_pretrain_script, 'ORTBertPretrainTest.test_pretrain_throughput'))
|
||||
run_subprocess([
|
||||
sys.executable, '-m', 'torch.distributed.launch', '--nproc_per_node', str(ngpus),
|
||||
'orttraining_run_bert_pretrain.py', 'ORTBertPretrainTest.test_pretrain_throughput'], cwd=cwd)
|
||||
'mpirun', '-n', str(ngpus), '-x', 'NCCL_DEBUG=INFO', sys.executable,
|
||||
bert_pretrain_script, 'ORTBertPretrainTest.test_pretrain_throughput'], cwd=cwd)
|
||||
|
||||
log.debug('RUN: {} -m torch.distributed.launch --nproc_per_node {} \
|
||||
orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence'.format(
|
||||
sys.executable, ngpus))
|
||||
log.debug('RUN: mpirun -n {} ''-x' 'NCCL_DEBUG=INFO'' {} {} {}'.format(
|
||||
ngpus, sys.executable, bert_pretrain_script, 'ORTBertPretrainTest.test_pretrain_convergence'))
|
||||
run_subprocess([
|
||||
sys.executable, '-m', 'torch.distributed.launch', '--nproc_per_node', str(ngpus),
|
||||
'orttraining_run_bert_pretrain.py', 'ORTBertPretrainTest.test_pretrain_convergence'], cwd=cwd)
|
||||
'mpirun', '-n', str(ngpus), '-x', 'NCCL_DEBUG=INFO', sys.executable,
|
||||
bert_pretrain_script, 'ORTBertPretrainTest.test_pretrain_convergence'], cwd=cwd)
|
||||
|
||||
# a long run
|
||||
log.debug('RUN: mpirun -n {} ''-x' 'NCCL_DEBUG=INFO'' {} {}'.format(
|
||||
ngpus, sys.executable, bert_pretrain_script))
|
||||
run_subprocess([
|
||||
'mpirun', '-n', str(ngpus), '-x', 'NCCL_DEBUG=INFO', sys.executable,
|
||||
bert_pretrain_script], cwd=cwd)
|
||||
|
||||
log.debug('RUN: mpirun -n {} {} orttraining_run_glue.py'.format(ngpus, sys.executable))
|
||||
run_subprocess([
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ trigger: none
|
|||
jobs:
|
||||
- job: Onnxruntime_Linux_GPU_Training_FrontEnd
|
||||
|
||||
timeoutInMinutes: 240
|
||||
timeoutInMinutes: 300
|
||||
|
||||
steps:
|
||||
- checkout: self
|
||||
|
|
|
|||
|
|
@ -6,4 +6,5 @@ transformers==v2.10.0
|
|||
torch==1.6.0.dev20200610
|
||||
torchvision==0.7.0.dev20200610
|
||||
torchtext==0.6.0.dev20200610
|
||||
tensorboard==v2.0.0
|
||||
h5py
|
||||
|
|
|
|||
Loading…
Reference in a new issue