diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 078ce4d27c..772b9bd9e3 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -14,7 +14,7 @@ from onnx import onnx_pb as onnx_proto from sympy import Symbol, simplify from sympy.parsing.sympy_parser import parse_expr -from onnxruntime.training.utils import PTable +from onnxruntime.training.utils import PTable, log_memory_usage from ._execution_agent import TrainingAgent from .options import _MemoryOptimizationLevel, _RuntimeOptions @@ -509,6 +509,8 @@ class MemoryObserver: self._is_first_inspect = True + self._m = m + def is_enabled(self) -> bool: """Check if memory inspector is enabled.""" return self._is_enabled @@ -621,29 +623,13 @@ class MemoryObserver: need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0) if need_print: - cur_mem_allocated = self._normalize(torch.cuda.memory_allocated()) - max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated()) - cur_mem_cached = self._normalize(torch.cuda.memory_reserved()) - max_mem_cached = self._normalize(torch.cuda.max_memory_reserved()) - torch_mem_stat = torch.cuda.memory_stats() - cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) - max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) - - mem_stats = [ - ["phase", _convert_phase_to_string(cur_phase)], - ["allocated", cur_mem_allocated], # current memory allocated for tensors - ["max allocated", max_mem_allocated], # peak memory allocated for tensors - ["cached", cur_mem_cached], # current memory cached for the caching allocator - ["max cached", max_mem_cached], # peak memory cached for caching allocator. - ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory - ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory - ] - - summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})" - for stat in mem_stats: - summ += f" | {stat[0]}: {stat[1]}" - - self._logger.info(summ) + log_memory_usage( + _convert_phase_to_string(cur_phase), + rank_0_only=True, + step_info=f"step {self._current_step}", + logger=self._logger, + module=self._m, + ) if cur_phase == self._last_phase: self._increase_step() @@ -655,9 +641,6 @@ class MemoryObserver: def _increase_step(self): self._current_step += 1 - def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str: - return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}" - def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]: mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map) diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index b4a518d573..ecfb7d7907 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -12,6 +12,7 @@ from onnxruntime.training.utils.torch_io_helper import ( unflatten_data_using_schema, ) from onnxruntime.training.utils.torch_profile_utils import ( + log_memory_usage, nvtx_function_decorator, torch_nvtx_range_pop, torch_nvtx_range_push, @@ -31,6 +32,7 @@ __all__ = [ "torch_nvtx_range_push", "torch_nvtx_range_pop", "nvtx_function_decorator", + "log_memory_usage", "pytorch_type_to_onnx_dtype", "onnx_dtype_to_pytorch_dtype", "pytorch_scalar_type_to_pytorch_dtype", diff --git a/orttraining/orttraining/python/training/utils/torch_profile_utils.py b/orttraining/orttraining/python/training/utils/torch_profile_utils.py index 382d7dac14..9e8a41e0dc 100644 --- a/orttraining/orttraining/python/training/utils/torch_profile_utils.py +++ b/orttraining/orttraining/python/training/utils/torch_profile_utils.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + import torch @@ -26,3 +28,77 @@ def nvtx_function_decorator(func): return ret_val return wrapped_fn + + +def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="", logger=None, module=None): + """Log memory usage for the current phase. + Args: + cur_phase (str): The current phase. + rank_0_only (bool, optional): Only log the memory usage for rank 0. Defaults to True. + step_info (str, optional): The step information. Defaults to "". + logger (logging.Logger, optional): The logger to log the memory usage. Defaults to None, which means print to stdout. + module (torch.nn.Module, optional): The module to get parameter, buffer and grad sizes. Defaults to None. + """ + rank = 0 + if rank_0_only is True: + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + if rank != 0: + return + + _normalizer_factor = float(1024 * 1024) + _normalizer_unit = "MiB" + + def _normalize(mem_size_in_bytes: float | int) -> str: + return f"{float(mem_size_in_bytes) / _normalizer_factor:.0f}" + + cur_mem_allocated = _normalize(torch.cuda.memory_allocated()) + max_mem_allocated = _normalize(torch.cuda.max_memory_allocated()) + cur_mem_cached = _normalize(torch.cuda.memory_reserved()) + max_mem_cached = _normalize(torch.cuda.max_memory_reserved()) + torch_mem_stat = torch.cuda.memory_stats() + cur_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) + max_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) + + mem_stats = [ + ["phase", cur_phase], + ["allocated", cur_mem_allocated], # current memory allocated for tensors + ["max allocated", max_mem_allocated], # peak memory allocated for tensors + ["cached", cur_mem_cached], # current memory cached for the caching allocator + ["max cached", max_mem_cached], # peak memory cached for caching allocator. + ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory + ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory + ] + + # Calculate the total size of parameters and gradients in the model + if module: + param_total_size = 0 + grad_total_size = 0 + for p in module.parameters(): + if p.is_cuda: + param_total_size += p.numel() * p.element_size() + if p.grad is not None and p.grad.is_cuda: + grad_total_size += p.grad.numel() * p.grad.element_size() + + # Calculate the total size of buffers in the model + buffer_total_size = 0 + for b in module.buffers(): + if b.is_cuda: + buffer_total_size += b.numel() * b.element_size() + + mem_stats.extend( + [ + ["param", _normalize(param_total_size)], + ["grad", _normalize(grad_total_size)], + ["buffer", _normalize(buffer_total_size)], + ] + ) + + summ = f"rank-{rank} {step_info} memory ({_normalizer_unit})" + for stat in mem_stats: + summ += f" | {stat[0]}: {stat[1]}" + + if logger is None: + print(summ) + else: + logger.info(summ)