diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 2dd7f21e6c..d6eca4c1e6 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -150,7 +150,7 @@ based performance optimizations. #### ORTMODULE_PRINT_INPUT_DENSITY - **Feature Area**: *ORTMODULE/RuntimeInspector* -- **Description**: By default, this is disabled. This env var can be used for print the input data sparsity +- **Description**: By default, this is disabled. This env var can be used for printing the input data sparsity inspection results to standard outputs. ```bash @@ -158,6 +158,17 @@ inspection results to standard outputs. export ORTMODULE_PRINT_INPUT_DENSITY=0 # Disable ``` +#### ORTMODULE_PRINT_MEMORY_STATS + +- **Feature Area**: *ORTMODULE/RuntimeInspector* +- **Description**: By default, this is disabled. This env var can be used for printing the memory inspection results +to standard outputs. + + ```bash + export ORTMODULE_PRINT_MEMORY_STATS=1 # Enable + export ORTMODULE_PRINT_MEMORY_STATS=0 # Disable + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 1b164d6104..b2ad5567bb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -123,7 +123,6 @@ class GraphExecutionManager(GraphExecutionInterface): ) self._first_skip_check_warning = True - # Inspect embedding input index sparsity. self._rt_inspector = _runtime_inspector.RuntimeInspector(self._logger) # Graph transformer config @@ -205,7 +204,7 @@ class GraphExecutionManager(GraphExecutionInterface): ) self._print_input_density = ortmodule._defined_from_envvar("ORTMODULE_PRINT_INPUT_DENSITY", 0, warn=True) == 1 - + self._print_memory_stat = ortmodule._defined_from_envvar("ORTMODULE_PRINT_MEMORY_STATS", 0, warn=True) == 1 self._enable_memory_optimizer = ortmodule._defined_from_envvar("ORTMODULE_MEMORY_OPT_CONFIG", "", warn=True) # Flag to re-export the model due to attribute change on the original module. @@ -582,7 +581,7 @@ class GraphExecutionManager(GraphExecutionInterface): [ " -FLOPReduction", "ON" if self._enable_compute_optimizer else "OFF", - "Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0", + "Reduce FLOPs by upstreaming shrinking-sized ops", ], ] ) @@ -638,6 +637,9 @@ class GraphExecutionManager(GraphExecutionInterface): if not self._print_input_density: self._rt_inspector.disable_input_inspector() + if self._print_memory_stat: + self._rt_inspector.enable_memory_inspector(self._original_module) + def _log_feature_stats(self): rank = 0 if torch.distributed.is_initialized(): diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index d847d490c7..a756527f09 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from enum import IntEnum from logging import Logger from typing import List, Tuple, Union @@ -12,21 +13,40 @@ from onnx import ModelProto, helper from onnx import onnx_pb as onnx_proto +class Phase(IntEnum): + INVALID = -1 + PRE_FORWARD = 0 + POST_FORWARD = 1 + PRE_BACKWARD = 2 # not applicable for inference + POST_BACKWARD = 3 # not applicable for inference + + +def _convert_phase_to_string(phase: Phase) -> str: + if phase == Phase.PRE_FORWARD: + return "pre_forward" + elif phase == Phase.POST_FORWARD: + return "post_forward" + elif phase == Phase.PRE_BACKWARD: + return "pre_backward" + elif phase == Phase.POST_BACKWARD: + return "post_backward" + else: + return "invalid" + + class RuntimeInspector: """ Runtime inspector for ORTModule. - - Currently, it only wraps input density inspector. """ def __init__(self, logger: Logger): self._logger = logger self.input_density_ob: Union[InputDensityObserver, None] = None + self.memory_ob: Union[MemoryObserver, None] = None def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None: - """ - Initialize input inspector from the given ONNX model and user input names. + """Initialize input inspector from the given ONNX model and user input names. Args: model: ONNX model. @@ -41,8 +61,7 @@ class RuntimeInspector: return self.input_density_ob.initialize(model, user_input_names) def inspect_input(self, input_name, input_data) -> Tuple[bool, float, float]: - """ - Inspect input data and print statistics. + """Inspect input data and print statistics. Args: input_name: User input name. @@ -63,10 +82,29 @@ class RuntimeInspector: """Disable input density inspector.""" self.input_density_ob = None + def enable_memory_inspector(self, module: torch.nn.Module): + """Enable memory inspector for ORTModule. + + Args: + module: ORTModule. + """ + if self.memory_ob is None: + self.memory_ob = MemoryObserver(module, self._logger) + else: + raise RuntimeError("Memory observer is already enabled.") + + def inspect_memory(self, phase: Phase) -> None: + """Inspect memory usage and print statistics. + + Args: + phase: Phase to inspect. + """ + if self.memory_ob is not None: + self.memory_ob.inspect_memory(phase) + class InputDensityObserver: - """ - Training input data observer for ORTModule. + """Training input data observer for ORTModule. Data observer is used to collect data/compute sparsity information for embedding and label inputs. It needs to be firstly initialized with the ONNX model and user input names. Then, it can be used to inspect the input data @@ -89,8 +127,7 @@ class InputDensityObserver: self._tensor_to_node_map = {} def initialize(self, model: ModelProto, user_input_names: List[str]) -> None: - """ - Initialize data observer from the given ONNX model and user input names. + """Initialize data observer from the given ONNX model and user input names. For embedding input (e.g. ATen embedding), try to parse the padding_idx from the ONNX model, if padding_idx is valid, register it in _embedding_graph_input_to_padding_idx_map. @@ -283,8 +320,7 @@ class InputDensityObserver: ) def inspect_from_input_data(self, name: str, inp) -> Tuple[bool, float, float]: - """ - Inspect input data and print statistics. + """Inspect input data and print statistics. Args: name: User input name. @@ -379,7 +415,7 @@ class InputDensityObserver: stat += "\t| {:<10} | {:<10} | {:<15} | {:<10} | {:<10} | {:<15} | {:<15} | {:<15} |\n".format( "STEP", "INPUT TYPE", - " INPUT NAME", + "INPUT NAME", "PAD IDX", "DENSITY", "VALID TOKENS", @@ -422,3 +458,87 @@ class InputDensityObserver: return None value = onnx.numpy_helper.to_array(tensor) return value + + +class MemoryObserver: + """Memory inspector across the training lifetime. + + On different training/inference phases, `inspect_memory` is called to print out the memory usage, including + current/peak memory usage, current/peak inactive and non-releasable memory. + """ + + NORMALIZER_FACTOR = float(1024 * 1024) + NORMALIZER_UNIT = "MiB" + + def __init__(self, m: torch.nn.Module, logger: Logger): + self._logger = logger + self._current_step = 0 + self._rank = 0 + self._world_size = 1 + if torch.distributed.is_initialized(): + self._rank = torch.distributed.get_rank() + self._world_size = torch.distributed.get_world_size() + + self._rank_info = f"[{self._rank}/{self._world_size}]" + self._pre_phase = Phase.INVALID + self._last_phase = Phase.POST_BACKWARD if m.training else Phase.POST_FORWARD + + self._is_first_inspect = True + + def inspect_memory(self, cur_phase: Phase): + if not torch.cuda.is_available(): + return + + if self._is_first_inspect: + # Clean the memory cache and memory stats before the first time run forward pass, FOR EVERY RANK. + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + self._is_first_inspect = False + + if self._rank != 0: + return + + if cur_phase < Phase.PRE_FORWARD or cur_phase > self._last_phase: + raise RuntimeError(f"Invalid phase detected: {cur_phase}") + + if (cur_phase - self._pre_phase) != 1: + raise RuntimeError(f"Invalid phase transition detected: {self._pre_phase} -> {cur_phase}") + + cur_mem_allocated = self._normalize(torch.cuda.memory_allocated()) + max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated()) + cur_mem_cached = self._normalize(torch.cuda.memory_reserved()) + max_mem_cached = self._normalize(torch.cuda.max_memory_reserved()) + torch_mem_stat = torch.cuda.memory_stats() + cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) + max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) + + mem_stats = [ + ["phase", _convert_phase_to_string(cur_phase)], + ["allocated", cur_mem_allocated], # current memory alloeated for tensors + ["max allocated", max_mem_allocated], # peak memory allocated for tensors + ["cached", cur_mem_cached], # current memory cached for caching allocator + ["max cached", max_mem_cached], # peak memory cached for caching allocator. + ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory + ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory + ] + + summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})" + for stat in mem_stats: + summ += f" | {stat[0]}: {stat[1]}" + + # For the 10+ steps, only print when it is power of 2. + if self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0): + self._logger.info(summ) + + if cur_phase == self._last_phase: + self._increase_step() + self._pre_phase = Phase.INVALID + return + + self._pre_phase = cur_phase + + def _increase_step(self): + self._current_step += 1 + + def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str: + return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}" diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index a6755551c6..10edadf8f9 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -18,6 +18,7 @@ from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPo from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck from ._io import _FlattenedModule, _InputInfo +from ._runtime_inspector import Phase from .debug_options import DebugOptions @@ -103,6 +104,7 @@ class TrainingManager(GraphExecutionManager): Module outputs are returned to the user """ + self._rt_inspector.inspect_memory(Phase.PRE_FORWARD) if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: # Assert that the input and model device match @@ -137,12 +139,16 @@ class TrainingManager(GraphExecutionManager): for idx in self._graph_info.output_grad_indices_non_differentiable: ctx.mark_non_differentiable(user_outputs[idx]) + self._rt_inspector.inspect_memory(Phase.POST_FORWARD) + return user_outputs @staticmethod def backward(ctx, *grad_outputs): """Performs backward pass based on grad wrt module output""" + self._rt_inspector.inspect_memory(Phase.PRE_BACKWARD) + assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()" if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs) @@ -190,8 +196,11 @@ class TrainingManager(GraphExecutionManager): # Fast version: all backward_outputs are converted first. # This version only works if backward_outputs is an OrtValueVector. - transfered_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) - return tuple(transfered_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) + + self._rt_inspector.inspect_memory(Phase.POST_BACKWARD) + + return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) return _ORTModuleFunction diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index 6cec29a308..0dd06eee13 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -153,7 +153,8 @@ class StatisticsSubscriber(SubscriberBase): min_value = min_buckets.min() max_value = max_buckets.max() mean_value = float(mean_buckets.sum()) / float(element_count) - # Here we refer https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups + # Here we refer to + # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups # to calculate the combined standard deviation of all buckets. s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * ( (mean_buckets - mean_value) ** 2