mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
89 lines
2.3 KiB
ReStructuredText
89 lines
2.3 KiB
ReStructuredText
Train the Model on the Device
==============================
|
|
|
|
Once the training artifacts have been generated, the model can be trained on the device using the ONNX Runtime training Python API.
|
|
|
|
The expected training artifacts are:
|
|
|
|
1. The training ONNX model
|
|
2. The checkpoint state
|
|
3. The optimizer ONNX model
|
|
4. The evaluation (eval) ONNX model (optional)
|
|
|
|
Sample usage:
|
|
|
|
.. code-block:: python
|
|
|
|
from onnxruntime.training.api import CheckpointState, Module, Optimizer
|
|
|
|
# Load the checkpoint state
|
|
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)
|
|
|
|
# Create the module
|
|
module = Module(path_to_the_training_model,
|
|
state,
|
|
path_to_the_eval_model,
|
|
device="cpu")
|
|
|
|
optimizer = Optimizer(path_to_the_optimizer_model, module)
|
|
|
|
# Training loop
|
|
for ...:
|
|
module.train()
|
|
training_loss = module(...)
|
|
optimizer.step()
|
|
module.lazy_reset_grad()
|
|
|
|
# Eval
|
|
module.eval()
|
|
eval_loss = module(...)
|
|
|
|
# Save the checkpoint
|
|
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
|
|
|
|
|
|
.. autoclass:: onnxruntime.training.api.checkpoint_state.Parameter
|
|
:members:
|
|
:show-inheritance:
|
|
:member-order: bysource
|
|
:inherited-members:
|
|
:special-members: __repr__
|
|
|
|
.. autoclass:: onnxruntime.training.api.checkpoint_state.Parameters
|
|
:members:
|
|
:show-inheritance:
|
|
:member-order: bysource
|
|
:inherited-members:
|
|
:special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__
|
|
|
|
.. autoclass:: onnxruntime.training.api.checkpoint_state.Properties
|
|
:members:
|
|
:show-inheritance:
|
|
:member-order: bysource
|
|
:inherited-members:
|
|
:special-members: __getitem__, __setitem__, __contains__, __iter__, __repr__, __len__
|
|
|
|
.. autoclass:: onnxruntime.training.api.CheckpointState
|
|
:members:
|
|
:show-inheritance:
|
|
:member-order: bysource
|
|
:inherited-members:
|
|
|
|
.. autoclass:: onnxruntime.training.api.Module
|
|
:members:
|
|
:show-inheritance:
|
|
:member-order: bysource
|
|
:inherited-members:
|
|
:special-members: __call__
|
|
|
|
.. autoclass:: onnxruntime.training.api.Optimizer
|
|
:members:
|
|
:show-inheritance:
|
|
:member-order: bysource
|
|
:inherited-members:
|
|
|
|
.. autoclass:: onnxruntime.training.api.LinearLRScheduler
|
|
:members:
|
|
:show-inheritance:
|
|
:member-order: bysource
|
|
:inherited-members:
|