mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Improve memory matrix for ORTModule (#19620)
### Memory matrix for ORTModule Collect parameter/gradient/buffers sizes also. Exposed as a function, can be used externally for debugging purpose. ``` 2024-02-27 07:18:55,283 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,322 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,358 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,438 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▏ | 2/3200 [01:27<32:05:11, 36.12s/it]2024-02-27 07:18:55,498 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,537 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,576 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,657 orttraining.rank-0 [INFO] - rank-0 step 2 memory 
(MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▏ | 3/3200 [01:27<17:30:57, 19.72s/it]2024-02-27 07:18:55,711 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,750 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,786 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,867 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 [2024-02-27 07:18:55,886] [INFO] [loss_scaler.py:190:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, but hysteresis is 2. 
Reducing hysteresis to 1 0%|▎ | 4/3200 [01:28<10:39:52, 12.01s/it]2024-02-27 07:18:55,902 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,944 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,979 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,060 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▍ | 5/3200 [01:28<6:53:04, 7.76s/it]2024-02-27 07:18:56,115 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,154 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,190 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,270 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max 
cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▍ | 6/3200 [01:28<4:36:19, 5.19s/it]2024-02-27 07:18:56,323 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,365 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,398 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,478 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▌ | 7/3200 [01:28<3:09:33, 3.56s/it]2024-02-27 07:18:56,533 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,572 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,608 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,727 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_backward | 
allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▌ | 8/3200 [01:28<2:13:48, 2.52s/it]2024-02-27 07:18:56,806 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,846 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,882 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,962 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▋ | 9/3200 [01:29<1:36:03, 1.81s/it]2024-02-27 07:18:57,053 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:57,094 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 ```
This commit is contained in:
parent
f95c0773a1
commit
026e3178ae
3 changed files with 88 additions and 27 deletions
|
|
@ -14,7 +14,7 @@ from onnx import onnx_pb as onnx_proto
|
|||
from sympy import Symbol, simplify
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
||||
from onnxruntime.training.utils import PTable
|
||||
from onnxruntime.training.utils import PTable, log_memory_usage
|
||||
|
||||
from ._execution_agent import TrainingAgent
|
||||
from .options import _MemoryOptimizationLevel, _RuntimeOptions
|
||||
|
|
@ -509,6 +509,8 @@ class MemoryObserver:
|
|||
|
||||
self._is_first_inspect = True
|
||||
|
||||
self._m = m
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if memory inspector is enabled."""
|
||||
return self._is_enabled
|
||||
|
|
@ -621,29 +623,13 @@ class MemoryObserver:
|
|||
need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0)
|
||||
|
||||
if need_print:
|
||||
cur_mem_allocated = self._normalize(torch.cuda.memory_allocated())
|
||||
max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated())
|
||||
cur_mem_cached = self._normalize(torch.cuda.memory_reserved())
|
||||
max_mem_cached = self._normalize(torch.cuda.max_memory_reserved())
|
||||
torch_mem_stat = torch.cuda.memory_stats()
|
||||
cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
|
||||
max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))
|
||||
|
||||
mem_stats = [
|
||||
["phase", _convert_phase_to_string(cur_phase)],
|
||||
["allocated", cur_mem_allocated], # current memory allocated for tensors
|
||||
["max allocated", max_mem_allocated], # peak memory allocated for tensors
|
||||
["cached", cur_mem_cached], # current memory cached for the caching allocator
|
||||
["max cached", max_mem_cached], # peak memory cached for caching allocator.
|
||||
["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory
|
||||
["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory
|
||||
]
|
||||
|
||||
summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})"
|
||||
for stat in mem_stats:
|
||||
summ += f" | {stat[0]}: {stat[1]}"
|
||||
|
||||
self._logger.info(summ)
|
||||
log_memory_usage(
|
||||
_convert_phase_to_string(cur_phase),
|
||||
rank_0_only=True,
|
||||
step_info=f"step {self._current_step}",
|
||||
logger=self._logger,
|
||||
module=self._m,
|
||||
)
|
||||
|
||||
if cur_phase == self._last_phase:
|
||||
self._increase_step()
|
||||
|
|
@ -655,9 +641,6 @@ class MemoryObserver:
|
|||
def _increase_step(self):
|
||||
self._current_step += 1
|
||||
|
||||
def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str:
|
||||
return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}"
|
||||
|
||||
def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]:
|
||||
mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from onnxruntime.training.utils.torch_io_helper import (
|
|||
unflatten_data_using_schema,
|
||||
)
|
||||
from onnxruntime.training.utils.torch_profile_utils import (
|
||||
log_memory_usage,
|
||||
nvtx_function_decorator,
|
||||
torch_nvtx_range_pop,
|
||||
torch_nvtx_range_push,
|
||||
|
|
@ -31,6 +32,7 @@ __all__ = [
|
|||
"torch_nvtx_range_push",
|
||||
"torch_nvtx_range_pop",
|
||||
"nvtx_function_decorator",
|
||||
"log_memory_usage",
|
||||
"pytorch_type_to_onnx_dtype",
|
||||
"onnx_dtype_to_pytorch_dtype",
|
||||
"pytorch_scalar_type_to_pytorch_dtype",
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
|
@ -26,3 +28,77 @@ def nvtx_function_decorator(func):
|
|||
return ret_val
|
||||
|
||||
return wrapped_fn
|
||||
|
||||
|
||||
def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="", logger=None, module=None):
    """Log a one-line summary of CUDA memory usage for the current phase.

    The summary reports the values returned by ``torch.cuda.memory_allocated``,
    ``torch.cuda.max_memory_allocated``, ``torch.cuda.memory_reserved``,
    ``torch.cuda.max_memory_reserved`` and the ``inactive_split_bytes`` entries of
    ``torch.cuda.memory_stats``, all converted to MiB. When ``module`` is given, the
    total byte sizes of its CUDA parameters, CUDA gradients and CUDA buffers are
    appended as ``param``, ``grad`` and ``buffer``.

    Args:
        cur_phase (str): The current phase.
        rank_0_only (bool, optional): Only log the memory usage for rank 0. Defaults to True.
        step_info (str, optional): The step information. Defaults to "".
        logger (logging.Logger, optional): The logger to log the memory usage. Defaults to None,
            which means print to stdout.
        module (torch.nn.Module, optional): The module to get parameter, buffer and grad sizes.
            Defaults to None.
    """
    # Resolve the real rank even when every rank is allowed to log, so that the
    # "rank-N" prefix of the summary identifies the process that produced it.
    rank = 0
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()

    if rank_0_only and rank != 0:
        return

    normalizer_factor = float(1024 * 1024)
    normalizer_unit = "MiB"

    def _normalize(mem_size_in_bytes) -> str:
        """Convert a byte count into a whole-number MiB string."""
        return f"{float(mem_size_in_bytes) / normalizer_factor:.0f}"

    cur_mem_allocated = _normalize(torch.cuda.memory_allocated())
    max_mem_allocated = _normalize(torch.cuda.max_memory_allocated())
    cur_mem_cached = _normalize(torch.cuda.memory_reserved())
    max_mem_cached = _normalize(torch.cuda.max_memory_reserved())
    torch_mem_stat = torch.cuda.memory_stats()
    cur_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
    max_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))

    mem_stats = [
        ["phase", cur_phase],
        ["allocated", cur_mem_allocated],  # current memory allocated for tensors
        ["max allocated", max_mem_allocated],  # peak memory allocated for tensors
        ["cached", cur_mem_cached],  # current memory cached for the caching allocator
        ["max cached", max_mem_cached],  # peak memory cached for caching allocator.
        ["inactive", cur_mem_inactive],  # amount of inactive, non-releasable memory
        ["max inactive", max_mem_inactive],  # peak of inactive, non-releasable memory
    ]

    # Compare against None explicitly: container modules that define __len__
    # are falsy when empty, and their columns should still be reported.
    if module is not None:
        # Total size of the parameters and gradients that live on a CUDA device.
        param_total_size = 0
        grad_total_size = 0
        for p in module.parameters():
            if p.is_cuda:
                param_total_size += p.numel() * p.element_size()
            if p.grad is not None and p.grad.is_cuda:
                grad_total_size += p.grad.numel() * p.grad.element_size()

        # Total size of the buffers that live on a CUDA device.
        buffer_total_size = sum(b.numel() * b.element_size() for b in module.buffers() if b.is_cuda)

        mem_stats.extend(
            [
                ["param", _normalize(param_total_size)],
                ["grad", _normalize(grad_total_size)],
                ["buffer", _normalize(buffer_total_size)],
            ]
        )

    summ = f"rank-{rank} {step_info} memory ({normalizer_unit})"
    summ += "".join(f" | {name}: {value}" for name, value in mem_stats)

    if logger is None:
        print(summ)
    else:
        logger.info(summ)
|
||||
|
|
|
|||
Loading…
Reference in a new issue