TransformerModel example

This example was adapted from PyTorch's "Sequence-to-Sequence Modeling with nn.Transformer and TorchText" tutorial.

Requirements

  • PyTorch 1.6+
  • TorchText 0.6+
  • ONNX Runtime 1.5+
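
The minimum versions listed above can be checked before running the example. The following is a minimal sketch, not part of the sample itself: the helper names are hypothetical, and the distribution names (`torch`, `torchtext`, `onnxruntime`) are assumptions. ONNX Runtime in particular may be installed under a different distribution name, depending on how it was built or packaged.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum (major, minor) versions taken from the requirements list.
# The distribution names are assumptions and may differ on your system.
REQUIREMENTS = {
    "torch": (1, 6),
    "torchtext": (0, 6),
    "onnxruntime": (1, 5),
}


def parse_major_minor(version_string):
    """Return (major, minor) from a version string such as '1.6.0+cu101'."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def check_requirements(requirements=REQUIREMENTS):
    """Map each package name to 'ok', 'too old (found X)', or 'not installed'."""
    report = {}
    for name, minimum in requirements.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
            continue
        if parse_major_minor(found) >= minimum:
            report[name] = "ok"
        else:
            report[name] = f"too old (found {found})"
    return report


if __name__ == "__main__":
    for name, status in check_requirements().items():
        print(f"{name}: {status}")
```

A "not installed" result for ONNX Runtime does not necessarily mean it is missing; it may only mean the assumed distribution name does not match your installation.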

Running PyTorch version

python pt_train.py

Running ONNX Runtime version

python ort_train.py

Optional arguments

| Argument | Description | Default |
|---|---|---|
| --batch-size | input batch size for training | 20 |
| --test-batch-size | input batch size for testing | 20 |
| --epochs | number of epochs to train | 2 |
| --lr | learning rate | 0.001 |
| --no-cuda | disables CUDA training | False |
| --seed | random seed | 1 |
| --log-interval | how many batches to wait before logging training status | 200 |
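
To make the defaults above concrete, here is a minimal sketch of a parser that accepts the same options. It assumes the scripts use Python's standard `argparse` module; the function name `build_parser` is hypothetical, and the actual scripts may define their arguments differently.

```python
import argparse


def build_parser():
    """Build a parser mirroring the optional arguments listed above (a sketch only)."""
    parser = argparse.ArgumentParser(description="TransformerModel example")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="input batch size for training")
    parser.add_argument("--test-batch-size", type=int, default=20,
                        help="input batch size for testing")
    parser.add_argument("--epochs", type=int, default=2,
                        help="number of epochs to train")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    # store_true makes this a flag: absent means False, present means True.
    parser.add_argument("--no-cuda", action="store_true", default=False,
                        help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1,
                        help="random seed")
    parser.add_argument("--log-interval", type=int, default=200,
                        help="how many batches to wait before logging training status")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is self-contained.
args = build_parser().parse_args(["--epochs", "5", "--no-cuda"])
print(vars(args))
```

Note that `argparse` converts dashes in option names to underscores, so `--batch-size` is read as `args.batch_size`. The same flags can be appended to either run command, for example `--epochs 5 --no-cuda`.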