diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index e6b68b3d37..eeb581d16a 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -616,7 +616,12 @@ class ORTModule(torch.nn.Module): for input in backward_graph_inputs: if input in forward_graph_outputs: # inputs of backward graph that are also outputs from forward graph need to be added to backward graph input - add_input(backward_model, input, tensor_elem_types[input] if input in tensor_elem_types else 1) + # TODO: thiagofc: BERT: Remove this once graph splitter can handle unspecified optional input (without type) + input_type = tensor_elem_types[input] if input in tensor_elem_types else 1 + if input in {'1835', '1813', '1781','1760', '1683','1651','1630','1553','1521','1500','1423','1391','1370','1293','1261','1240','1163','1131','1110','1033','1001','980','871', + '267','330','351','383','460','481','513','590','611','643','720','741','773','850','903'}: + input_type = 9 + add_input(backward_model, input, input_type) elif input in forward_graph_initializer_names: # inputs from forward graph initializers need to be added to backward graph input add_input_from_initializer(backward_model, initializers[input]) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py index 8a93aa862a..fd6fea783c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_basic.py @@ -136,6 +136,10 @@ def main(): print('Training MNIST on ORTModule....') model = ORTModule(model) + # TODO: change it to False to stop saving ONNX models + model._save_onnx = True + model._save_onnx_prefix = 'MNIST' + # Set log level numeric_level = getattr(logging, args.log_level.upper(), None) if not isinstance(numeric_level, int): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index 2bb30604bc..724f7e17ce 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -1,3 +1,7 @@ + +import pdb + +import argparse import torch import wget import os @@ -18,17 +22,28 @@ import datetime import onnxruntime from onnxruntime.training import ORTModule -# 1. Device setup -# TODO: Hard-coding for CPU for ORTModule -# if torch.cuda.is_available(): -# device = torch.device("cuda") -# print('There are %d GPU(s) available.' % torch.cuda.device_count()) -# print('We will use the GPU:', torch.cuda.get_device_name(0)) -# else: -# print('No GPU available, using the CPU instead.') -# device = torch.device("cpu") -device = torch.device("cpu") +# 0. Common stuff +parser = argparse.ArgumentParser(description='PyTorch MNIST Example') +parser.add_argument('--pytorch-only', action='store_true', default=False, + help='disables ONNX Runtime training') +parser.add_argument('--view-graphs', action='store_true', default=False, + help='views forward and backward graphs') +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') +parser.add_argument('--epochs', type=int, default=4, metavar='N', + help='number of epochs to train (default: 4)') +args = parser.parse_args() + + +# 1. Device setup +if torch.cuda.is_available() and not args.no_cuda: + device = torch.device("cuda") + print('There are %d GPU(s) available.' % torch.cuda.device_count()) + print('We will use the GPU:', torch.cuda.get_device_name(0)) +else: + print('No GPU available, using the CPU instead.') + device = torch.device("cpu") # 2. Loading CoLA Dataset print('Downloading dataset...') @@ -141,15 +156,17 @@ model = BertForSequenceClassification.from_pretrained( output_attentions = False, # Whether the model returns attentions weights. output_hidden_states = False, # Whether the model returns all hidden-states. ) -model = ORTModule(model) + +if not args.pytorch_only: + model = ORTModule(model) # TODO: change it to False to stop saving ONNX models model._save_onnx = True model._save_onnx_prefix = 'BertForSequenceClassification' # Tell pytorch to run this model on the GPU. -# TODO: Hard coding it to CPU for ORTModule -# model.cuda() +if torch.cuda.is_available() and not args.no_cuda: + model.cuda() # Note: AdamW is a class from the huggingface library (as opposed to pytorch) optimizer = AdamW(model.parameters(), @@ -157,11 +174,9 @@ optimizer = AdamW(model.parameters(), eps = 1e-8 # args.adam_epsilon - default is 1e-8. ) -# Number of training epochs (authors recommend between 2 and 4) -epochs = 4 - +# Authors recommend between 2 and 4 epochs # Total number of training steps is number of batches * number of epochs. -total_steps = len(train_dataloader) * epochs +total_steps = len(train_dataloader) * args.epochs # Create the learning rate scheduler. scheduler = get_linear_schedule_with_warmup(optimizer, @@ -193,8 +208,8 @@ random.seed(seed_val) np.random.seed(seed_val) torch.manual_seed(seed_val) onnxruntime.set_seed(seed_val) -# TODO: We are not using CUDA for ORTModule just yet -# torch.cuda.manual_seed_all(seed_val) +if torch.cuda.is_available() and not args.no_cuda: + torch.cuda.manual_seed_all(seed_val) # Store the average loss after each epoch so we can plot them. loss_values = [] @@ -202,10 +217,10 @@ loss_values = [] # ======================================== # Training # ======================================== -for epoch_i in range(0, epochs): +for epoch_i in range(0, args.epochs): # Perform one full pass over the training set. print("") - print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs)) + print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, args.epochs)) # Measure how long the training epoch takes. t0 = time.time() @@ -259,14 +274,22 @@ for epoch_i in range(0, epochs): attention_mask=b_input_mask, labels=b_labels) + if args.view_graphs: + import torchviz + pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters()))) + pytorch_backward_graph.view() + # The call to `model` always returns a tuple, so we need to pull the # loss value out of the tuple. + # pdb.set_trace() loss = outputs[0] # Accumulate the training loss over all of the batches so that we can # calculate the average loss at the end. `loss` is a Tensor containing a # single value; the `.item()` function just returns the Python value # from the tensor. + print(loss.shape) + print(loss) total_loss += loss.item() # total_loss += loss