Refactor MNIST and BERT classifier to add time measures

This commit is contained in:
Thiago Crepaldi 2020-11-19 15:04:54 -08:00
parent 395e082bc3
commit 07f5ae95e5
3 changed files with 86 additions and 21 deletions

View file

@ -155,6 +155,7 @@ class ORTModule(torch.nn.Module):
self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs)
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
# TODO: PyTorch exporter bug: changes the initializer order
initializer_names = [p[0] for p in self._original_module.named_parameters()]
onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \

View file

@ -32,6 +32,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
# Measure how long the training epoch takes.
t0 = time.time()
start_time = t0
# Reset the total loss for this epoch.
total_loss = 0
@ -90,11 +91,6 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
None,
b_labels)
if args.view_graphs:
import torchviz
pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters())))
pytorch_backward_graph.view()
# The call to `model` always returns a tuple, so we need to pull the
# loss value out of the tuple.
loss = outputs[0]
@ -102,10 +98,17 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
# Progress update every `args.log_interval` batches (skipping step 0).
if step % args.log_interval == 0 and not step == 0:
# Calculate the time elapsed since the previous progress report, in seconds.
elapsed = format_time(time.time() - t0)
curr_time = time.time()
elapsed_time = curr_time - start_time
# Report progress.
print(f'Batch {step} of {len(train_dataloader)}. Elapsed: {elapsed}. Loss: {loss.item()}')
print(f'Batch {step:4} of {len(train_dataloader):4}. Execution time: {elapsed_time:.4f}. Loss: {loss.item():.4f}')
start_time = curr_time
if args.view_graphs:
import torchviz
pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters())))
pytorch_backward_graph.view()
# Accumulate the training loss over all of the batches so that we can
# calculate the average loss at the end. `loss` is a Tensor containing a
@ -131,8 +134,10 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
# Calculate the average loss over the training data.
avg_train_loss = total_loss / len(train_dataloader)
epoch_time = time.time() - t0
print("\n Average training loss: {0:.2f}".format(avg_train_loss))
print(" Training epoch took: {:}".format(format_time(time.time() - t0)))
print(" Training epoch took: {:.4f}s".format(epoch_time))
return epoch_time
def test(model, validation_dataloader, device, args):
# ========================================
@ -142,12 +147,12 @@ def test(model, validation_dataloader, device, args):
# our validation set.
print("\nRunning Validation...")
t0 = time.time()
# Put the model in evaluation mode--the dropout layers behave differently
# during evaluation.
model.eval()
t0 = time.time()
# Tracking variables
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
@ -207,8 +212,10 @@ def test(model, validation_dataloader, device, args):
nb_eval_steps += 1
# Report the final accuracy for this validation run.
epoch_time = time.time() - t0
print(" Accuracy: {0:.2f}".format(eval_accuracy/nb_eval_steps))
print(" Validation took: {:}".format(format_time(time.time() - t0)))
print(" Validation took: {:.4f}s".format(epoch_time))
return epoch_time
def load_dataset(args):
# 2. Loading CoLA Dataset
@ -417,9 +424,20 @@ def main():
torch.cuda.manual_seed_all(args.seed)
# 4. Train loop (fine-tune)
total_training_time, total_test_time, epoch_0_training = 0, 0, 0
for epoch_i in range(0, args.epochs):
train(model, optimizer, scheduler, train_dataloader, epoch_i, device, args)
test(model, validation_dataloader, device, args)
total_training_time += train(model, optimizer, scheduler, train_dataloader, epoch_i, device, args)
if not args.pytorch_only and epoch_i == 0:
epoch_0_training = total_training_time
total_test_time += test(model, validation_dataloader, device, args)
print('\n======== Global stats ========')
# The export estimate subtracts the mean duration of epochs 1..N-1 from the
# duration of epoch 0. That mean only exists when at least two epochs ran;
# with a single epoch the divisor (args.epochs - 1) is zero, so the estimate
# is skipped instead of raising ZeroDivisionError.
if not args.pytorch_only and args.epochs > 1:
    mean_later_epoch_time = (total_training_time - epoch_0_training) / (args.epochs - 1)
    estimated_export = epoch_0_training - mean_later_epoch_time
    print(" Estimated ONNX export took: {:.4f}s".format(estimated_export))
    print(" Accumulated training without export took: {:.4f}s".format(total_training_time - estimated_export))
print(" Accumulated training took: {:.4f}s".format(total_training_time))
print(" Accumulated validation took: {:.4f}s".format(total_test_time))
if __name__ == '__main__':
main()

View file

@ -1,6 +1,7 @@
import argparse
import logging
import torch
import time
from torchvision import datasets, transforms
import onnxruntime
@ -23,7 +24,15 @@ class NeuralNet(torch.nn.Module):
def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
print('\n======== Epoch {:} / {:} ========'.format(epoch+1, args.epochs))
model.train()
# Measure how long the training epoch takes.
t0 = time.time()
start_time = t0
# Reset the total loss for this epoch.
total_loss = 0
for iteration, (data, target) in enumerate(train_loader):
if iteration == args.train_steps:
break
@ -42,18 +51,38 @@ def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
pytorch_backward_graph.view()
loss = loss_fn(probability, target)
# Accumulate the training loss over all of the batches so that we can
# calculate the average loss at the end. `loss` is a Tensor containing a
# single value; the `.item()` function just returns the Python value
# from the tensor.
total_loss += loss.item()
loss.backward()
optimizer.step()
# Stats
if iteration % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, iteration * len(data), len(train_loader.dataset),
100. * iteration / len(train_loader), loss))
curr_time = time.time()
elapsed_time = curr_time - start_time
print('[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}'.format(
iteration * len(data), len(train_loader.dataset),
100. * iteration / len(train_loader), loss, elapsed_time))
start_time = curr_time
# Calculate the average loss over the training data.
avg_train_loss = total_loss / len(train_loader)
epoch_time = time.time() - t0
print("\n Average training loss: {0:.2f}".format(avg_train_loss))
print(" Training epoch took: {:.4f}s".format(epoch_time))
return epoch_time
def test(args, model, device, loss_fn, test_loader):
model.eval()
t0 = time.time()
test_loss = 0
correct = 0
with torch.no_grad():
@ -71,6 +100,12 @@ def test(args, model, device, loss_fn, test_loader):
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
# Report the final accuracy for this validation run.
epoch_time = time.time() - t0
print(" Accuracy: {0:.2f}".format(float(correct)/len(test_loader.dataset)))
print(" Validation took: {:.4f}s".format(epoch_time))
return epoch_time
def my_loss(x, target, is_train=True):
if is_train:
return torch.nn.CrossEntropyLoss()(x, target)
@ -94,8 +129,8 @@ def main():
help='random seed (default: 42)')
parser.add_argument('--pytorch-only', action='store_true', default=False,
help='disables ONNX Runtime training')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
help='how many batches to wait before logging training status (default: 100)')
parser.add_argument('--log-interval', type=int, default=300, metavar='N',
help='how many batches to wait before logging training status (default: 300)')
parser.add_argument('--view-graphs', action='store_true', default=False,
help='views forward and backward graphs')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
@ -148,10 +183,21 @@ def main():
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
# Train loop
for epoch in range(1, args.epochs + 1):
train(args, model, device, optimizer, my_loss, train_loader, epoch)
total_training_time, total_test_time, epoch_0_training = 0, 0, 0
for epoch in range(0, args.epochs):
total_training_time += train(args, model, device, optimizer, my_loss, train_loader, epoch)
if not args.pytorch_only and epoch == 0:
epoch_0_training = total_training_time
if args.test_batch_size > 0:
test(args, model, device, my_loss, test_loader)
total_test_time += test(args, model, device, my_loss, test_loader)
print('\n======== Global stats ========')
# The export estimate subtracts the mean duration of epochs 1..N-1 from the
# duration of epoch 0. That mean only exists when at least two epochs ran;
# with a single epoch the divisor (args.epochs - 1) is zero, so the estimate
# is skipped instead of raising ZeroDivisionError.
if not args.pytorch_only and args.epochs > 1:
    mean_later_epoch_time = (total_training_time - epoch_0_training) / (args.epochs - 1)
    estimated_export = epoch_0_training - mean_later_epoch_time
    print(" Estimated ONNX export took: {:.4f}s".format(estimated_export))
    print(" Accumulated training without export took: {:.4f}s".format(total_training_time - estimated_export))
print(" Accumulated training took: {:.4f}s".format(total_training_time))
print(" Accumulated validation took: {:.4f}s".format(total_test_time))
if __name__ == '__main__':