onnxruntime/samples/python/training/orttrainer/pytorch_transformer

TransformerModel example

This example was adapted from PyTorch's "Sequence-to-Sequence Modeling with nn.Transformer and TorchText" tutorial.

Requirements

  • PyTorch 1.6+
  • TorchText 0.6+
  • ONNX Runtime 1.5+

Running PyTorch version

python pt_train.py

Running ONNX Runtime version

python ort_train.py

Optional arguments

| Argument            | Description                                              | Default |
|---------------------|----------------------------------------------------------|---------|
| --batch-size        | input batch size for training                            | 20      |
| --test-batch-size   | input batch size for testing                             | 20      |
| --epochs            | number of epochs to train                                | 2       |
| --lr                | learning rate                                            | 0.001   |
| --no-cuda           | disables CUDA training                                   | False   |
| --seed              | random seed                                              | 1       |
| --log-interval      | how many batches to wait before logging training status | 200     |
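
The arguments above can be sketched as a standard `argparse` parser. This is only an illustration of how the documented flags and defaults fit together, not the parser actually used in `pt_train.py` or `ort_train.py`; the function name `build_parser` and the description string are hypothetical, and the argument types are assumed from the default values listed above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="TransformerModel example (sketch)")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="input batch size for training")
    parser.add_argument("--test-batch-size", type=int, default=20,
                        help="input batch size for testing")
    parser.add_argument("--epochs", type=int, default=2,
                        help="number of epochs to train")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    # store_true means the flag takes no value; omitting it leaves the default False
    parser.add_argument("--no-cuda", action="store_true", default=False,
                        help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1,
                        help="random seed")
    parser.add_argument("--log-interval", type=int, default=200,
                        help="how many batches to wait before logging training status")
    return parser


if __name__ == "__main__":
    # Example: override two options, keep the remaining defaults
    args = build_parser().parse_args(["--epochs", "5", "--no-cuda"])
    print(args.epochs, args.batch_size, args.no_cuda)
```

Under this sketch, an invocation such as `python ort_train.py --epochs 5 --no-cuda` would train for five epochs on the CPU and keep every other option at its default. Note that `argparse` exposes dashed flags as underscored attributes, for example `args.batch_size` for `--batch-size`.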