mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Hard-code input types for DropoutGrad on BERT
This commit is contained in:
parent
3b267d1d60
commit
d4917f2d65
3 changed files with 54 additions and 22 deletions
|
|
@ -616,7 +616,12 @@ class ORTModule(torch.nn.Module):
|
|||
for input in backward_graph_inputs:
|
||||
if input in forward_graph_outputs:
|
||||
# inputs of backward graph that are also outputs from forward graph need to be added to backward graph input
|
||||
add_input(backward_model, input, tensor_elem_types[input] if input in tensor_elem_types else 1)
|
||||
# TODO: thiagofc: BERT: Remove this once graph splitter can handle unspecified optional input (without type)
|
||||
input_type = tensor_elem_types[input] if input in tensor_elem_types else 1
|
||||
if input in {'1835', '1813', '1781','1760', '1683','1651','1630','1553','1521','1500','1423','1391','1370','1293','1261','1240','1163','1131','1110','1033','1001','980','871',
|
||||
'267','330','351','383','460','481','513','590','611','643','720','741','773','850','903'}:
|
||||
input_type = 9
|
||||
add_input(backward_model, input, input_type)
|
||||
elif input in forward_graph_initializer_names:
|
||||
# inputs from forward graph initializers need to be added to backward graph input
|
||||
add_input_from_initializer(backward_model, initializers[input])
|
||||
|
|
|
|||
|
|
@ -136,6 +136,10 @@ def main():
|
|||
print('Training MNIST on ORTModule....')
|
||||
model = ORTModule(model)
|
||||
|
||||
# TODO: change it to False to stop saving ONNX models
|
||||
model._save_onnx = True
|
||||
model._save_onnx_prefix = 'MNIST'
|
||||
|
||||
# Set log level
|
||||
numeric_level = getattr(logging, args.log_level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
|
||||
import pdb
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import wget
|
||||
import os
|
||||
|
|
@ -18,17 +22,28 @@ import datetime
|
|||
import onnxruntime
|
||||
from onnxruntime.training import ORTModule
|
||||
|
||||
# 1. Device setup
|
||||
# TODO: Hard-coding for CPU for ORTModule
|
||||
|
||||
# if torch.cuda.is_available():
|
||||
# device = torch.device("cuda")
|
||||
# print('There are %d GPU(s) available.' % torch.cuda.device_count())
|
||||
# print('We will use the GPU:', torch.cuda.get_device_name(0))
|
||||
# else:
|
||||
# print('No GPU available, using the CPU instead.')
|
||||
# device = torch.device("cpu")
|
||||
device = torch.device("cpu")
|
||||
# 0. Common stuff
|
||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
parser.add_argument('--pytorch-only', action='store_true', default=False,
|
||||
help='disables ONNX Runtime training')
|
||||
parser.add_argument('--view-graphs', action='store_true', default=False,
|
||||
help='views forward and backward graphs')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--epochs', type=int, default=4, metavar='N',
|
||||
help='number of epochs to train (default: 4)')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# 1. Device setup
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
device = torch.device("cuda")
|
||||
print('There are %d GPU(s) available.' % torch.cuda.device_count())
|
||||
print('We will use the GPU:', torch.cuda.get_device_name(0))
|
||||
else:
|
||||
print('No GPU available, using the CPU instead.')
|
||||
device = torch.device("cpu")
|
||||
|
||||
# 2. Loading CoLA Dataset
|
||||
print('Downloading dataset...')
|
||||
|
|
@ -141,15 +156,17 @@ model = BertForSequenceClassification.from_pretrained(
|
|||
output_attentions = False, # Whether the model returns attentions weights.
|
||||
output_hidden_states = False, # Whether the model returns all hidden-states.
|
||||
)
|
||||
model = ORTModule(model)
|
||||
|
||||
if not args.pytorch_only:
|
||||
model = ORTModule(model)
|
||||
|
||||
# TODO: change it to False to stop saving ONNX models
|
||||
model._save_onnx = True
|
||||
model._save_onnx_prefix = 'BertForSequenceClassification'
|
||||
|
||||
# Tell pytorch to run this model on the GPU.
|
||||
# TODO: Hard coding it to CPU for ORTModule
|
||||
# model.cuda()
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
model.cuda()
|
||||
|
||||
# Note: AdamW is a class from the huggingface library (as opposed to pytorch)
|
||||
optimizer = AdamW(model.parameters(),
|
||||
|
|
@ -157,11 +174,9 @@ optimizer = AdamW(model.parameters(),
|
|||
eps = 1e-8 # args.adam_epsilon - default is 1e-8.
|
||||
)
|
||||
|
||||
# Number of training epochs (authors recommend between 2 and 4)
|
||||
epochs = 4
|
||||
|
||||
# Authors recommend between 2 and 4 epochs
|
||||
# Total number of training steps is number of batches * number of epochs.
|
||||
total_steps = len(train_dataloader) * epochs
|
||||
total_steps = len(train_dataloader) * args.epochs
|
||||
|
||||
# Create the learning rate scheduler.
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer,
|
||||
|
|
@ -193,8 +208,8 @@ random.seed(seed_val)
|
|||
np.random.seed(seed_val)
|
||||
torch.manual_seed(seed_val)
|
||||
onnxruntime.set_seed(seed_val)
|
||||
# TODO: We are not using CUDA for ORTModule just yet
|
||||
# torch.cuda.manual_seed_all(seed_val)
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
torch.cuda.manual_seed_all(seed_val)
|
||||
|
||||
# Store the average loss after each epoch so we can plot them.
|
||||
loss_values = []
|
||||
|
|
@ -202,10 +217,10 @@ loss_values = []
|
|||
# ========================================
|
||||
# Training
|
||||
# ========================================
|
||||
for epoch_i in range(0, epochs):
|
||||
for epoch_i in range(0, args.epochs):
|
||||
# Perform one full pass over the training set.
|
||||
print("")
|
||||
print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
|
||||
print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, args.epochs))
|
||||
|
||||
# Measure how long the training epoch takes.
|
||||
t0 = time.time()
|
||||
|
|
@ -259,14 +274,22 @@ for epoch_i in range(0, epochs):
|
|||
attention_mask=b_input_mask,
|
||||
labels=b_labels)
|
||||
|
||||
if args.view_graphs:
|
||||
import torchviz
|
||||
pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters())))
|
||||
pytorch_backward_graph.view()
|
||||
|
||||
# The call to `model` always returns a tuple, so we need to pull the
|
||||
# loss value out of the tuple.
|
||||
# pdb.set_trace()
|
||||
loss = outputs[0]
|
||||
|
||||
# Accumulate the training loss over all of the batches so that we can
|
||||
# calculate the average loss at the end. `loss` is a Tensor containing a
|
||||
# single value; the `.item()` function just returns the Python value
|
||||
# from the tensor.
|
||||
print(loss.shape)
|
||||
print(loss)
|
||||
total_loss += loss.item()
|
||||
# total_loss += loss
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue