From ebeeff22dd7774e6a1a72fa328f8698a74cf96ba Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Thu, 24 Sep 2020 16:28:07 -0700
Subject: [PATCH] Update PyTorch TransformerModel sample (#5275)

---
 samples/python/pytorch_transformer/README.md  |  16 ++-
 .../python/pytorch_transformer/ort_train.py   | 103 ++++++++++++++++
 .../python/pytorch_transformer/pt_train.py    | 121 ++++++++++++++++++
 3 files changed, 238 insertions(+), 2 deletions(-)
 create mode 100644 samples/python/pytorch_transformer/ort_train.py
 create mode 100644 samples/python/pytorch_transformer/pt_train.py

diff --git a/samples/python/pytorch_transformer/README.md b/samples/python/pytorch_transformer/README.md
index 569b94aa70..7514d6cf70 100644
--- a/samples/python/pytorch_transformer/README.md
+++ b/samples/python/pytorch_transformer/README.md
@@ -10,12 +10,24 @@ This example was adapted from Pytorch's [Sequence-to-Sequence Modeling with nn.T
 
 ## Running PyTorch version
 
-```python
+```bash
 python pt_model.py
 ```
 
 ## Running ONNX Runtime version
 
-```python
+```bash
 python ort_model.py
 ```
+
+## Optional arguments
+
+| Argument          | Description                                             |   Default |
+| :---------------- | :-----------------------------------------------------: | --------: |
+| --batch-size      | input batch size for training                           |        20 |
+| --test-batch-size | input batch size for testing                            |        20 |
+| --epochs          | number of epochs to train                               |         2 |
+| --lr              | learning rate                                           |     0.001 |
+| --no-cuda         | disables CUDA training                                  |     False |
+| --seed            | random seed                                             |         1 |
+| --log-interval    | how many batches to wait before logging training status |       200 |
diff --git a/samples/python/pytorch_transformer/ort_train.py b/samples/python/pytorch_transformer/ort_train.py
new file mode 100644
index 0000000000..9830106799
--- /dev/null
+++ b/samples/python/pytorch_transformer/ort_train.py
@@ -0,0 +1,103 @@
+import argparse
+import math
+import torch
+import onnxruntime
+
+from utils import prepare_data, get_batch
+from ort_utils import my_loss, transformer_model_description_dynamic_axes
+from pt_model import TransformerModel
+
+
+def train(trainer, data_source, device, epoch, args, bptt=35):
+    """Run one training epoch through ``trainer`` and log the running loss.
+
+    Args:
+        trainer: object whose ``train_step(data, targets)`` returns ``(loss, pred)``.
+        data_source: training tensor, walked along dim 0 in steps of ``bptt``.
+        device: not read by this function.
+        epoch: epoch number, used only in the log line.
+        args: parsed CLI arguments; only ``args.log_interval`` is read.
+        bptt: stride between consecutive ``get_batch`` offsets.
+    """
+    total_loss = 0.
+    for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)):
+        data, targets = get_batch(data_source, i)
+
+        loss, _ = trainer.train_step(data, targets)
+        total_loss += loss.item()
+        if batch % args.log_interval == 0 and batch > 0:
+            cur_loss = total_loss / args.log_interval
+            print('epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}'.format(
+                epoch, batch, len(data_source) // bptt, cur_loss))
+            total_loss = 0
+
+
+def evaluate(trainer, data_source, bptt=35):
+    """Return the length-weighted average loss of ``trainer`` over ``data_source``.
+
+    Each batch loss is weighted by ``len(data)`` and the sum is divided by
+    ``len(data_source) - 1``.
+
+    Args:
+        trainer: object whose ``eval_step(data, targets)`` returns ``(loss, pred)``.
+        data_source: evaluation tensor, walked along dim 0 in steps of ``bptt``.
+        bptt: stride between consecutive ``get_batch`` offsets.
+    """
+    total_loss = 0.
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, bptt):
+            data, targets = get_batch(data_source, i)
+            loss, _ = trainer.eval_step(data, targets)
+            total_loss += len(data) * loss.item()
+    return total_loss / (len(data_source) - 1)
+
+
+if __name__ == "__main__":
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch TransformerModel example')
+    parser.add_argument('--batch-size', type=int, default=20, metavar='N',
+                        help='input batch size for training (default: 20)')
+    parser.add_argument('--test-batch-size', type=int, default=20, metavar='N',
+                        help='input batch size for testing (default: 20)')
+    parser.add_argument('--epochs', type=int, default=2, metavar='N',
+                        help='number of epochs to train (default: 2)')
+    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
+                        help='learning rate (default: 0.001)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=200, metavar='N',
+                        help='how many batches to wait before logging training status (default: 200)')
+
+    # Basic setup
+    args = parser.parse_args()
+    if not args.no_cuda and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    torch.manual_seed(args.seed)
+    onnxruntime.set_seed(args.seed)
+
+    # Model
+    optim_config = onnxruntime.training.optim.SGDConfig(lr=args.lr)
+    model_desc = transformer_model_description_dynamic_axes()
+    model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
+
+    # Preparing data
+    train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size)
+    trainer = onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss)
+
+    # Train
+    for epoch in range(1, args.epochs + 1):
+        train(trainer, train_data, device, epoch, args)
+        val_loss = evaluate(trainer, val_data)
+        print('-' * 89)
+        print('| end of epoch {:3d} | valid loss {:5.2f} | '.format(epoch, val_loss))
+        print('-' * 89)
+
+    # Evaluate
+    test_loss = evaluate(trainer, test_data)
+    print('=' * 89)
+    print('| End of training | test loss {:5.2f}'.format(test_loss))
+    print('=' * 89)
diff --git a/samples/python/pytorch_transformer/pt_train.py b/samples/python/pytorch_transformer/pt_train.py
new file mode 100644
index 0000000000..7d3e8851c9
--- /dev/null
+++ b/samples/python/pytorch_transformer/pt_train.py
@@ -0,0 +1,121 @@
+import argparse
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils import prepare_data, get_batch
+from pt_model import TransformerModel
+
+
+def train(model, data_source, device, epoch, args, bptt=35, criterion=None, optimizer=None):
+    """Run one training epoch over ``data_source`` and log the running loss.
+
+    Args:
+        model: module to train; called as ``model(data)``.
+        data_source: training tensor, walked along dim 0 in steps of ``bptt``.
+        device: not read by this function.
+        epoch: epoch number, used only in the log line.
+        args: parsed CLI arguments; only ``args.log_interval`` is read.
+        bptt: stride between consecutive ``get_batch`` offsets.
+        criterion: loss callable applied to ``(flattened output, targets)``.
+        optimizer: optimizer stepped once per batch.
+    """
+    # Fall back to module-level ``criterion``/``optimizer`` when they are not
+    # passed in, so callers that define them as globals keep working.
+    if criterion is None:
+        criterion = globals()['criterion']
+    if optimizer is None:
+        optimizer = globals()['optimizer']
+
+    total_loss = 0.
+    model.train()
+    for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)):
+        data, targets = get_batch(data_source, i)
+
+        optimizer.zero_grad()
+        output = model(data)
+        # Flatten with the output's own last dimension instead of a hard-coded 28785.
+        loss = criterion(output.view(-1, output.size(-1)), targets)
+        loss.backward()
+        optimizer.step()
+
+        total_loss += loss.item()
+        if batch % args.log_interval == 0 and batch > 0:
+            cur_loss = total_loss / args.log_interval
+            print('epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}'.format(
+                epoch, batch, len(data_source) // bptt, cur_loss))
+            total_loss = 0
+
+
+def evaluate(model, data_source, criterion, bptt=35):
+    """Return the length-weighted average loss of ``model`` over ``data_source``.
+
+    Each batch loss is weighted by ``len(data)`` and the sum is divided by
+    ``len(data_source) - 1``.
+
+    Args:
+        model: module to evaluate; switched to eval mode here.
+        data_source: evaluation tensor, walked along dim 0 in steps of ``bptt``.
+        criterion: loss callable applied to ``(flattened output, targets)``.
+        bptt: stride between consecutive ``get_batch`` offsets.
+    """
+    total_loss = 0.
+    model.eval()
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, bptt):
+            data, targets = get_batch(data_source, i)
+            output = model(data)
+            output_flat = output.view(-1, output.size(-1))
+            total_loss += len(data) * criterion(output_flat, targets).item()
+    return total_loss / (len(data_source) - 1)
+
+
+if __name__ == "__main__":
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch TransformerModel example')
+    parser.add_argument('--batch-size', type=int, default=20, metavar='N',
+                        help='input batch size for training (default: 20)')
+    parser.add_argument('--test-batch-size', type=int, default=20, metavar='N',
+                        help='input batch size for testing (default: 20)')
+    parser.add_argument('--epochs', type=int, default=2, metavar='N',
+                        help='number of epochs to train (default: 2)')
+    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
+                        help='learning rate (default: 0.001)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=200, metavar='N',
+                        help='how many batches to wait before logging training status (default: 200)')
+
+    # Basic setup
+    args = parser.parse_args()
+    if not args.no_cuda and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    torch.manual_seed(args.seed)
+
+    # Model
+    criterion = nn.CrossEntropyLoss()
+    model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
+    # Use the parsed --lr so the flag actually controls the SGD learning rate.
+    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
+
+    # Preparing data
+    train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size)
+
+    # Train
+    for epoch in range(1, args.epochs + 1):
+        train(model, train_data, device, epoch, args, criterion=criterion, optimizer=optimizer)
+        val_loss = evaluate(model, val_data, criterion)
+        print('-' * 89)
+        print('| end of epoch {:3d} | valid loss {:5.2f} | '.format(epoch, val_loss))
+        print('-' * 89)
+
+    # Evaluate
+    test_loss = evaluate(model, test_data, criterion)
+    print('=' * 89)
+    print('| End of training | test loss {:5.2f}'.format(test_loss))
+    print('=' * 89)