Getting Started
This is a short guide on how to use the onnxruntime training APIs.
What's needed for training?
The ORT training APIs need the following files to perform training:
- The training ONNX model.
- The eval ONNX model (optional).
- The optimizer ONNX model (optional).
- The checkpoint file.
To generate these files, refer to the onnxblock README.
Once the ONNX models have been generated, you can use the training APIs to run your training.
Training Loop
```python
from onnxruntime.training.api import Module, Optimizer, CheckpointState

# Load the checkpoint state.
state = CheckpointState.load_checkpoint("checkpoint.ckpt")

# Create the Module and Optimizer.
model = Module("training_model.onnx", state, "eval_model.onnx")
optimizer = Optimizer("optimizer.onnx", model)

# Set the model in training mode and run a train step.
model.train()
training_model_outputs = model(<inputs to your training model>)

# Optimizer step.
optimizer.step()

# Set the model in eval mode and run an eval step.
model.eval()
eval_model_outputs = model(<inputs to your eval model>)

# Assuming the loss is the first output of the training model.
print("Loss:", training_model_outputs[0])

# Save the checkpoint.
CheckpointState.save_checkpoint(state, "checkpoint_export.ckpt")
```
For more detailed information, refer to the Module and Optimizer documentation.