onnxruntime/docs/ORTModule_Convergence_Notes.md

Statistics tool for ORTModule convergence parity (#15020) ### Statistics tool for ORTModule convergence parity As ORTModule get more and more validated, it is pretty fast to intergrade PyTorch based model with ORT. The same time, we need make sure once there is convergence issue, we don't spend months of time to investigate. As part of this efforts, this PR is introducing a tool to dump activation statistics without much involvement from users. The dumping results contains only some statistic numbers plus sampled data, which is not big, compared with dumping all the tensors, it is much faster and space efficient. For us to use it, two single lines are needed before wrapping ORTModule. For baseline run, need also apply the same trick. ``` + from onnxruntime.training.utils.hooks import SubscriberManager, StatisticsSubscriber + SubscriberManager.subscribe(model, [StatisticsSubscriber("pt_out", override_output_dir=True)]) ``` Once you run the steps, following command can be used to merge result into per-step-summary respectively for ORT and baseline runs. ```bash python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output ``` Docs is added here as part of this PR [convergence investigation notes](https://github.com/microsoft/onnxruntime/blob/pengwa/conv_tool/docs/ORTModule_Convergence_Notes.md) Based on the generated merged files, we can compare them with tools. ![image](https://user-images.githubusercontent.com/10530022/224653929-4e4480bd-bb02-4bbe-bd44-2672bdf91a87.png) ### Design and Implementation This PR introduced a common mechanism registering custom logic for nn.Module's post forward hooks. And statistics for activation (StatisticsSubscriber) is one of the implementations. If there is other needs, we can define another XXSubscriber to do the customized things.
2023-03-23 12:34:24 +00:00
# ORTModule Training Convergence Investigation
## 1. Discovering
Convergence issues can be identified by:
- Large discrepancies in core training metrics such as training loss, evaluation loss, and model-specific AUC metrics.
- Runtime failures (for example when the loss scaler reaches the minimum, triggering an exception).
Before looking into this further, we should clarify a few things (if possible):
- If we change the seed for the baseline run, is the metric difference still large?
  (Make sure the discrepancy is not introduced by randomness.)
- At which step do we first see obvious divergence?
- Is the issue still reproducible once randomness is removed?
  - Set the same seeds.
  - Set the dropout ratio to 0.
  - Set compute to be deterministic and torch-comparable (TODO(pengwa): need a flag for this).
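The seed and dropout items above can be sketched on the PyTorch side roughly as follows. This is a minimal sketch, not part of ORTModule: `seed_everything` and `disable_dropout` are hypothetical helper names, NumPy and PyTorch are seeded only if they are installed, and nothing here makes ORT kernels deterministic (that is still the TODO above).

```python
import random


def seed_everything(seed: int) -> None:
    """Seed every RNG that is available in this environment (hypothetical helper)."""
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is optional for this sketch.
    try:
        import torch

        torch.manual_seed(seed)  # Seeds the CPU and CUDA generators.
        # Prefer deterministic cuDNN kernels over autotuned ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass  # PyTorch is optional for this sketch.


def disable_dropout(model) -> None:
    """Set p=0 on every nn.Dropout submodule (functional dropout calls are not covered)."""
    import torch

    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0
```

Calling `seed_everything` with the same value at the start of both the baseline run and the ORTModule run keeps data shuffling and weight initialization aligned, so remaining differences are more likely to come from the compute itself.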
## 2. Collect Activation Statistics
### 2.1 Use `GlobalSubscriberManager` to collect `nn.Module` forward() outputs
<table>
<tr>
<th>Baseline</th>
<th>ORTModule</th>
</tr>
<tr>
<td>
<sub>
```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True)]
)
```
</sub>
</td>
<td>
<sub>
```python
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
model = ORTModule(model)
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True)]
)
```
</sub>
</td>
</tr>
<tr>
<td>
- Run the training script up to the steps that trigger the divergence.
- A folder named `pt_out` is created in the current working directory.
- For each step, there is a folder containing summaries for every activation tensor.
</td>
<td>
- Run the training script up to the steps that trigger the divergence.
- Similarly, a folder named `ort_out` is created in the current working directory.
- `StatisticsSubscriber` can be subscribed before or after wrapping the model with ORTModule.
</td>
</tr>
</table>
Arguments:
- `output_dir`: the directory in which all activation statistics files will be stored.
- `start_step` [optional]: the first step that runs subscriber actions.
- `end_step` [optional]: the end step (exclusive) for subscriber actions.
- `override_output_dir`: whether `output_dir` can be overridden if it already exists.
- `run_on_cpu`: whether to run the subscriber actions on CPU. Use this as a last resort, when the inserted
inspector nodes raise the memory peak enough that the original recipe run fails with OOM.
- `bucket_size`: the size of the buckets into which the statistics calculation is split.
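Read together, `start_step` and `end_step` describe a half-open window of steps. The sketch below only illustrates that reading and is not the library's implementation: `should_collect` is a hypothetical name, and treating an omitted bound as "unbounded" is an assumption.

```python
from typing import Optional


def should_collect(step: int, start_step: Optional[int] = None, end_step: Optional[int] = None) -> bool:
    """Return True when `step` falls inside the half-open window [start_step, end_step)."""
    if start_step is not None and step < start_step:
        return False  # Before the first collected step.
    if end_step is not None and step >= end_step:
        return False  # end_step itself is excluded.
    return True


# With start_step=2 and end_step=4, only steps 2 and 3 are collected.
collected = [step for step in range(6) if should_collect(step, start_step=2, end_step=4)]
```

Narrowing the window to a few steps around the first divergence keeps the dumped output small.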
### 2.2 Use `_InspectActivation` to collect intermediate tensors in a `nn.Module` forward()
`GlobalSubscriberManager` only dumps the output tensors of each `nn.Module`'s forward(). To dump intermediate
tensors inside a `nn.Module`'s forward function, refer to the following example:
```diff
  class BloomForCausalLM(BloomPreTrainedModel):
      def __init__(self, config: BloomConfig):
          ...

      def forward(self, input_ids, ...):
          ...
          transformer_outputs = self.transformer(...)
          hidden_states = transformer_outputs[0]
          lm_logits = self.lm_head(hidden_states)
+         lm_logits = _InspectActivation.apply("lm_logits", None, GlobalSubscriberManager.get_run_context(), lm_logits)
          # Shift so that tokens < n predict n
          shift_logits = lm_logits[..., :-1, :].contiguous()
+         shift_logits = _InspectActivation.apply("shift_logits", None, GlobalSubscriberManager.get_run_context(), shift_logits)
          shift_labels = labels[..., 1:].contiguous()
          batch_size, seq_length, vocab_size = shift_logits.shape
          # Flatten the tokens
          loss_fct = CrossEntropyLoss()
          loss = loss_fct(
              shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
          )
          return loss
```
Note: make sure the activation name (the first argument of `_InspectActivation.apply`) is unique. Otherwise, the
stat file named after the activation will be overwritten by the last write. The dumped data are stored in the `output_dir`.
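One way to catch accidental name reuse early is to keep the activation names in one list and validate it once, for example at import time. This is a minimal sketch: `assert_unique_activation_names` is a hypothetical helper and is not provided by onnxruntime.

```python
from collections import Counter
from typing import Iterable, Tuple


def assert_unique_activation_names(names: Iterable[str]) -> Tuple[str, ...]:
    """Raise ValueError if any activation name appears more than once."""
    names = tuple(names)
    duplicates = sorted(name for name, count in Counter(names).items() if count > 1)
    if duplicates:
        raise ValueError(f"duplicate activation names: {duplicates}")
    return names


# Validated once; the same strings are then passed to `_InspectActivation.apply`.
ACTIVATION_NAMES = assert_unique_activation_names(["lm_logits", "shift_logits"])
```

The check runs once rather than inside forward(), because forward() is called on every step and would see the same names again.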
### 2.3 Collect on multiple ranks
`GlobalSubscriberManager` does not explicitly handle the race condition that occurs when multiple ranks write to the same file path.
To collect statistics on multiple ranks, give each rank its own output directory, for example:
```python
import torch
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

# One output directory per rank, so ranks never write to the same file path.
rank_dir = "ort_out_" + str(torch.distributed.get_rank())
GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir=rank_dir, override_output_dir=True)])
```
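If the subscriber is set up before `torch.distributed` is initialized, the rank can instead be read from the `RANK` environment variable that launchers such as `torchrun` export. This is a minimal sketch: `per_rank_output_dir` is a hypothetical helper, and falling back to rank 0 when `RANK` is unset is an assumption.

```python
import os
from typing import Optional


def per_rank_output_dir(prefix: str, rank: Optional[int] = None) -> str:
    """Build a rank-specific directory name such as 'ort_out_3'."""
    if rank is None:
        # Assumption: single-process runs have no RANK set and are treated as rank 0.
        rank = int(os.environ.get("RANK", "0"))
    return f"{prefix}_{rank}"
```

The returned string can be passed as `output_dir` to `StatisticsSubscriber`.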
Check [StatisticsSubscriber implementation](../orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py) for more information.
### 2.4 Run command to generate per-step summary
```bash
python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```
### 2.5 Manually compare the generated per-step summary to find the first big diff.
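The manual comparison can be partly automated. The sketch below walks two text summaries line by line and reports the first line whose numbers differ beyond a tolerance. It assumes that the baseline and ORT summaries list the same activations in the same order, one record per line; the merged file layout is not specified here, so treat this as a starting point. `first_divergence` and the sample lines are hypothetical and are not real tool output.

```python
import math
import re

# Matches integers, decimals and scientific notation, e.g. 3, -0.25, 1.5e-05.
_NUMBER = re.compile(r"[-+]?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][-+]?\d+)?")


def first_divergence(baseline_lines, candidate_lines, rel_tol=1e-3, abs_tol=1e-6):
    """Return (line_index, baseline_line, candidate_line) for the first mismatch, else None."""
    for index, (base, cand) in enumerate(zip(baseline_lines, candidate_lines)):
        base_nums = [float(token) for token in _NUMBER.findall(base)]
        cand_nums = [float(token) for token in _NUMBER.findall(cand)]
        if len(base_nums) != len(cand_nums):
            return index, base, cand  # The two lines do not have the same layout.
        for b, c in zip(base_nums, cand_nums):
            if not math.isclose(b, c, rel_tol=rel_tol, abs_tol=abs_tol):
                return index, base, cand
    return None


# Synthetic placeholder lines, used only to show the call.
pt_lines = ["lm_logits mean 0.5000 std 1.2000", "shift_logits mean 0.2500 std 0.9000"]
ort_lines = ["lm_logits mean 0.5001 std 1.2000", "shift_logits mean 0.4100 std 0.9000"]
result = first_divergence(pt_lines, ort_lines)
```

Here `result[0]` is `1`: the first line agrees within the relative tolerance, while the second does not. In practice the two line lists would come from reading the corresponding per-step files for the baseline and ORT runs.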