onnxruntime/orttraining
pengwa 735a32fee1
Introduce memory observer for ORTModule (#16213)
### Introduce memory observer for ORTModule

To analyze memory usage for ORTModule training, we need collect
per-iteration memory footprint in different stages (pre-forward,
post-forward, pre-backward, and post-backward).

Currently we only collect the data using torch.cuda APIs. The next step
is, we could collect the detailed stashed activation list and its
percentage within ORT backend, which is beyond this PR.

Sample as below: 
```
0/8] step 0 memory (MiB) | phase: pre_forward | allocated: 1866 | max allocated: 1866 | cached: 1874 | max cached: 1874 | inactive: 8 | max inactive: 8
[0/8] step 0 memory (MiB) | phase: post_forward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405
[0/8] step 0 memory (MiB) | phase: pre_backward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405
[0/8] step 0 memory (MiB) | phase: post_backward | allocated: 2932 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 6158 | max inactive: 6158
  0%|█                                                                                                                                                                                                            | 1/200 [00:26<1:26:18, 26.02s/it]
[0/8] step 1 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 2454 | max inactive: 6165
[0/8] step 1 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 1 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 1 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165
  1%|██                                                                                                                                                                                                             | 2/200 [00:26<36:47, 11.15s/it]
[0/8] step 2 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2454 | max inactive: 6165
[0/8] step 2 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 2 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 2 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165
```
2023-06-15 15:45:36 +08:00
..
orttraining Introduce memory observer for ORTModule (#16213) 2023-06-15 15:45:36 +08:00
pytorch_frontend_examples Enable pylint and numpy rules (#15218) 2023-03-27 20:37:53 -07:00
tools [ROCm] reduce batch size to fix CI error (#15714) 2023-05-16 13:10:02 +08:00