Introduce memory observer for ORTModule (#16213)

### Introduce memory observer for ORTModule

To analyze memory usage for ORTModule training, we need to collect the
per-iteration memory footprint at different stages (pre-forward,
post-forward, pre-backward, and post-backward).

Currently we collect the data using `torch.cuda` APIs only. As a next step,
we could also collect the detailed stashed-activation list and each
activation's share of memory within the ORT backend, but that is beyond the
scope of this PR.
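
Conceptually, each probe is a thin wrapper around the `torch.cuda` allocator statistics. Below is a minimal standalone sketch of what gets sampled at each phase (illustrative only; `snapshot_cuda_memory` is not part of this PR's API):

```python
import torch

def snapshot_cuda_memory(phase: str, step: int) -> str:
    """Format one line of per-phase CUDA memory stats, normalized to MiB (requires a CUDA device)."""
    mib = float(1024 * 1024)
    stats = torch.cuda.memory_stats()
    fields = [
        ("phase", phase),
        ("allocated", int(torch.cuda.memory_allocated() / mib)),          # current tensor memory
        ("max allocated", int(torch.cuda.max_memory_allocated() / mib)),  # peak tensor memory
        ("cached", int(torch.cuda.memory_reserved() / mib)),              # caching-allocator reservation
        ("max cached", int(torch.cuda.max_memory_reserved() / mib)),      # peak reservation
        ("inactive", int(stats.get("inactive_split_bytes.all.current", 0) / mib)),  # non-releasable splits
        ("max inactive", int(stats.get("inactive_split_bytes.all.peak", 0) / mib)),
    ]
    return f"step {step} memory (MiB) | " + " | ".join(f"{k}: {v}" for k, v in fields)
```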

Sample output as below:
```
[0/8] step 0 memory (MiB) | phase: pre_forward | allocated: 1866 | max allocated: 1866 | cached: 1874 | max cached: 1874 | inactive: 8 | max inactive: 8
[0/8] step 0 memory (MiB) | phase: post_forward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405
[0/8] step 0 memory (MiB) | phase: pre_backward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405
[0/8] step 0 memory (MiB) | phase: post_backward | allocated: 2932 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 6158 | max inactive: 6158
  0%|█                                                                                                                                                                                                            | 1/200 [00:26<1:26:18, 26.02s/it]
[0/8] step 1 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 2454 | max inactive: 6165
[0/8] step 1 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 1 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 1 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165
  1%|██                                                                                                                                                                                                             | 2/200 [00:26<36:47, 11.15s/it]
[0/8] step 2 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2454 | max inactive: 6165
[0/8] step 2 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 2 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 2 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165
```
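
The observer is disabled by default. To get output like the above, set the new environment variable (documented in the diff below) before ORTModule is constructed. A hypothetical minimal usage sketch, with a toy `torch.nn.Linear` standing in for any user module:

```python
import os

# Must be set before ORTModule reads its configuration at wrap time.
os.environ["ORTMODULE_PRINT_MEMORY_STATS"] = "1"

import torch
from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule

# INFO log level so the observer's logger.info(...) lines show up.
model = ORTModule(torch.nn.Linear(8, 8).cuda(), DebugOptions(log_level=LogLevel.INFO))
loss = model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()  # pre/post forward and backward phases are all sampled
```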
pengwa 2023-06-15 15:45:36 +08:00 committed by GitHub
parent 574e17ade4
commit 735a32fee1
5 changed files with 163 additions and 20 deletions


@@ -150,7 +150,7 @@ based performance optimizations.
#### ORTMODULE_PRINT_INPUT_DENSITY
- **Feature Area**: *ORTMODULE/RuntimeInspector*
- **Description**: By default, this is disabled. This env var can be used for print the input data sparsity
- **Description**: By default, this is disabled. This env var can be used for printing the input data sparsity
inspection results to standard outputs.
```bash
@@ -158,6 +158,17 @@ inspection results to standard outputs.
export ORTMODULE_PRINT_INPUT_DENSITY=0 # Disable
```
#### ORTMODULE_PRINT_MEMORY_STATS
- **Feature Area**: *ORTMODULE/RuntimeInspector*
- **Description**: By default, this is disabled. This env var can be used for printing the memory inspection results
to standard outputs.
```bash
export ORTMODULE_PRINT_MEMORY_STATS=1 # Enable
export ORTMODULE_PRINT_MEMORY_STATS=0 # Disable
```
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*


@@ -123,7 +123,6 @@ class GraphExecutionManager(GraphExecutionInterface):
)
self._first_skip_check_warning = True
# Inspect embedding input index sparsity.
self._rt_inspector = _runtime_inspector.RuntimeInspector(self._logger)
# Graph transformer config
@@ -205,7 +204,7 @@ class GraphExecutionManager(GraphExecutionInterface):
)
self._print_input_density = ortmodule._defined_from_envvar("ORTMODULE_PRINT_INPUT_DENSITY", 0, warn=True) == 1
self._print_memory_stat = ortmodule._defined_from_envvar("ORTMODULE_PRINT_MEMORY_STATS", 0, warn=True) == 1
self._enable_memory_optimizer = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True)
# Flag to re-export the model due to attribute change on the original module.
@@ -582,7 +581,7 @@ class GraphExecutionManager(GraphExecutionInterface):
[
" -FLOPReduction",
"ON" if self._enable_compute_optimizer else "OFF",
"Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0",
"Reduce FLOPs by upstreaming shrinking-sized ops",
],
]
)
@@ -638,6 +637,9 @@ class GraphExecutionManager(GraphExecutionInterface):
if not self._print_input_density:
self._rt_inspector.disable_input_inspector()
if self._print_memory_stat:
self._rt_inspector.enable_memory_inspector(self._original_module)
def _log_feature_stats(self):
rank = 0
if torch.distributed.is_initialized():


@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from enum import IntEnum
from logging import Logger
from typing import List, Tuple, Union
@@ -12,21 +13,40 @@ from onnx import ModelProto, helper
from onnx import onnx_pb as onnx_proto
class Phase(IntEnum):
    INVALID = -1
    PRE_FORWARD = 0
    POST_FORWARD = 1
    PRE_BACKWARD = 2  # not applicable for inference
    POST_BACKWARD = 3  # not applicable for inference


def _convert_phase_to_string(phase: Phase) -> str:
    if phase == Phase.PRE_FORWARD:
        return "pre_forward"
    elif phase == Phase.POST_FORWARD:
        return "post_forward"
    elif phase == Phase.PRE_BACKWARD:
        return "pre_backward"
    elif phase == Phase.POST_BACKWARD:
        return "post_backward"
    else:
        return "invalid"
class RuntimeInspector:
"""
Runtime inspector for ORTModule.
Currently, it only wraps input density inspector.
"""
def __init__(self, logger: Logger):
self._logger = logger
self.input_density_ob: Union[InputDensityObserver, None] = None
self.memory_ob: Union[MemoryObserver, None] = None
def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None:
"""
Initialize input inspector from the given ONNX model and user input names.
"""Initialize input inspector from the given ONNX model and user input names.
Args:
model: ONNX model.
@@ -41,8 +61,7 @@ class RuntimeInspector:
return self.input_density_ob.initialize(model, user_input_names)
def inspect_input(self, input_name, input_data) -> Tuple[bool, float, float]:
"""
Inspect input data and print statistics.
"""Inspect input data and print statistics.
Args:
input_name: User input name.
@@ -63,10 +82,29 @@ class RuntimeInspector:
"""Disable input density inspector."""
self.input_density_ob = None
def enable_memory_inspector(self, module: torch.nn.Module):
"""Enable memory inspector for ORTModule.
Args:
module: ORTModule.
"""
if self.memory_ob is None:
self.memory_ob = MemoryObserver(module, self._logger)
else:
raise RuntimeError("Memory observer is already enabled.")
def inspect_memory(self, phase: Phase) -> None:
"""Inspect memory usage and print statistics.
Args:
phase: Phase to inspect.
"""
if self.memory_ob is not None:
self.memory_ob.inspect_memory(phase)
class InputDensityObserver:
"""
Training input data observer for ORTModule.
"""Training input data observer for ORTModule.
Data observer is used to collect data/compute sparsity information for embedding and label inputs. It needs to be
firstly initialized with the ONNX model and user input names. Then, it can be used to inspect the input data
@@ -89,8 +127,7 @@ class InputDensityObserver:
self._tensor_to_node_map = {}
def initialize(self, model: ModelProto, user_input_names: List[str]) -> None:
"""
Initialize data observer from the given ONNX model and user input names.
"""Initialize data observer from the given ONNX model and user input names.
For embedding input (e.g. ATen embedding), try to parse the padding_idx from the ONNX model, if padding_idx is
valid, register it in _embedding_graph_input_to_padding_idx_map.
@@ -283,8 +320,7 @@ class InputDensityObserver:
)
def inspect_from_input_data(self, name: str, inp) -> Tuple[bool, float, float]:
"""
Inspect input data and print statistics.
"""Inspect input data and print statistics.
Args:
name: User input name.
@@ -379,7 +415,7 @@ class InputDensityObserver:
stat += "\t| {:<10} | {:<10} | {:<15} | {:<10} | {:<10} | {:<15} | {:<15} | {:<15} |\n".format(
"STEP",
"INPUT TYPE",
" INPUT NAME",
"INPUT NAME",
"PAD IDX",
"DENSITY",
"VALID TOKENS",
@@ -422,3 +458,87 @@ class InputDensityObserver:
return None
value = onnx.numpy_helper.to_array(tensor)
return value
class MemoryObserver:
    """Memory inspector across the training lifetime.

    At different training/inference phases, `inspect_memory` is called to print the memory usage, including
    current/peak allocated memory, current/peak cached memory, and current/peak inactive, non-releasable memory.
    """

    NORMALIZER_FACTOR = float(1024 * 1024)
    NORMALIZER_UNIT = "MiB"

    def __init__(self, m: torch.nn.Module, logger: Logger):
        self._logger = logger
        self._current_step = 0
        self._rank = 0
        self._world_size = 1
        if torch.distributed.is_initialized():
            self._rank = torch.distributed.get_rank()
            self._world_size = torch.distributed.get_world_size()

        self._rank_info = f"[{self._rank}/{self._world_size}]"
        self._pre_phase = Phase.INVALID
        self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD
        self._is_first_inspect = True

    def inspect_memory(self, cur_phase: Phase):
        if not torch.cuda.is_available():
            return

        if self._is_first_inspect:
            # Clean the memory cache and memory stats before the first forward pass runs, FOR EVERY RANK.
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            self._is_first_inspect = False

        if self._rank != 0:
            return

        if cur_phase < Phase.PRE_FORWARD or cur_phase > self._last_phase:
            raise RuntimeError(f"Invalid phase detected: {cur_phase}")

        if (cur_phase - self._pre_phase) != 1:
            raise RuntimeError(f"Invalid phase transition detected: {self._pre_phase} -> {cur_phase}")

        cur_mem_allocated = self._normalize(torch.cuda.memory_allocated())
        max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated())
        cur_mem_cached = self._normalize(torch.cuda.memory_reserved())
        max_mem_cached = self._normalize(torch.cuda.max_memory_reserved())
        torch_mem_stat = torch.cuda.memory_stats()
        cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
        max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))

        mem_stats = [
            ["phase", _convert_phase_to_string(cur_phase)],
            ["allocated", cur_mem_allocated],  # current memory allocated for tensors
            ["max allocated", max_mem_allocated],  # peak memory allocated for tensors
            ["cached", cur_mem_cached],  # current memory cached for the caching allocator
            ["max cached", max_mem_cached],  # peak memory cached for the caching allocator
            ["inactive", cur_mem_inactive],  # amount of inactive, non-releasable memory
            ["max inactive", max_mem_inactive],  # peak of inactive, non-releasable memory
        ]

        summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})"
        for stat in mem_stats:
            summ += f" | {stat[0]}: {stat[1]}"

        # Print every step for the first 10 steps; after that, only when the step count is a power of 2.
        if self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0):
            self._logger.info(summ)

        if cur_phase == self._last_phase:
            self._increase_step()
            self._pre_phase = Phase.INVALID
            return

        self._pre_phase = cur_phase

    def _increase_step(self):
        self._current_step += 1

    def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str:
        return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}"


@@ -18,6 +18,7 @@ from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPo
from ._gradient_accumulation_manager import GradientAccumulationManager
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck
from ._io import _FlattenedModule, _InputInfo
from ._runtime_inspector import Phase
from .debug_options import DebugOptions
@@ -103,6 +104,7 @@ class TrainingManager(GraphExecutionManager):
Module outputs are returned to the user
"""
self._rt_inspector.inspect_memory(Phase.PRE_FORWARD)
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
# Assert that the input and model device match
@@ -137,12 +139,16 @@ class TrainingManager(GraphExecutionManager):
for idx in self._graph_info.output_grad_indices_non_differentiable:
ctx.mark_non_differentiable(user_outputs[idx])
self._rt_inspector.inspect_memory(Phase.POST_FORWARD)
return user_outputs
@staticmethod
def backward(ctx, *grad_outputs):
"""Performs backward pass based on grad wrt module output"""
self._rt_inspector.inspect_memory(Phase.PRE_BACKWARD)
assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()"
if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False:
_utils._check_same_device(self._device, "Input argument to backward", *grad_outputs)
@@ -190,8 +196,11 @@ class TrainingManager(GraphExecutionManager):
# Fast version: all backward_outputs are converted first.
# This version only works if backward_outputs is an OrtValueVector.
transfered_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
return tuple(transfered_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
self._rt_inspector.inspect_memory(Phase.POST_BACKWARD)
return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
return _ORTModuleFunction


@@ -153,7 +153,8 @@ class StatisticsSubscriber(SubscriberBase):
min_value = min_buckets.min()
max_value = max_buckets.max()
mean_value = float(mean_buckets.sum()) / float(element_count)
# Here we refer https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
# Here we refer to
# https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
# to calculate the combined standard deviation of all buckets.
s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
(mean_buckets - mean_value) ** 2