Liqun/e2e transformer test (#3540)

* initial change to transformer.py

* prepare e2e transformer tests

* refactor transformer tests

* put test python files in a flat folder

* fix typo pip install transform(s)

* python 3.6

* python version to 3.6 in install_ubuntu.sh

* remove argparser

* to use opset ver 12

* workaround loss_scale naming patch in case of loss_fn_

* assign self.loss_fn_ so it can be checked

* skip a few unneeded post-process steps

* fix loss_scale_input_name, clean up post process steps

* skip non-frontend tests

* move cpu/cuda related files to corresponding cpu/cuda folder (#3668)

Co-authored-by: Weixing Zhang <wezhan@microsoft.com>

* type cast for ratio is not necessary for dropout (#3682)

Co-authored-by: Weixing Zhang <wezhan@microsoft.com>

* ThrustAllocator is not needed since cub is used directly for gather now. (#3683)

Co-authored-by: Weixing Zhang <wezhan@microsoft.com>

* GatherND-12 Implementation (#3645)

* Renamed, UT passing

* Move GatherND CUDA Kernel into onnxruntime

* Merge GatherNDOpTest

* Refactor Test code

* Merge CPU Kernel Impl

* Handle Negative Indices, Fix UT

* Improve CUDA kernel to handle negative index

* Minor Fixes

* Preserve GatherND-1 CUDA kernel

* Fix Mac build

* fix UT

* Fix Build

* fix GatherNDOpTest.double: CUDA error cudaErrorInvalidDeviceFunction (invalid device function)

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Peng Wang (pengwa) <pengwa@microsoft.com>

* update with reviewers' comments

* testBertTrainingGradientAccumulation was not using rtol and could occasionally fail with a small (on the order of 1e-06) difference

* fix merge mistakes

Co-authored-by: liqun <liqun@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Weixing Zhang <weixingzhang@users.noreply.github.com>
Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
Co-authored-by: Sherlock <baihan.huang@gmail.com>
Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Peng Wang (pengwa) <pengwa@microsoft.com>
liqunfu 2020-04-30 12:26:38 -07:00 committed by GitHub
parent 177c1357f4
commit af3988198c
12 changed files with 728 additions and 610 deletions


@@ -173,6 +173,7 @@ endif()
file(GLOB onnxruntime_python_test_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/python/*.py"
"${ORTTRAINING_SOURCE_DIR}/test/python/*.py"
)
file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/*.py"


@@ -1,593 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
import sys
import os
from enum import Enum
# set transformer_repo to the huggingface repo location
transformer_repo = ''
nvidia_deep_learning_examples_repo = ''
sys.path.append(transformer_repo)
from transformers import is_torch_available
sys.path.append(os.path.join(transformer_repo, "transformers/tests"))
from modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from configuration_common_test import ConfigTester
if is_torch_available():
from transformers import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertForMultipleChoice)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
import onnxruntime as ort
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample
import torch
sys.path.append(os.path.join(nvidia_deep_learning_examples_repo, 'PyTorch/LanguageModeling/BERT'))
from run_pretraining import postprocess_model
def map_optimizer_attributes(name):
no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys)
if no_decay:
return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
else:
return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}
from torch.utils.data import Dataset, DataLoader
class OrtTestDataset(Dataset):
def __init__(self, input_desc, seq_len, device):
import copy
self.input_desc_ = copy.deepcopy(input_desc)
for input_desc in self.input_desc_:
shape_ = []
for i, axis in enumerate(input_desc.shape_):
if axis == 'max_seq_len_in_batch':
shape_ = shape_ + [seq_len, ]
elif axis != 'batch':
shape_ = input_desc.shape_[i]
input_desc.shape_ = shape_
self.device_ = device
def __len__(self):
return 100
def __getitem__(self, item):
input_batch = []
for input_desc in self.input_desc_:
input_sample = generate_sample(input_desc, self.device_)
input_batch.append(input_sample)
return input_batch
def create_ort_test_dataloader(input_desc, batch_size, seq_len, device):
dataset = OrtTestDataset(input_desc, seq_len, device)
return DataLoader(dataset, batch_size=batch_size)
class BatchArgsOption(Enum):
List = 1
Dict = 2
ListAndDict = 3
def split_batch(batch, input_desc, args_count):
total_argument_count = len(input_desc)
# batch=[input_ids[batch, seglen], attention_mask[batch, seglen], token_type_ids[batch,seglen], token_type_ids[batch, seglen]]
args = [] # (input_ids[batch, seglen], attention_mask[batch, seglen])
kwargs = {} # {'token_type_ids': token_type_ids[batch,seglen], 'position_ids': token_type_ids[batch, seglen]}
for i in range(args_count):
args = args + [batch[i]]
for i in range(args_count, total_argument_count):
kwargs[input_desc[i].name_] = batch[i]
return args, kwargs
def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16,
allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step, loss_scaler, use_internal_loss_scaler,
batch_args_option):
dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, device)
model = ORTTrainer(model, None, model_desc, "LambOptimizer",
map_optimizer_attributes=map_optimizer_attributes,
learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32),
device=device, postprocess_model=postprocess_model,
gradient_accumulation_steps=gradient_accumulation_steps,
# BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
world_rank=args.local_rank, world_size=args.world_size,
use_mixed_precision=fp16,
allreduce_post_accumulation=allreduce_post_accumulation,
get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
loss_scaler=loss_scaler if use_internal_loss_scaler else None)
# training loop
eval_batch = None
model.train()
for step, batch in enumerate(dataloader):
if eval_batch is None:
eval_batch = batch
if not use_internal_get_lr_this_step:
lr = get_lr_this_step(step)
learning_rate = torch.tensor([lr])
if not use_internal_loss_scaler and fp16:
loss_scale = torch.tensor(loss_scaler.loss_scale_)
if batch_args_option == BatchArgsOption.List:
if not use_internal_get_lr_this_step:
batch = batch + [learning_rate, ]
if not use_internal_loss_scaler and fp16:
batch = batch + [loss_scale, ]
outputs = model(*batch)
elif batch_args_option == BatchArgsOption.Dict:
args, kwargs = split_batch(batch, model_desc.inputs_, 0)
if not use_internal_get_lr_this_step:
kwargs['Learning_Rate'] = learning_rate
if not use_internal_loss_scaler and fp16:
kwargs[model.loss_scale_input_name] = loss_scale
outputs = model(*args, **kwargs)
else:
args_count = int(len(model_desc.inputs_) / 2) # approx half args, half kwargs
args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
if not use_internal_get_lr_this_step:
kwargs['Learning_Rate'] = learning_rate
if not use_internal_loss_scaler and fp16:
kwargs[model.loss_scale_input_name] = loss_scale
outputs = model(*args, **kwargs)
# eval
model.eval()
if batch_args_option == BatchArgsOption.List:
outputs = model(*batch)
elif batch_args_option == BatchArgsOption.Dict:
args, kwargs = split_batch(batch, model_desc.inputs_, 0)
outputs = model(*args, **kwargs)
else:
args_count = int(len(model_desc.inputs_) / 2) # approx half args, half kwargs
args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
outputs = model(*args, **kwargs)
return (output.cpu().numpy() for output in outputs)
class BertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification) if is_torch_available() else ()
class BertModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
device='cpu',
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.device = device
# 1. superset of bert input/output descs
# see BertPreTrainedModel doc
self.input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size)
self.attention_mask_desc = IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)
self.token_type_ids_desc = IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)
self.position_ids_desc = IODescription('position_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.max_position_embeddings)
self.head_mask_desc = IODescription('head_mask', [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2)
self.inputs_embeds_desc = IODescription('inputs_embeds', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
self.encoder_hidden_states_desc = IODescription('encoder_hidden_states', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
self.encoder_attention_mask_desc = IODescription('encoder_attention_mask', ['batch', 'max_seq_len_in_batch'], torch.float32)
# see BertForPreTraining doc
self.masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size)
self.next_sentence_label_desc = IODescription('next_sentence_label', ['batch',], torch.int64, num_classes=2)
# outputs
self.loss_desc = IODescription('loss', [1,], torch.float32)
self.prediction_scores_desc = IODescription('prediction_scores', ['batch', 'max_seq_len_in_batch', self.vocab_size], torch.float32)
self.seq_relationship_scores_desc = IODescription('seq_relationship_scores', ['batch', 2], torch.float32) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32)
self.hidden_states_desc = IODescription('hidden_states', [self.num_hidden_layers, 'batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
self.attentions_desc = IODescription('attentions', [self.num_hidden_layers, 'batch', self.num_attention_heads, 'max_seq_len_in_batch', 'max_seq_len_in_batch'], torch.float32)
self.last_hidden_state_desc = IODescription('last_hidden_state', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
self.pooler_output_desc = IODescription('pooler_output', ['batch', self.hidden_size], torch.float32)
# BertForPreTraining forward:
# def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
# position_ids??=None, head_mask??=None, inputs_embeds??=None,
# masked_lm_labels=None, next_sentence_label=None):
#
# create_and_check_bert_for_pretraining calls BertForPreTraining:
# model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
# masked_lm_labels=token_labels, next_sentence_label=sequence_labels)
def BertForPreTraining_descs(self):
return ModelDescription(
[self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc, self.next_sentence_label_desc],
# returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided
# hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states
[self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc,
#hidden_states_desc, attentions_desc
])
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device)
choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device)
config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def prepare_config_and_inputs_for_decoder(self):
config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = self.prepare_config_and_inputs()
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertModel(config=config)
model.to(input_ids.device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
# fails because there is no loss output
model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc],
[self.last_hidden_state_desc, self.pooler_output_desc])
args_gradient_accumulation_steps = 8
args_local_rank = 0
args_world_size = 1
args_fp16 = True
args_allreduce_post_accumulation = True
model = ORTTrainer(model, None, model_desc, "LambOptimizer",
map_optimizer_attributes=map_optimizer_attributes,
learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
device=self.device, postprocess_model=postprocess_model,
gradient_accumulation_steps=args_gradient_accumulation_steps,
world_rank=args_local_rank, world_size=args_world_size,
use_mixed_precision=True if args_fp16 else False,
allreduce_post_accumulation=True if args_allreduce_post_accumulation else False)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_model_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask):
model = BertModel(config)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
#####
model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc],
[self.loss_desc, self.prediction_scores_desc])
args_gradient_accumulation_steps = 8
args_local_rank = 0
args_world_size = 1
args_fp16 = True
args_allreduce_post_accumulation = True
model = ORTTrainer(model, None, model_desc, "LambOptimizer",
map_optimizer_attributes=map_optimizer_attributes,
learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
device=self.device, postprocess_model=postprocess_model,
gradient_accumulation_steps=args_gradient_accumulation_steps,
world_rank=args_local_rank, world_size=args_world_size,
use_mixed_precision=True if args_fp16 else False,
allreduce_post_accumulation=True if args_allreduce_post_accumulation else False)
model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
def create_and_check_bert_model_for_masked_lm_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask):
model = BertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask)
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForNextSentencePrediction(config=config)
model.eval()
loss, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels)
result = {
"loss": loss,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual(
list(result["seq_relationship_score"].size()),
[self.batch_size, 2])
self.check_loss_output(result)
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForPreTraining(config=config)
model.eval()
loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
masked_lm_labels=token_labels, next_sentence_label=sequence_labels)
model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc,
self.masked_lm_labels_desc, self.next_sentence_label_desc],
[self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc])
import argparse
args_ = argparse.Namespace(fp16=True, amp_opt_level='O1')
from collections import namedtuple
MyArgs = namedtuple("MyArgs",
"local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")
args = MyArgs(local_rank=0, world_size=1, max_steps=100, learning_rate=0.00001, warmup_proportion=0.01, batch_size=13, seq_len=7)
from train_with_ort_trainer import get_lr
def get_lr_this_step(global_step):
return get_lr(args, global_step)
loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000)
option_gradient_accumulation_steps = [8]
option_fp16 = [True, False]
option_allreduce_post_accumulation = True
option_use_internal_get_lr_this_step = False
option_use_internal_loss_scaler = False
# TODO: test with fetches
for gradient_accumulation_steps in option_gradient_accumulation_steps:
for fp16 in option_fp16:
for option_split_batch in BatchArgsOption:
loss_ort, prediction_scores_ort, seq_relationship_score_ort =\
run_test(model, model_desc, self.device, args, gradient_accumulation_steps, fp16,
option_allreduce_post_accumulation,
get_lr_this_step, option_use_internal_get_lr_this_step,
loss_scaler, option_use_internal_loss_scaler,
option_split_batch)
print(loss_ort)
print(prediction_scores_ort)
print(seq_relationship_score_ort)
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForQuestionAnswering(config=config)
model.eval()
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
start_positions=sequence_labels, end_positions=sequence_labels)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = BertForSequenceClassification(config)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_labels])
self.check_loss_output(result)
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = BertForTokenClassification(config=config)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_choices = self.num_choices
model = BertForMultipleChoice(config=config)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_choices])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask,
sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
return config, inputs_dict
def setUp(self):
self.model_tester = BertModelTest.BertModelTester(self)
self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
# def test_config(self):
# self.config_tester.run_common_tests()
# def test_bert_model(self, use_cuda=False):
# # ^^ This could be a real fixture
# if use_cuda:
# self.model_tester.device = "cuda"
# config_and_inputs = self.model_tester.prepare_config_and_inputs()
# self.model_tester.create_and_check_bert_model(*config_and_inputs)
# def test_bert_model_as_decoder(self):
# config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
# self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
# def test_for_masked_lm_decoder(self):
# config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
# self.model_tester.create_and_check_bert_model_for_masked_lm_as_decoder(*config_and_inputs)
# def test_for_multiple_choice(self):
# config_and_inputs = self.model_tester.prepare_config_and_inputs()
# self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
# def test_for_next_sequence_prediction(self):
# config_and_inputs = self.model_tester.prepare_config_and_inputs()
# self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
# def test_for_question_answering(self):
# config_and_inputs = self.model_tester.prepare_config_and_inputs()
# self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
# def test_for_sequence_classification(self):
# config_and_inputs = self.model_tester.prepare_config_and_inputs()
# self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
# def test_for_token_classification(self):
# config_and_inputs = self.model_tester.prepare_config_and_inputs()
# self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
# @pytest.mark.slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()


@@ -309,6 +309,7 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
# Other export options to use (this is for backward compatibility).
other_export_options = {}
# This option was added after the 1.4 release.
if LooseVersion(torch.__version__) > LooseVersion('1.4.0'):
other_export_options['enable_onnx_checker'] = False
@@ -630,6 +631,7 @@ class ORTTrainer():
if loss_fn is not None:
print("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.")
# TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn
self.loss_fn_ = None
self.model_desc_ = model_desc
self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description]
@@ -664,9 +666,12 @@ class ORTTrainer():
self.enable_grad_norm_clip_ = enable_grad_norm_clip
self.frozen_weights_ = frozen_weights
self.opset_version_ = _opset_version
self.loss_scale_input_name = ''
self.state_dict_ = None
# Use this special string to work around a corner case where an external loss_scale is passed into train_step as kwargs.
# See prepare_input_and_fetches for more details.
self.loss_scale_input_name = 'default_loss_scale_input_name'
self._init_session()
def _init_session(self):
@@ -790,6 +795,15 @@ class ORTTrainer():
input = input + (internal_learning_rate,)
if internal_loss_scale is not None:
input = input + (internal_loss_scale,)
elif self.use_mixed_precision:
# loss_scale input name is needed to call train_step, for example:
# kwargs[model.loss_scale_input_name] = loss_scale
# outputs = model.train_step(*args, **kwargs)
# However, the first time train_step is called, model.loss_scale_input_name is not yet set.
# To work around this problem, we use the special name 'default_loss_scale_input_name' to indicate
# the loss_scale.
if 'default_loss_scale_input_name' in kwargs.keys():
input = input + (kwargs['default_loss_scale_input_name'],)
fetches = None
if 'fetches' in kwargs:
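
For reference, a minimal sketch of how a caller drives train_step with an external LossScaler, as the comments above describe. The trainer construction is omitted: `trainer` is assumed to be an ORTTrainer created with use_mixed_precision=True and no internal loss_scaler, and `batch_args`/`batch_kwargs` are illustrative names, not part of this change.

import torch
from onnxruntime.capi.ort_trainer import LossScaler

loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000)

def train_step_with_external_scaler(trainer, batch_args, batch_kwargs):
    # trainer.loss_scale_input_name starts out as the special default name and
    # is updated once the training graph has been built.
    loss_scale = torch.tensor([loss_scaler.loss_scale_])
    batch_kwargs[trainer.loss_scale_input_name] = loss_scale
    return trainer.train_step(*batch_args, **batch_kwargs)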


@@ -0,0 +1,12 @@
from orttraining_test_model_transform import add_name, fix_transpose, add_expand_shape
from orttraining_test_layer_norm_transform import layer_norm_transform
def postprocess_model(model):
add_name(model)
# remove transpose node if its input is a 2d weight which only feeds to the node
fix_transpose(model)
add_expand_shape(model)
layer_norm_transform(model)
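
A minimal usage sketch of the post-processing above; the file paths are illustrative assumptions, and postprocess_model mutates the loaded ModelProto in place.

import onnx
from orttraining_test_bert_postprocess import postprocess_model

model = onnx.load("bert_for_pretraining.onnx")               # exported BERT model (path is hypothetical)
postprocess_model(model)                                     # add names, fold transposes, add Expand shape, fuse LayerNormalization
onnx.save(model, "bert_for_pretraining_postprocessed.onnx")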


@@ -0,0 +1,85 @@
from enum import Enum
import random
import torch
from torch.utils.data import Dataset, DataLoader
from onnxruntime.capi.ort_trainer import generate_sample
global_rng = random.Random()
def ids_tensor(shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = global_rng
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
def floats_tensor(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor of the shape within the vocab size."""
if rng is None:
rng = global_rng
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.random() * scale)
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
class OrtTestDataset(Dataset):
def __init__(self, input_desc, seq_len, device):
import copy
self.input_desc_ = copy.deepcopy(input_desc)
for input_desc in self.input_desc_:
shape_ = []
for i, axis in enumerate(input_desc.shape_):
if axis == 'max_seq_len_in_batch':
shape_ = shape_ + [seq_len, ]
elif axis != 'batch':
shape_ = input_desc.shape_[i]
input_desc.shape_ = shape_
self.device_ = device
def __len__(self):
return 100
def __getitem__(self, item):
input_batch = []
for input_desc in self.input_desc_:
input_sample = generate_sample(input_desc, self.device_)
input_batch.append(input_sample)
return input_batch
def create_ort_test_dataloader(input_desc, batch_size, seq_len, device):
dataset = OrtTestDataset(input_desc, seq_len, device)
return DataLoader(dataset, batch_size=batch_size)
class BatchArgsOption(Enum):
List = 1
Dict = 2
ListAndDict = 3
def split_batch(batch, input_desc, args_count):
total_argument_count = len(input_desc)
# batch=[input_ids[batch, seglen], attention_mask[batch, seglen], token_type_ids[batch,seglen], token_type_ids[batch, seglen]]
args = [] # (input_ids[batch, seglen], attention_mask[batch, seglen])
kwargs = {} # {'token_type_ids': token_type_ids[batch,seglen], 'position_ids': token_type_ids[batch, seglen]}
for i in range(args_count):
args = args + [batch[i]]
for i in range(args_count, total_argument_count):
kwargs[input_desc[i].name_] = batch[i]
return args, kwargs
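
A small sketch of how split_batch divides a batch between positional args and kwargs keyed by the IODescription names, per the comments above; the descriptions and zero tensors here are illustrative only.

import torch
from onnxruntime.capi.ort_trainer import IODescription
from orttraining_test_data_loader import split_batch

input_desc = [IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=99),
              IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2),
              IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)]
batch = [torch.zeros(2, 7, dtype=torch.int64) for _ in input_desc]

# The first tensor stays positional; the remaining two become kwargs keyed by their description names.
args, kwargs = split_batch(batch, input_desc, 1)
assert len(args) == 1 and sorted(kwargs) == ['attention_mask', 'token_type_ids']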


@@ -0,0 +1,177 @@
import onnx
def find_node(graph_proto, op_type):
nodes = []
map_input_node = {}
for node in graph_proto.node:
if node.op_type == op_type:
map_input_node[node.input[0]] = node
if op_type == 'Div' or op_type == 'Mul':
map_input_node[node.input[1]] = node
nodes.append(node)
return nodes, map_input_node
def gen_attribute(key, value):
attr = AttributeProto()
attr.name = key
attr.ints.extend(int(v) for v in value)
attr.type = AttributeProto.INTS
return attr
def layer_norm_transform(model_proto):
# a layer norm subgraph
# input
# |
# ReduceMean
# __|____
# | |
# Sub Sub
# | |
# | Pow
# | |
# | ReduceMean
# | |
# | Add
# | |
# |__ __Sqrt
# | |
# Div
# |
# Mul
# |
# Add
# |
# output
graph_proto = model_proto.graph
_, map_input_Div = find_node(graph_proto, 'Div')
_, map_input_Sqrt = find_node(graph_proto, 'Sqrt')
_, map_input_Add = find_node(graph_proto, 'Add')
nodes_ReduceMean, map_input_ReduceMean = find_node(graph_proto, 'ReduceMean')
_, map_input_Pow = find_node(graph_proto, 'Pow')
_, map_input_Mul = find_node(graph_proto, 'Mul')
# find the right-side Sub (see the layer norm subgraph above)
nodes_Sub = []
map_input_Sub = {}
for node in graph_proto.node:
if node.op_type == 'Sub':
if node.output[0] in map_input_Pow:
nodes_Sub.append(node)
map_input_Sub[node.input[1]] = node
# find first ReduceMean
first_ReduceMean = []
first_ReduceMean_outputs = []
for node in nodes_ReduceMean:
if node.output[0] in map_input_Sub:
first_ReduceMean.append(node)
first_ReduceMean_outputs.append(node.output[0])
# find constant node
nodes_Constant = []
map_output_Constant = {}
for node in graph_proto.node:
if node.op_type == 'Constant':
nodes_Constant.append(node)
map_output_Constant[node.output[0]] = node
id = 0
removed_nodes = []
layer_norm_nodes = []
# Replace with layer norm
for node in first_ReduceMean:
layer_norm_input = []
layer_norm_output = []
layer_norm_input.append(node.input[0])
# collect nodes within a layer norm subgraph.
# skip building the layer norm node if there is a pattern mismatch.
if node.output[0] not in map_input_Sub:
continue
node_sub = map_input_Sub[node.output[0]]
if node_sub.output[0] not in map_input_Pow:
continue
node_pow = map_input_Pow[node_sub.output[0]]
if node_pow.output[0] not in map_input_ReduceMean:
continue
node_reduce = map_input_ReduceMean[node_pow.output[0]]
if node_reduce.output[0] not in map_input_Add:
continue
node_Add = map_input_Add[node_reduce.output[0]]
if node_Add.output[0] not in map_input_Sqrt:
continue
node_Sqrt = map_input_Sqrt[node_Add.output[0]]
if node_Sqrt.output[0] not in map_input_Div:
continue
node_Div = map_input_Div[node_Sqrt.output[0]]
if node_Div.output[0] not in map_input_Mul:
continue
node_Mul = map_input_Mul[node_Div.output[0]]
if node_Mul.input[0] != node_Div.output[0]:
layer_norm_input.append(node_Mul.input[0])
else:
layer_norm_input.append(node_Mul.input[1])
if node_Mul.output[0] not in map_input_Add:
continue
node_Add1 = map_input_Add[node_Mul.output[0]]
layer_norm_input.append(node_Add1.input[1])
removed_nodes.append(node)
removed_nodes.append(node_sub)
removed_nodes.append(node_pow)
removed_nodes.append(node_reduce)
removed_nodes.append(node_Add)
removed_nodes.append(node_Sqrt)
removed_nodes.append(node_Div)
removed_nodes.append(node_Mul)
removed_nodes.append(node_Add1)
removed_nodes.append(map_output_Constant[node_pow.input[1]])
removed_nodes.append(map_output_Constant[node_Add.input[1]])
layer_norm_output.append(node_Add1.output[0])
id = id + 1
layer_norm_output.append('saved_mean_' + str(id))
id = id + 1
layer_norm_output.append('saved_inv_std_var_' + str(id))
layer_norm = onnx.helper.make_node("LayerNormalization",
layer_norm_input,
layer_norm_output,
"LayerNormalization_" + str(id),
None,
axis = node_reduce.attribute[0].ints[0],
epsilon = 9.999999960041972e-13)
layer_norm_nodes.append(layer_norm)
# remove left side Subs
for node in graph_proto.node:
if node.op_type == 'Sub':
if node.input[1] in first_ReduceMean_outputs:
removed_nodes.append(node)
all_nodes = []
for node in graph_proto.node:
if node not in removed_nodes:
all_nodes.append(node)
for node in layer_norm_nodes:
all_nodes.append(node)
graph_proto.ClearField("node")
graph_proto.node.extend(all_nodes)


@@ -0,0 +1,106 @@
from onnx import numpy_helper
def add_name(model):
i = 0
for node in model.graph.node:
node.name = '%s_%d' %(node.op_type, i)
i += 1
def find_single_output_node(model, arg):
result = []
for node in model.graph.node:
for input in node.input:
if input == arg:
result.append(node)
return result[0] if len(result) == 1 else None
def find_input_as_initializer(model, arg):
for initializer in model.graph.initializer:
if initializer.name == arg:
return initializer
return None
def get_node_index(model, node):
for i, n in enumerate(model.graph.node):
if n == node:
return i
return None
def replace_input_arg(model, arg, new_arg):
for node in model.graph.node:
for i in range(len(node.input)):
if node.input[i] == arg:
node.input[i] = new_arg
def find_weight_index(model, name):
for index, w in enumerate(model.graph.initializer):
if w.name == name:
return index
index += 1
return None
def fix_transpose(model):
"""
Remove a Transpose node if its input is a 2-D weight that feeds only that node.
"""
# Find Transpose nodes whose input is an initializer weight.
# The input weight must feed only the Transpose node.
# Collect these nodes and weights.
transpose = []
for node in model.graph.node:
if node.op_type == 'Transpose':
weight = find_input_as_initializer(model, node.input[0])
if weight is not None:
result = []
for n in model.graph.node:
for input in n.input:
if input == weight.name:
result.append(n)
if len(result) > 1:
continue
perm = node.attribute[0]
assert perm.name == 'perm'
perm = perm.ints
assert len(perm) == 2 and perm[0] == 1 and perm[1] == 0
transpose.append((get_node_index(model, node), weight))
# Transpose the collected weights and add them to the model initializers.
# The transposed weight initializers become inputs to the transpose nodes' recipient nodes.
for t in transpose:
node = model.graph.node[t[0]]
weight = numpy_helper.to_array(t[1])
assert len(weight.shape) == 2
weight = weight.transpose(perm)
new_weight = numpy_helper.from_array(weight, "%s_transposed" % t[1].name)
model.graph.initializer.extend([new_weight])
replace_input_arg(model, node.output[0], new_weight.name)
# collected transpose nodes can be removed.
transpose.sort(reverse=True)
for t in transpose:
del model.graph.node[t[0]]
# The original weight initializer can be removed.
# (Remember that a weight is only collected if it feeds nothing but the transpose node.)
old_ws = []
for t in transpose:
if find_single_output_node(model, t[1].name) is None:
old_ws.append(find_weight_index(model, t[1].name))
old_ws.sort(reverse=True)
for w_i in old_ws:
del model.graph.initializer[w_i]
def add_expand_shape(model):
"""
This method is very specific to the BERT model, where there is a single Expand op.
The training backend requires the op's output shape, which is the same as the shape of the model's (single) input.
"""
expand_node = [n for n in model.graph.node if n.op_type == 'Expand']
if len(expand_node) != 1:
raise "cannot find the single expand node in the BERT model."
return
expand_out = model.graph.value_info.add()
expand_out.name = expand_node[0].output[0] # base: '421' # tiny: '85'
expand_out.type.CopyFrom(model.graph.input[0].type)
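
For intuition, a small standalone sketch of the equivalence fix_transpose exploits: rather than running Transpose(perm=[1,0]) on a 2-D weight at runtime, the weight is transposed once offline and stored back as a new initializer. The array and name here are illustrative only.

import numpy as np
from onnx import numpy_helper

w = np.arange(6, dtype=np.float32).reshape(2, 3)                 # original 2-D weight initializer
w_t = w.transpose([1, 0])                                        # what the removed Transpose(perm=[1,0]) would produce
new_initializer = numpy_helper.from_array(w_t, "w_transposed")   # appended to model.graph.initializer in fix_transpose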


@@ -0,0 +1,198 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
import os
from transformers import (BertConfig, BertForPreTraining, BertModel)
from orttraining_test_data_loader import ids_tensor, BatchArgsOption
from orttraining_test_utils import run_test, get_lr
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler
import torch
class BertModelTest(unittest.TestCase):
class BertModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
device='cpu',
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.device = device
# 1. superset of bert input/output descs
# see BertPreTrainedModel doc
self.input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size)
self.attention_mask_desc = IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)
self.token_type_ids_desc = IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)
self.position_ids_desc = IODescription('position_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.max_position_embeddings)
self.head_mask_desc = IODescription('head_mask', [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2)
self.inputs_embeds_desc = IODescription('inputs_embeds', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
self.encoder_hidden_states_desc = IODescription('encoder_hidden_states', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
self.encoder_attention_mask_desc = IODescription('encoder_attention_mask', ['batch', 'max_seq_len_in_batch'], torch.float32)
# see BertForPreTraining doc
self.masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size)
self.next_sentence_label_desc = IODescription('next_sentence_label', ['batch',], torch.int64, num_classes=2)
# outputs
self.loss_desc = IODescription('loss', [1,], torch.float32)
self.prediction_scores_desc = IODescription('prediction_scores', ['batch', 'max_seq_len_in_batch', self.vocab_size], torch.float32)
self.seq_relationship_scores_desc = IODescription('seq_relationship_scores', ['batch', 2], torch.float32) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32)
self.hidden_states_desc = IODescription('hidden_states', [self.num_hidden_layers, 'batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
self.attentions_desc = IODescription('attentions', [self.num_hidden_layers, 'batch', self.num_attention_heads, 'max_seq_len_in_batch', 'max_seq_len_in_batch'], torch.float32)
self.last_hidden_state_desc = IODescription('last_hidden_state', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
self.pooler_output_desc = IODescription('pooler_output', ['batch', self.hidden_size], torch.float32)
def BertForPreTraining_descs(self):
return ModelDescription(
[self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc, self.next_sentence_label_desc],
# returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided
# hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states
[self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc,
#hidden_states_desc, attentions_desc
])
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device)
choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device)
config = BertConfig(
vocab_size=self.vocab_size,
vocab_size_or_config_json_file=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForPreTraining(config=config)
model.eval()
loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
masked_lm_labels=token_labels, next_sentence_label=sequence_labels)
model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc,
self.masked_lm_labels_desc, self.next_sentence_label_desc],
[self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc])
from collections import namedtuple
MyArgs = namedtuple("MyArgs",
"local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")
args = MyArgs(local_rank=0, world_size=1, max_steps=100, learning_rate=0.00001, warmup_proportion=0.01, batch_size=13, seq_len=7)
def get_lr_this_step(global_step):
return get_lr(args, global_step)
loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000)
# It would be better to test both with/without mixed precision and allreduce_post_accumulation.
# However, stress testing all 4 cases is not stable, at least on the test machine.
# Therefore we only test mixed precision and allreduce_post_accumulation because they are the most useful use cases.
option_fp16 = [True]
option_allreduce_post_accumulation = [True]
option_gradient_accumulation_steps = [1, 8]
option_use_internal_get_lr_this_step = [True, False]
option_use_internal_loss_scaler = [True, False]
option_split_batch = [BatchArgsOption.ListAndDict]
for fp16 in option_fp16:
for allreduce_post_accumulation in option_allreduce_post_accumulation:
for gradient_accumulation_steps in option_gradient_accumulation_steps:
for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step:
for use_internal_loss_scaler in option_use_internal_loss_scaler:
for split_batch in option_split_batch:
print("gradient_accumulation_steps:", gradient_accumulation_steps)
print("use_internal_loss_scaler:", use_internal_loss_scaler)
loss_ort, prediction_scores_ort, seq_relationship_score_ort =\
run_test(model, model_desc, self.device, args, gradient_accumulation_steps, fp16,
allreduce_post_accumulation,
get_lr_this_step, use_internal_get_lr_this_step,
loss_scaler, use_internal_loss_scaler,
split_batch)
print(loss_ort)
print(prediction_scores_ort)
print(seq_relationship_score_ort)
def setUp(self):
self.model_tester = BertModelTest.BertModelTester(self)
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,118 @@
import math
import torch
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription
from orttraining_test_data_loader import create_ort_test_dataloader, BatchArgsOption, split_batch
from orttraining_test_bert_postprocess import postprocess_model
def warmup_cosine(x, warmup=0.002):
if x < warmup:
return x/warmup
return 0.5 * (1.0 + torch.cos(math.pi * x))
def warmup_constant(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return max((x - 1. )/ (warmup - 1.), 0.)
def warmup_poly(x, warmup=0.002, degree=0.5):
if x < warmup:
return x/warmup
return (1.0 - x)**degree
SCHEDULES = {
'warmup_cosine':warmup_cosine,
'warmup_constant':warmup_constant,
'warmup_linear':warmup_linear,
'warmup_poly':warmup_poly,
}
def get_lr(args, training_steps, schedule='warmup_poly'):
if args.max_steps == -1:
return args.learning_rate
schedule_fct = SCHEDULES[schedule]
return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion)
def map_optimizer_attributes(name):
no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys)
if no_decay:
return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
else:
return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}
def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16,
allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step, loss_scaler, use_internal_loss_scaler,
batch_args_option):
dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, device)
model = ORTTrainer(model, None, model_desc, "LambOptimizer",
map_optimizer_attributes=map_optimizer_attributes,
learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32),
device=device, postprocess_model=postprocess_model,
gradient_accumulation_steps=gradient_accumulation_steps,
# BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
world_rank=args.local_rank, world_size=args.world_size,
use_mixed_precision=fp16,
allreduce_post_accumulation=allreduce_post_accumulation,
get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
loss_scaler=loss_scaler if use_internal_loss_scaler else None,
_opset_version=12)
# training loop
eval_batch = None
model.train()
for step, batch in enumerate(dataloader):
if eval_batch is None:
eval_batch = batch
if not use_internal_get_lr_this_step:
lr = get_lr_this_step(step)
learning_rate = torch.tensor([lr])
if not use_internal_loss_scaler and fp16:
loss_scale = torch.tensor([loss_scaler.loss_scale_])
if batch_args_option == BatchArgsOption.List:
if not use_internal_get_lr_this_step:
batch = batch + [learning_rate, ]
if not use_internal_loss_scaler and fp16:
batch = batch + [loss_scale, ]
outputs = model(*batch)
elif batch_args_option == BatchArgsOption.Dict:
args, kwargs = split_batch(batch, model_desc.inputs_, 0)
if not use_internal_get_lr_this_step:
kwargs['Learning_Rate'] = learning_rate
if not use_internal_loss_scaler and fp16:
kwargs[model.loss_scale_input_name] = loss_scale
outputs = model(*args, **kwargs)
else:
args_count = int(len(model_desc.inputs_) / 2) # approx half args, half kwargs
args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
if not use_internal_get_lr_this_step:
kwargs['Learning_Rate'] = learning_rate
if not use_internal_loss_scaler and fp16:
kwargs[model.loss_scale_input_name] = loss_scale
outputs = model(*args, **kwargs)
# eval
model.eval()
if batch_args_option == BatchArgsOption.List:
outputs = model(*batch)
elif batch_args_option == BatchArgsOption.Dict:
args, kwargs = split_batch(batch, model_desc.inputs_, 0)
outputs = model(*args, **kwargs)
else:
args_count = int(len(model_desc.inputs_) / 2) # approx half args, half kwargs
args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
outputs = model(*args, **kwargs)
return (output.cpu().numpy() for output in outputs)
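
A quick sketch of how the default warmup_poly schedule above behaves, reusing the same MyArgs namedtuple the test builds; the step values are illustrative.

from collections import namedtuple
from orttraining_test_utils import get_lr

MyArgs = namedtuple("MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")
args = MyArgs(local_rank=0, world_size=1, max_steps=100, learning_rate=0.00001, warmup_proportion=0.01, batch_size=13, seq_len=7)

for step in (0, 1, 50, 100):
    # ramps up over the first warmup_proportion of max_steps, then decays as (1 - x) ** 0.5
    print(step, get_lr(args, step))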


@@ -19,14 +19,6 @@ from torchvision import datasets, transforms
import numpy as np
import os
# TODO: remove after ready for CV
# import sys
# sys.path.insert(0, '/bert_ort/liqun/onnxruntime/build/Linux/Debug/')
# import onnxruntime as ort
# sys.path.insert(0, '/bert_ort/liqun/onnxruntime/onnxruntime/python/')
# from ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel
from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel
from mpi4py import MPI
from onnxruntime.capi._pybind_state import set_cuda_device_id


@@ -1006,18 +1006,27 @@ def adb_push(source_dir, src, dest, **kwargs):
def adb_shell(*args, **kwargs):
return run_subprocess(['adb', 'shell', *args], **kwargs)
def run_training_python_frontend_e2e_tests(args, cwd, dll_path):
def run_training_python_frontend_e2e_tests(args, cwd):
# frontend tests are to be added here:
log.info("Running python frontend e2e tests.")
run_subprocess(
[sys.executable, 'onnxruntime_test_ort_trainer_with_mixed_precision.py'],
cwd=cwd, dll_path=dll_path)
run_subprocess([sys.executable, 'orttraining_test_transformers.py'], cwd=cwd)
run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd)
run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer_with_mixed_precision.py'], cwd=cwd)
def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs,
enable_tvm=False, enable_tensorrt=False):
for config in configs:
log.info("Running tests for %s configuration", config)
cwd = get_config_build_dir(build_dir, config)
if args.enable_training and args.use_cuda and args.enable_training_python_frontend_e2e_tests:
# run frontend tests for orttraining-linux-gpu-frontend_test-ci-pipeline.
# this is not a PR merge test so skip other tests.
run_training_python_frontend_e2e_tests(args, cwd=cwd)
continue
android_x86_64 = args.android_abi == 'x86_64'
if android_x86_64:
run_subprocess(os.path.join(
@@ -1099,10 +1108,6 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs,
[sys.executable, 'onnxruntime_test_training_unit_tests.py'],
cwd=cwd, dll_path=dll_path)
# run additional frontend tests for orttraining-linux-gpu-frontend_test_ci-pipeline
if args.enable_training_python_frontend_e2e_tests:
run_training_python_frontend_e2e_tests(args, cwd=cwd, dll_path=dll_path)
try:
import onnx # noqa
# gen_test_models.py used by onnx_test requires scipy.


@@ -112,6 +112,9 @@ elif [ $DEVICE_TYPE = "gpu" ]; then
if [[ $BUILD_EXTR_PAR = *--enable_training* ]]; then
${PYTHON_EXE} -m pip install --upgrade --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
fi
if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then
${PYTHON_EXE} -m pip install transformers
fi
fi