Mirror of https://github.com/saymrwulf/onnxruntime.git, synced 2026-05-14 20:48:00 +00:00
Liqun/e2e transformer test (#3540)
* initial change to transformer.py
* prepare e2e transformer tests
* refactor transformer tests
* put test python files in a flat folder
* fix typo pip install transform(s)
* python 3.6
* python version to 3.6 in install_ubuntu.sh
* remove argparser
* to use opset ver 12
* workaround loss_scale naming patch in case of loss_fn_
* assign self.loss_fn_ so it can be checked
* skip a few un-needed post-process steps
* fix loss_scale_input_name, clean up post-process steps
* skip non-frontend tests
* move cpu/cuda related files to corresponding cpu/cuda folder (#3668)
* type cast for ratio is not necessary for dropout (#3682)
* ThrustAllocator is not needed since cub is used directly for gather now (#3683)
* GatherND-12 implementation (#3645)
  * Renamed, UT passing
  * Move GatherND CUDA kernel into onnxruntime
  * Merge GatherNDOpTest
  * Refactor test code
  * Merge CPU kernel impl
  * Handle negative indices, fix UT
  * Improve CUDA kernel to handle negative index
  * Minor fixes
  * Preserve GatherND-1 CUDA kernel
  * Fix Mac build
  * Fix UT
  * Fix build
  * Fix GatherNDOpTest.double > CUDA error cudaErrorInvalidDeviceFunction: invalid device function
* update with reviewers' comments
* testBertTrainingGradientAccumulation was not using rtol and could occasionally fail with a small (1e-06) difference
* fix merge mistakes

Co-authored-by: liqun <liqun@OrtTrainingDev4.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Weixing Zhang <weixingzhang@users.noreply.github.com>
Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
Co-authored-by: Sherlock <baihan.huang@gmail.com>
Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: Peng Wang (pengwa) <pengwa@microsoft.com>
This commit is contained in:
parent 177c1357f4
commit af3988198c

12 changed files with 728 additions and 610 deletions
@@ -173,6 +173,7 @@ endif()
  file(GLOB onnxruntime_python_test_srcs CONFIGURE_DEPENDS
      "${ONNXRUNTIME_ROOT}/test/python/*.py"
      "${ORTTRAINING_SOURCE_DIR}/test/python/*.py"
  )
  file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
      "${ONNXRUNTIME_ROOT}/python/tools/*.py"
@@ -1,593 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import shutil
import pytest
import sys
import os
from enum import Enum

# set transformer_repo to the huggingface repo location
transformer_repo = ''
nvidia_deep_learning_examples_repo = ''

sys.path.append(transformer_repo)

from transformers import is_torch_available

sys.path.append(os.path.join(transformer_repo, "transformers/tests"))
from modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from configuration_common_test import ConfigTester

if is_torch_available():
    from transformers import (BertConfig, BertModel, BertForMaskedLM,
                              BertForNextSentencePrediction, BertForPreTraining,
                              BertForQuestionAnswering, BertForSequenceClassification,
                              BertForTokenClassification, BertForMultipleChoice)
    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
else:
    pytestmark = pytest.mark.skip("Require Torch")

import onnxruntime as ort
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample
import torch

sys.path.append(os.path.join(nvidia_deep_learning_examples_repo, 'PyTorch/LanguageModeling/BERT'))
from run_pretraining import postprocess_model

def map_optimizer_attributes(name):
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys)
    if no_decay:
        return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
    else:
        return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}

from torch.utils.data import Dataset, DataLoader

class OrtTestDataset(Dataset):
    def __init__(self, input_desc, seq_len, device):
        import copy
        self.input_desc_ = copy.deepcopy(input_desc)
        for input_desc in self.input_desc_:
            shape_ = []
            for i, axis in enumerate(input_desc.shape_):
                if axis == 'max_seq_len_in_batch':
                    shape_ = shape_ + [seq_len, ]
                elif axis != 'batch':
                    shape_ = input_desc.shape_[i]
            input_desc.shape_ = shape_
        self.device_ = device

    def __len__(self):
        return 100

    def __getitem__(self, item):
        input_batch = []
        for input_desc in self.input_desc_:
            input_sample = generate_sample(input_desc, self.device_)
            input_batch.append(input_sample)
        return input_batch

def create_ort_test_dataloader(input_desc, batch_size, seq_len, device):
    dataset = OrtTestDataset(input_desc, seq_len, device)
    return DataLoader(dataset, batch_size=batch_size)

class BatchArgsOption(Enum):
    List = 1
    Dict = 2
    ListAndDict = 3

def split_batch(batch, input_desc, args_count):
    total_argument_count = len(input_desc)
    # batch=[input_ids[batch, seglen], attention_mask[batch, seglen], token_type_ids[batch,seglen], token_type_ids[batch, seglen]]
    args = []    # (input_ids[batch, seglen], attention_mask[batch, seglen])
    kwargs = {}  # {'token_type_ids': token_type_ids[batch,seglen], 'position_ids': token_type_ids[batch, seglen]}
    for i in range(args_count):
        args = args + [batch[i]]

    for i in range(args_count, total_argument_count):
        kwargs[input_desc[i].name_] = batch[i]

    return args, kwargs

def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16,
             allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step, loss_scaler, use_internal_loss_scaler,
             batch_args_option):
    dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, device)

    model = ORTTrainer(model, None, model_desc, "LambOptimizer",
                       map_optimizer_attributes=map_optimizer_attributes,
                       learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32),
                       device=device, postprocess_model=postprocess_model,
                       gradient_accumulation_steps=gradient_accumulation_steps,
                       # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
                       world_rank=args.local_rank, world_size=args.world_size,
                       use_mixed_precision=fp16,
                       allreduce_post_accumulation=allreduce_post_accumulation,
                       get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
                       loss_scaler=loss_scaler if use_internal_loss_scaler else None)

    # training loop
    eval_batch = None
    model.train()
    for step, batch in enumerate(dataloader):
        if eval_batch is None:
            eval_batch = batch

        if not use_internal_get_lr_this_step:
            lr = get_lr_this_step(step)
            learning_rate = torch.tensor([lr])

        if not use_internal_loss_scaler and fp16:
            loss_scale = torch.tensor(loss_scaler.loss_scale_)

        if batch_args_option == BatchArgsOption.List:
            if not use_internal_get_lr_this_step:
                batch = batch + [learning_rate, ]
            if not use_internal_loss_scaler and fp16:
                batch = batch + [loss_scale, ]
            outputs = model(*batch)
        elif batch_args_option == BatchArgsOption.Dict:
            args, kwargs = split_batch(batch, model_desc.inputs_, 0)
            if not use_internal_get_lr_this_step:
                kwargs['Learning_Rate'] = learning_rate
            if not use_internal_loss_scaler and fp16:
                kwargs[model.loss_scale_input_name] = loss_scale
            outputs = model(*args, **kwargs)
        else:
            args_count = int(len(model_desc.inputs_) / 2)  # approximately half args, half kwargs
            args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
            if not use_internal_get_lr_this_step:
                kwargs['Learning_Rate'] = learning_rate
            if not use_internal_loss_scaler and fp16:
                kwargs[model.loss_scale_input_name] = loss_scale
            outputs = model(*args, **kwargs)

    # eval
    model.eval()
    if batch_args_option == BatchArgsOption.List:
        outputs = model(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model(*args, **kwargs)
    else:
        args_count = int(len(model_desc.inputs_) / 2)  # approximately half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)


class BertModelTest(CommonTestCases.CommonModelTester):

    all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
                         BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
                         BertForTokenClassification) if is_torch_available() else ()

    class BertModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_input_mask=True,
                     use_token_type_ids=True,
                     use_labels=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     intermediate_size=37,
                     hidden_act="gelu",
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     num_choices=4,
                     scope=None,
                     device='cpu',
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope
            self.device = device

            # 1. superset of bert input/output descs
            # see BertPreTrainedModel doc
            self.input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size)
            self.attention_mask_desc = IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)
            self.token_type_ids_desc = IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)
            self.position_ids_desc = IODescription('position_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.max_position_embeddings)
            self.head_mask_desc = IODescription('head_mask', [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2)
            self.inputs_embeds_desc = IODescription('inputs_embeds', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)

            self.encoder_hidden_states_desc = IODescription('encoder_hidden_states', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
            self.encoder_attention_mask_desc = IODescription('encoder_attention_mask', ['batch', 'max_seq_len_in_batch'], torch.float32)

            # see BertForPreTraining doc
            self.masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size)
            self.next_sentence_label_desc = IODescription('next_sentence_label', ['batch',], torch.int64, num_classes=2)

            # outputs
            self.loss_desc = IODescription('loss', [1,], torch.float32)
            self.prediction_scores_desc = IODescription('prediction_scores', ['batch', 'max_seq_len_in_batch', self.vocab_size], torch.float32)

            self.seq_relationship_scores_desc = IODescription('seq_relationship_scores', ['batch', 2], torch.float32)  # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32)
            self.hidden_states_desc = IODescription('hidden_states', [self.num_hidden_layers, 'batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
            self.attentions_desc = IODescription('attentions', [self.num_hidden_layers, 'batch', self.num_attention_heads, 'max_seq_len_in_batch', 'max_seq_len_in_batch'], torch.float32)
            self.last_hidden_state_desc = IODescription('last_hidden_state', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
            self.pooler_output_desc = IODescription('pooler_output', ['batch', self.hidden_size], torch.float32)

        # BertForPreTraining forward:
        # def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
        #             position_ids=None, head_mask=None, inputs_embeds=None,
        #             masked_lm_labels=None, next_sentence_label=None):
        #
        # create_and_check_bert_for_pretraining calls BertForPreTraining:
        # model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
        #       masked_lm_labels=token_labels, next_sentence_label=sequence_labels)

        def BertForPreTraining_descs(self):
            return ModelDescription(
                [self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc, self.next_sentence_label_desc],
                # returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided
                # hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states
                [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc,
                 # hidden_states_desc, attentions_desc
                 ])

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device)

            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device)
                choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device)

            config = BertConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                is_decoder=False,
                initializer_range=self.initializer_range)

            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

        def prepare_config_and_inputs_for_decoder(self):
            config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = self.prepare_config_and_inputs()

            config.is_decoder = True
            encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
            encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask

        def check_loss_output(self, result):
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

        def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertModel(config=config)
            model.to(input_ids.device)
            model.eval()

            sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)

            # fails because there is no loss output
            model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc],
                                          [self.last_hidden_state_desc, self.pooler_output_desc])
            args_gradient_accumulation_steps = 8
            args_local_rank = 0
            args_world_size = 1
            args_fp16 = True
            args_allreduce_post_accumulation = True

            model = ORTTrainer(model, None, model_desc, "LambOptimizer",
                               map_optimizer_attributes=map_optimizer_attributes,
                               learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
                               device=self.device, postprocess_model=postprocess_model,
                               gradient_accumulation_steps=args_gradient_accumulation_steps,
                               world_rank=args_local_rank, world_size=args_world_size,
                               use_mixed_precision=True if args_fp16 else False,
                               allreduce_post_accumulation=True if args_allreduce_post_accumulation else False)

            sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
            sequence_output, pooled_output = model(input_ids)

            result = {
                "sequence_output": sequence_output,
                "pooled_output": pooled_output,
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

        def create_and_check_bert_model_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask):
            model = BertModel(config)
            model.eval()
            sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask)
            sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
            sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)

            result = {
                "sequence_output": sequence_output,
                "pooled_output": pooled_output,
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

        def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertForMaskedLM(config=config)
            model.eval()
            loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)

            #####
            model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc],
                                          [self.loss_desc, self.prediction_scores_desc])
            args_gradient_accumulation_steps = 8
            args_local_rank = 0
            args_world_size = 1
            args_fp16 = True
            args_allreduce_post_accumulation = True

            model = ORTTrainer(model, None, model_desc, "LambOptimizer",
                               map_optimizer_attributes=map_optimizer_attributes,
                               learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
                               device=self.device, postprocess_model=postprocess_model,
                               gradient_accumulation_steps=args_gradient_accumulation_steps,
                               world_rank=args_local_rank, world_size=args_world_size,
                               use_mixed_precision=True if args_fp16 else False,
                               allreduce_post_accumulation=True if args_allreduce_post_accumulation else False)
            model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)

        def create_and_check_bert_model_for_masked_lm_as_decoder(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, encoder_hidden_states, encoder_attention_mask):
            model = BertForMaskedLM(config=config)
            model.eval()
            loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask)
            loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels, encoder_hidden_states=encoder_hidden_states)
            result = {
                "loss": loss,
                "prediction_scores": prediction_scores,
            }
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.check_loss_output(result)

        def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertForNextSentencePrediction(config=config)
            model.eval()
            loss, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels)
            result = {
                "loss": loss,
                "seq_relationship_score": seq_relationship_score,
            }
            self.parent.assertListEqual(
                list(result["seq_relationship_score"].size()),
                [self.batch_size, 2])
            self.check_loss_output(result)

        def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertForPreTraining(config=config)
            model.eval()
            loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
                                                                    masked_lm_labels=token_labels, next_sentence_label=sequence_labels)
            model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc,
                                           self.masked_lm_labels_desc, self.next_sentence_label_desc],
                                          [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc])

            import argparse
            args_ = argparse.Namespace(fp16=True, amp_opt_level='O1')

            from collections import namedtuple
            MyArgs = namedtuple("MyArgs",
                                "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")
            args = MyArgs(local_rank=0, world_size=1, max_steps=100, learning_rate=0.00001, warmup_proportion=0.01, batch_size=13, seq_len=7)

            from train_with_ort_trainer import get_lr
            def get_lr_this_step(global_step):
                return get_lr(args, global_step)
            loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000)

            option_gradient_accumulation_steps = [8]
            option_fp16 = [True, False]
            option_allreduce_post_accumulation = True
            option_use_internal_get_lr_this_step = False
            option_use_internal_loss_scaler = False
            # TODO: test with fetches

            for gradient_accumulation_steps in option_gradient_accumulation_steps:
                for fp16 in option_fp16:
                    for option_split_batch in BatchArgsOption:
                        loss_ort, prediction_scores_ort, seq_relationship_score_ort = \
                            run_test(model, model_desc, self.device, args, gradient_accumulation_steps, fp16,
                                     option_allreduce_post_accumulation,
                                     get_lr_this_step, option_use_internal_get_lr_this_step,
                                     loss_scaler, option_use_internal_loss_scaler,
                                     option_split_batch)

                        print(loss_ort)
                        print(prediction_scores_ort)
                        print(seq_relationship_score_ort)

        def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = BertForQuestionAnswering(config=config)
            model.eval()
            loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
                                                   start_positions=sequence_labels, end_positions=sequence_labels)
            result = {
                "loss": loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
            self.parent.assertListEqual(
                list(result["start_logits"].size()),
                [self.batch_size, self.seq_length])
            self.parent.assertListEqual(
                list(result["end_logits"].size()),
                [self.batch_size, self.seq_length])
            self.check_loss_output(result)

        def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            config.num_labels = self.num_labels
            model = BertForSequenceClassification(config)
            model.eval()
            loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
            result = {
                "loss": loss,
                "logits": logits,
            }
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.num_labels])
            self.check_loss_output(result)

        def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            config.num_labels = self.num_labels
            model = BertForTokenClassification(config=config)
            model.eval()
            loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
            result = {
                "loss": loss,
                "logits": logits,
            }
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.seq_length, self.num_labels])
            self.check_loss_output(result)

        def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            config.num_choices = self.num_choices
            model = BertForMultipleChoice(config=config)
            model.eval()
            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
            multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
            loss, logits = model(multiple_choice_inputs_ids,
                                 attention_mask=multiple_choice_input_mask,
                                 token_type_ids=multiple_choice_token_type_ids,
                                 labels=choice_labels)
            result = {
                "loss": loss,
                "logits": logits,
            }
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.num_choices])
            self.check_loss_output(result)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, token_type_ids, input_mask,
             sequence_labels, token_labels, choice_labels) = config_and_inputs
            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = BertModelTest.BertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)

    # def test_config(self):
    #     self.config_tester.run_common_tests()

    # def test_bert_model(self, use_cuda=False):
    #     # ^^ This could be a real fixture
    #     if use_cuda:
    #         self.model_tester.device = "cuda"
    #     config_and_inputs = self.model_tester.prepare_config_and_inputs()
    #     self.model_tester.create_and_check_bert_model(*config_and_inputs)

    # def test_bert_model_as_decoder(self):
    #     config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
    #     self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)

    # def test_for_masked_lm_decoder(self):
    #     config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
    #     self.model_tester.create_and_check_bert_model_for_masked_lm_as_decoder(*config_and_inputs)

    # def test_for_multiple_choice(self):
    #     config_and_inputs = self.model_tester.prepare_config_and_inputs()
    #     self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)

    # def test_for_next_sequence_prediction(self):
    #     config_and_inputs = self.model_tester.prepare_config_and_inputs()
    #     self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)

    # def test_for_question_answering(self):
    #     config_and_inputs = self.model_tester.prepare_config_and_inputs()
    #     self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)

    # def test_for_sequence_classification(self):
    #     config_and_inputs = self.model_tester.prepare_config_and_inputs()
    #     self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)

    # def test_for_token_classification(self):
    #     config_and_inputs = self.model_tester.prepare_config_and_inputs()
    #     self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)

    # @pytest.mark.slow
    # def test_model_from_pretrained(self):
    #     cache_dir = "/tmp/transformers_test/"
    #     for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
    #         model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
    #         shutil.rmtree(cache_dir)
    #         self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()
@@ -309,6 +309,7 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op

    # Other export options to use (this is for backward compatibility).
    other_export_options = {}

    # This option was added after the 1.4 release.
    if LooseVersion(torch.__version__) > LooseVersion('1.4.0'):
        other_export_options['enable_onnx_checker'] = False
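For reference, a minimal sketch of this version-gated export pattern; the model, sample inputs, and output path in the commented call are hypothetical, and opset 12 matches the "to use opset ver 12" item in the commit message:

import torch
from distutils.version import LooseVersion

other_export_options = {}
# enable_onnx_checker is a torch.onnx.export flag that only exists in
# releases after 1.4, so it is passed only when the running torch is new enough
if LooseVersion(torch.__version__) > LooseVersion('1.4.0'):
    other_export_options['enable_onnx_checker'] = False

# torch.onnx.export(model, sample_inputs, 'model.onnx',
#                   opset_version=12, **other_export_options)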
@@ -630,6 +631,7 @@ class ORTTrainer():
            if loss_fn is not None:
                print("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.")
            # TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn
            self.loss_fn_ = None

        self.model_desc_ = model_desc
        self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description]
@@ -664,9 +666,12 @@ class ORTTrainer():
        self.enable_grad_norm_clip_ = enable_grad_norm_clip
        self.frozen_weights_ = frozen_weights
        self.opset_version_ = _opset_version
        self.loss_scale_input_name = ''
        self.state_dict_ = None

        # Use this special string to work around a corner case where an external loss_scale
        # is passed into train_step as a kwarg. See prepare_input_and_fetches for details.
        self.loss_scale_input_name = 'default_loss_scale_input_name'

        self._init_session()

    def _init_session(self):
@@ -790,6 +795,15 @@ class ORTTrainer():
                input = input + (internal_learning_rate,)
            if internal_loss_scale is not None:
                input = input + (internal_loss_scale,)
            elif self.use_mixed_precision:
                # A loss_scale input name is needed to call train_step, for example:
                #     kwargs[model.loss_scale_input_name] = loss_scale
                #     outputs = model.train_step(*args, **kwargs)
                # However, when train_step is called for the first time, model.loss_scale_input_name
                # is not yet set. To work around this, the special name 'default_loss_scale_input_name'
                # is used to indicate the loss_scale.
                if 'default_loss_scale_input_name' in kwargs.keys():
                    input = input + (kwargs['default_loss_scale_input_name'],)

        fetches = None
        if 'fetches' in kwargs:
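As a usage illustration of this workaround, a hedged sketch of driving train_step with an external LossScaler; the trainer and batch tensors (model, input_ids, input_mask, token_type_ids, token_labels, sequence_labels) are assumed to be prepared as in the BERT tests below:

import torch
from onnxruntime.capi.ort_trainer import LossScaler

loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000)
loss_scale = torch.tensor(loss_scaler.loss_scale_)

kwargs = {'masked_lm_labels': token_labels, 'next_sentence_label': sequence_labels}
# on the very first step model.loss_scale_input_name is still the placeholder
# 'default_loss_scale_input_name'; on later steps it is the real graph input name
kwargs[model.loss_scale_input_name] = loss_scale
outputs = model.train_step(input_ids, input_mask, token_type_ids, **kwargs)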
@@ -0,0 +1,12 @@
from orttraining_test_model_transform import add_name, fix_transpose, add_expand_shape
from orttraining_test_layer_norm_transform import layer_norm_transform

def postprocess_model(model):
    add_name(model)

    # remove a Transpose node if its input is a 2d weight that feeds only that node
    fix_transpose(model)

    add_expand_shape(model)

    layer_norm_transform(model)
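A minimal driver sketch for this pipeline; the input and output paths are hypothetical, and postprocess_model mutates the ModelProto in place:

import onnx
from orttraining_test_bert_postprocess import postprocess_model

model = onnx.load("bert_for_pretraining.onnx")  # hypothetical exported model
postprocess_model(model)
onnx.save(model, "bert_for_pretraining_postprocessed.onnx")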
@@ -0,0 +1,85 @@
from enum import Enum
import random
import torch
from torch.utils.data import Dataset, DataLoader
from onnxruntime.capi.ort_trainer import generate_sample

global_rng = random.Random()

def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Creates a random int32 tensor of the given shape with values within the vocab size."""
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor of the given shape with values in [0, scale)."""
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()


class OrtTestDataset(Dataset):
    def __init__(self, input_desc, seq_len, device):
        import copy
        self.input_desc_ = copy.deepcopy(input_desc)
        for input_desc in self.input_desc_:
            shape_ = []
            for i, axis in enumerate(input_desc.shape_):
                if axis == 'max_seq_len_in_batch':
                    shape_ = shape_ + [seq_len, ]
                elif axis != 'batch':
                    # keep concrete (non-symbolic) dimensions as-is;
                    # 'batch' axes are dropped because the DataLoader adds the batch dimension
                    shape_ = shape_ + [input_desc.shape_[i], ]
            input_desc.shape_ = shape_
        self.device_ = device

    def __len__(self):
        return 100

    def __getitem__(self, item):
        input_batch = []
        for input_desc in self.input_desc_:
            input_sample = generate_sample(input_desc, self.device_)
            input_batch.append(input_sample)
        return input_batch

def create_ort_test_dataloader(input_desc, batch_size, seq_len, device):
    dataset = OrtTestDataset(input_desc, seq_len, device)
    return DataLoader(dataset, batch_size=batch_size)

class BatchArgsOption(Enum):
    List = 1
    Dict = 2
    ListAndDict = 3

def split_batch(batch, input_desc, args_count):
    total_argument_count = len(input_desc)
    # batch=[input_ids[batch, seglen], attention_mask[batch, seglen], token_type_ids[batch,seglen], token_type_ids[batch, seglen]]
    args = []    # (input_ids[batch, seglen], attention_mask[batch, seglen])
    kwargs = {}  # {'token_type_ids': token_type_ids[batch,seglen], 'position_ids': token_type_ids[batch, seglen]}
    for i in range(args_count):
        args = args + [batch[i]]

    for i in range(args_count, total_argument_count):
        kwargs[input_desc[i].name_] = batch[i]

    return args, kwargs
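A toy illustration of split_batch semantics, assuming IODescription as used throughout this commit (its name_ attribute keys the kwargs):

import torch
from onnxruntime.capi.ort_trainer import IODescription
from orttraining_test_data_loader import split_batch

desc = [IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=99),
        IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)]
batch = [torch.zeros(2, 4, dtype=torch.int64), torch.ones(2, 4, dtype=torch.int64)]

args, kwargs = split_batch(batch, desc, args_count=1)
# the first tensor stays positional; the rest are keyed by IODescription.name_
assert len(args) == 1 and set(kwargs) == {'token_type_ids'}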
@@ -0,0 +1,177 @@
import onnx

def find_node(graph_proto, op_type):
    nodes = []
    map_input_node = {}
    for node in graph_proto.node:
        if node.op_type == op_type:
            map_input_node[node.input[0]] = node
            if op_type == 'Div' or op_type == 'Mul':
                map_input_node[node.input[1]] = node
            nodes.append(node)
    return nodes, map_input_node

def gen_attribute(key, value):
    attr = onnx.AttributeProto()
    attr.name = key
    attr.ints.extend(int(v) for v in value)
    attr.type = onnx.AttributeProto.INTS
    return attr

def layer_norm_transform(model_proto):
    # a layer norm subgraph:
    #
    #      input
    #        |
    #    ReduceMean
    #     ___|___
    #    |       |
    #   Sub     Sub
    #    |       |
    #    |      Pow
    #    |       |
    #    |   ReduceMean
    #    |       |
    #    |      Add
    #    |       |
    #    |__   __Sqrt
    #       | |
    #       Div
    #        |
    #       Mul
    #        |
    #       Add
    #        |
    #      output

    graph_proto = model_proto.graph

    _, map_input_Div = find_node(graph_proto, 'Div')
    _, map_input_Sqrt = find_node(graph_proto, 'Sqrt')
    _, map_input_Add = find_node(graph_proto, 'Add')
    nodes_ReduceMean, map_input_ReduceMean = find_node(graph_proto, 'ReduceMean')
    _, map_input_Pow = find_node(graph_proto, 'Pow')
    _, map_input_Mul = find_node(graph_proto, 'Mul')

    # find the right-hand Sub (see the layer norm subgraph above)
    nodes_Sub = []
    map_input_Sub = {}
    for node in graph_proto.node:
        if node.op_type == 'Sub':
            if node.output[0] in map_input_Pow:
                nodes_Sub.append(node)
                map_input_Sub[node.input[1]] = node

    # find the first ReduceMean
    first_ReduceMean = []
    first_ReduceMean_outputs = []
    for node in nodes_ReduceMean:
        if node.output[0] in map_input_Sub:
            first_ReduceMean.append(node)
            first_ReduceMean_outputs.append(node.output[0])

    # find Constant nodes
    nodes_Constant = []
    map_output_Constant = {}
    for node in graph_proto.node:
        if node.op_type == 'Constant':
            nodes_Constant.append(node)
            map_output_Constant[node.output[0]] = node

    id = 0
    removed_nodes = []
    layer_norm_nodes = []
    # replace each matched subgraph with a LayerNormalization node
    for node in first_ReduceMean:
        layer_norm_input = []
        layer_norm_output = []
        layer_norm_input.append(node.input[0])

        # collect the nodes within a layer norm subgraph;
        # skip building the layer norm node if there is a pattern mismatch
        if node.output[0] not in map_input_Sub:
            continue

        node_sub = map_input_Sub[node.output[0]]
        if node_sub.output[0] not in map_input_Pow:
            continue

        node_pow = map_input_Pow[node_sub.output[0]]
        if node_pow.output[0] not in map_input_ReduceMean:
            continue

        node_reduce = map_input_ReduceMean[node_pow.output[0]]
        if node_reduce.output[0] not in map_input_Add:
            continue

        node_Add = map_input_Add[node_reduce.output[0]]
        if node_Add.output[0] not in map_input_Sqrt:
            continue

        node_Sqrt = map_input_Sqrt[node_Add.output[0]]
        if node_Sqrt.output[0] not in map_input_Div:
            continue

        node_Div = map_input_Div[node_Sqrt.output[0]]
        if node_Div.output[0] not in map_input_Mul:
            continue

        node_Mul = map_input_Mul[node_Div.output[0]]

        if node_Mul.input[0] != node_Div.output[0]:
            layer_norm_input.append(node_Mul.input[0])
        else:
            layer_norm_input.append(node_Mul.input[1])

        if node_Mul.output[0] not in map_input_Add:
            continue

        node_Add1 = map_input_Add[node_Mul.output[0]]
        layer_norm_input.append(node_Add1.input[1])

        removed_nodes.append(node)
        removed_nodes.append(node_sub)
        removed_nodes.append(node_pow)
        removed_nodes.append(node_reduce)
        removed_nodes.append(node_Add)
        removed_nodes.append(node_Sqrt)
        removed_nodes.append(node_Div)
        removed_nodes.append(node_Mul)
        removed_nodes.append(node_Add1)
        removed_nodes.append(map_output_Constant[node_pow.input[1]])
        removed_nodes.append(map_output_Constant[node_Add.input[1]])

        layer_norm_output.append(node_Add1.output[0])
        id = id + 1
        layer_norm_output.append('saved_mean_' + str(id))
        id = id + 1
        layer_norm_output.append('saved_inv_std_var_' + str(id))
        layer_norm = onnx.helper.make_node("LayerNormalization",
                                           layer_norm_input,
                                           layer_norm_output,
                                           "LayerNormalization_" + str(id),
                                           None,
                                           axis=node_reduce.attribute[0].ints[0],
                                           epsilon=9.999999960041972e-13)
        layer_norm_nodes.append(layer_norm)

    # remove the left-hand Subs
    for node in graph_proto.node:
        if node.op_type == 'Sub':
            if node.input[1] in first_ReduceMean_outputs:
                removed_nodes.append(node)

    all_nodes = []
    for node in graph_proto.node:
        if node not in removed_nodes:
            all_nodes.append(node)

    for node in layer_norm_nodes:
        all_nodes.append(node)

    graph_proto.ClearField("node")
    graph_proto.node.extend(all_nodes)
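A hedged usage sketch: fuse the decomposed LayerNorm subgraphs of an exported model in place and count the result; the file path is hypothetical:

import onnx
from orttraining_test_layer_norm_transform import layer_norm_transform

model = onnx.load("bert_exported.onnx")  # hypothetical exported model
before = sum(1 for n in model.graph.node if n.op_type == 'ReduceMean')
layer_norm_transform(model)
after = sum(1 for n in model.graph.node if n.op_type == 'LayerNormalization')
print("ReduceMean before: %d, LayerNormalization after: %d" % (before, after))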
@@ -0,0 +1,106 @@
from onnx import numpy_helper

def add_name(model):
    i = 0
    for node in model.graph.node:
        node.name = '%s_%d' % (node.op_type, i)
        i += 1

def find_single_output_node(model, arg):
    result = []
    for node in model.graph.node:
        for input in node.input:
            if input == arg:
                result.append(node)
    return result[0] if len(result) == 1 else None

def find_input_as_initializer(model, arg):
    for initializer in model.graph.initializer:
        if initializer.name == arg:
            return initializer
    return None

def get_node_index(model, node):
    for i, n in enumerate(model.graph.node):
        if n == node:
            return i
    return None

def replace_input_arg(model, arg, new_arg):
    for node in model.graph.node:
        for i in range(len(node.input)):
            if node.input[i] == arg:
                node.input[i] = new_arg

def find_weight_index(model, name):
    for index, w in enumerate(model.graph.initializer):
        if w.name == name:
            return index
    return None

def fix_transpose(model):
    """
    Remove a Transpose node if its input is a 2d weight that feeds only that node.
    """

    # Find Transpose nodes whose input is an initializer weight.
    # The weight must be fed only into the Transpose node.
    # Collect these nodes and weights.
    transpose = []
    for node in model.graph.node:
        if node.op_type == 'Transpose':
            weight = find_input_as_initializer(model, node.input[0])
            if weight is not None:
                result = []
                for n in model.graph.node:
                    for input in n.input:
                        if input == weight.name:
                            result.append(n)
                if len(result) > 1:
                    continue
                perm = node.attribute[0]
                assert perm.name == 'perm'
                perm = perm.ints
                assert len(perm) == 2 and perm[0] == 1 and perm[1] == 0
                transpose.append((get_node_index(model, node), weight))

    # Transpose the collected weights and add them to the model initializers.
    # The transposed weights become the inputs of the nodes that consumed the Transpose outputs.
    for t in transpose:
        node = model.graph.node[t[0]]
        weight = numpy_helper.to_array(t[1])
        assert len(weight.shape) == 2
        weight = weight.transpose(perm)
        new_weight = numpy_helper.from_array(weight, "%s_transposed" % t[1].name)
        model.graph.initializer.extend([new_weight])
        replace_input_arg(model, node.output[0], new_weight.name)

    # the collected Transpose nodes can be removed
    transpose.sort(reverse=True)
    for t in transpose:
        del model.graph.node[t[0]]

    # the original weight initializers can be removed
    # (recall that a collected weight is fed only into its Transpose node)
    old_ws = []
    for t in transpose:
        if find_single_output_node(model, t[1].name) is None:
            old_ws.append(find_weight_index(model, t[1].name))
    old_ws.sort(reverse=True)
    for w_i in old_ws:
        del model.graph.initializer[w_i]

def add_expand_shape(model):
    """
    This method is very specific to the BERT model, which has a single Expand op.
    The training backend requires the op's output shape; it is the same as the
    shape of the model's (single) input.
    """

    expand_node = [n for n in model.graph.node if n.op_type == 'Expand']
    if len(expand_node) != 1:
        raise RuntimeError("cannot find the single Expand node in the BERT model.")
    expand_out = model.graph.value_info.add()
    expand_out.name = expand_node[0].output[0]  # base: '421' # tiny: '85'
    expand_out.type.CopyFrom(model.graph.input[0].type)
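A small self-check sketch for fix_transpose on a hand-built graph: the weight W feeds only a Transpose, so the transform should fold the transpose into the initializer and drop both the node and the original weight. The graph names and shapes are illustrative:

import numpy as np
import onnx
from onnx import helper, numpy_helper, TensorProto
from orttraining_test_model_transform import add_name, fix_transpose

W = numpy_helper.from_array(np.arange(6, dtype=np.float32).reshape(2, 3), "W")
nodes = [helper.make_node("Transpose", ["W"], ["W_T"], perm=[1, 0]),
         helper.make_node("MatMul", ["X", "W_T"], ["Y"])]
graph = helper.make_graph(
    nodes, "g",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 3])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 2])],
    initializer=[W])
model = helper.make_model(graph)

add_name(model)
fix_transpose(model)

# the Transpose node is gone; MatMul now consumes the pre-transposed weight
assert [n.op_type for n in model.graph.node] == ["MatMul"]
assert [w.name for w in model.graph.initializer] == ["W_transposed"]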
|
||||
|
|
@ -0,0 +1,198 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
import os
|
||||
|
||||
from transformers import (BertConfig, BertForPreTraining, BertModel)
|
||||
|
||||
from orttraining_test_data_loader import ids_tensor, BatchArgsOption
|
||||
from orttraining_test_utils import run_test, get_lr
|
||||
|
||||
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler
|
||||
|
||||
import torch
|
||||
|
||||
class BertModelTest(unittest.TestCase):
|
||||
|
||||
class BertModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
device='cpu',
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.device = device
|
||||
|
||||
# 1. superset of bert input/output descs
|
||||
# see BertPreTrainedModel doc
|
||||
self.input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size)
|
||||
self.attention_mask_desc = IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)
|
||||
self.token_type_ids_desc = IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2)
|
||||
self.position_ids_desc = IODescription('position_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.max_position_embeddings)
|
||||
self.head_mask_desc = IODescription('head_mask', [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2)
|
||||
self.inputs_embeds_desc = IODescription('inputs_embeds', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
|
||||
|
||||
self.encoder_hidden_states_desc = IODescription('encoder_hidden_states', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
|
||||
self.encoder_attention_mask_desc = IODescription('encoder_attention_mask', ['batch', 'max_seq_len_in_batch'], torch.float32)
|
||||
|
||||
# see BertForPreTraining doc
|
||||
self.masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size)
|
||||
self.next_sentence_label_desc = IODescription('next_sentence_label', ['batch',], torch.int64, num_classes=2)
|
||||
|
||||
# outputs
|
||||
self.loss_desc = IODescription('loss', [1,], torch.float32)
|
||||
self.prediction_scores_desc = IODescription('prediction_scores', ['batch', 'max_seq_len_in_batch', self.vocab_size], torch.float32)
|
||||
|
||||
self.seq_relationship_scores_desc = IODescription('seq_relationship_scores', ['batch', 2], torch.float32) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32)
|
||||
self.hidden_states_desc = IODescription('hidden_states', [self.num_hidden_layers, 'batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
|
||||
self.attentions_desc = IODescription('attentions', [self.num_hidden_layers, 'batch', self.num_attention_heads, 'max_seq_len_in_batch', 'max_seq_len_in_batch'], torch.float32)
|
||||
self.last_hidden_state_desc = IODescription('last_hidden_state', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
|
||||
self.pooler_output_desc = IODescription('pooler_output', ['batch', self.hidden_size], torch.float32)
|
||||
|
||||
def BertForPreTraining_descs(self):
|
||||
return ModelDescription(
|
||||
[self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc, self.next_sentence_label_desc],
|
||||
# returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided
|
||||
# hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states
|
||||
[self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc,
|
||||
#hidden_states_desc, attentions_desc
|
||||
])
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device)
|
||||
|
||||
config = BertConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertForPreTraining(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
|
||||
masked_lm_labels=token_labels, next_sentence_label=sequence_labels)
|
||||
model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc,
|
||||
self.masked_lm_labels_desc, self.next_sentence_label_desc],
|
||||
[self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc])
|
||||
|
||||
from collections import namedtuple
|
||||
MyArgs = namedtuple("MyArgs",
|
||||
"local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len")
|
||||
args = MyArgs(local_rank=0, world_size=1, max_steps=100, learning_rate=0.00001, warmup_proportion=0.01, batch_size=13, seq_len=7)
|
||||
|
||||
def get_lr_this_step(global_step):
|
||||
return get_lr(args, global_step)
|
||||
loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000)
|
||||
|
||||
# It would be better to test both with/without mixed precision and allreduce_post_accumulation.
|
||||
# However, stress test of all the 4 cases is not stable at lease on the test machine.
|
||||
# There we only test mixed precision and allreduce_post_accumulation because it is the most useful use cases.
|
||||
option_fp16 = [True]
|
||||
option_allreduce_post_accumulation = [True]
|
||||
option_gradient_accumulation_steps = [1, 8]
|
||||
option_use_internal_get_lr_this_step = [True, False]
|
||||
option_use_internal_loss_scaler = [True, False]
|
||||
option_split_batch = [BatchArgsOption.ListAndDict]

        for fp16 in option_fp16:
            for allreduce_post_accumulation in option_allreduce_post_accumulation:
                for gradient_accumulation_steps in option_gradient_accumulation_steps:
                    for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step:
                        for use_internal_loss_scaler in option_use_internal_loss_scaler:
                            for split_batch in option_split_batch:
                                print("gradient_accumulation_steps:", gradient_accumulation_steps)
                                print("use_internal_loss_scaler:", use_internal_loss_scaler)
                                loss_ort, prediction_scores_ort, seq_relationship_score_ort =\
                                    run_test(model, model_desc, self.device, args, gradient_accumulation_steps, fp16,
                                             allreduce_post_accumulation,
                                             get_lr_this_step, use_internal_get_lr_this_step,
                                             loss_scaler, use_internal_loss_scaler,
                                             split_batch)

                                print(loss_ort)
                                print(prediction_scores_ort)
                                print(seq_relationship_score_ort)
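As an aside, the six nested loops above enumerate a full cross-product of test configurations; a minimal equivalent sketch with itertools.product (option values reused from the test, the run_test call replaced by a print for illustration):

import itertools

options = {
    "fp16": [True],
    "allreduce_post_accumulation": [True],
    "gradient_accumulation_steps": [1, 8],
    "use_internal_get_lr_this_step": [True, False],
    "use_internal_loss_scaler": [True, False],
    "split_batch": ["ListAndDict"],
}
for combo in itertools.product(*options.values()):
    config = dict(zip(options.keys(), combo))
    print(config)  # stand-in for the run_test call above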

    def setUp(self):
        self.model_tester = BertModelTest.BertModelTester(self)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)


if __name__ == "__main__":
    unittest.main()

orttraining/orttraining/test/python/orttraining_test_utils.py (new file, 118 lines)
@@ -0,0 +1,118 @@
import math  # used by warmup_cosine below
import torch

from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription

from orttraining_test_data_loader import create_ort_test_dataloader, BatchArgsOption, split_batch
from orttraining_test_bert_postprocess import postprocess_model

def warmup_cosine(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 0.5 * (1.0 + math.cos(math.pi * x))  # x is a plain float here, so math.cos applies


def warmup_constant(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 1.0


def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return max((x - 1.) / (warmup - 1.), 0.)


def warmup_poly(x, warmup=0.002, degree=0.5):
    if x < warmup:
        return x / warmup
    return (1.0 - x)**degree


SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
    'warmup_poly': warmup_poly,
}

def get_lr(args, training_steps, schedule='warmup_poly'):
    if args.max_steps == -1:
        return args.learning_rate

    schedule_fct = SCHEDULES[schedule]
    return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion)
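A quick sanity check of the default schedule (the args values below are hypothetical, chosen only for this example):

from collections import namedtuple

_Args = namedtuple("_Args", "max_steps learning_rate warmup_proportion")
_args = _Args(max_steps=100, learning_rate=1e-5, warmup_proportion=0.01)
for _step in (0, 1, 10, 50, 100):
    # linear warmup while x < warmup_proportion, then (1 - x) ** 0.5 polynomial decay
    print(_step, get_lr(_args, _step))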

def map_optimizer_attributes(name):
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys)
    if no_decay:
        return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
    else:
        return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}

def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16,
             allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step, loss_scaler, use_internal_loss_scaler,
             batch_args_option):
    dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, device)

    model = ORTTrainer(model, None, model_desc, "LambOptimizer",
                       map_optimizer_attributes=map_optimizer_attributes,
                       learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
                       device=device, postprocess_model=postprocess_model,
                       gradient_accumulation_steps=gradient_accumulation_steps,
                       # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
                       world_rank=args.local_rank, world_size=args.world_size,
                       use_mixed_precision=fp16,
                       allreduce_post_accumulation=allreduce_post_accumulation,
                       get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None,
                       loss_scaler=loss_scaler if use_internal_loss_scaler else None,
                       _opset_version=12)
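    # Note: with use_internal_get_lr_this_step / use_internal_loss_scaler set to True the
    # trainer drives learning-rate scheduling and fp16 loss scaling itself; with False, the
    # loop below feeds Learning_Rate (and the loss scale) to the model as extra inputs.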

    # training loop
    eval_batch = None
    model.train()
    for step, batch in enumerate(dataloader):
        if eval_batch is None:
            eval_batch = batch

        if not use_internal_get_lr_this_step:
            lr = get_lr_this_step(step)
            learning_rate = torch.tensor([lr])

        if not use_internal_loss_scaler and fp16:
            loss_scale = torch.tensor([loss_scaler.loss_scale_])

        if batch_args_option == BatchArgsOption.List:
            if not use_internal_get_lr_this_step:
                batch = batch + [learning_rate, ]
            if not use_internal_loss_scaler and fp16:
                batch = batch + [loss_scale, ]
            outputs = model(*batch)
        elif batch_args_option == BatchArgsOption.Dict:
            # (rebinds the 'args' parameter; it is not needed after this point)
            args, kwargs = split_batch(batch, model_desc.inputs_, 0)
            if not use_internal_get_lr_this_step:
                kwargs['Learning_Rate'] = learning_rate
            if not use_internal_loss_scaler and fp16:
                kwargs[model.loss_scale_input_name] = loss_scale
            outputs = model(*args, **kwargs)
        else:  # BatchArgsOption.ListAndDict
            args_count = int(len(model_desc.inputs_) / 2)  # approximately half args, half kwargs
            args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
            if not use_internal_get_lr_this_step:
                kwargs['Learning_Rate'] = learning_rate
            if not use_internal_loss_scaler and fp16:
                kwargs[model.loss_scale_input_name] = loss_scale
            outputs = model(*args, **kwargs)

    # eval (reuses the last training batch)
    model.eval()
    if batch_args_option == BatchArgsOption.List:
        outputs = model(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model(*args, **kwargs)
    else:  # BatchArgsOption.ListAndDict
        args_count = int(len(model_desc.inputs_) / 2)  # approximately half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)
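split_batch itself lives in orttraining_test_data_loader and is not shown in this diff; from its call sites, a hedged sketch of the assumed contract:

# Illustrative only; the real helper may differ. Splits a flat batch into positional
# args plus a name->tensor dict keyed by the remaining input descriptions.
def split_batch_sketch(batch, input_descs, args_count):
    args = list(batch[:args_count])
    kwargs = {desc.name_: value for desc, value in zip(input_descs[args_count:], batch[args_count:])}
    return args, kwargs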

@@ -19,14 +19,6 @@ from torchvision import datasets, transforms
 import numpy as np
 import os

-# TODO: remove after ready for CV
-# import sys
-# sys.path.insert(0, '/bert_ort/liqun/onnxruntime/build/Linux/Debug/')
-# import onnxruntime as ort
-
-# sys.path.insert(0, '/bert_ort/liqun/onnxruntime/onnxruntime/python/')
-# from ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel
-
 from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer, ORTModel
 from mpi4py import MPI
 from onnxruntime.capi._pybind_state import set_cuda_device_id

@@ -1006,18 +1006,27 @@ def adb_push(source_dir, src, dest, **kwargs):
 def adb_shell(*args, **kwargs):
     return run_subprocess(['adb', 'shell', *args], **kwargs)

-def run_training_python_frontend_e2e_tests(args, cwd, dll_path):
+def run_training_python_frontend_e2e_tests(args, cwd):
     # frontend tests are to be added here:
     log.info("Running python frontend e2e tests.")
-    run_subprocess(
-        [sys.executable, 'onnxruntime_test_ort_trainer_with_mixed_precision.py'],
-        cwd=cwd, dll_path=dll_path)
+    run_subprocess([sys.executable, 'orttraining_test_transformers.py'], cwd=cwd)
+
+    run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd)
+
+    run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer_with_mixed_precision.py'], cwd=cwd)

 def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs,
                           enable_tvm=False, enable_tensorrt=False):
     for config in configs:
         log.info("Running tests for %s configuration", config)
         cwd = get_config_build_dir(build_dir, config)

+        if args.enable_training and args.use_cuda and args.enable_training_python_frontend_e2e_tests:
+            # run frontend tests for orttraining-linux-gpu-frontend_test-ci-pipeline.
+            # this is not a PR merge test so skip other tests.
+            run_training_python_frontend_e2e_tests(args, cwd=cwd)
+            continue
+
         android_x86_64 = args.android_abi == 'x86_64'
         if android_x86_64:
             run_subprocess(os.path.join(

@@ -1099,10 +1108,6 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs,
             [sys.executable, 'onnxruntime_test_training_unit_tests.py'],
             cwd=cwd, dll_path=dll_path)

-        # run additional frontend tests for orttraining-linux-gpu-frontend_test_ci-pipeline
-        if args.enable_training_python_frontend_e2e_tests:
-            run_training_python_frontend_e2e_tests(args, cwd=cwd, dll_path=dll_path)
-
         try:
             import onnx  # noqa
             # gen_test_models.py used by onnx_test requires scipy.

@@ -112,6 +112,9 @@ elif [ $DEVICE_TYPE = "gpu" ]; then
     if [[ $BUILD_EXTR_PAR = *--enable_training* ]]; then
         ${PYTHON_EXE} -m pip install --upgrade --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
     fi
+    if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then
+        ${PYTHON_EXE} -m pip install transformers
+    fi
 fi