### Statistics tool for ORTModule convergence parity
As ORTModule becomes more thoroughly validated, integrating a PyTorch-based model with ORT is quick.
At the same time, we need to make sure that when a convergence issue does
appear, we don't spend months investigating it. As part of this effort, this
PR introduces a tool that dumps activation statistics with minimal
involvement from users. The dumped results contain only a few statistics
plus sampled data, so they stay small; compared with dumping all the
tensors, this is much faster and more space efficient.
To use it, just two lines are needed before wrapping with ORTModule.
The baseline run needs the same two lines (pointing at its own output directory).
```diff
+ from onnxruntime.training.utils.hooks import SubscriberManager, StatisticsSubscriber
+ SubscriberManager.subscribe(model, [StatisticsSubscriber("pt_out", override_output_dir=True)])
```
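For context, here is a minimal sketch of where those two lines sit in a baseline training script; `MyModel`, `loader`, and the optimizer settings are hypothetical placeholders for your own training code:

```python
import torch
from onnxruntime.training.utils.hooks import SubscriberManager, StatisticsSubscriber

model = MyModel()  # hypothetical: your own torch.nn.Module
# The only addition required by the tool: register the statistics subscriber.
SubscriberManager.subscribe(model, [StatisticsSubscriber("pt_out", override_output_dir=True)])

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for inputs, labels in loader:  # hypothetical data loader
    loss = model(inputs, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```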
Once both runs have been executed, the following command merges the results
into per-step summaries for the ORT and baseline runs:
```bash
python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```
Documentation is added as part of this PR: [convergence investigation
notes](https://github.com/microsoft/onnxruntime/blob/pengwa/conv_tool/docs/ORTModule_Convergence_Notes.md).
Based on the generated merged files, we can compare the two runs with standard diff tools.

### Design and Implementation
This PR introduces a common mechanism for registering custom logic through
nn.Module's post-forward hooks. The activation statistics collector
(StatisticsSubscriber) is one implementation. If other needs arise, another
XXSubscriber can be defined to do the customized work.
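To illustrate the underlying primitive (this is standard PyTorch API, not the PR's internal code): such a mechanism builds on `register_forward_hook`, and a subscriber's custom logic conceptually runs in a callback like this:

```python
import torch

def post_forward_hook(module, inputs, output):
    # Invoked after every forward() of the hooked module; a subscriber's
    # custom logic (e.g., computing statistics) would go here.
    if isinstance(output, torch.Tensor):
        print(f"{module.__class__.__name__}: mean={output.float().mean().item():.6f}")

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
for submodule in model.modules():
    submodule.register_forward_hook(post_forward_hook)

model(torch.randn(2, 4))
```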
## ORTModule Training Convergence Investigation
### 1. Discovering
Convergence issues can be identified by:
- A large discrepancy in core training metrics, including training loss, evaluation loss, and model-specific metrics such as AUC.
- Runtime failures (for example, the loss scaler reaching its minimum and triggering an exception).
Before investigating further, we should clarify a few things (if possible):
- If we change the seed for the baseline run, is the metric difference still large? (Make sure the discrepancy is not merely introduced by randomness.)
- At which steps do we first see an obvious divergence?
- Does the issue still reproduce once randomness is removed (see the sketch after this list)?
  - Set the same seeds.
  - Set the dropout ratio to 0.
  - Make compute deterministic and torch-comparable (TODO(pengwa): need a flag for this).
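A minimal sketch of removing randomness in a torch-based script; all flags below are standard PyTorch APIs (the ORT-side determinism flag from the TODO above is not yet available), and `model` refers to your own module:

```python
import random
import numpy as np
import torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Prefer deterministic kernels; ops lacking a deterministic
# implementation will raise instead of silently diverging.
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Zero out all dropout ratios in the model under test.
for m in model.modules():  # `model` is your nn.Module
    if isinstance(m, torch.nn.Dropout):
        m.p = 0.0
```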
### 2. Collect Activation Statistics
For the baseline (pure PyTorch) run, add the following lines:
```diff
+ from onnxruntime.training.utils.hooks import SubscriberManager, StatisticsSubscriber
+ SubscriberManager.subscribe(model, [StatisticsSubscriber("pt_out", override_output_dir=True)])
```
Run the training script up to the steps that triggered the divergence. A folder named pt_out is created in the current working directory. For each step, it contains a folder with summaries for every activation tensor.
For the ORT run, add a few lines of code:
```diff
  from onnxruntime.training.ortmodule import ORTModule
+ from onnxruntime.training.utils.hooks import SubscriberManager, StatisticsSubscriber
  model = ORTModule(model)
+ SubscriberManager.subscribe(model, [StatisticsSubscriber("ort_out", override_output_dir=True)])
```
The StatisticsSubscriber can be registered before or after wrapping with ORTModule.
Run the training script up to the steps that triggered the divergence. Similarly, a folder named ort_out is created in the current working directory.
Run the following command to generate the per-step summaries:
```bash
python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```
Manually diff the generated per-step summaries to find where the first large difference occurs.
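As one option for that manual diff, here is a minimal sketch using Python's difflib; the per-step file names under `/tmp/output` are hypothetical, so adjust the paths to the layout the merge tool actually produced:

```python
import difflib
from pathlib import Path

# Hypothetical file names; inspect /tmp/output for the real layout.
pt_lines = Path("/tmp/output/step_0_pt.txt").read_text().splitlines()
ort_lines = Path("/tmp/output/step_0_ort.txt").read_text().splitlines()

for line in difflib.unified_diff(pt_lines, ort_lines, fromfile="pt", tofile="ort", lineterm=""):
    print(line)
```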