onnxruntime/orttraining/orttraining/python/training/api
Ashwini Khade 68b5b2d7d3
Refactor training build options (#13964)
### Description
1. Renames all references to on device training as training apis, to
keep the naming general. Nothing prevents us from using the same apis
on servers/non-edge devices.
2. Updates the ENABLE_TRAINING option: with this PR, when this option is
enabled, training apis and torch interop are also enabled.
3. Refactors the onnxruntime_ENABLE_TRAINING_TORCH_INTEROP option:
   - Removed the user facing option.
   - onnxruntime_ENABLE_TRAINING_TORCH_INTEROP is set to ON when
     onnxruntime_ENABLE_TRAINING is ON, as we always build with torch interop.
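The coupling described in items 2 and 3 can be expressed as a small Python model of option resolution. This is an illustrative sketch, not the project's CMake logic: `resolve_training_options` is a hypothetical helper, and `onnxruntime_ENABLE_TRAINING_APIS` is an assumed name for the training apis option.

```python
# Hypothetical model of the option coupling described above; not the real
# CMake logic. onnxruntime_ENABLE_TRAINING_APIS is an assumed option name.
ENABLE_TRAINING = "onnxruntime_ENABLE_TRAINING"
ENABLE_TRAINING_APIS = "onnxruntime_ENABLE_TRAINING_APIS"  # assumed name
ENABLE_TORCH_INTEROP = "onnxruntime_ENABLE_TRAINING_TORCH_INTEROP"


def resolve_training_options(requested):
    """Return the effective options for a dict of requested boolean options."""
    effective = dict(requested)
    # Torch interop is no longer user facing, so any requested value is ignored.
    effective.pop(ENABLE_TORCH_INTEROP, None)
    training = bool(effective.get(ENABLE_TRAINING, False))
    # Torch interop is on exactly when the full training build is on.
    effective[ENABLE_TORCH_INTEROP] = training
    # The full training build also turns on the training apis.
    if training:
        effective[ENABLE_TRAINING_APIS] = True
    return effective
```

With this model, requesting only `onnxruntime_ENABLE_TRAINING` turns on both coupled options, while a training-apis-only request leaves torch interop off, which is how this sketch reads the dedicated edge build described under Motivation and Context.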

Once this PR is merged, selecting --enable_training will produce a
"FULL Build" for training (with all the training entry points and
features).
Training entry points include:
1. ORTModule
2. Training APIs

Features include:
1. ATen Fallback
2. All training ops, including communication and collectives
3. Strided Tensor Support
4. Python Op (torch interop)
5. ONNXBlock (front-end tools for preparing training artifacts when
using training apis)
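To check which of these entry points a given installation provides, a sketch like the following probes for their Python modules with `importlib.util.find_spec`. `onnxruntime.training.api` is the package imported in the training loop below; the `ortmodule` and `onnxblock` module paths are assumptions about the package layout and may vary by release. `available_training_modules` is a hypothetical helper.

```python
import importlib.util

# onnxruntime.training.api is confirmed by the README's import line; the
# ortmodule and onnxblock paths are assumptions about the package layout.
TRAINING_MODULES = {
    "ORTModule": "onnxruntime.training.ortmodule",
    "Training APIs": "onnxruntime.training.api",
    "ONNXBlock": "onnxruntime.training.onnxblock",
}


def available_training_modules(modules=TRAINING_MODULES):
    """Report, per label, whether the module can be located in this environment."""
    found = {}
    for label, module_name in modules.items():
        try:
            # find_spec imports the parent packages first, which raises if
            # onnxruntime (or a training dependency) is not installed.
            found[label] = importlib.util.find_spec(module_name) is not None
        except (ImportError, ValueError):
            found[label] = False
    return found


if __name__ == "__main__":
    for label, ok in available_training_modules().items():
        print(f"{label}: {'available' if ok else 'not found'}")
```

Locating a module only shows that it is present; it does not guarantee that importing it succeeds, so a failed import can still surface later.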

### Motivation and Context
The intention is to simplify the options for building training-enabled
builds. This is part of the larger work item to create a dedicated build
for learning-on-the-edge scenarios with just the training apis enabled.
2023-01-03 13:28:16 -08:00
__init__.py expose lr scheduler python bindings for on device training. (#13882) 2022-12-22 18:44:04 -08:00
checkpoint_state.py
lr_scheduler.py expose lr scheduler python bindings for on device training. (#13882) 2022-12-22 18:44:04 -08:00
module.py Miscellaneous updates to training apis (#13929) 2022-12-14 13:33:07 -08:00
optimizer.py add cuda support to python bindings (#13700) 2022-12-08 16:03:53 -08:00
README.md Refactor training build options (#13964) 2023-01-03 13:28:16 -08:00

Getting Started

This is a simple guide to using the onnxruntime training APIs.

What's needed for training?

The onnxruntime training APIs need the following files to perform training:

  1. The training onnx model.
  2. The eval onnx model (optional).
  3. The optimizer onnx model.
  4. The checkpoint file.
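Before starting a run, it can help to confirm that these files are in place. The sketch below is a hypothetical helper, `check_training_artifacts`, that reports missing artifacts by name; the file names are taken from the training loop example below and are assumptions, not names the APIs require.

```python
from pathlib import Path

# Illustrative names matching the training loop example; not required by onnxruntime.
REQUIRED_ARTIFACTS = ("training_model.onnx", "optimizer.onnx", "checkpoint.ckpt")
OPTIONAL_ARTIFACTS = ("eval_model.onnx",)


def check_training_artifacts(directory):
    """Return (missing_required, missing_optional) artifact names in directory."""
    root = Path(directory)
    # exists() rather than is_file(), so the check does not assume whether the
    # checkpoint is stored as a single file or as a directory.
    missing_required = [name for name in REQUIRED_ARTIFACTS if not (root / name).exists()]
    missing_optional = [name for name in OPTIONAL_ARTIFACTS if not (root / name).exists()]
    return missing_required, missing_optional
```

A non-empty first list means training cannot start; a missing eval model only means the eval step is unavailable.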

To generate these files, refer to the onnxblock README.

Once these files are generated, you can use the training APIs to run your training.

Training Loop

from onnxruntime.training.api import Module, Optimizer, CheckpointState

# Load the checkpoint state.
state = CheckpointState("checkpoint.ckpt")
# Create the Module and the Optimizer.
model = Module("training_model.onnx", state, "eval_model.onnx")
optimizer = Optimizer("optimizer.onnx", model)

# The forward inputs must be a list of numpy arrays.
forward_inputs = ...

# Set the model to training mode and run a training step.
model.train()
model(forward_inputs)

# Update the parameters with an optimizer step.
optimizer.step()

# Set the model to eval mode and run an eval step.
model.eval()
loss = model(forward_inputs)

# Here the loss is assumed to be the first element of the output.
print("Loss:", loss[0])

# Save a checkpoint.
model.save_checkpoint("checkpoint_export.ckpt")
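The example above runs a single training step and a single eval step. Extending it to several epochs could look like the sketch below, which uses only the `Module` and `Optimizer` calls shown above plus one assumption: that gradients accumulate across steps and must be reset. The reset method name is release dependent (the sketch tries `reset_grad` and then `lazy_reset_grad`), so verify that part against module.py. `train_epochs` is a hypothetical helper.

```python
def train_epochs(model, optimizer, train_batches, eval_batches, epochs=1):
    """Hypothetical multi-epoch loop built from the Module/Optimizer calls above.

    train_batches and eval_batches are lists of input lists (each a list of
    numpy arrays), so they can be iterated once per epoch. As in the example,
    the loss is assumed to be the first model output. Returns the mean eval
    loss of each epoch.
    """
    history = []
    for _ in range(epochs):
        model.train()
        for inputs in train_batches:
            model(inputs)
            optimizer.step()
            # Assumption: gradients accumulate and must be cleared after each
            # step; the method name differs between releases, so look it up.
            reset_grad = getattr(model, "reset_grad", None) or getattr(model, "lazy_reset_grad", None)
            if reset_grad is not None:
                reset_grad()
        model.eval()
        losses = [float(model(inputs)[0]) for inputs in eval_batches]
        history.append(sum(losses) / len(losses) if losses else float("nan"))
    return history
```

For example, `train_epochs(model, optimizer, train_batches, eval_batches, epochs=5)` with `model` and `optimizer` created as in the example, where each batch is a list of numpy arrays.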

For more detailed information, refer to the Module and Optimizer classes (module.py and optimizer.py).