Improve memory matrix for ORTModule (#19620)

### Memory matrix for ORTModule

Also collect parameter, gradient, and buffer sizes.
The logic is exposed as a function (`log_memory_usage`), so it can also be called externally for debugging purposes.


```
2024-02-27 07:18:55,283 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,322 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,358 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,438 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▏                                                                                                                                                                                                                                              | 2/3200 [01:27<32:05:11, 36.12s/it]2024-02-27 07:18:55,498 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,537 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,576 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,657 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▏                                                                                                                                                                                                                                              | 3/3200 [01:27<17:30:57, 19.72s/it]2024-02-27 07:18:55,711 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,750 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,786 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,867 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
[2024-02-27 07:18:55,886] [INFO] [loss_scaler.py:190:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, but hysteresis is 2. Reducing hysteresis to 1
  0%|▎                                                                                                                                                                                                                                              | 4/3200 [01:28<10:39:52, 12.01s/it]2024-02-27 07:18:55,902 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,944 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:55,979 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,060 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▍                                                                                                                                                                                                                                               | 5/3200 [01:28<6:53:04,  7.76s/it]2024-02-27 07:18:56,115 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,154 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,190 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,270 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▍                                                                                                                                                                                                                                               | 6/3200 [01:28<4:36:19,  5.19s/it]2024-02-27 07:18:56,323 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,365 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,398 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,478 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▌                                                                                                                                                                                                                                               | 7/3200 [01:28<3:09:33,  3.56s/it]2024-02-27 07:18:56,533 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,572 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,608 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,727 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▌                                                                                                                                                                                                                                               | 8/3200 [01:28<2:13:48,  2.52s/it]2024-02-27 07:18:56,806 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,846 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,882 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:56,962 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8
  0%|▋                                                                                                                                                                                                                                               | 9/3200 [01:29<1:36:03,  1.81s/it]2024-02-27 07:18:57,053 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8
2024-02-27 07:18:57,094 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8

```
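For post-processing, the log lines above can be turned into records with a few lines of Python. This is a minimal sketch, not part of the commit: the helper name `parse_memory_line` is hypothetical, and the layout it assumes (`rank-<r> step <n> memory (<unit>)` followed by ` | key: value` pairs) is inferred only from the sample output shown here.

```python
import re
from typing import Dict, Optional, Union

# Assumed header layout, inferred from the sample log lines:
#   "... rank-<r> step <n> memory (MiB) | key: value | key: value ..."
_HEADER = re.compile(r"rank-(\d+) step (\d+) memory \((\w+)\)")


def parse_memory_line(line: str) -> Optional[Dict[str, Union[int, str]]]:
    """Return the fields of one memory log line, or None if it is not one."""
    match = _HEADER.search(line)
    if match is None:
        return None
    fields: Dict[str, Union[int, str]] = {
        "rank": int(match.group(1)),
        "step": int(match.group(2)),
        "unit": match.group(3),
    }
    # Everything after the header is a sequence of " | key: value" pairs.
    for pair in line[match.end():].split("|"):
        key, sep, value = pair.partition(":")
        if not sep:
            continue  # the empty chunk before the first "|"
        value = value.strip()
        fields[key.strip()] = int(value) if value.isdigit() else value
    return fields


sample = (
    "2024-02-27 07:18:55,438 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) "
    "| phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 "
    "| max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8"
)
record = parse_memory_line(sample)
```

Because the header is located with `search`, lines that carry a progress-bar prefix are handled the same way, and unrelated lines (such as the DeepSpeed overflow message) yield `None`.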
pengwa 2024-02-28 15:57:05 +08:00 committed by GitHub
parent f95c0773a1
commit 026e3178ae
3 changed files with 88 additions and 27 deletions


@@ -14,7 +14,7 @@ from onnx import onnx_pb as onnx_proto
 from sympy import Symbol, simplify
 from sympy.parsing.sympy_parser import parse_expr
-from onnxruntime.training.utils import PTable
+from onnxruntime.training.utils import PTable, log_memory_usage
 from ._execution_agent import TrainingAgent
 from .options import _MemoryOptimizationLevel, _RuntimeOptions
@@ -509,6 +509,8 @@ class MemoryObserver:
         self._is_first_inspect = True
+        self._m = m
     def is_enabled(self) -> bool:
         """Check if memory inspector is enabled."""
         return self._is_enabled
@@ -621,29 +623,13 @@ class MemoryObserver:
         need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0)
         if need_print:
-            cur_mem_allocated = self._normalize(torch.cuda.memory_allocated())
-            max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated())
-            cur_mem_cached = self._normalize(torch.cuda.memory_reserved())
-            max_mem_cached = self._normalize(torch.cuda.max_memory_reserved())
-            torch_mem_stat = torch.cuda.memory_stats()
-            cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
-            max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))
-            mem_stats = [
-                ["phase", _convert_phase_to_string(cur_phase)],
-                ["allocated", cur_mem_allocated],  # current memory allocated for tensors
-                ["max allocated", max_mem_allocated],  # peak memory allocated for tensors
-                ["cached", cur_mem_cached],  # current memory cached for the caching allocator
-                ["max cached", max_mem_cached],  # peak memory cached for caching allocator.
-                ["inactive", cur_mem_inactive],  # amount of inactive, non-releasable memory
-                ["max inactive", max_mem_inactive],  # peak of inactive, non-releasable memory
-            ]
-            summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})"
-            for stat in mem_stats:
-                summ += f" | {stat[0]}: {stat[1]}"
-            self._logger.info(summ)
+            log_memory_usage(
+                _convert_phase_to_string(cur_phase),
+                rank_0_only=True,
+                step_info=f"step {self._current_step}",
+                logger=self._logger,
+                module=self._m,
+            )
         if cur_phase == self._last_phase:
             self._increase_step()
@@ -655,9 +641,6 @@ class MemoryObserver:
     def _increase_step(self):
         self._current_step += 1
-    def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str:
-        return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}"
     def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]:
         mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map)

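The `need_print` condition in `MemoryObserver` throttles logging: every step below 10 is logged, and after that only steps that are powers of two. A minimal sketch of that cadence follows; the function name `should_log` is hypothetical and only mirrors the expression from the diff.

```python
def should_log(step: int) -> bool:
    """Mirror of the `need_print` expression used by MemoryObserver."""
    # `step & (step - 1) == 0` holds for 0 and for every power of two,
    # because a power of two has exactly one bit set.
    return step < 10 or (step & (step - 1)) == 0


logged_steps = [s for s in range(70) if should_log(s)]
print(logged_steps)
```

In Python, `&` binds more tightly than `==`, so the unparenthesized form in the diff evaluates the same way as the parenthesized one here.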

@@ -12,6 +12,7 @@ from onnxruntime.training.utils.torch_io_helper import (
     unflatten_data_using_schema,
 )
 from onnxruntime.training.utils.torch_profile_utils import (
+    log_memory_usage,
     nvtx_function_decorator,
     torch_nvtx_range_pop,
     torch_nvtx_range_push,
@@ -31,6 +32,7 @@ __all__ = [
     "torch_nvtx_range_push",
     "torch_nvtx_range_pop",
     "nvtx_function_decorator",
+    "log_memory_usage",
     "pytorch_type_to_onnx_dtype",
     "onnx_dtype_to_pytorch_dtype",
     "pytorch_scalar_type_to_pytorch_dtype",

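With `log_memory_usage` exported from `onnxruntime.training.utils`, it can be called from user code as well. The sketch below is an illustration only: `train_one_step`, the placeholder loss, and the phase strings are assumptions, while the keyword arguments follow the signature added in this commit. It presumes a CUDA build of PyTorch and an installed onnxruntime-training package.

```python
import logging
from typing import Optional


def train_one_step(model, batch, step: int, logger: Optional[logging.Logger] = None) -> None:
    """Hypothetical training step instrumented with log_memory_usage."""
    # Imported lazily so this sketch can be loaded without onnxruntime-training installed.
    from onnxruntime.training.utils import log_memory_usage

    step_info = f"step {step}"
    log_memory_usage("pre_forward", step_info=step_info, logger=logger, module=model)
    loss = model(batch).sum()  # placeholder loss; assumes the model returns a tensor
    log_memory_usage("post_forward", step_info=step_info, logger=logger, module=model)
    loss.backward()
    log_memory_usage("post_backward", step_info=step_info, logger=logger, module=model)
```

When `logger` is left as `None`, the function prints the summary line to stdout; passing `module` adds the `param`, `grad`, and `buffer` columns.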

@@ -3,6 +3,8 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+from __future__ import annotations
 import torch
@@ -26,3 +28,77 @@ def nvtx_function_decorator(func):
         return ret_val
     return wrapped_fn
+def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="", logger=None, module=None):
+    """Log memory usage for the current phase.
+    Args:
+        cur_phase (str): The current phase.
+        rank_0_only (bool, optional): Only log the memory usage for rank 0. Defaults to True.
+        step_info (str, optional): The step information. Defaults to "".
+        logger (logging.Logger, optional): The logger to log the memory usage. Defaults to None, which means print to stdout.
+        module (torch.nn.Module, optional): The module to get parameter, buffer and grad sizes. Defaults to None.
+    """
+    rank = 0
+    if rank_0_only is True:
+        if torch.distributed.is_initialized():
+            rank = torch.distributed.get_rank()
+        if rank != 0:
+            return
+    _normalizer_factor = float(1024 * 1024)
+    _normalizer_unit = "MiB"
+    def _normalize(mem_size_in_bytes: float | int) -> str:
+        return f"{float(mem_size_in_bytes) / _normalizer_factor:.0f}"
+    cur_mem_allocated = _normalize(torch.cuda.memory_allocated())
+    max_mem_allocated = _normalize(torch.cuda.max_memory_allocated())
+    cur_mem_cached = _normalize(torch.cuda.memory_reserved())
+    max_mem_cached = _normalize(torch.cuda.max_memory_reserved())
+    torch_mem_stat = torch.cuda.memory_stats()
+    cur_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
+    max_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))
+    mem_stats = [
+        ["phase", cur_phase],
+        ["allocated", cur_mem_allocated],  # current memory allocated for tensors
+        ["max allocated", max_mem_allocated],  # peak memory allocated for tensors
+        ["cached", cur_mem_cached],  # current memory cached for the caching allocator
+        ["max cached", max_mem_cached],  # peak memory cached for caching allocator.
+        ["inactive", cur_mem_inactive],  # amount of inactive, non-releasable memory
+        ["max inactive", max_mem_inactive],  # peak of inactive, non-releasable memory
+    ]
+    # Calculate the total size of parameters and gradients in the model
+    if module:
+        param_total_size = 0
+        grad_total_size = 0
+        for p in module.parameters():
+            if p.is_cuda:
+                param_total_size += p.numel() * p.element_size()
+            if p.grad is not None and p.grad.is_cuda:
+                grad_total_size += p.grad.numel() * p.grad.element_size()
+        # Calculate the total size of buffers in the model
+        buffer_total_size = 0
+        for b in module.buffers():
+            if b.is_cuda:
+                buffer_total_size += b.numel() * b.element_size()
+        mem_stats.extend(
+            [
+                ["param", _normalize(param_total_size)],
+                ["grad", _normalize(grad_total_size)],
+                ["buffer", _normalize(buffer_total_size)],
+            ]
+        )
+    summ = f"rank-{rank} {step_info} memory ({_normalizer_unit})"
+    for stat in mem_stats:
+        summ += f" | {stat[0]}: {stat[1]}"
+    if logger is None:
+        print(summ)
+    else:
+        logger.info(summ)
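
The `param`, `grad`, and `buffer` columns are sums of `numel() * element_size()` over CUDA-resident tensors, divided by 1024 * 1024 and rounded to whole MiB. The arithmetic can be checked without a GPU using a stand-in object. This is a sketch under stated assumptions: `FakeTensor` and `total_mib` are hypothetical names that imitate only the tensor attributes the function reads.

```python
from dataclasses import dataclass
from typing import Iterable, Optional

MIB = 1024 * 1024


@dataclass
class FakeTensor:
    """Hypothetical stand-in exposing the attributes log_memory_usage reads."""

    n_elements: int
    bytes_per_element: int
    is_cuda: bool = True
    grad: Optional["FakeTensor"] = None

    def numel(self) -> int:
        return self.n_elements

    def element_size(self) -> int:
        return self.bytes_per_element


def total_mib(tensors: Iterable[FakeTensor]) -> str:
    """Sum the byte sizes of CUDA-resident tensors and format them as whole MiB."""
    size_in_bytes = sum(t.numel() * t.element_size() for t in tensors if t.is_cuda)
    return f"{size_in_bytes / MIB:.0f}"


# One fp16 4096x4096 weight on the GPU (with a gradient) and one fp32 tensor on the CPU.
params = [
    FakeTensor(4096 * 4096, 2, grad=FakeTensor(4096 * 4096, 2)),
    FakeTensor(1024 * 1024, 4, is_cuda=False),
]
grads = [p.grad for p in params if p.grad is not None]

param_mib = total_mib(params)  # the CPU tensor is skipped
grad_mib = total_mib(grads)
```

The fp16 weight holds 16,777,216 elements at 2 bytes each, which is 32 MiB; the CPU tensor contributes nothing because only tensors with `is_cuda` set are counted.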