ORTModule Training Convergence Investigation
1. Discovering
Convergence issues can be identified by:
- Large discrepancies in core training metrics, including training loss, evaluation loss, and model-specific metrics such as AUC.
- Runtime failures (for example, when the loss scaler reaches its minimum value and triggers an exception).
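The following minimal sketch shows why a loss scaler can fail at runtime. It uses dynamic loss scaling as commonly done in mixed-precision training. The class `DynamicLossScaler` and its parameters are hypothetical and are not the ORTModule or PyTorch API. The scale halves on every step whose gradients overflow and doubles after `growth_interval` consecutive clean steps. Once it drops below `min_scale`, the sketch raises, which mirrors the kind of exception described above.

```python
class DynamicLossScaler:
    """Hypothetical dynamic loss scaler, for illustration only."""

    def __init__(self, init_scale: float = 2.0**16, min_scale: float = 1.0,
                 growth_interval: int = 2000) -> None:
        self.scale = init_scale
        self.min_scale = min_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            # Gradients contained inf/NaN: the optimizer step is skipped and the scale backs off.
            self.scale /= 2.0
            self._good_steps = 0
            if self.scale < self.min_scale:
                # Overflow keeps recurring, which often means training has already diverged.
                raise RuntimeError(
                    f"loss scale {self.scale} fell below minimum {self.min_scale}"
                )
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                # A long run of finite gradients: try a larger scale again.
                self.scale *= 2.0
                self._good_steps = 0
```

If activations or gradients blow up on every step, the scale only ever shrinks. The exception is then a symptom of divergence rather than its cause, so the investigation below still applies.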
Before looking into this further, we should clarify a few things (if possible):
- If we change the seed for the baseline run, is the metric difference large? (This makes sure the discrepancy is not introduced by randomness.)
- At which step does obvious divergence first appear?
- Is the issue still reproducible once randomness is removed?
  - Set the same seeds.
  - Set the dropout ratio to 0.
  - Set compute to be deterministic and torch-comparable (TODO(pengwa): need a flag for this).
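The first two questions can be answered with a small, framework-agnostic sketch. It assumes that per-step loss values have already been logged to plain Python lists for three runs: the baseline, the baseline with a different seed, and the ORTModule run. The helper names are illustrative assumptions and are not part of ORTModule. So is the choice of twice the seed-to-seed noise as the tolerance.

```python
from typing import Optional, Sequence


def relative_diff(a: float, b: float, eps: float = 1e-12) -> float:
    """Relative difference of b with respect to a, guarded against division by zero."""
    return abs(a - b) / max(abs(a), eps)


def noise_band(baseline: Sequence[float], reseeded: Sequence[float]) -> float:
    """Largest per-step relative diff between two baseline runs that differ only by seed."""
    return max(relative_diff(a, b) for a, b in zip(baseline, reseeded))


def first_divergent_step(
    baseline: Sequence[float], candidate: Sequence[float], tolerance: float
) -> Optional[int]:
    """Zero-based index of the first step whose relative diff exceeds tolerance, or None."""
    for step, (a, b) in enumerate(zip(baseline, candidate)):
        if relative_diff(a, b) > tolerance:
            return step
    return None


if __name__ == "__main__":
    baseline = [2.30, 1.90, 1.60, 1.40, 1.25]
    reseeded = [2.31, 1.89, 1.61, 1.41, 1.24]  # same script, different seed
    ort_run = [2.30, 1.91, 1.60, 1.70, 2.10]   # same script under ORTModule
    band = noise_band(baseline, reseeded)
    print(f"seed noise band: {band:.4f}")
    print("first divergent step:", first_divergent_step(baseline, ort_run, 2 * band))
```

If the ORTModule curve stays within the seed-to-seed band, the "discrepancy" is likely just randomness. Otherwise, the returned step is where to start collecting activation statistics.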
2. Collect Activation Statistics
Add a few lines of code to the training script, then run it to collect statistics:
| Baseline | ORTModule |
|---|---|
| | |
`StatisticsSubscriber` arguments:
- output_dir: the directory in which all activation statistics files will be stored.
- start_step [optional]: the first step that runs subscriber actions.
- end_step [optional]: the step (exclusive) at which subscriber actions stop, so actions run for steps in [start_step, end_step).
- override_output_dir: whether `output_dir` can be overridden if it already exists.
Check the `StatisticsSubscriber` implementation for more information.
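The following hypothetical, dependency-free sketch makes the argument semantics concrete by showing what a statistics subscriber does. Several parts of it are illustrative assumptions:

- the names `ToyStatisticsSubscriber` and `summarize`;
- the chosen statistics (mean, std, min, max, and the count of NaN/Inf values);
- the `step_<n>/<name>.json` layout;
- raising `FileExistsError` when `override_output_dir` is false.

The real `StatisticsSubscriber` hooks into the model's modules and may compute and store different data.

```python
import json
import math
import os
import shutil
from typing import Dict, List, Optional


def summarize(values: List[float]) -> Dict[str, float]:
    """Basic statistics of one flattened activation (chosen for illustration)."""
    finite = [v for v in values if math.isfinite(v)]
    n = len(finite)
    mean = sum(finite) / n if n else float("nan")
    var = sum((v - mean) ** 2 for v in finite) / n if n else float("nan")
    return {
        "mean": mean,
        "std": math.sqrt(var),
        "min": min(finite) if n else float("nan"),
        "max": max(finite) if n else float("nan"),
        "nan_or_inf": len(values) - n,
    }


class ToyStatisticsSubscriber:
    """Hypothetical stand-in for StatisticsSubscriber; not the onnxruntime class."""

    def __init__(self, output_dir: str, start_step: Optional[int] = None,
                 end_step: Optional[int] = None, override_output_dir: bool = False) -> None:
        if os.path.exists(output_dir):
            if not override_output_dir:
                raise FileExistsError(f"{output_dir} already exists")
            shutil.rmtree(output_dir)  # start from an empty directory
        os.makedirs(output_dir)
        self.output_dir = output_dir
        self.start_step = start_step
        self.end_step = end_step

    def is_active(self, step: int) -> bool:
        # Steps in [start_step, end_step) are recorded; None means unbounded.
        if self.start_step is not None and step < self.start_step:
            return False
        if self.end_step is not None and step >= self.end_step:
            return False
        return True

    def record(self, step: int, name: str, values: List[float]) -> None:
        """Write the summary of one activation for one step, if the step is in range."""
        if not self.is_active(step):
            return
        step_dir = os.path.join(self.output_dir, f"step_{step}")
        os.makedirs(step_dir, exist_ok=True)
        with open(os.path.join(step_dir, f"{name}.json"), "w") as f:
            json.dump(summarize(values), f)
```

Keeping one file per activation per step makes it straightforward to compare the baseline and ORTModule directories step by step.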
Run the following command to generate a per-step summary:

```bash
python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```
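Once per-step statistics exist for both runs, a useful next step is to find the earliest activation where they disagree. The sketch below assumes each step's statistics have been loaded into plain dictionaries that map activation names, in execution order, to statistic values. That in-memory layout, the function names, and the default tolerances are assumptions for illustration. They do not describe the output format of `merge_activation_summary`.

```python
import math
from typing import Dict, Optional, Tuple

Stats = Dict[str, float]


def close(a: float, b: float, rtol: float, atol: float) -> bool:
    """NaN-aware closeness check built on math.isclose."""
    if math.isnan(a) or math.isnan(b):
        return math.isnan(a) and math.isnan(b)
    return math.isclose(a, b, rel_tol=rtol, abs_tol=atol)


def first_mismatch(
    pt_step: Dict[str, Stats],
    ort_step: Dict[str, Stats],
    rtol: float = 1e-3,
    atol: float = 1e-5,
) -> Optional[Tuple[str, str, float, float]]:
    """Return (activation, statistic, pt_value, ort_value) for the first mismatch.

    Activations are visited in the baseline's insertion order, assumed to be
    execution order, so the first hit is the earliest activation that disagrees.
    """
    for name, pt_stats in pt_step.items():
        ort_stats = ort_step.get(name)
        if ort_stats is None:
            continue  # activation not captured in the ORTModule run
        for stat, pt_value in pt_stats.items():
            ort_value = ort_stats.get(stat)
            if ort_value is not None and not close(pt_value, ort_value, rtol, atol):
                return name, stat, pt_value, ort_value
    return None


if __name__ == "__main__":
    pt = {"embed": {"mean": 0.10, "std": 1.00},
          "attn": {"mean": 0.02, "std": 0.50},
          "mlp": {"mean": -0.3, "std": 2.0}}
    ort = {"embed": {"mean": 0.10, "std": 1.00},
           "attn": {"mean": 0.02, "std": 0.57},
           "mlp": {"mean": -0.9, "std": 6.0}}
    print(first_mismatch(pt, ort))  # ('attn', 'std', 0.5, 0.57)
```

Later activations such as `mlp` usually inherit an earlier mismatch. Reporting only the first disagreement therefore points at the operator most likely responsible.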