mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Refactor MNIST and BERT classifier to add time measures
This commit is contained in:
parent
395e082bc3
commit
07f5ae95e5
3 changed files with 86 additions and 21 deletions
|
|
@ -155,6 +155,7 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
self._onnx_training = ORTModule._get_forward_graph(self._original_module, *inputs, **kwargs)
|
||||
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
|
||||
# TODO: PyTorch exporter bug: changes the initializer order
|
||||
initializer_names = [p[0] for p in self._original_module.named_parameters()]
|
||||
onnx_gradient, self._onnx_forward, self._onnx_backward, self._onnx_graphs_info = \
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
|
||||
# Measure how long the training epoch takes.
|
||||
t0 = time.time()
|
||||
start_time = t0
|
||||
|
||||
# Reset the total loss for this epoch.
|
||||
total_loss = 0
|
||||
|
|
@ -90,11 +91,6 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
None,
|
||||
b_labels)
|
||||
|
||||
if args.view_graphs:
|
||||
import torchviz
|
||||
pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters())))
|
||||
pytorch_backward_graph.view()
|
||||
|
||||
# The call to `model` always returns a tuple, so we need to pull the
|
||||
# loss value out of the tuple.
|
||||
loss = outputs[0]
|
||||
|
|
@ -102,10 +98,17 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
# Progress update every 40 batches.
|
||||
if step % args.log_interval == 0 and not step == 0:
|
||||
# Calculate elapsed time in minutes.
|
||||
elapsed = format_time(time.time() - t0)
|
||||
curr_time = time.time()
|
||||
elapsed_time = curr_time - start_time
|
||||
|
||||
# Report progress.
|
||||
print(f'Batch {step} of {len(train_dataloader)}. Elapsed: {elapsed}. Loss: {loss.item()}')
|
||||
print(f'Batch {step:4} of {len(train_dataloader):4}. Execution time: {elapsed_time:.4f}. Loss: {loss.item():.4f}')
|
||||
start_time = curr_time
|
||||
|
||||
if args.view_graphs:
|
||||
import torchviz
|
||||
pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters())))
|
||||
pytorch_backward_graph.view()
|
||||
|
||||
# Accumulate the training loss over all of the batches so that we can
|
||||
# calculate the average loss at the end. `loss` is a Tensor containing a
|
||||
|
|
@ -131,8 +134,10 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args):
|
|||
# Calculate the average loss over the training data.
|
||||
avg_train_loss = total_loss / len(train_dataloader)
|
||||
|
||||
epoch_time = time.time() - t0
|
||||
print("\n Average training loss: {0:.2f}".format(avg_train_loss))
|
||||
print(" Training epoch took: {:}".format(format_time(time.time() - t0)))
|
||||
print(" Training epoch took: {:.4f}s".format(epoch_time))
|
||||
return epoch_time
|
||||
|
||||
def test(model, validation_dataloader, device, args):
|
||||
# ========================================
|
||||
|
|
@ -142,12 +147,12 @@ def test(model, validation_dataloader, device, args):
|
|||
# our validation set.
|
||||
print("\nRunning Validation...")
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# Put the model in evaluation mode--the dropout layers behave differently
|
||||
# during evaluation.
|
||||
model.eval()
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# Tracking variables
|
||||
eval_loss, eval_accuracy = 0, 0
|
||||
nb_eval_steps, nb_eval_examples = 0, 0
|
||||
|
|
@ -207,8 +212,10 @@ def test(model, validation_dataloader, device, args):
|
|||
nb_eval_steps += 1
|
||||
|
||||
# Report the final accuracy for this validation run.
|
||||
epoch_time = time.time() - t0
|
||||
print(" Accuracy: {0:.2f}".format(eval_accuracy/nb_eval_steps))
|
||||
print(" Validation took: {:}".format(format_time(time.time() - t0)))
|
||||
print(" Validation took: {:.4f}s".format(epoch_time))
|
||||
return epoch_time
|
||||
|
||||
def load_dataset(args):
|
||||
# 2. Loading CoLA Dataset
|
||||
|
|
@ -417,9 +424,20 @@ def main():
|
|||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
# 4. Train loop (fine-tune)
|
||||
total_training_time, total_test_time, epoch_0_training = 0, 0, 0
|
||||
for epoch_i in range(0, args.epochs):
|
||||
train(model, optimizer, scheduler, train_dataloader, epoch_i, device, args)
|
||||
test(model, validation_dataloader, device, args)
|
||||
total_training_time += train(model, optimizer, scheduler, train_dataloader, epoch_i, device, args)
|
||||
if not args.pytorch_only and epoch_i == 0:
|
||||
epoch_0_training = total_training_time
|
||||
total_test_time += test(model, validation_dataloader, device, args)
|
||||
|
||||
print('\n======== Global stats ========')
|
||||
if not args.pytorch_only:
|
||||
estimated_export = epoch_0_training - (total_training_time - epoch_0_training)/(args.epochs-1)
|
||||
print(" Estimated ONNX export took: {:.4f}s".format(estimated_export))
|
||||
print(" Accumulated training without export took: {:.4f}s".format(total_training_time - estimated_export))
|
||||
print(" Accumulated training took: {:.4f}s".format(total_training_time))
|
||||
print(" Accumulated validation took: {:.4f}s".format(total_test_time))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import time
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
import onnxruntime
|
||||
|
|
@ -23,7 +24,15 @@ class NeuralNet(torch.nn.Module):
|
|||
|
||||
|
||||
def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
|
||||
print('\n======== Epoch {:} / {:} ========'.format(epoch+1, args.epochs))
|
||||
model.train()
|
||||
# Measure how long the training epoch takes.
|
||||
t0 = time.time()
|
||||
start_time = t0
|
||||
|
||||
# Reset the total loss for this epoch.
|
||||
total_loss = 0
|
||||
|
||||
for iteration, (data, target) in enumerate(train_loader):
|
||||
if iteration == args.train_steps:
|
||||
break
|
||||
|
|
@ -42,18 +51,38 @@ def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
|
|||
pytorch_backward_graph.view()
|
||||
|
||||
loss = loss_fn(probability, target)
|
||||
# Accumulate the training loss over all of the batches so that we can
|
||||
# calculate the average loss at the end. `loss` is a Tensor containing a
|
||||
# single value; the `.item()` function just returns the Python value
|
||||
# from the tensor.
|
||||
total_loss += loss.item()
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Stats
|
||||
if iteration % args.log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, iteration * len(data), len(train_loader.dataset),
|
||||
100. * iteration / len(train_loader), loss))
|
||||
curr_time = time.time()
|
||||
elapsed_time = curr_time - start_time
|
||||
print('[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}'.format(
|
||||
iteration * len(data), len(train_loader.dataset),
|
||||
100. * iteration / len(train_loader), loss, elapsed_time))
|
||||
start_time = curr_time
|
||||
|
||||
# Calculate the average loss over the training data.
|
||||
avg_train_loss = total_loss / len(train_loader)
|
||||
|
||||
epoch_time = time.time() - t0
|
||||
print("\n Average training loss: {0:.2f}".format(avg_train_loss))
|
||||
print(" Training epoch took: {:.4f}s".format(epoch_time))
|
||||
return epoch_time
|
||||
|
||||
|
||||
def test(args, model, device, loss_fn, test_loader):
|
||||
model.eval()
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
|
|
@ -71,6 +100,12 @@ def test(args, model, device, loss_fn, test_loader):
|
|||
test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
# Report the final accuracy for this validation run.
|
||||
epoch_time = time.time() - t0
|
||||
print(" Accuracy: {0:.2f}".format(float(correct)/len(test_loader.dataset)))
|
||||
print(" Validation took: {:.4f}s".format(epoch_time))
|
||||
return epoch_time
|
||||
|
||||
def my_loss(x, target, is_train=True):
|
||||
if is_train:
|
||||
return torch.nn.CrossEntropyLoss()(x, target)
|
||||
|
|
@ -94,8 +129,8 @@ def main():
|
|||
help='random seed (default: 42)')
|
||||
parser.add_argument('--pytorch-only', action='store_true', default=False,
|
||||
help='disables ONNX Runtime training')
|
||||
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
|
||||
help='how many batches to wait before logging training status (default: 100)')
|
||||
parser.add_argument('--log-interval', type=int, default=300, metavar='N',
|
||||
help='how many batches to wait before logging training status (default: 300)')
|
||||
parser.add_argument('--view-graphs', action='store_true', default=False,
|
||||
help='views forward and backward graphs')
|
||||
parser.add_argument('--epochs', type=int, default=10, metavar='N',
|
||||
|
|
@ -148,10 +183,21 @@ def main():
|
|||
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
|
||||
|
||||
# Train loop
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(args, model, device, optimizer, my_loss, train_loader, epoch)
|
||||
total_training_time, total_test_time, epoch_0_training = 0, 0, 0
|
||||
for epoch in range(0, args.epochs):
|
||||
total_training_time += train(args, model, device, optimizer, my_loss, train_loader, epoch)
|
||||
if not args.pytorch_only and epoch == 0:
|
||||
epoch_0_training = total_training_time
|
||||
if args.test_batch_size > 0:
|
||||
test(args, model, device, my_loss, test_loader)
|
||||
total_test_time += test(args, model, device, my_loss, test_loader)
|
||||
|
||||
print('\n======== Global stats ========')
|
||||
if not args.pytorch_only:
|
||||
estimated_export = epoch_0_training - (total_training_time - epoch_0_training)/(args.epochs-1)
|
||||
print(" Estimated ONNX export took: {:.4f}s".format(estimated_export))
|
||||
print(" Accumulated training without export took: {:.4f}s".format(total_training_time - estimated_export))
|
||||
print(" Accumulated training took: {:.4f}s".format(total_training_time))
|
||||
print(" Accumulated validation took: {:.4f}s".format(total_test_time))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
Loading…
Reference in a new issue