
TransformerModel example

This example was adapted from PyTorch's "Sequence-to-Sequence Modeling with nn.Transformer and TorchText" tutorial.

Requirements

  • PyTorch 1.6+
  • TorchText 0.6+
  • ONNX Runtime 1.5+
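
The minimum versions above can be checked before running either script. The following is a minimal sketch, not part of this example's code: the helper names are hypothetical, and the distribution names `torch`, `torchtext`, and `onnxruntime` are assumptions (a training-enabled ONNX Runtime build may be installed under a different name).

```python
import re
from importlib import metadata

# Minimum versions taken from the Requirements list.
# The distribution names used as keys are assumptions.
MINIMUMS = {"torch": "1.6", "torchtext": "0.6", "onnxruntime": "1.5"}


def version_tuple(version):
    """Return the leading numeric components, e.g. '1.6.0+cu101' -> (1, 6, 0)."""
    parts = []
    for piece in version.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare two version strings by their numeric components."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad the shorter tuple with zeros so that '1.6' equals '1.6.0'.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_requirements(minimums=MINIMUMS):
    """Map each package to True/False, or None when it is not installed."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
            continue
        report[name] = meets_minimum(installed, minimum)
    return report


if __name__ == "__main__":
    for name, ok in check_requirements().items():
        status = "not installed" if ok is None else ("ok" if ok else "too old")
        print(f"{name}: {status}")
```

The comparison looks only at the numeric release components, so pre-release and local build suffixes are ignored.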

Running PyTorch version

python pt_train.py

Running ONNX Runtime version

python ort_train.py

Optional arguments

| Argument | Description | Default |
| --- | --- | --- |
| --batch-size | input batch size for training | 20 |
| --test-batch-size | input batch size for testing | 20 |
| --epochs | number of epochs to train | 2 |
| --lr | learning rate | 0.001 |
| --no-cuda | disables CUDA training | False |
| --seed | random seed | 1 |
| --log-interval | how many batches to wait before logging training status | 200 |
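
The arguments listed above could be declared with `argparse` roughly as follows. This is a sketch that mirrors the documented names and defaults; the function name `build_parser` is hypothetical, the argument types are inferred from the defaults, and the actual parsing code in `pt_train.py` and `ort_train.py` may differ.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented arguments (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="TransformerModel example")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="input batch size for training")
    parser.add_argument("--test-batch-size", type=int, default=20,
                        help="input batch size for testing")
    parser.add_argument("--epochs", type=int, default=2,
                        help="number of epochs to train")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    # A flag: absent means False (CUDA allowed), present means True.
    parser.add_argument("--no-cuda", action="store_true", default=False,
                        help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1,
                        help="random seed")
    parser.add_argument("--log-interval", type=int, default=200,
                        help="how many batches to wait before logging training status")
    return parser


if __name__ == "__main__":
    # Parse an explicit argument list for illustration; a script would
    # normally call parse_args() with no arguments to read sys.argv.
    args = build_parser().parse_args(["--epochs", "5", "--no-cuda"])
    print(args.epochs, args.no_cuda, args.batch_size)
```

`argparse` converts dashes in option names to underscores, so `--batch-size` is read back as `args.batch_size`.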