# TransformerModel example

This example was adapted from PyTorch's *Sequence-to-Sequence Modeling with nn.Transformer and TorchText* tutorial.

## Requirements

- PyTorch 1.6+
- TorchText 0.6+
- ONNX Runtime 1.5+
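
The minimum versions above can be checked before running either script. The following is a minimal sketch that is not part of this example. It assumes Python 3.8+ (for `importlib.metadata`) and that the packages are installed under the PyPI distribution names `torch`, `torchtext`, and `onnxruntime`. A GPU or training build of ONNX Runtime may be installed under a different name, such as `onnxruntime-gpu`. The helper names (`parse_major_minor`, `meets_minimum`, `check_requirements`) are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum (major, minor) versions from the "Requirements" list.
# Assumed distribution names; adjust if, e.g., onnxruntime-gpu is installed instead.
MINIMUM_VERSIONS = {
    "torch": (1, 6),
    "torchtext": (0, 6),
    "onnxruntime": (1, 5),
}


def parse_major_minor(version_string):
    """Extract (major, minor) from a version string such as '1.6.0+cu101'."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version_string, minimum):
    """Return True if version_string is at least the (major, minor) minimum."""
    # Tuple comparison orders (1, 10) after (1, 6), unlike plain string comparison.
    return parse_major_minor(version_string) >= minimum


def check_requirements(requirements=MINIMUM_VERSIONS):
    """Return a list of problems found; an empty list means every check passed."""
    problems = []
    for package, minimum in requirements.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        if not meets_minimum(installed, minimum):
            wanted = ".".join(map(str, minimum))
            problems.append(f"{package} {installed} is older than {wanted}")
    return problems


if __name__ == "__main__":
    for problem in check_requirements():
        print(problem)
```

Comparing parsed integer tuples rather than raw strings avoids treating `1.10` as older than `1.6`.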

## Running PyTorch version

```bash
python pt_train.py
```

## Running ONNX Runtime version

```bash
python ort_train.py
```

## Optional arguments

| Argument | Description | Default |
|---|---|---|
| --batch-size | input batch size for training | 20 |
| --test-batch-size | input batch size for testing | 20 |
| --epochs | number of epochs to train | 2 |
| --lr | learning rate | 0.001 |
| --no-cuda | disables CUDA training | False |
| --seed | random seed | 1 |
| --log-interval | how many batches to wait before logging training status | 200 |
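
The options in the table follow the usual `argparse` conventions. The sketch below is a hypothetical parser that reproduces the flag names and defaults listed above. It is not taken from `pt_train.py` or `ort_train.py`, and the argument types (`int` for counts, `float` for `--lr`, a boolean switch for `--no-cuda`) are assumptions inferred from the defaults.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags and defaults in the table above."""
    parser = argparse.ArgumentParser(description="TransformerModel example")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="input batch size for training (default: 20)")
    parser.add_argument("--test-batch-size", type=int, default=20,
                        help="input batch size for testing (default: 20)")
    parser.add_argument("--epochs", type=int, default=2,
                        help="number of epochs to train (default: 2)")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate (default: 0.001)")
    # store_true: the flag is False unless --no-cuda is passed on the command line.
    parser.add_argument("--no-cuda", action="store_true", default=False,
                        help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1,
                        help="random seed (default: 1)")
    parser.add_argument("--log-interval", type=int, default=200,
                        help="how many batches to wait before logging training status (default: 200)")
    return parser
```

With this parser, a command line such as `python ort_train.py --epochs 5 --no-cuda` corresponds to `build_parser().parse_args(["--epochs", "5", "--no-cuda"])`. `argparse` exposes hyphenated flags as underscored attributes, for example `args.batch_size` and `args.no_cuda`.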