diff --git a/docs/ORTModule_Convergence_Notes.md b/docs/ORTModule_Convergence_Notes.md index be28ef23d4..8f54fd6b5a 100644 --- a/docs/ORTModule_Convergence_Notes.md +++ b/docs/ORTModule_Convergence_Notes.md @@ -18,7 +18,8 @@ Before looking into this further, we should clarify a few things (if possible): ## 2. Collect Activation Statistics -### Add a few lines of code, run script to collect statistics: + +### 2.1 Use `GlobalSubscriberManager` to collect `nn.Module` forward() outputs
| -```diff -+ from onnxruntime.training.utils.hooks import SubscriberManager, -+ StatisticsSubscriber -+ sub_m = SubscriberManager() -+ sub_m.subscribe(model, [StatisticsSubscriber(output_dir="pt_out", -+ override_output_dir=True)]) +```python +from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber +GlobalSubscriberManager.subscribe( + model, [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True)] +) ``` @@ -42,13 +42,12 @@ Before looking into this further, we should clarify a few things (if possible): | -```diff +```python model = ORTModule(model) -+ from onnxruntime.training.utils.hooks import SubscriberManager, -+ StatisticsSubscriber -+ sub_m = SubscriberManager() -+ sub_m.subscribe(model, [StatisticsSubscriber(output_dir="ort_out", -+ override_output_dir=True)]) +from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber +GlobalSubscriberManager.subscribe( + model, [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True)] +) ``` @@ -77,16 +76,64 @@ model = ORTModule(model) Arguments: - output_dir: the directory in all activation statistics files will be stored. -- start_step [optional]: the first step that runs subscriber actions. -- end_step [optional]: the end step (exclusively) that runs subscriber actions. -- override_output_dir: whether `output_dir` can be overridden if it already exists. +- `start_step` [optional]: the first step that runs subscriber actions. +- `end_step` [optional]: the end step (exclusively) that runs subscriber actions. +- `override_output_dir`: whether `output_dir` can be overridden if it already exists. +- `run_on_cpu`: whether to run the subscriber actions on CPU, this should be the last resort when inserted + inspector node affects memory peak causing the original recipe run to fail with OOM. +- `bucket_size`: the size of the bucket to split the statistic calculation. 
-Check [StatisticsSubscriber implementation](../orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py) for more information. +### 2.2 Use `_InspectActivation` to collect intermediate tensors in a `nn.Module` forward() -### Run command to generate per-step summary +The limitation of `GlobalSubscriberManager` is that only an `nn.Module`'s forward output tensors will be dumped; if you want to +dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example: + +```diff +class BloomForCausalLM(BloomPreTrainedModel): + def __init__(self, config: BloomConfig): + ... + + def forward(self, input_ids, ...): + ... + transformer_outputs = self.transformer(...) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) ++ lm_logits = _InspectActivation.apply("lm_logits", None, GlobalSubscriberManager.get_run_context(), lm_logits) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() ++ shift_logits = _InspectActivation.apply("shift_logits", None, GlobalSubscriberManager.get_run_context(), shift_logits) + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) + + return loss +``` + +Note: make sure the activation name (as the first argument of `_InspectActivation.apply`) is unique, otherwise +the stat file using the activation name will be overwritten by the last write. The dumped data are stored in the `output_dir`. 
+ + +### 2.3 Collect on multiple ranks + +`GlobalSubscriberManager` does not explicitly handle the race condition when multiple ranks write into the same file path, +here is an example of collecting statistics on multiple ranks: + +```python +from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber +GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir="ort_out_" + str(torch.distributed.get_rank()), + override_output_dir=True)]) +``` + +Check [StatisticsSubscriber implementation](../orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py) for more information. + +### 2.4 Run command to generate per-step summary ```bash python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output ``` -### Manually compare the generated per-step summary to find the first big diff. +### 2.5 Manually compare the generated per-step summary to find the first big diff. diff --git a/orttraining/orttraining/python/training/utils/hooks/__init__.py b/orttraining/orttraining/python/training/utils/hooks/__init__.py index 34e4a5bd04..91e919b1c5 100644 --- a/orttraining/orttraining/python/training/utils/hooks/__init__.py +++ b/orttraining/orttraining/python/training/utils/hooks/__init__.py @@ -5,8 +5,12 @@ __all__ = [ "StatisticsSubscriber", - "SubscriberManager", + "GlobalSubscriberManager", + "_InspectActivation", ] from ._statistics_subscriber import StatisticsSubscriber -from ._subscriber_manager import SubscriberManager +from ._subscriber_manager import SubscriberManager, _InspectActivation + +# Define a global uninitialized subscriber manager for usage where it is needed by different Python files. 
+GlobalSubscriberManager = SubscriberManager() diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index 60f6586f1f..6cec29a308 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -39,6 +39,8 @@ class StatisticsSubscriber(SubscriberBase): start_step: Union[None, int] = None, end_step: Union[None, int] = None, override_output_dir: bool = False, + run_on_cpu: bool = False, + bucket_size: int = 1024 * 1024 * 1024 // 2, ): """ Steps in [start_step, end_step) will run subscriber actions. @@ -48,9 +50,14 @@ class StatisticsSubscriber(SubscriberBase): start_step: the first step that runs subscriber actions. end_step: the end step (exclusively) that runs subscriber actions. override_output_dir: whether `output_dir` can be overridden if it already exists. + run_on_cpu: whether to run the subscriber actions on CPU, this should be the last resort when inserted + inspector node affects memory peak causing the original recipe run to fail with OOM. + bucket_size: the size of the bucket to split the statistic calculation. """ super().__init__(start_step=start_step, end_step=end_step) self._output_dir = output_dir + self._run_on_cpu = run_on_cpu + self._bucket_size = bucket_size if os.path.exists(self._output_dir): if override_output_dir: warnings.warn(f"Output directory {self._output_dir} already exists, overriding it.") @@ -87,26 +94,82 @@ class StatisticsSubscriber(SubscriberBase): # though it does not always guarantee to do this way. 
torch.set_printoptions(precision=6, linewidth=128) - flatten_array = tensor.flatten() - zero_tensor = torch.tensor(0, dtype=flatten_array.dtype, device=flatten_array.device) - num_nan = torch.isnan(flatten_array).sum() - num_inf = torch.isinf(flatten_array).sum() + tensor_shape = tensor.shape + tensor_dtype = tensor.dtype + flatten_array = tensor.flatten().view(-1) + + if self._run_on_cpu: + flatten_array = flatten_array.to("cpu") + + if self._run_on_cpu: + num_nan = torch.isnan(flatten_array).sum() + num_inf = torch.isinf(flatten_array).sum() + num_neg = (flatten_array < 0).sum() + num_pos = (flatten_array > 0).sum() + num_zero = (flatten_array == 0).sum() + min_value = flatten_array.min() + max_value = flatten_array.max() + mean_value = flatten_array.mean() + std_value = flatten_array.std() + else: + # Split the calculation for each bucket, then do another round of calculation on the bucket results. + # This can at the best effort reduce the peak memory impact. + bucket_size = self._bucket_size + element_count = flatten_array.numel() + ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size) + nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, 
device=flatten_array.device) + + # Summary for each bucket + element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + for i in range(ceil_bucket_count): + end = min((i + 1) * bucket_size, element_count) + bucket = flatten_array[i * bucket_size : end] + element_count_per_bucket[i] = bucket.numel() + + nan_buckets[i] = torch.isnan(bucket).sum() + inf_buckets[i] = torch.isinf(bucket).sum() + neg_buckets[i] = (bucket < 0).sum() + pos_buckets[i] = (bucket > 0).sum() + zero_buckets[i] = (bucket == 0).sum() + min_buckets[i] = bucket.min() + max_buckets[i] = bucket.max() + mean_buckets[i] = bucket.sum() + std_buckets[i] = bucket.std() + + # Reduction across all buckets + num_nan = nan_buckets.sum() + num_inf = inf_buckets.sum() + num_neg = neg_buckets.sum() + num_pos = pos_buckets.sum() + num_zero = zero_buckets.sum() + min_value = min_buckets.min() + max_value = max_buckets.max() + mean_value = float(mean_buckets.sum()) / float(element_count) + # Here we refer https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups + # to calculate the combined standard deviation of all buckets. 
+ s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * ( + (mean_buckets - mean_value) ** 2 + ) + std_value = torch.sqrt(s.sum() / (element_count - 1)) with order_file_path.open(mode="a", encoding="utf-8") as f: f.write(f"{output_file_name}\n") with tensor_file_path.open(mode="w", encoding="utf-8") as f: f.write( - f"{'>'*depth + display_name} shape: {tensor.shape} dtype: {tensor.dtype} size: {flatten_array.size()} \n" - f"min: {flatten_array.min()} max: {flatten_array.max()}, mean: {flatten_array.mean()}, " - f"std: {flatten_array.std()} \n" + f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n" + f"min: {min_value} max: {max_value}, mean: {mean_value}, " + f"std: {std_value} \n" f"nan: {num_nan}, inf: {num_inf}\n" ) f.write(f"samples(top 128): {flatten_array[:128]}\n") - - f.write( - f"neg: {torch.less(flatten_array, zero_tensor).to(torch.int64).sum()}, " - f"pos: {torch.greater(flatten_array, zero_tensor).to(torch.int64).sum()}, " - f"zero: {torch.eq(flatten_array, zero_tensor).to(torch.int64).sum()},\n" - ) + f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n") f.write(f"{'='*16}\n") diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index 3312182334..d0d7fcce8d 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -5,7 +5,7 @@ from collections import abc -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union import torch @@ -52,15 +52,34 @@ class _RuntimeStates: class _InspectActivation(torch.autograd.Function): """ This class is used to run the subscriber's forward and backward functions. + The function will be called by two kinds of callers: + 1. 
SubscriberManager calls it for each registered nn.Module. + 2. Users who want to inspect the activation tensor at any place of model definition code. """ @staticmethod - def forward(ctx, activation_name: str, module_idx: int, run_ctx: _RuntimeStates, input_tensor): + def forward( + ctx, activation_name: str, module_idx: Optional[int], run_ctx: _RuntimeStates, input_tensor: torch.Tensor + ): """ + Args: + ctx: context object to store intermediate information. + activation_name: the name of the activation tensor. + module_idx: + For call case 1 - the unique id of the module that the activation belongs to, it is detected by the + SubscriberManager automatically. + For call case 2 - e.g, _InspectActivation is called by users (NOT by SubscriberManager), module_idx can + be None. + run_ctx: runtime context. + For call case 2 - need retrieve the runtime state from GlobalSubscriberManager. + input_tensor: the activation tensor. + Make sure there is a same number of `tensor` type inputs and outputs. This is enforced by ORT's PythonOp's schema check. """ - depth = run_ctx.global_states.module_index_to_depth[module_idx] + depth = -1 + if module_idx is not None: + depth = run_ctx.global_states.module_index_to_depth[module_idx] input_tensor_copied = None if input_tensor is None or not isinstance(input_tensor, torch.Tensor): @@ -81,7 +100,7 @@ class _InspectActivation(torch.autograd.Function): return input_tensor.detach() if input_tensor is not None else None @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output: torch.Tensor): val = None if grad_output is None or not isinstance(grad_output, torch.Tensor): val = grad_output @@ -105,7 +124,7 @@ class _IncrementStep(torch.autograd.Function): """ @staticmethod - def forward(ctx, run_ctx, input_tensor): + def forward(ctx, run_ctx: _RuntimeStates, input_tensor: torch.Tensor): """ Make sure there is a same number of `tensor` inputs and outputs. This is enforced by ORT's PythonOp's schema check. 
@@ -135,7 +154,7 @@ class _IncrementStep(torch.autograd.Function): return input_tensor.detach() if isinstance(input_tensor, torch.Tensor) else input_tensor @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output: torch.Tensor): # In case there are multiple backward calls for multiple outputs of the outside-most module. if ctx.current_step == ctx.run_ctx.global_states.execution_step: if ctx.current_step >= 0: @@ -183,6 +202,9 @@ class SubscriberManager: self._initialize(module) + def get_run_context(self) -> _RuntimeStates: + return self._run_ctx + def _reset_all_states(self): self._run_ctx = _RuntimeStates() diff --git a/orttraining/orttraining/test/python/orttraining_test_hooks.py b/orttraining/orttraining/test/python/orttraining_test_hooks.py index 80f29ad881..4fb416e640 100644 --- a/orttraining/orttraining/test/python/orttraining_test_hooks.py +++ b/orttraining/orttraining/test/python/orttraining_test_hooks.py @@ -8,7 +8,7 @@ import pytest import torch from onnxruntime.training.ortmodule import ORTModule -from onnxruntime.training.utils.hooks import StatisticsSubscriber, SubscriberManager +from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber, _InspectActivation class NeuralNetSingleOutput(torch.nn.Module): @@ -55,8 +55,7 @@ def test_statistic_subscriber_single_output(device, backend): with tempfile.TemporaryDirectory() as temporary_dir: output_dir_path = os.path.join(temporary_dir, f"{backend}_out") - sub_manager = SubscriberManager() - sub_manager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)]) + GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)]) if backend == "ortmodule": model = ORTModule(model) @@ -100,9 +99,7 @@ def test_statistic_subscriber_multiple_outputs(device, backend): with tempfile.TemporaryDirectory() as temporary_dir: output_dir_path = os.path.join(temporary_dir, f"{backend}_out") - - 
sub_manager = SubscriberManager() - sub_manager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)]) + GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)]) if backend == "ortmodule": model = ORTModule(model) @@ -137,3 +134,69 @@ def test_statistic_subscriber_multiple_outputs(device, backend): assert len(os.listdir(step_dir)) == len(expected_files) for file in expected_files: assert os.path.exists(os.path.join(step_dir, file)) + + +class NeuralNetUserAnnotateIntermediateTensor(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super().__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, input1, input2): + model_input = input1 + input2 + out = self.fc1(model_input) + out = _InspectActivation.apply("fc1_out", None, GlobalSubscriberManager.get_run_context(), out) + out = self.relu(out) + out = _InspectActivation.apply("relu_out", None, GlobalSubscriberManager.get_run_context(), out) + out = self.fc2(out) + return out + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("backend", ["torch", "ortmodule"]) +def test_statistic_subscriber_user_annotate_intermediate_tensors(device, backend): + input_size = 8 + hidden_size = 16 + num_classes = 32 + model = NeuralNetUserAnnotateIntermediateTensor(input_size, hidden_size, num_classes) + model.to(device) + model.train() + + with tempfile.TemporaryDirectory() as temporary_dir: + output_dir_path = os.path.join(temporary_dir, f"{backend}_out") + GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)]) + + if backend == "ortmodule": + model = ORTModule(model) + + batch_size = 4 + input1_tensor = torch.randn(batch_size, input_size, device=device) + input2_tensor = torch.randn(batch_size, input_size, device=device) + for _ 
in range(5): + y = model(input1_tensor, input2_tensor) + y.sum().backward() + + assert os.path.exists(output_dir_path) + + expected_files = [ + "order.txt", + "Linear_1_0th_output_forward", + "Linear_1_0th_output_backward", + "NeuralNetUserAnnotateIntermediateTensor_0_0th_output_forward", + "NeuralNetUserAnnotateIntermediateTensor_0_0th_output_backward", + "ReLU_2_0th_output_forward", + "ReLU_2_0th_output_backward", + "Linear_3_0th_output_forward", + "Linear_3_0th_output_backward", + "fc1_out_forward", + "fc1_out_backward", + "relu_out_forward", + "relu_out_backward", + ] + + for i in range(5): + step_dir = os.path.join(output_dir_path, f"step_{i}") + for file in expected_files: + assert os.path.exists(os.path.join(step_dir, file)) |