From af3988198cf173ad47b540ad26a045db7e38cdd1 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Thu, 30 Apr 2020 12:26:38 -0700 Subject: [PATCH] Liqun/e2e transformer test (#3540) * initial change to transformer.py * prepare e2e transformer tests * refactor transformer tests * put test python files in a flat folder * fix typo pip install transform(s) * python 3.6 * python version to 3.6 in install_ubuntu.sh * remove argparser * to use opset ver 12 * workaround loss_scale naming patch in case of loss_fn_ * assign self.loss_fn_ so it can be checked * skip a few un-needed post-process steps * fix loss_scale_input_name, clean up post process steps * skip non-frontend tests * move cpu/cuda related files to coresponding cpu/cuda folder (#3668) Co-authored-by: Weixing Zhang * type cast for ratio is not necessary for dropout (#3682) Co-authored-by: Weixing Zhang * thrustallocator is not needed since cub is used directly for gather now. (#3683) Co-authored-by: Weixing Zhang * GatherND-12 Implementation (#3645) * Renamed, UT passing * Move GatherND CUDA Kerenl into onnxruntime * Merge GatherNDOpTest * Refactor Test code * Merge CPU Kernel Impl * Handle Negative Indice, Fix UT * Improve CUDA kernel to handle negative index * Minor Fixes * Preserve GatherND-1 Cuda kernel * Fix Mac build * fix UT * Fix Build * fix GatherNDOpTest.double > CUDA error cudaErrorInvalidDeviceFunction:invalid device function Co-authored-by: Sherlock Huang Co-authored-by: Peng Wang (pengwa) * update with reviewers' comments * testBertTrainingGradientAccumulation was not using rtol and may fail occasionally with small (e-06) difference * fix merge mistakes Co-authored-by: liqun Co-authored-by: Weixing Zhang Co-authored-by: Weixing Zhang Co-authored-by: Sherlock Co-authored-by: Sherlock Huang Co-authored-by: Peng Wang (pengwa) --- cmake/onnxruntime_python.cmake | 1 + .../python/onnxruntime_test_transformers.py | 593 ------------------ orttraining/orttraining/python/ort_trainer.py | 16 +- .../orttraining_test_bert_postprocess.py | 12 + .../python/orttraining_test_data_loader.py | 85 +++ .../orttraining_test_layer_norm_transform.py | 177 ++++++ .../orttraining_test_model_transform.py | 106 ++++ .../python/orttraining_test_transformers.py | 198 ++++++ .../test/python/orttraining_test_utils.py | 118 ++++ .../mnist_training.py | 8 - tools/ci_build/build.py | 21 +- .../linux/docker/scripts/install_deps.sh | 3 + 12 files changed, 728 insertions(+), 610 deletions(-) delete mode 100644 onnxruntime/test/python/onnxruntime_test_transformers.py create mode 100644 orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py create mode 100644 orttraining/orttraining/test/python/orttraining_test_data_loader.py create mode 100644 orttraining/orttraining/test/python/orttraining_test_layer_norm_transform.py create mode 100644 orttraining/orttraining/test/python/orttraining_test_model_transform.py create mode 100644 orttraining/orttraining/test/python/orttraining_test_transformers.py create mode 100644 orttraining/orttraining/test/python/orttraining_test_utils.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index a9a0f6b68d..9701a2e1d4 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -173,6 +173,7 @@ endif() file(GLOB onnxruntime_python_test_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/python/*.py" + "${ORTTRAINING_SOURCE_DIR}/test/python/*.py" ) file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/*.py" diff --git 
a/onnxruntime/test/python/onnxruntime_test_transformers.py b/onnxruntime/test/python/onnxruntime_test_transformers.py deleted file mode 100644 index ba0c0a29ca..0000000000 --- a/onnxruntime/test/python/onnxruntime_test_transformers.py +++ /dev/null @@ -1,593 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import unittest -import shutil -import pytest -import sys -import os -from enum import Enum - -# set transformer_repo to the huggingface repo location -transformer_repo = '' -nvidia_deep_learning_examples_repo = '' - -sys.path.append(transformer_repo) - -from transformers import is_torch_available - -sys.path.append(os.path.join(transformer_repo, "transformers/tests")) -from modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor) -from configuration_common_test import ConfigTester - -if is_torch_available(): - from transformers import (BertConfig, BertModel, BertForMaskedLM, - BertForNextSentencePrediction, BertForPreTraining, - BertForQuestionAnswering, BertForSequenceClassification, - BertForTokenClassification, BertForMultipleChoice) - from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP -else: - pytestmark = pytest.mark.skip("Require Torch") - -import onnxruntime as ort -from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample -import torch - -sys.path.append(os.path.join(nvidia_deep_learning_examples_repo, 'PyTorch/LanguageModeling/BERT')) -from run_pretraining import postprocess_model - -def map_optimizer_attributes(name): - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) - if no_decay: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - else: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6} - -from torch.utils.data import Dataset, DataLoader -class OrtTestDataset(Dataset): - def __init__(self, input_desc, seq_len, device): - import copy - self.input_desc_ = copy.deepcopy(input_desc) - for input_desc in self.input_desc_: - shape_ = [] - for i, axis in enumerate(input_desc.shape_): - if axis == 'max_seq_len_in_batch': - shape_ = shape_ + [seq_len, ] - elif axis != 'batch': - shape_ = input_desc.shape_[i] - input_desc.shape_ = shape_ - self.device_ = device - - def __len__(self): - return 100 - - def __getitem__(self, item): - input_batch = [] - for input_desc in self.input_desc_: - input_sample = generate_sample(input_desc, self.device_) - input_batch.append(input_sample) - return input_batch - -def create_ort_test_dataloader(input_desc, batch_size, seq_len, device): - dataset = OrtTestDataset(input_desc, seq_len, device) - return DataLoader(dataset, batch_size=batch_size) - -class BatchArgsOption(Enum): - List = 1 - Dict = 2 - ListAndDict = 3 - -def split_batch(batch, input_desc, args_count): - total_argument_count = len(input_desc) - # batch=[input_ids[batch, seglen], attention_mask[batch, seglen], token_type_ids[batch,seglen], token_type_ids[batch, seglen]] - args = [] # (input_ids[batch, seglen], attention_mask[batch, seglen]) - kwargs = {} # {'token_type_ids': token_type_ids[batch,seglen], 'position_ids': token_type_ids[batch, seglen]} - for i in range(args_count): - args = args + [batch[i]] - - for i in range(args_count, total_argument_count): - kwargs[input_desc[i].name_] = batch[i] - - return args, kwargs - -def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16, 
- allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step, loss_scaler, use_internal_loss_scaler, - batch_args_option): - dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, device) - - model = ORTTrainer(model, None, model_desc, "LambOptimizer", - map_optimizer_attributes=map_optimizer_attributes, - learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32), - device=device, postprocess_model=postprocess_model, - gradient_accumulation_steps=gradient_accumulation_steps, - # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6 - world_rank=args.local_rank, world_size=args.world_size, - use_mixed_precision=fp16, - allreduce_post_accumulation=allreduce_post_accumulation, - get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None, - loss_scaler=loss_scaler if use_internal_loss_scaler else None) - - # trainig loop - eval_batch = None - model.train() - for step, batch in enumerate(dataloader): - if eval_batch is None: - eval_batch = batch - - if not use_internal_get_lr_this_step: - lr = get_lr_this_step(step) - learning_rate = torch.tensor([lr]) - - if not use_internal_loss_scaler and fp16: - loss_scale = torch.tensor(loss_scaler.loss_scale_) - - if batch_args_option == BatchArgsOption.List: - if not use_internal_get_lr_this_step: - batch = batch + [learning_rate, ] - if not use_internal_loss_scaler and fp16: - batch = batch + [loss_scale, ] - outputs = model(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - if not use_internal_get_lr_this_step: - kwargs['Learning_Rate'] = learning_rate - if not use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - if not use_internal_get_lr_this_step: - kwargs['Learning_Rate'] = learning_rate - if not use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model(*args, **kwargs) - - # eval - model.eval() - if batch_args_option == BatchArgsOption.List: - outputs = model(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - outputs = model(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - outputs = model(*args, **kwargs) - - return (output.cpu().numpy() for output in outputs) - - -class BertModelTest(CommonTestCases.CommonModelTester): - - all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction, - BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, - BertForTokenClassification) if is_torch_available() else () - - class BertModelTester(object): - - def __init__(self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - scope=None, - device='cpu', - ): - self.parent = parent - self.batch_size 
= batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.scope = scope - self.device = device - - # 1. superset of bert input/output descs - # see BertPreTrainedModel doc - self.input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size) - self.attention_mask_desc = IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2) - self.token_type_ids_desc = IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2) - self.position_ids_desc = IODescription('position_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.max_position_embeddings) - self.head_mask_desc = IODescription('head_mask', [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2) - self.inputs_embeds_desc = IODescription('inputs_embeds', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) - - self.encoder_hidden_states_desc = IODescription('encoder_hidden_states', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) - self.encoder_attention_mask_desc = IODescription('encoder_attention_mask', ['batch', 'max_seq_len_in_batch'], torch.float32) - - # see BertForPreTraining doc - self.masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size) - self.next_sentence_label_desc = IODescription('next_sentence_label', ['batch',], torch.int64, num_classes=2) - - # outputs - self.loss_desc = IODescription('loss', [1,], torch.float32) - self.prediction_scores_desc = IODescription('prediction_scores', ['batch', 'max_seq_len_in_batch', self.vocab_size], torch.float32) - - self.seq_relationship_scores_desc = IODescription('seq_relationship_scores', ['batch', 2], torch.float32) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32) - self.hidden_states_desc = IODescription('hidden_states', [self.num_hidden_layers, 'batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) - self.attentions_desc = IODescription('attentions', [self.num_hidden_layers, 'batch', self.num_attention_heads, 'max_seq_len_in_batch', 'max_seq_len_in_batch'], torch.float32) - self.last_hidden_state_desc = IODescription('last_hidden_state', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) - self.pooler_output_desc = IODescription('pooler_output', ['batch', self.hidden_size], torch.float32) - - # BertForPreTraining forward: - # def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, - # position_ids??=None, head_mask??=None, inputs_embeds??=None, - # masked_lm_labels=None, next_sentence_label=None): - # - # create_and_check_bert_for_pretraining calls 
BertForPreTraining: - # model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, - # masked_lm_labels=token_labels, next_sentence_label=sequence_labels) - - def BertForPreTraining_descs(self): - return ModelDescription( - [self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc, self.next_sentence_label_desc], - # returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided - # hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states - [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc, - #hidden_states_desc, attentions_desc - ]) - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device) - - input_mask = None - if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device) - choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device) - - config = BertConfig( - vocab_size_or_config_json_file=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range) - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def prepare_config_and_inputs_for_decoder(self): - config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = self.prepare_config_and_inputs() - - config.is_decoder = True - encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) - encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask - - def check_loss_output(self, result): - self.parent.assertListEqual( - list(result["loss"].size()), - []) - - def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): - model = BertModel(config=config) - model.to(input_ids.device) - model.eval() - - sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) - - # failed because there is not loss output - model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc], - [self.last_hidden_state_desc, self.pooler_output_desc]) - args_gradient_accumulation_steps = 8 - args_local_rank = 0 - args_world_size = 1 - args_fp16 = True - args_allreduce_post_accumulation = True - - model = ORTTrainer(model, None, model_desc, 
"LambOptimizer", - map_optimizer_attributes=map_optimizer_attributes, - learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32), - device=self.device, postprocess_model=postprocess_model, - gradient_accumulation_steps=args_gradient_accumulation_steps, - world_rank=args_local_rank, world_size=args_world_size, - use_mixed_precision=True if args_fp16 else False, - allreduce_post_accumulation=True if args_allreduce_post_accumulation else False) - - sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids) - sequence_output, pooled_output = model(input_ids) - - result = { - "sequence_output": sequence_output, - "pooled_output": pooled_output, - } - self.parent.assertListEqual( - list(result["sequence_output"].size()), - [self.batch_size, self.seq_length, self.hidden_size]) - self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) - - def create_and_check_bert_model_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask): - model = BertModel(config) - model.eval() - sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask) - sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states) - sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) - - result = { - "sequence_output": sequence_output, - "pooled_output": pooled_output, - } - self.parent.assertListEqual( - list(result["sequence_output"].size()), - [self.batch_size, self.seq_length, self.hidden_size]) - self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) - - def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): - model = BertForMaskedLM(config=config) - model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels) - - ##### - model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc], - [self.loss_desc, self.prediction_scores_desc]) - args_gradient_accumulation_steps = 8 - args_local_rank = 0 - args_world_size = 1 - args_fp16 = True - args_allreduce_post_accumulation = True - - model = ORTTrainer(model, None, model_desc, "LambOptimizer", - map_optimizer_attributes=map_optimizer_attributes, - learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32), - device=self.device, postprocess_model=postprocess_model, - gradient_accumulation_steps=args_gradient_accumulation_steps, - world_rank=args_local_rank, world_size=args_world_size, - use_mixed_precision=True if args_fp16 else False, - allreduce_post_accumulation=True if args_allreduce_post_accumulation else False) - model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels) - - def create_and_check_bert_model_for_masked_lm_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask): - model = BertForMaskedLM(config=config) - model.eval() - loss, prediction_scores = 
model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask) - loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states) - result = { - "loss": loss, - "prediction_scores": prediction_scores, - } - self.parent.assertListEqual( - list(result["prediction_scores"].size()), - [self.batch_size, self.seq_length, self.vocab_size]) - self.check_loss_output(result) - - def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): - model = BertForNextSentencePrediction(config=config) - model.eval() - loss, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels) - result = { - "loss": loss, - "seq_relationship_score": seq_relationship_score, - } - self.parent.assertListEqual( - list(result["seq_relationship_score"].size()), - [self.batch_size, 2]) - self.check_loss_output(result) - - def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): - model = BertForPreTraining(config=config) - model.eval() - loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, - masked_lm_labels=token_labels, next_sentence_label=sequence_labels) - model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, - self.masked_lm_labels_desc, self.next_sentence_label_desc], - [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc]) - - import argparse - args_ = argparse.Namespace(fp16=True, amp_opt_level='O1') - - from collections import namedtuple - MyArgs = namedtuple("MyArgs", - "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len") - args = MyArgs(local_rank=0, world_size=1, max_steps=100, learning_rate=0.00001, warmup_proportion=0.01, batch_size=13, seq_len=7) - - from train_with_ort_trainer import get_lr - def get_lr_this_step(global_step): - return get_lr(args, global_step) - loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000) - - option_gradient_accumulation_steps = [8] - option_fp16 = [True, False] - option_allreduce_post_accumulation = True - option_use_internal_get_lr_this_step = False - option_use_internal_loss_scaler = False - # TODO: with with fetches - - for gradient_accumulation_steps in option_gradient_accumulation_steps: - for fp16 in option_fp16: - for option_split_batch in BatchArgsOption: - loss_ort, prediction_scores_ort, seq_relationship_score_ort =\ - run_test(model, model_desc, self.device, args, gradient_accumulation_steps, fp16, - option_allreduce_post_accumulation, - get_lr_this_step, option_use_internal_get_lr_this_step, - loss_scaler, option_use_internal_loss_scaler, - option_split_batch) - - print(loss_ort) - print(prediction_scores_ort) - print(seq_relationship_score_ort) - - def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): - model = BertForQuestionAnswering(config=config) - model.eval() - loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, - 
start_positions=sequence_labels, end_positions=sequence_labels) - result = { - "loss": loss, - "start_logits": start_logits, - "end_logits": end_logits, - } - self.parent.assertListEqual( - list(result["start_logits"].size()), - [self.batch_size, self.seq_length]) - self.parent.assertListEqual( - list(result["end_logits"].size()), - [self.batch_size, self.seq_length]) - self.check_loss_output(result) - - def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): - config.num_labels = self.num_labels - model = BertForSequenceClassification(config) - model.eval() - loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) - result = { - "loss": loss, - "logits": logits, - } - self.parent.assertListEqual( - list(result["logits"].size()), - [self.batch_size, self.num_labels]) - self.check_loss_output(result) - - def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): - config.num_labels = self.num_labels - model = BertForTokenClassification(config=config) - model.eval() - loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) - result = { - "loss": loss, - "logits": logits, - } - self.parent.assertListEqual( - list(result["logits"].size()), - [self.batch_size, self.seq_length, self.num_labels]) - self.check_loss_output(result) - - def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): - config.num_choices = self.num_choices - model = BertForMultipleChoice(config=config) - model.eval() - multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - loss, logits = model(multiple_choice_inputs_ids, - attention_mask=multiple_choice_input_mask, - token_type_ids=multiple_choice_token_type_ids, - labels=choice_labels) - result = { - "loss": loss, - "logits": logits, - } - self.parent.assertListEqual( - list(result["logits"].size()), - [self.batch_size, self.num_choices]) - self.check_loss_output(result) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - (config, input_ids, token_type_ids, input_mask, - sequence_labels, token_labels, choice_labels) = config_and_inputs - inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask} - return config, inputs_dict - - def setUp(self): - self.model_tester = BertModelTest.BertModelTester(self) - self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37) - - # def test_config(self): - # self.config_tester.run_common_tests() - - # def test_bert_model(self, use_cuda=False): - # # ^^ This could be a real fixture - # if use_cuda: - # self.model_tester.device = "cuda" - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # self.model_tester.create_and_check_bert_model(*config_and_inputs) - - # def test_bert_model_as_decoder(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() - # self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs) - - def 
test_for_masked_lm(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs) - - # def test_for_masked_lm_decoder(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() - # self.model_tester.create_and_check_bert_model_for_masked_lm_as_decoder(*config_and_inputs) - - # def test_for_multiple_choice(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs) - - # def test_for_next_sequence_prediction(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs) - - def test_for_pretraining(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs) - - # def test_for_question_answering(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs) - - # def test_for_sequence_classification(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs) - - # def test_for_token_classification(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs) - - # @pytest.mark.slow - # def test_model_from_pretrained(self): - # cache_dir = "/tmp/transformers_test/" - # for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - # model = BertModel.from_pretrained(model_name, cache_dir=cache_dir) - # shutil.rmtree(cache_dir) - # self.assertIsNotNone(model) - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 7c16a2e716..71d583c08d 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -309,6 +309,7 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op # Other export options to use(this is for backward compatibility). other_export_options = {} + # This option was added after 1.4 release. if LooseVersion(torch.__version__) > LooseVersion('1.4.0'): other_export_options['enable_onnx_checker'] = False @@ -630,6 +631,7 @@ class ORTTrainer(): if loss_fn is not None: print("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.") # TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn + self.loss_fn_ = None self.model_desc_ = model_desc self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description] @@ -664,9 +666,12 @@ class ORTTrainer(): self.enable_grad_norm_clip_ = enable_grad_norm_clip self.frozen_weights_ = frozen_weights self.opset_version_ = _opset_version - self.loss_scale_input_name = '' self.state_dict_ = None + # use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs. + # see prepare_input_and_fetches for more details. 
+ self.loss_scale_input_name = 'default_loss_scale_input_name' + self._init_session() def _init_session(self): @@ -790,6 +795,15 @@ class ORTTrainer(): input = input + (internal_learning_rate,) if internal_loss_scale is not None: input = input + (internal_loss_scale,) + elif self.use_mixed_precision: + # loss_scale input name is needed to call train_step, for example: + # kwargs[model.loss_scale_input_name] = loss_scale + # outputs = model.train_step(*args, **kwargs) + # However, when first time train_step is called model.loss_scale_input_name is not set. + # To workaround this problem, we use the special name 'default_loss_scale_input_name' to indicate + # the loss_scale. + if 'default_loss_scale_input_name' in kwargs.keys(): + input = input + (kwargs['default_loss_scale_input_name'],) fetches = None if 'fetches' in kwargs: diff --git a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py b/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py new file mode 100644 index 0000000000..b91122a45b --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py @@ -0,0 +1,12 @@ +from orttraining_test_model_transform import add_name, fix_transpose, add_expand_shape +from orttraining_test_layer_norm_transform import layer_norm_transform + +def postprocess_model(model): + add_name(model) + + # remove transpose node if its input is a 2d weight which only feeds to the node + fix_transpose(model) + + add_expand_shape(model) + + layer_norm_transform(model) diff --git a/orttraining/orttraining/test/python/orttraining_test_data_loader.py b/orttraining/orttraining/test/python/orttraining_test_data_loader.py new file mode 100644 index 0000000000..09fd3e2fb5 --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_data_loader.py @@ -0,0 +1,85 @@ +from enum import Enum +import random +import torch +from torch.utils.data import Dataset, DataLoader +from onnxruntime.capi.ort_trainer import generate_sample + +global_rng = random.Random() + +def ids_tensor(shape, vocab_size, rng=None, name=None): + """Creates a random int32 tensor of the shape within the vocab size.""" + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() + + +def floats_tensor(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor of the shape within the vocab size.""" + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() + + +class OrtTestDataset(Dataset): + def __init__(self, input_desc, seq_len, device): + import copy + self.input_desc_ = copy.deepcopy(input_desc) + for input_desc in self.input_desc_: + shape_ = [] + for i, axis in enumerate(input_desc.shape_): + if axis == 'max_seq_len_in_batch': + shape_ = shape_ + [seq_len, ] + elif axis != 'batch': + shape_ = input_desc.shape_[i] + input_desc.shape_ = shape_ + self.device_ = device + + def __len__(self): + return 100 + + def __getitem__(self, item): + input_batch = [] + for input_desc in self.input_desc_: + input_sample = generate_sample(input_desc, self.device_) + input_batch.append(input_sample) + return input_batch + +def create_ort_test_dataloader(input_desc, 
batch_size, seq_len, device): + dataset = OrtTestDataset(input_desc, seq_len, device) + return DataLoader(dataset, batch_size=batch_size) + +class BatchArgsOption(Enum): + List = 1 + Dict = 2 + ListAndDict = 3 + +def split_batch(batch, input_desc, args_count): + total_argument_count = len(input_desc) + # batch=[input_ids[batch, seglen], attention_mask[batch, seglen], token_type_ids[batch,seglen], token_type_ids[batch, seglen]] + args = [] # (input_ids[batch, seglen], attention_mask[batch, seglen]) + kwargs = {} # {'token_type_ids': token_type_ids[batch,seglen], 'position_ids': token_type_ids[batch, seglen]} + for i in range(args_count): + args = args + [batch[i]] + + for i in range(args_count, total_argument_count): + kwargs[input_desc[i].name_] = batch[i] + + return args, kwargs diff --git a/orttraining/orttraining/test/python/orttraining_test_layer_norm_transform.py b/orttraining/orttraining/test/python/orttraining_test_layer_norm_transform.py new file mode 100644 index 0000000000..883d7386e7 --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_layer_norm_transform.py @@ -0,0 +1,177 @@ +import onnx + +def find_node(graph_proto, op_type): + nodes = [] + map_input_node = {} + for node in graph_proto.node: + if node.op_type == op_type: + map_input_node[node.input[0]] = node + if op_type == 'Div' or op_type == 'Mul': + map_input_node[node.input[1]] = node + nodes.append(node) + return nodes, map_input_node + +def gen_attribute(key, value): + attr = AttributeProto() + attr.name = key + attr.ints.extend(int(v) for v in value) + attr.type = AttributeProto.INTS + return attr + +def layer_norm_transform(model_proto): + # a layer norm subgraph + # input + # | + # ReduceMean + # __|____ + # | | + # Sub Sub + # | | + # | Pow + # | | + # | ReduceMean + # | | + # | Add + # | | + # |__ __Sqrt + # | | + # Div + # | + # Mul + # | + # Add + # | + # output + + graph_proto = model_proto.graph + + _, map_input_Div = find_node(graph_proto, 'Div') + + _, map_input_Sqrt = find_node(graph_proto, 'Sqrt') + + _, map_input_Add = find_node(graph_proto, 'Add') + + nodes_ReduceMean, map_input_ReduceMean = find_node(graph_proto, 'ReduceMean') + + _, map_input_Pow = find_node(graph_proto, 'Pow') + + _, map_input_Mul = find_node(graph_proto, 'Mul') + + # find right side Sub (see the layer norm subgrapg) + nodes_Sub = [] + map_input_Sub = {} + for node in graph_proto.node: + if node.op_type == 'Sub': + if node.output[0] in map_input_Pow: + nodes_Sub.append(node) + map_input_Sub[node.input[1]] = node + + # find first ReduceMean + first_ReduceMean = [] + first_ReduceMean_outputs = [] + for node in nodes_ReduceMean: + if node.output[0] in map_input_Sub: + first_ReduceMean.append(node) + first_ReduceMean_outputs.append(node.output[0]) + + # find constant node + nodes_Constant = [] + map_output_Constant = {} + for node in graph_proto.node: + if node.op_type == 'Constant': + nodes_Constant.append(node) + map_output_Constant[node.output[0]] = node + + id = 0 + removed_nodes = [] + layer_norm_nodes = [] + # Replace with layer norm + for node in first_ReduceMean: + layer_norm_input = [] + layer_norm_output = [] + layer_norm_input.append(node.input[0]) + + # collect nodes within a layer norm subgraph. + # skip building layer norm node if there is a pattern miss-match. 
+ if node.output[0] not in map_input_Sub: + continue + + node_sub = map_input_Sub[node.output[0]] + if node_sub.output[0] not in map_input_Pow: + continue + + node_pow = map_input_Pow[node_sub.output[0]] + if node_pow.output[0] not in map_input_ReduceMean: + continue + + node_reduce = map_input_ReduceMean[node_pow.output[0]] + if node_reduce.output[0] not in map_input_Add: + continue + + node_Add = map_input_Add[node_reduce.output[0]] + if node_Add.output[0] not in map_input_Sqrt: + continue + + node_Sqrt = map_input_Sqrt[node_Add.output[0]] + if node_Sqrt.output[0] not in map_input_Div: + continue + + node_Div = map_input_Div[node_Sqrt.output[0]] + if node_Div.output[0] not in map_input_Mul: + continue + + node_Mul = map_input_Mul[node_Div.output[0]] + + if node_Mul.input[0] != node_Div.output[0]: + layer_norm_input.append(node_Mul.input[0]) + else: + layer_norm_input.append(node_Mul.input[1]) + + if node_Mul.output[0] not in map_input_Add: + continue + + node_Add1 = map_input_Add[node_Mul.output[0]] + layer_norm_input.append(node_Add1.input[1]) + + removed_nodes.append(node) + removed_nodes.append(node_sub) + removed_nodes.append(node_pow) + removed_nodes.append(node_reduce) + removed_nodes.append(node_Add) + removed_nodes.append(node_Sqrt) + removed_nodes.append(node_Div) + removed_nodes.append(node_Mul) + removed_nodes.append(node_Add1) + removed_nodes.append(map_output_Constant[node_pow.input[1]]) + + removed_nodes.append(map_output_Constant[node_Add.input[1]]) + layer_norm_output.append(node_Add1.output[0]) + id = id + 1 + layer_norm_output.append('saved_mean_' + str(id)) + id = id + 1 + layer_norm_output.append('saved_inv_std_var_' + str(id)) + layer_norm = onnx.helper.make_node("LayerNormalization", + layer_norm_input, + layer_norm_output, + "LayerNormalization_" + str(id), + None, + axis = node_reduce.attribute[0].ints[0], + epsilon = 9.999999960041972e-13) + layer_norm_nodes.append(layer_norm) + + # remove left side Subs + for node in graph_proto.node: + if node.op_type == 'Sub': + if node.input[1] in first_ReduceMean_outputs: + removed_nodes.append(node) + + all_nodes = [] + for node in graph_proto.node: + if node not in removed_nodes: + all_nodes.append(node) + + for node in layer_norm_nodes: + all_nodes.append(node) + + graph_proto.ClearField("node") + graph_proto.node.extend(all_nodes) diff --git a/orttraining/orttraining/test/python/orttraining_test_model_transform.py b/orttraining/orttraining/test/python/orttraining_test_model_transform.py new file mode 100644 index 0000000000..9ef92aabcf --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_model_transform.py @@ -0,0 +1,106 @@ +from onnx import numpy_helper + +def add_name(model): + i = 0 + for node in model.graph.node: + node.name = '%s_%d' %(node.op_type, i) + i += 1 + +def find_single_output_node(model, arg): + result = [] + for node in model.graph.node: + for input in node.input: + if input == arg: + result.append(node) + return result[0] if len(result) == 1 else None + +def find_input_as_initializer(model, arg): + for initializer in model.graph.initializer: + if initializer.name == arg: + return initializer + return None + +def get_node_index(model, node): + for i, n in enumerate(model.graph.node): + if n == node: + return i + return None + +def replace_input_arg(model, arg, new_arg): + for node in model.graph.node: + for i in range(len(node.input)): + if node.input[i] == arg: + node.input[i] = new_arg + +def find_weight_index(model, name): + for index, w in enumerate(model.graph.initializer): + if 
w.name == name: + return index + index += 1 + return None + +def fix_transpose(model): + """ + Remove a transpose node if its input is a 2d weight which only feeds into that node. + """ + + # Find transpose nodes with initializer weight as input. + # The input weight must only be fed into the transpose node. + # Collect these nodes and weights. + transpose = [] + for node in model.graph.node: + if node.op_type == 'Transpose': + weight = find_input_as_initializer(model, node.input[0]) + if weight is not None: + result = [] + for n in model.graph.node: + for input in n.input: + if input == weight.name: + result.append(n) + if len(result) > 1: + continue + perm = node.attribute[0] + assert perm.name == 'perm' + perm = perm.ints + assert len(perm) == 2 and perm[0] == 1 and perm[1] == 0 + transpose.append((get_node_index(model, node), weight)) + + # Transpose the collected weights and add them to the model initializers. + # The transposed weight initializers become inputs to the transpose nodes' recipient nodes. + for t in transpose: + node = model.graph.node[t[0]] + weight = numpy_helper.to_array(t[1]) + assert len(weight.shape) == 2 + weight = weight.transpose(perm) + new_weight = numpy_helper.from_array(weight, "%s_transposed" % t[1].name) + model.graph.initializer.extend([new_weight]) + replace_input_arg(model, node.output[0], new_weight.name) + + # collected transpose nodes can be removed. + transpose.sort(reverse=True) + for t in transpose: + del model.graph.node[t[0]] + + # the original weight initializer can be removed. + # (remember that a weight is collected only if it feeds solely into the transpose node) + old_ws = [] + for t in transpose: + if find_single_output_node(model, t[1].name) is None: + old_ws.append(find_weight_index(model, t[1].name)) + old_ws.sort(reverse=True) + for w_i in old_ws: + del model.graph.initializer[w_i] + +def add_expand_shape(model): + """ + This method is very specific to the BERT model, which contains a single Expand op. + The training backend requires the op's output shape; it is the same as the shape of the (single) model input. + """ + + expand_node = [n for n in model.graph.node if n.op_type == 'Expand'] + if len(expand_node) != 1: + raise RuntimeError("cannot find the single expand node in the BERT model.")
+ return + expand_out = model.graph.value_info.add() + expand_out.name = expand_node[0].output[0] # base: '421' # tiny: '85' + expand_out.type.CopyFrom(model.graph.input[0].type) \ No newline at end of file diff --git a/orttraining/orttraining/test/python/orttraining_test_transformers.py b/orttraining/orttraining/test/python/orttraining_test_transformers.py new file mode 100644 index 0000000000..d61b43c877 --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_transformers.py @@ -0,0 +1,198 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import shutil +import pytest +import os + +from transformers import (BertConfig, BertForPreTraining, BertModel) + +from orttraining_test_data_loader import ids_tensor, BatchArgsOption +from orttraining_test_utils import run_test, get_lr + +from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler + +import torch + +class BertModelTest(unittest.TestCase): + + class BertModelTester(object): + + def __init__(self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + device='cpu', + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.device = device + + # 1. 
superset of bert input/output descs + # see BertPreTrainedModel doc + self.input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size) + self.attention_mask_desc = IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2) + self.token_type_ids_desc = IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2) + self.position_ids_desc = IODescription('position_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.max_position_embeddings) + self.head_mask_desc = IODescription('head_mask', [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2) + self.inputs_embeds_desc = IODescription('inputs_embeds', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) + + self.encoder_hidden_states_desc = IODescription('encoder_hidden_states', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) + self.encoder_attention_mask_desc = IODescription('encoder_attention_mask', ['batch', 'max_seq_len_in_batch'], torch.float32) + + # see BertForPreTraining doc + self.masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size) + self.next_sentence_label_desc = IODescription('next_sentence_label', ['batch',], torch.int64, num_classes=2) + + # outputs + self.loss_desc = IODescription('loss', [1,], torch.float32) + self.prediction_scores_desc = IODescription('prediction_scores', ['batch', 'max_seq_len_in_batch', self.vocab_size], torch.float32) + + self.seq_relationship_scores_desc = IODescription('seq_relationship_scores', ['batch', 2], torch.float32) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32) + self.hidden_states_desc = IODescription('hidden_states', [self.num_hidden_layers, 'batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) + self.attentions_desc = IODescription('attentions', [self.num_hidden_layers, 'batch', self.num_attention_heads, 'max_seq_len_in_batch', 'max_seq_len_in_batch'], torch.float32) + self.last_hidden_state_desc = IODescription('last_hidden_state', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) + self.pooler_output_desc = IODescription('pooler_output', ['batch', self.hidden_size], torch.float32) + + def BertForPreTraining_descs(self): + return ModelDescription( + [self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc, self.next_sentence_label_desc], + # returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided + # hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states + [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc, + #hidden_states_desc, attentions_desc + ]) + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], 
self.type_sequence_label_size).to(self.device) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device) + choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device) + + config = BertConfig( + vocab_size=self.vocab_size, + vocab_size_or_config_json_file=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = BertForPreTraining(config=config) + model.eval() + loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, + masked_lm_labels=token_labels, next_sentence_label=sequence_labels) + model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, + self.masked_lm_labels_desc, self.next_sentence_label_desc], + [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc]) + + from collections import namedtuple + MyArgs = namedtuple("MyArgs", + "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len") + args = MyArgs(local_rank=0, world_size=1, max_steps=100, learning_rate=0.00001, warmup_proportion=0.01, batch_size=13, seq_len=7) + + def get_lr_this_step(global_step): + return get_lr(args, global_step) + loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000) + + # It would be better to test both with/without mixed precision and allreduce_post_accumulation. + # However, a stress test of all 4 cases is not stable, at least on the test machine. + # Therefore we only test mixed precision and allreduce_post_accumulation, since that is the most useful configuration.
+ option_fp16 = [True] + option_allreduce_post_accumulation = [True] + option_gradient_accumulation_steps = [1, 8] + option_use_internal_get_lr_this_step = [True, False] + option_use_internal_loss_scaler = [True, False] + option_split_batch = [BatchArgsOption.ListAndDict] + + for fp16 in option_fp16: + for allreduce_post_accumulation in option_allreduce_post_accumulation: + for gradient_accumulation_steps in option_gradient_accumulation_steps: + for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step: + for use_internal_loss_scaler in option_use_internal_loss_scaler: + for split_batch in option_split_batch: + print("gradient_accumulation_steps:", gradient_accumulation_steps) + print("use_internal_loss_scaler:", use_internal_loss_scaler) + loss_ort, prediction_scores_ort, seq_relationship_score_ort =\ + run_test(model, model_desc, self.device, args, gradient_accumulation_steps, fp16, + allreduce_post_accumulation, + get_lr_this_step, use_internal_get_lr_this_step, + loss_scaler, use_internal_loss_scaler, + split_batch) + + print(loss_ort) + print(prediction_scores_ort) + print(seq_relationship_score_ort) + + def setUp(self): + self.model_tester = BertModelTest.BertModelTester(self) + + def test_for_pretraining(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs) + +if __name__ == "__main__": + unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_utils.py b/orttraining/orttraining/test/python/orttraining_test_utils.py new file mode 100644 index 0000000000..9e8152a83a --- /dev/null +++ b/orttraining/orttraining/test/python/orttraining_test_utils.py @@ -0,0 +1,118 @@ +import torch + +from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription + +from orttraining_test_data_loader import create_ort_test_dataloader, BatchArgsOption, split_batch +from orttraining_test_bert_postprocess import postprocess_model + +def warmup_cosine(x, warmup=0.002): + if x < warmup: + return x/warmup + return 0.5 * (1.0 + torch.cos(math.pi * x)) + +def warmup_constant(x, warmup=0.002): + if x < warmup: + return x/warmup + return 1.0 + +def warmup_linear(x, warmup=0.002): + if x < warmup: + return x/warmup + return max((x - 1. )/ (warmup - 1.), 0.) 
+ +def warmup_poly(x, warmup=0.002, degree=0.5): + if x < warmup: + return x/warmup + return (1.0 - x)**degree + + +SCHEDULES = { + 'warmup_cosine':warmup_cosine, + 'warmup_constant':warmup_constant, + 'warmup_linear':warmup_linear, + 'warmup_poly':warmup_poly, +} + +def get_lr(args, training_steps, schedule='warmup_poly'): + if args.max_steps == -1: + return args.learning_rate + + schedule_fct = SCHEDULES[schedule] + return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion) + +def map_optimizer_attributes(name): + no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] + no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) + if no_decay: + return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} + else: + return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6} + +def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16, + allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step, loss_scaler, use_internal_loss_scaler, + batch_args_option): + dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, device) + + model = ORTTrainer(model, None, model_desc, "LambOptimizer", + map_optimizer_attributes=map_optimizer_attributes, + learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32), + device=device, postprocess_model=postprocess_model, + gradient_accumulation_steps=gradient_accumulation_steps, + # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6 + world_rank=args.local_rank, world_size=args.world_size, + use_mixed_precision=fp16, + allreduce_post_accumulation=allreduce_post_accumulation, + get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None, + loss_scaler=loss_scaler if use_internal_loss_scaler else None, + _opset_version=12) + + # trainig loop + eval_batch = None + model.train() + for step, batch in enumerate(dataloader): + if eval_batch is None: + eval_batch = batch + + if not use_internal_get_lr_this_step: + lr = get_lr_this_step(step) + learning_rate = torch.tensor([lr]) + + if not use_internal_loss_scaler and fp16: + loss_scale = torch.tensor([loss_scaler.loss_scale_]) + + if batch_args_option == BatchArgsOption.List: + if not use_internal_get_lr_this_step: + batch = batch + [learning_rate, ] + if not use_internal_loss_scaler and fp16: + batch = batch + [loss_scale, ] + outputs = model(*batch) + elif batch_args_option == BatchArgsOption.Dict: + args, kwargs = split_batch(batch, model_desc.inputs_, 0) + if not use_internal_get_lr_this_step: + kwargs['Learning_Rate'] = learning_rate + if not use_internal_loss_scaler and fp16: + kwargs[model.loss_scale_input_name] = loss_scale + outputs = model(*args, **kwargs) + else: + args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs + args, kwargs = split_batch(batch, model_desc.inputs_, args_count) + if not use_internal_get_lr_this_step: + kwargs['Learning_Rate'] = learning_rate + if not use_internal_loss_scaler and fp16: + kwargs[model.loss_scale_input_name] = loss_scale + outputs = model(*args, **kwargs) + + # eval + model.eval() + if batch_args_option == BatchArgsOption.List: + outputs = model(*batch) + elif batch_args_option == BatchArgsOption.Dict: + args, kwargs = split_batch(batch, model_desc.inputs_, 0) + outputs = model(*args, **kwargs) + else: + args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs + args, kwargs = split_batch(batch, model_desc.inputs_, args_count) + outputs = 
model(*args, **kwargs) + + return (output.cpu().numpy() for output in outputs) + diff --git a/orttraining/pytorch_frontend_examples/mnist_training.py b/orttraining/pytorch_frontend_examples/mnist_training.py index ed73a132e4..d30c135c12 100644 --- a/orttraining/pytorch_frontend_examples/mnist_training.py +++ b/orttraining/pytorch_frontend_examples/mnist_training.py @@ -19,14 +19,6 @@ from torchvision import datasets, transforms import numpy as np import os -# TODO: remove after ready for CV -# import sys -# sys.path.insert(0, '/bert_ort/liqun/onnxruntime/build/Linux/Debug/') -# import onnxruntime as ort - -# sys.path.insert(0, '/bert_ort/liqun/onnxruntime/onnxruntime/python/') -# from ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel - from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel from mpi4py import MPI from onnxruntime.capi._pybind_state import set_cuda_device_id diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 63cc9170eb..d362f4085f 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1006,18 +1006,27 @@ def adb_push(source_dir, src, dest, **kwargs): def adb_shell(*args, **kwargs): return run_subprocess(['adb', 'shell', *args], **kwargs) -def run_training_python_frontend_e2e_tests(args, cwd, dll_path): +def run_training_python_frontend_e2e_tests(args, cwd): # frontend tests are to be added here: log.info("Running python frontend e2e tests.") - run_subprocess( - [sys.executable, 'onnxruntime_test_ort_trainer_with_mixed_precision.py'], - cwd=cwd, dll_path=dll_path) + run_subprocess([sys.executable, 'orttraining_test_transformers.py'], cwd=cwd) + + run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd) + + run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer_with_mixed_precision.py'], cwd=cwd) def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs, enable_tvm=False, enable_tensorrt=False): for config in configs: log.info("Running tests for %s configuration", config) cwd = get_config_build_dir(build_dir, config) + + if args.enable_training and args.use_cuda and args.enable_training_python_frontend_e2e_tests: + # run frontend tests for orttraining-linux-gpu-frontend_test-ci-pipeline. + # this is not a PR merge test so skip other tests. + run_training_python_frontend_e2e_tests(args, cwd=cwd) + continue + android_x86_64 = args.android_abi == 'x86_64' if android_x86_64: run_subprocess(os.path.join( @@ -1099,10 +1108,6 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs, [sys.executable, 'onnxruntime_test_training_unit_tests.py'], cwd=cwd, dll_path=dll_path) - # run additional frontend tests for orttraining-linux-gpu-frontend_test_ci-pipeline - if args.enable_training_python_frontend_e2e_tests: - run_training_python_frontend_e2e_tests(args, cwd=cwd, dll_path=dll_path) - try: import onnx # noqa # gen_test_models.py used by onnx_test requires scipy. 
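For local reproduction of the new CI stage, the dispatch added to build.py amounts to running the three frontend test scripts from the build configuration directory. A minimal sketch is below; the build path is an assumption (adjust it to your own --build_dir/--config), and the transformers package must be installed first, as the updated install_deps.sh does.

    import os
    import subprocess
    import sys

    # Hypothetical build output directory; matches a default Linux RelWithDebInfo layout.
    build_config_dir = os.path.expanduser("~/onnxruntime/build/Linux/RelWithDebInfo")

    # The same scripts that run_training_python_frontend_e2e_tests launches in CI.
    for test_script in ("orttraining_test_transformers.py",
                        "onnxruntime_test_ort_trainer.py",
                        "onnxruntime_test_ort_trainer_with_mixed_precision.py"):
        subprocess.check_call([sys.executable, test_script], cwd=build_config_dir)
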
diff --git a/tools/ci_build/github/linux/docker/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_deps.sh index 23b97752b7..0bb2e0df9b 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_deps.sh @@ -112,6 +112,9 @@ elif [ $DEVICE_TYPE = "gpu" ]; then if [[ $BUILD_EXTR_PAR = *--enable_training* ]]; then ${PYTHON_EXE} -m pip install --upgrade --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html fi + if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then + ${PYTHON_EXE} -m pip install transformers + fi fi
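
As a sanity check on the schedule that orttraining_test_utils.get_lr applies during the e2e run, the standalone sketch below re-states warmup_poly and get_lr locally (only the default schedule is reproduced) and evaluates them with the same MyArgs values used in orttraining_test_transformers.py.

    from collections import namedtuple

    def warmup_poly(x, warmup=0.002, degree=0.5):
        if x < warmup:
            return x / warmup
        return (1.0 - x) ** degree

    SCHEDULES = {'warmup_poly': warmup_poly}  # only the default schedule is shown here

    def get_lr(args, training_steps, schedule='warmup_poly'):
        if args.max_steps == -1:
            return args.learning_rate
        schedule_fct = SCHEDULES[schedule]
        return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion)

    MyArgs = namedtuple("MyArgs", "max_steps learning_rate warmup_proportion")
    args = MyArgs(max_steps=100, learning_rate=0.00001, warmup_proportion=0.01)

    for step in (0, 1, 50, 99):
        print(step, get_lr(args, step))  # 0.0, ~9.95e-06, ~7.07e-06, 1e-06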