mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Introduce memory observer for ORTModule (#16213)
### Introduce memory observer for ORTModule To analyze memory usage for ORTModule training, we need collect per-iteration memory footprint in different stages (pre-forward, post-forward, pre-backward, and post-backward). Currently we only collect the data using torch.cuda APIs. The next step is, we could collect the detailed stashed activation list and its percentage within ORT backend, which is beyond this PR. Sample as below: ``` 0/8] step 0 memory (MiB) | phase: pre_forward | allocated: 1866 | max allocated: 1866 | cached: 1874 | max cached: 1874 | inactive: 8 | max inactive: 8 [0/8] step 0 memory (MiB) | phase: post_forward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405 [0/8] step 0 memory (MiB) | phase: pre_backward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405 [0/8] step 0 memory (MiB) | phase: post_backward | allocated: 2932 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 6158 | max inactive: 6158 0%|█ | 1/200 [00:26<1:26:18, 26.02s/it] [0/8] step 1 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 2454 | max inactive: 6165 [0/8] step 1 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165 [0/8] step 1 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165 [0/8] step 1 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165 1%|██ | 2/200 [00:26<36:47, 11.15s/it] [0/8] step 2 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2454 | max inactive: 6165 [0/8] step 2 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165 [0/8] step 2 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165 [0/8] step 2 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165 ```
This commit is contained in:
parent
574e17ade4
commit
735a32fee1
5 changed files with 163 additions and 20 deletions
|
|
@ -150,7 +150,7 @@ based performance optimizations.
|
|||
#### ORTMODULE_PRINT_INPUT_DENSITY
|
||||
|
||||
- **Feature Area**: *ORTMODULE/RuntimeInspector*
|
||||
- **Description**: By default, this is disabled. This env var can be used for print the input data sparsity
|
||||
- **Description**: By default, this is disabled. This env var can be used for printing the input data sparsity
|
||||
inspection results to standard outputs.
|
||||
|
||||
```bash
|
||||
|
|
@ -158,6 +158,17 @@ inspection results to standard outputs.
|
|||
export ORTMODULE_PRINT_INPUT_DENSITY=0 # Disable
|
||||
```
|
||||
|
||||
#### ORTMODULE_PRINT_MEMORY_STATS
|
||||
|
||||
- **Feature Area**: *ORTMODULE/RuntimeInspector*
|
||||
- **Description**: By default, this is disabled. This env var can be used for printing the memory inspection results
|
||||
to standard outputs.
|
||||
|
||||
```bash
|
||||
export ORTMODULE_PRINT_MEMORY_STATS=1 # Enable
|
||||
export ORTMODULE_PRINT_MEMORY_STATS=0 # Disable
|
||||
```
|
||||
|
||||
### 2.2 Memory Optimization
|
||||
|
||||
Q: *Want to run a bigger batch size?*
|
||||
|
|
|
|||
|
|
@ -123,7 +123,6 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
)
|
||||
self._first_skip_check_warning = True
|
||||
|
||||
# Inspect embedding input index sparsity.
|
||||
self._rt_inspector = _runtime_inspector.RuntimeInspector(self._logger)
|
||||
|
||||
# Graph transformer config
|
||||
|
|
@ -205,7 +204,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
)
|
||||
|
||||
self._print_input_density = ortmodule._defined_from_envvar("ORTMODULE_PRINT_INPUT_DENSITY", 0, warn=True) == 1
|
||||
|
||||
self._print_memory_stat = ortmodule._defined_from_envvar("ORTMODULE_PRINT_MEMORY_STATS", 0, warn=True) == 1
|
||||
self._enable_memory_optimizer = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True)
|
||||
|
||||
# Flag to re-export the model due to attribute change on the original module.
|
||||
|
|
@ -582,7 +581,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
[
|
||||
" -FLOPReduction",
|
||||
"ON" if self._enable_compute_optimizer else "OFF",
|
||||
"Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0",
|
||||
"Reduce FLOPs by upstreaming shrinking-sized ops",
|
||||
],
|
||||
]
|
||||
)
|
||||
|
|
@ -638,6 +637,9 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
if not self._print_input_density:
|
||||
self._rt_inspector.disable_input_inspector()
|
||||
|
||||
if self._print_memory_stat:
|
||||
self._rt_inspector.enable_memory_inspector(self._original_module)
|
||||
|
||||
def _log_feature_stats(self):
|
||||
rank = 0
|
||||
if torch.distributed.is_initialized():
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from enum import IntEnum
|
||||
from logging import Logger
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
|
|
@ -12,21 +13,40 @@ from onnx import ModelProto, helper
|
|||
from onnx import onnx_pb as onnx_proto
|
||||
|
||||
|
||||
class Phase(IntEnum):
|
||||
INVALID = -1
|
||||
PRE_FORWARD = 0
|
||||
POST_FORWARD = 1
|
||||
PRE_BACKWARD = 2 # not applicable for inference
|
||||
POST_BACKWARD = 3 # not applicable for inference
|
||||
|
||||
|
||||
def _convert_phase_to_string(phase: Phase) -> str:
|
||||
if phase == Phase.PRE_FORWARD:
|
||||
return "pre_forward"
|
||||
elif phase == Phase.POST_FORWARD:
|
||||
return "post_forward"
|
||||
elif phase == Phase.PRE_BACKWARD:
|
||||
return "pre_backward"
|
||||
elif phase == Phase.POST_BACKWARD:
|
||||
return "post_backward"
|
||||
else:
|
||||
return "invalid"
|
||||
|
||||
|
||||
class RuntimeInspector:
|
||||
"""
|
||||
Runtime inspector for ORTModule.
|
||||
|
||||
Currently, it only wraps input density inspector.
|
||||
"""
|
||||
|
||||
def __init__(self, logger: Logger):
|
||||
self._logger = logger
|
||||
|
||||
self.input_density_ob: Union[InputDensityObserver, None] = None
|
||||
self.memory_ob: Union[MemoryObserver, None] = None
|
||||
|
||||
def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None:
|
||||
"""
|
||||
Initialize input inspector from the given ONNX model and user input names.
|
||||
"""Initialize input inspector from the given ONNX model and user input names.
|
||||
|
||||
Args:
|
||||
model: ONNX model.
|
||||
|
|
@ -41,8 +61,7 @@ class RuntimeInspector:
|
|||
return self.input_density_ob.initialize(model, user_input_names)
|
||||
|
||||
def inspect_input(self, input_name, input_data) -> Tuple[bool, float, float]:
|
||||
"""
|
||||
Inspect input data and print statistics.
|
||||
"""Inspect input data and print statistics.
|
||||
|
||||
Args:
|
||||
input_name: User input name.
|
||||
|
|
@ -63,10 +82,29 @@ class RuntimeInspector:
|
|||
"""Disable input density inspector."""
|
||||
self.input_density_ob = None
|
||||
|
||||
def enable_memory_inspector(self, module: torch.nn.Module):
|
||||
"""Enable memory inspector for ORTModule.
|
||||
|
||||
Args:
|
||||
module: ORTModule.
|
||||
"""
|
||||
if self.memory_ob is None:
|
||||
self.memory_ob = MemoryObserver(module, self._logger)
|
||||
else:
|
||||
raise RuntimeError("Memory observer is already enabled.")
|
||||
|
||||
def inspect_memory(self, phase: Phase) -> None:
|
||||
"""Inspect memory usage and print statistics.
|
||||
|
||||
Args:
|
||||
phase: Phase to inspect.
|
||||
"""
|
||||
if self.memory_ob is not None:
|
||||
self.memory_ob.inspect_memory(phase)
|
||||
|
||||
|
||||
class InputDensityObserver:
|
||||
"""
|
||||
Training input data observer for ORTModule.
|
||||
"""Training input data observer for ORTModule.
|
||||
|
||||
Data observer is used to collect data/compute sparsity information for embedding and label inputs. It needs to be
|
||||
firstly initialized with the ONNX model and user input names. Then, it can be used to inspect the input data
|
||||
|
|
@ -89,8 +127,7 @@ class InputDensityObserver:
|
|||
self._tensor_to_node_map = {}
|
||||
|
||||
def initialize(self, model: ModelProto, user_input_names: List[str]) -> None:
|
||||
"""
|
||||
Initialize data observer from the given ONNX model and user input names.
|
||||
"""Initialize data observer from the given ONNX model and user input names.
|
||||
|
||||
For embedding input (e.g. ATen embedding), try to parse the padding_idx from the ONNX model, if padding_idx is
|
||||
valid, register it in _embedding_graph_input_to_padding_idx_map.
|
||||
|
|
@ -283,8 +320,7 @@ class InputDensityObserver:
|
|||
)
|
||||
|
||||
def inspect_from_input_data(self, name: str, inp) -> Tuple[bool, float, float]:
|
||||
"""
|
||||
Inspect input data and print statistics.
|
||||
"""Inspect input data and print statistics.
|
||||
|
||||
Args:
|
||||
name: User input name.
|
||||
|
|
@ -379,7 +415,7 @@ class InputDensityObserver:
|
|||
stat += "\t| {:<10} | {:<10} | {:<15} | {:<10} | {:<10} | {:<15} | {:<15} | {:<15} |\n".format(
|
||||
"STEP",
|
||||
"INPUT TYPE",
|
||||
" INPUT NAME",
|
||||
"INPUT NAME",
|
||||
"PAD IDX",
|
||||
"DENSITY",
|
||||
"VALID TOKENS",
|
||||
|
|
@ -422,3 +458,87 @@ class InputDensityObserver:
|
|||
return None
|
||||
value = onnx.numpy_helper.to_array(tensor)
|
||||
return value
|
||||
|
||||
|
||||
class MemoryObserver:
|
||||
"""Memory inspector across the training lifetime.
|
||||
|
||||
On different training/inference phases, `inspect_memory` is called to print out the memory usage, including
|
||||
current/peak memory usage, current/peak inactive and non-releasable memory.
|
||||
"""
|
||||
|
||||
NORMALIZER_FACTOR = float(1024 * 1024)
|
||||
NORMALIZER_UNIT = "MiB"
|
||||
|
||||
def __init__(self, m: torch.nn.Module, logger: Logger):
|
||||
self._logger = logger
|
||||
self._current_step = 0
|
||||
self._rank = 0
|
||||
self._world_size = 1
|
||||
if torch.distributed.is_initialized():
|
||||
self._rank = torch.distributed.get_rank()
|
||||
self._world_size = torch.distributed.get_world_size()
|
||||
|
||||
self._rank_info = f"[{self._rank}/{self._world_size}]"
|
||||
self._pre_phase = Phase.INVALID
|
||||
self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD
|
||||
|
||||
self._is_first_inspect = True
|
||||
|
||||
def inspect_memory(self, cur_phase: Phase):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
if self._is_first_inspect:
|
||||
# Clean the memory cache and memory stats before the first time run forward pass, FOR EVERY RANK.
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self._is_first_inspect = False
|
||||
|
||||
if self._rank != 0:
|
||||
return
|
||||
|
||||
if cur_phase < Phase.PRE_FORWARD or cur_phase > self._last_phase:
|
||||
raise RuntimeError(f"Invalid phase detected: {cur_phase}")
|
||||
|
||||
if (cur_phase - self._pre_phase) != 1:
|
||||
raise RuntimeError(f"Invalid phase transition detected: {self._pre_phase} -> {cur_phase}")
|
||||
|
||||
cur_mem_allocated = self._normalize(torch.cuda.memory_allocated())
|
||||
max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated())
|
||||
cur_mem_cached = self._normalize(torch.cuda.memory_reserved())
|
||||
max_mem_cached = self._normalize(torch.cuda.max_memory_reserved())
|
||||
torch_mem_stat = torch.cuda.memory_stats()
|
||||
cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
|
||||
max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))
|
||||
|
||||
mem_stats = [
|
||||
["phase", _convert_phase_to_string(cur_phase)],
|
||||
["allocated", cur_mem_allocated], # current memory alloeated for tensors
|
||||
["max allocated", max_mem_allocated], # peak memory allocated for tensors
|
||||
["cached", cur_mem_cached], # current memory cached for caching allocator
|
||||
["max cached", max_mem_cached], # peak memory cached for caching allocator.
|
||||
["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory
|
||||
["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory
|
||||
]
|
||||
|
||||
summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})"
|
||||
for stat in mem_stats:
|
||||
summ += f" | {stat[0]}: {stat[1]}"
|
||||
|
||||
# For the 10+ steps, only print when it is power of 2.
|
||||
if self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0):
|
||||
self._logger.info(summ)
|
||||
|
||||
if cur_phase == self._last_phase:
|
||||
self._increase_step()
|
||||
self._pre_phase = Phase.INVALID
|
||||
return
|
||||
|
||||
self._pre_phase = cur_phase
|
||||
|
||||
def _increase_step(self):
|
||||
self._current_step += 1
|
||||
|
||||
def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str:
|
||||
return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}"
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPo
|
|||
from ._gradient_accumulation_manager import GradientAccumulationManager
|
||||
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck
|
||||
from ._io import _FlattenedModule, _InputInfo
|
||||
from ._runtime_inspector import Phase
|
||||
from .debug_options import DebugOptions
|
||||
|
||||
|
||||
|
|
@ -103,6 +104,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
Module outputs are returned to the user
|
||||
"""
|
||||
self._rt_inspector.inspect_memory(Phase.PRE_FORWARD)
|
||||
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
# Assert that the input and model device match
|
||||
|
|
@ -137,12 +139,16 @@ class TrainingManager(GraphExecutionManager):
|
|||
for idx in self._graph_info.output_grad_indices_non_differentiable:
|
||||
ctx.mark_non_differentiable(user_outputs[idx])
|
||||
|
||||
self._rt_inspector.inspect_memory(Phase.POST_FORWARD)
|
||||
|
||||
return user_outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outputs):
|
||||
"""Performs backward pass based on grad wrt module output"""
|
||||
|
||||
self._rt_inspector.inspect_memory(Phase.PRE_BACKWARD)
|
||||
|
||||
assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()"
|
||||
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
|
||||
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
|
||||
|
|
@ -190,8 +196,11 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
# Fast version: all backward_outputs are converted first.
|
||||
# This version only works if backward_outputs is an OrtValueVector.
|
||||
transfered_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
|
||||
return tuple(transfered_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
|
||||
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
|
||||
|
||||
self._rt_inspector.inspect_memory(Phase.POST_BACKWARD)
|
||||
|
||||
return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
|
||||
|
||||
return _ORTModuleFunction
|
||||
|
||||
|
|
|
|||
|
|
@ -153,7 +153,8 @@ class StatisticsSubscriber(SubscriberBase):
|
|||
min_value = min_buckets.min()
|
||||
max_value = max_buckets.max()
|
||||
mean_value = float(mean_buckets.sum()) / float(element_count)
|
||||
# Here we refer https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
|
||||
# Here we refer to
|
||||
# https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
|
||||
# to calculate the combined standard deviation of all buckets.
|
||||
s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
|
||||
(mean_buckets - mean_value) ** 2
|
||||
|
|
|
|||
Loading…
Reference in a new issue