Commit graph

3 commits

Author SHA1 Message Date
pengwa
40bcc0441b
Enhance StatisticsSubscriber (#16098)
### Enhance StatisticsSubscriber

There are few improvements for `StatisticsSubscriber`:

- Reduce peak memory impact for tensors (having many many many elements,
consuming too much GPU memory, causing original recipe run failed with
OOM), by split the statistics into two phases (split into buckets, and
merge result across buckets).
- Allow dump intermediate tensors. Originally only nn.Module forward()'s
return value are dumped, there are requirements we want to inspect some
specific intermediate tensor in the forward() function, now we support
it.
- Add documents for collecting dumps on multiple ranks

Docs link on this branch for better view:
https://github.com/microsoft/onnxruntime/blob/pengwa/conv_tool_v2/docs/ORTModule_Convergence_Notes.md

---------

Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>
2023-06-12 18:32:08 +08:00
pengwa
5baf5f506b
log level control + fix typos (#15302)
### log level control + fix typos
2023-04-04 20:19:13 +08:00
pengwa
1d32285536
Statistics tool for ORTModule convergence parity (#15020)
### Statistics tool for ORTModule convergence parity

As ORTModule get more and more validated, it is pretty fast to
intergrade PyTorch based model with ORT.

The same time, we need make sure once there is convergence issue, we
don't spend months of time to investigate. As part of this efforts, this
PR is introducing a tool to dump activation statistics without much
involvement from users. The dumping results contains only some statistic
numbers plus sampled data, which is not big, compared with dumping all
the tensors, it is much faster and space efficient.

For us to use it, two single lines are needed before wrapping ORTModule.
For baseline run, need also apply the same trick.

```
+	from onnxruntime.training.utils.hooks import SubscriberManager, StatisticsSubscriber
+	SubscriberManager.subscribe(model, [StatisticsSubscriber("pt_out", override_output_dir=True)])
```

Once you run the steps, following command can be used to merge result
into per-step-summary respectively for ORT and baseline runs.
 
```bash
python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```

Docs is added here as part of this PR [convergence investigation
notes](https://github.com/microsoft/onnxruntime/blob/pengwa/conv_tool/docs/ORTModule_Convergence_Notes.md)

Based on the generated merged files, we can compare them with tools. 


![image](https://user-images.githubusercontent.com/10530022/224653929-4e4480bd-bb02-4bbe-bd44-2672bdf91a87.png)

### Design and Implementation

This PR introduced a common mechanism registering custom logic for
nn.Module's post forward hooks. And statistics for activation
(StatisticsSubscriber) is one of the implementations. If there is other
needs, we can define another XXSubscriber to do the customized things.
2023-03-23 20:34:24 +08:00