mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
140 lines
4.9 KiB
Markdown
# ORTModule Training Convergence Investigation

## 1. Discovering

Convergence issues can be identified by:

- Large discrepancies in core training metrics, such as training loss, evaluation loss, or model-specific AUC metrics.
- Runtime failures (for example, when the loss scaler reaches its minimum value and an exception is raised).

Before looking into this further, we should clarify a few things (if possible):

- Does changing the seed of the baseline run alone produce a comparably large metric difference?
  (This makes sure the discrepancy is not introduced by randomness.)
- At which step does the divergence first become obvious?
- Is the issue still reproducible once randomness is removed?
  - Set the same seeds.
  - Set the dropout ratio to 0.
  - Set compute to be deterministic and torch-comparable (TODO(pengwa): need a flag for this).

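
The randomness-removal steps above can be scripted. The following is a minimal sketch, not an ORTModule API: `make_reproducible` and `disable_dropout` are hypothetical helper names built on standard PyTorch calls (`torch.manual_seed`, `torch.use_deterministic_algorithms`), and NumPy/PyTorch are skipped when they are not installed. To be safe, call `make_reproducible` before any CUDA work, because the `CUBLAS_WORKSPACE_CONFIG` setting must be in place before cuBLAS is initialized. `disable_dropout` only covers `torch.nn.Dropout` modules; models that call `torch.nn.functional.dropout` directly with probabilities taken from their config (as some Hugging Face models do) need those config values set to 0 instead.

```python
import os
import random


def make_reproducible(seed: int = 42) -> None:
    """Seed every installed RNG and request deterministic kernels (hypothetical helper)."""
    random.seed(seed)
    # cuBLAS needs this workspace setting for deterministic results; set it before CUDA is used.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is optional for this sketch

    try:
        import torch
    except ImportError:
        return  # nothing else to configure without PyTorch

    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    torch.use_deterministic_algorithms(True)  # raise on ops without a deterministic implementation
    torch.backends.cudnn.benchmark = False  # cuDNN autotuning may pick different kernels per run


def disable_dropout(model) -> int:
    """Set p=0 on every torch.nn.Dropout module in `model`; return how many were changed."""
    import torch

    changed = 0
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0
            changed += 1
    return changed
```

Call `make_reproducible` with the same seed in both the baseline and the ORTModule run, and apply `disable_dropout` before wrapping the model with ORTModule so the change is in place before the graph is exported.
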
## 2. Collect Activation Statistics

### 2.1 Use `GlobalSubscriberManager` to collect `nn.Module` forward() outputs

<table>
<tr>
<th>Baseline</th>
<th>ORTModule</th>
</tr>
<tr>
<td>
<sub>

```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

# model: the torch.nn.Module being trained
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True)]
)
```

</sub>
</td>
<td>
<sub>

```python
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

model = ORTModule(model)
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True)]
)
```

</sub>
</td>
</tr>
<tr>
<td>

- Run the training script up to the steps that trigger the divergence.
- A folder named `pt_out` is created in the current working directory.
- Inside it, each step has its own folder containing summaries for every activation tensor.

</td>
<td>

- Run the training script up to the steps that trigger the divergence.
- Similarly, a folder named `ort_out` is created in the current working directory.
- `StatisticsSubscriber` can be subscribed before OR after wrapping the model with ORTModule.

</td>
</tr>
</table>

`StatisticsSubscriber` arguments:

- `output_dir`: the directory in which all activation statistics files are stored.
- `start_step` [optional]: the first step that runs subscriber actions.
- `end_step` [optional]: the step at which subscriber actions stop (exclusive).
- `override_output_dir`: whether `output_dir` can be overridden if it already exists.
- `run_on_cpu`: whether to run the subscriber actions on CPU. Use this only as a last resort, when the inserted
  inspector nodes raise peak memory enough that the original recipe fails with OOM.
- `bucket_size`: the size of the buckets into which the statistics calculation is split.

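
To make the per-tensor "summaries" and the idea of `bucket_size` more concrete, here is a stdlib-only sketch that reduces a flattened activation to a few summary statistics while walking it in fixed-size buckets, so only one bucket's temporaries exist at a time. It is an illustration under assumptions, not `StatisticsSubscriber`'s implementation: the function name `summarize`, the chosen statistics, and the way partial results are combined are all hypothetical.

```python
import math
from typing import Dict, List


def summarize(values: List[float], bucket_size: int = 1024) -> Dict[str, float]:
    """Summarize a flattened tensor bucket by bucket (hypothetical illustration).

    Each bucket contributes partial results (count, sum, sum of squares, min,
    max, NaN/Inf counts); the partials are combined into the final summary.
    """
    if bucket_size <= 0:
        raise ValueError("bucket_size must be positive")

    count, total, total_sq = 0, 0.0, 0.0
    lo, hi = math.inf, -math.inf
    nan_count = inf_count = 0

    for start in range(0, len(values), bucket_size):
        bucket = values[start:start + bucket_size]
        nan_count += sum(1 for v in bucket if math.isnan(v))
        inf_count += sum(1 for v in bucket if math.isinf(v))
        finite = [v for v in bucket if math.isfinite(v)]
        if not finite:
            continue
        count += len(finite)
        total += sum(finite)
        total_sq += sum(v * v for v in finite)
        lo = min(lo, min(finite))
        hi = max(hi, max(finite))

    if count == 0:  # no finite values at all
        return {"min": math.nan, "max": math.nan, "mean": math.nan, "std": math.nan,
                "nan_count": nan_count, "inf_count": inf_count}

    mean = total / count
    variance = max(total_sq / count - mean * mean, 0.0)  # population variance, clamped for rounding
    return {"min": lo, "max": hi, "mean": mean, "std": math.sqrt(variance),
            "nan_count": nan_count, "inf_count": inf_count}
```

The sum-of-squares variance formula keeps the sketch short but can lose precision for large-magnitude data; combining per-bucket means and variances with Chan's parallel algorithm is the more robust choice.
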
### 2.2 Use `inspect_activation` to collect intermediate tensors in a `nn.Module` forward()

The limitation of `GlobalSubscriberManager` is that only the forward output tensors of `nn.Module`s are dumped. To dump
intermediate tensors inside an `nn.Module`'s forward function, wrap them with `inspect_activation`, as in the following example:

```diff
+ from onnxruntime.training.utils import inspect_activation
class BloomForCausalLM(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
        ...

    def forward(self, input_ids, ...):
        ...
        transformer_outputs = self.transformer(...)
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
+       lm_logits = inspect_activation("lm_logits", lm_logits)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
+       shift_logits = inspect_activation("shift_logits", shift_logits)
        shift_labels = labels[..., 1:].contiguous()
        batch_size, seq_length, vocab_size = shift_logits.shape
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
        )

        return loss
```

Note: make sure the activation name (the first argument of `inspect_activation`) is unique; otherwise the stat file named
after that activation is overwritten by the last write. The dumped data are stored in `output_dir`.

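
Duplicate names are easy to produce when the same `forward` code runs for several layers. One option is to generate the names through a small helper that appends a suffix on reuse. `ActivationNamer` below is hypothetical (not part of onnxruntime) and only manipulates strings; its output would be passed as the first argument of `inspect_activation`.

```python
from typing import Set


class ActivationNamer:
    """Hypothetical helper that hands out unique activation names.

    The first request for a name returns it unchanged; later requests for the
    same name get a numeric suffix (_1, _2, ...) that has not been issued yet.
    """

    def __init__(self) -> None:
        self._issued: Set[str] = set()

    def __call__(self, name: str) -> str:
        candidate, index = name, 0
        while candidate in self._issued:  # skip suffixes that are already taken
            index += 1
            candidate = f"{name}_{index}"
        self._issued.add(candidate)
        return candidate
```

Create a fresh `ActivationNamer` at the start of each `forward()` call and use it inside the layer loop, e.g. `hidden = inspect_activation(name("block_out"), hidden)`. Reusing one instance across steps would give the same tensor a different name in every step, which defeats the per-step comparison. When the layer index is at hand, embedding it directly (`f"block_{i}_out"`) achieves the same thing.
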
### 2.3 Collect on multiple ranks

`GlobalSubscriberManager` does not explicitly handle the race condition when multiple ranks write to the same file path.
To collect statistics on multiple ranks, give each rank its own output directory, for example:

```python
import torch.distributed
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

# Requires an initialized process group; each rank writes to its own directory (ort_out_0, ort_out_1, ...).
GlobalSubscriberManager.subscribe(
    model,
    [StatisticsSubscriber(output_dir="ort_out_" + str(torch.distributed.get_rank()), override_output_dir=True)],
)
```

Check the [StatisticsSubscriber implementation](../orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py) for more information.

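
Using `torch.distributed.get_rank()` directly assumes the default process group is already initialized; it raises an error otherwise. If the same script also runs without a process group (for example, when debugging on a single GPU), a helper with fallbacks can build the directory name instead. This is a sketch under assumptions: `rank_output_dir` is a hypothetical name, and it falls back first to the `RANK` environment variable that launchers such as `torchrun` set for each worker, then to rank 0.

```python
import os


def rank_output_dir(prefix: str) -> str:
    """Return a per-rank directory name such as 'ort_out_3' (hypothetical helper)."""
    rank = None
    try:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
    except ImportError:
        pass  # PyTorch not installed: rely on the environment instead
    if rank is None:
        # torchrun (and similar launchers) export RANK for every worker process.
        rank = int(os.environ.get("RANK", "0"))
    return f"{prefix}_{rank}"
```

The result would then be passed as `StatisticsSubscriber(output_dir=rank_output_dir("ort_out"), override_output_dir=True)`.
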
### 2.4 Run command to generate per-step summary

```bash
python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```

### 2.5 Manually compare the generated per-step summary to find the first big diff
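
The manual comparison amounts to scanning steps in order and stopping at the first tensor whose statistics disagree beyond a tolerance. The sketch below assumes the per-step summaries have already been loaded into nested dictionaries of the form `{step: {tensor_name: {stat_name: value}}}`; that layout, the default tolerances, and the function name `first_big_diff` are assumptions and do not describe the files `merge_activation_summary` actually writes.

```python
import math
from typing import Dict, Optional, Tuple

Stats = Dict[str, float]        # e.g. {"mean": 0.1, "std": 1.0}
StepSummary = Dict[str, Stats]  # tensor name -> statistics


def first_big_diff(
    pt: Dict[int, StepSummary],
    ort: Dict[int, StepSummary],
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-6,
) -> Optional[Tuple[int, str, str, float, float]]:
    """Return (step, tensor, stat, pt_value, ort_value) for the earliest mismatch, or None."""
    for step in sorted(set(pt) & set(ort)):  # only steps collected by both runs
        for tensor, pt_stats in pt[step].items():
            ort_stats = ort[step].get(tensor)
            if ort_stats is None:
                continue  # tensor missing from the ORTModule summary
            for stat, pt_value in pt_stats.items():
                if stat not in ort_stats:
                    continue
                ort_value = ort_stats[stat]
                if math.isnan(pt_value) and math.isnan(ort_value):
                    continue  # both NaN: treat as matching
                if not math.isclose(pt_value, ort_value, rel_tol=rel_tol, abs_tol=abs_tol):
                    return step, tensor, stat, pt_value, ort_value
    return None
```

Tensors are checked in the order they appear in the baseline summary, so if that order follows execution order, the first hit is also the earliest point in the forward pass where the runs diverge; later mismatches are often downstream effects of it.
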