Enhance StatisticsSubscriber (#16098)

### Enhance StatisticsSubscriber

There are a few improvements for `StatisticsSubscriber`:

- Reduce the peak memory impact for tensors with a very large number of elements (which consume
too much GPU memory and cause the original recipe run to fail with OOM), by splitting the
statistics calculation into two phases: split into buckets, then merge the results across
buckets (see the sketch after this list).
- Allow dumping intermediate tensors. Originally only an `nn.Module` forward()'s
return values were dumped; there are cases where we want to inspect a specific
intermediate tensor inside the forward() function, and this is now supported.
- Add documentation for collecting dumps on multiple ranks.
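
A simplified sketch of the two-phase bucket idea (illustrative only, not the PR code; the actual implementation in `_statistics_subscriber.py` appears in the diff below):

```python
import torch

def bucketed_nan_count(tensor: torch.Tensor, bucket_size: int = 1024 * 1024) -> int:
    """Count NaNs in two phases: reduce each bucket, then merge the bucket results."""
    flat = tensor.flatten()
    n = flat.numel()
    if n == 0:
        return 0
    # Phase 1: torch.isnan allocates a boolean mask only as large as one bucket
    # at a time, instead of a single mask covering the whole tensor.
    partials = [torch.isnan(flat[i : i + bucket_size]).sum() for i in range(0, n, bucket_size)]
    # Phase 2: merge the per-bucket counts.
    return int(torch.stack(partials).sum())
```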

Docs link on this branch for a better view:
https://github.com/microsoft/onnxruntime/blob/pengwa/conv_tool_v2/docs/ORTModule_Convergence_Notes.md

---------

Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>
pengwa 2023-06-12 18:32:08 +08:00 committed by GitHub
parent eed02a3f78
commit 40bcc0441b
5 changed files with 245 additions and 46 deletions

docs/ORTModule_Convergence_Notes.md

@@ -18,7 +18,8 @@ Before looking into this further, we should clarify a few things (if possible):
## 2. Collect Activation Statistics
### 2.1 Use `GlobalSubscriberManager` to collect `nn.Module` forward() outputs
<table>
<tr>
@@ -29,12 +30,11 @@ Before looking into this further, we should clarify a few things (if possible):
<td>
<sub>
```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True)]
)
```
</sub>
@@ -42,13 +42,12 @@ Before looking into this further, we should clarify a few things (if possible):
<td>
<sub>
```python
model = ORTModule(model)
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True)]
)
```
</sub>
@@ -77,16 +76,64 @@ model = ORTModule(model)
Arguments:
- `output_dir`: the directory in which all activation statistics files will be stored.
- `start_step` [optional]: the first step that runs subscriber actions.
- `end_step` [optional]: the end step (exclusively) that runs subscriber actions.
- `override_output_dir`: whether `output_dir` can be overridden if it already exists.
- `run_on_cpu`: whether to run the subscriber actions on CPU; this should be the last resort when the inserted
inspector node affects the memory peak, causing the original recipe run to fail with OOM.
- `bucket_size`: the bucket size used to split the statistics calculation.
Check [StatisticsSubscriber implementation](../orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py) for more information.
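For instance, a minimal sketch combining the optional arguments above (assuming `model` is an already-constructed `nn.Module`; the step range and bucket size are illustrative values, not recommendations):

```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

GlobalSubscriberManager.subscribe(
    model,  # assumed: your nn.Module instance
    [
        StatisticsSubscriber(
            output_dir="pt_out",
            start_step=10,  # first step that dumps statistics ...
            end_step=12,  # ... up to, but excluding, step 12
            override_output_dir=True,
            run_on_cpu=False,
            bucket_size=1024 * 1024,  # smaller buckets trade speed for lower peak memory
        )
    ],
)
```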
### 2.2 Use `_InspectActivation` to collect intermediate tensors in a `nn.Module` forward()
The limitation of `GlobalSubscriberManager` is that only `nn.Module` forward output tensors are dumped. If you want to
dump intermediate tensors in an `nn.Module`'s forward function, refer to the following example:
```diff
class BloomForCausalLM(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
        ...

    def forward(self, input_ids, ...):
        ...
        transformer_outputs = self.transformer(...)
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
+       lm_logits = _InspectActivation.apply("lm_logits", None, GlobalSubscriberManager.get_run_context(), lm_logits)
        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
+       shift_logits = _InspectActivation.apply("shift_logits", None, GlobalSubscriberManager.get_run_context(), shift_logits)
        shift_labels = labels[..., 1:].contiguous()
        batch_size, seq_length, vocab_size = shift_logits.shape
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
        )
        return loss
```
Note: make sure the activation name (the first argument of `_InspectActivation.apply`) is unique; otherwise the
stat file named after the activation will be overwritten by the last write. The dumped data are stored in `output_dir`.
### 2.3 Collect on multiple ranks
`GlobalSubscriberManager` does not explicitly handle the race condition when multiple ranks write to the same file path;
here is an example of collecting statistics on multiple ranks:
```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="ort_out_" + str(torch.distributed.get_rank()), override_output_dir=True)]
)
```
Check [StatisticsSubscriber implementation](../orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py) for more information.
### 2.4 Run command to generate per-step summary
```bash
python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```
### 2.5 Manually compare the generated per-step summary to find the first big diff.

orttraining/orttraining/python/training/utils/hooks/__init__.py

@@ -5,8 +5,12 @@
__all__ = [
    "StatisticsSubscriber",
    "SubscriberManager",
    "GlobalSubscriberManager",
    "_InspectActivation",
]
from ._statistics_subscriber import StatisticsSubscriber
from ._subscriber_manager import SubscriberManager, _InspectActivation
# Define a global uninitialized subscriber manager for use where it is needed across different Python files.
GlobalSubscriberManager = SubscriberManager()

orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py

@@ -39,6 +39,8 @@ class StatisticsSubscriber(SubscriberBase):
        start_step: Union[None, int] = None,
        end_step: Union[None, int] = None,
        override_output_dir: bool = False,
        run_on_cpu: bool = False,
        bucket_size: int = 1024 * 1024 * 1024 // 2,
    ):
"""
Steps in [start_step, end_step) will run subscriber actions.
@ -48,9 +50,14 @@ class StatisticsSubscriber(SubscriberBase):
start_step: the first step that runs subscriber actions.
end_step: the end step (exclusively) that runs subscriber actions.
override_output_dir: whether `output_dir` can be overridden if it already exists.
run_on_cpu: whether to run the subscriber actions on CPU, this should be the last resort when inserted
inspector node affects memory peak causing the original recipe run to fail with OOM.
bucket_size: the size of the bucket to split the statistic calculation.
"""
super().__init__(start_step=start_step, end_step=end_step)
self._output_dir = output_dir
self._run_on_cpu = run_on_cpu
self._bucket_size = bucket_size
if os.path.exists(self._output_dir):
if override_output_dir:
warnings.warn(f"Output directory {self._output_dir} already exists, overriding it.")
@@ -87,26 +94,82 @@ class StatisticsSubscriber(SubscriberBase):
        # though it does not always guarantee to do this way.
        torch.set_printoptions(precision=6, linewidth=128)
        tensor_shape = tensor.shape
        tensor_dtype = tensor.dtype
        flatten_array = tensor.flatten().view(-1)

        if self._run_on_cpu:
            flatten_array = flatten_array.to("cpu")

        if self._run_on_cpu:
            num_nan = torch.isnan(flatten_array).sum()
            num_inf = torch.isinf(flatten_array).sum()
            num_neg = (flatten_array < 0).sum()
            num_pos = (flatten_array > 0).sum()
            num_zero = (flatten_array == 0).sum()
            min_value = flatten_array.min()
            max_value = flatten_array.max()
            mean_value = flatten_array.mean()
            std_value = flatten_array.std()
        else:
            # Split the calculation into buckets, then do another round of calculation on the bucket results.
            # This reduces the peak memory impact on a best-effort basis.
            bucket_size = self._bucket_size
            element_count = flatten_array.numel()
            ceil_bucket_count = (element_count + bucket_size - 1) // bucket_size
            nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
            inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
            neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
            pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
            zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
            min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
            max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
            mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
            std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)

            # Summary for each bucket
            element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
            for i in range(ceil_bucket_count):
                end = min((i + 1) * bucket_size, element_count)
                bucket = flatten_array[i * bucket_size : end]
                element_count_per_bucket[i] = bucket.numel()
                nan_buckets[i] = torch.isnan(bucket).sum()
                inf_buckets[i] = torch.isinf(bucket).sum()
                neg_buckets[i] = (bucket < 0).sum()
                pos_buckets[i] = (bucket > 0).sum()
                zero_buckets[i] = (bucket == 0).sum()
                min_buckets[i] = bucket.min()
                max_buckets[i] = bucket.max()
                mean_buckets[i] = bucket.sum()  # per-bucket sum; bucket means are derived below
                std_buckets[i] = bucket.std()

            # Reduction across all buckets
            num_nan = nan_buckets.sum()
            num_inf = inf_buckets.sum()
            num_neg = neg_buckets.sum()
            num_pos = pos_buckets.sum()
            num_zero = zero_buckets.sum()
            min_value = min_buckets.min()
            max_value = max_buckets.max()
            mean_value = float(mean_buckets.sum()) / float(element_count)

            # Here we refer to https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
            # to calculate the combined standard deviation of all buckets.
            bucket_means = mean_buckets / element_count_per_bucket
            s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
                (bucket_means - mean_value) ** 2
            )
            std_value = torch.sqrt(s.sum() / (element_count - 1))
        with order_file_path.open(mode="a", encoding="utf-8") as f:
            f.write(f"{output_file_name}\n")

        with tensor_file_path.open(mode="w", encoding="utf-8") as f:
            f.write(
f"{'>'*depth + display_name} shape: {tensor.shape} dtype: {tensor.dtype} size: {flatten_array.size()} \n"
f"min: {flatten_array.min()} max: {flatten_array.max()}, mean: {flatten_array.mean()}, "
f"std: {flatten_array.std()} \n"
f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
f"min: {min_value} max: {max_value}, mean: {mean_value}, "
f"std: {std_value} \n"
f"nan: {num_nan}, inf: {num_inf}\n"
)
f.write(f"samples(top 128): {flatten_array[:128]}\n")
f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n")
f.write(f"{'='*16}\n")

orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py

@@ -5,7 +5,7 @@
from collections import abc
from typing import Callable, List, Optional, Union
import torch
@@ -52,15 +52,34 @@ class _RuntimeStates:
class _InspectActivation(torch.autograd.Function):
    """
    This class is used to run the subscriber's forward and backward functions.

    The function will be called by two kinds of callers:
    1. SubscriberManager calls it for each registered nn.Module.
    2. Users who want to inspect the activation tensor anywhere in the model definition code.
    """
    @staticmethod
    def forward(
        ctx, activation_name: str, module_idx: Optional[int], run_ctx: _RuntimeStates, input_tensor: torch.Tensor
    ):
"""
Args:
ctx: context object to store intermediate information.
activation_name: the name of the activation tensor.
module_idx:
For call case 1 - the unique id of the module that the activation belongs to, it is detected by the
SubscriberManager automatically.
For call case 2 - e.g, _InspectActivation is called by users (NOT by SubscriberManager), module_idx can
be None.
run_ctx: runtime context.
For call case 2 - need retrieve the runtime state from GlobalSubscriberManager.
input_tensor: the activation tensor.
Make sure there is a same number of `tensor` type inputs and outputs.
This is enforced by ORT's PythonOp's schema check.
"""
        depth = -1
        if module_idx is not None:
            depth = run_ctx.global_states.module_index_to_depth[module_idx]

        input_tensor_copied = None
        if input_tensor is None or not isinstance(input_tensor, torch.Tensor):
@@ -81,7 +100,7 @@ class _InspectActivation(torch.autograd.Function):
        return input_tensor.detach() if input_tensor is not None else None

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        val = None
        if grad_output is None or not isinstance(grad_output, torch.Tensor):
            val = grad_output
@@ -105,7 +124,7 @@ class _IncrementStep(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, run_ctx, input_tensor):
def forward(ctx, run_ctx: _RuntimeStates, input_tensor: torch.Tensor):
"""
Make sure there is a same number of `tensor` inputs and outputs.
This is enforced by ORT's PythonOp's schema check.
@@ -135,7 +154,7 @@ class _IncrementStep(torch.autograd.Function):
        return input_tensor.detach() if isinstance(input_tensor, torch.Tensor) else input_tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # In case there are multiple backward calls for multiple outputs of the outside-most module.
        if ctx.current_step == ctx.run_ctx.global_states.execution_step:
            if ctx.current_step >= 0:
@@ -183,6 +202,9 @@ class SubscriberManager:
        self._initialize(module)

    def get_run_context(self) -> _RuntimeStates:
        return self._run_ctx

    def _reset_all_states(self):
        self._run_ctx = _RuntimeStates()

(unit tests for StatisticsSubscriber and _InspectActivation)

@@ -8,7 +8,7 @@ import pytest
import torch
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber, _InspectActivation
class NeuralNetSingleOutput(torch.nn.Module):
@@ -55,8 +55,7 @@ def test_statistic_subscriber_single_output(device, backend):
    with tempfile.TemporaryDirectory() as temporary_dir:
        output_dir_path = os.path.join(temporary_dir, f"{backend}_out")
        GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)])
        if backend == "ortmodule":
            model = ORTModule(model)
@@ -100,9 +99,7 @@ def test_statistic_subscriber_multiple_outputs(device, backend):
    with tempfile.TemporaryDirectory() as temporary_dir:
        output_dir_path = os.path.join(temporary_dir, f"{backend}_out")
        GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)])
        if backend == "ortmodule":
            model = ORTModule(model)
@@ -137,3 +134,69 @@ def test_statistic_subscriber_multiple_outputs(device, backend):
            assert len(os.listdir(step_dir)) == len(expected_files)
            for file in expected_files:
                assert os.path.exists(os.path.join(step_dir, file))
class NeuralNetUserAnnotateIntermediateTensor(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input1, input2):
        model_input = input1 + input2
        out = self.fc1(model_input)
        out = _InspectActivation.apply("fc1_out", None, GlobalSubscriberManager.get_run_context(), out)
        out = self.relu(out)
        out = _InspectActivation.apply("relu_out", None, GlobalSubscriberManager.get_run_context(), out)
        out = self.fc2(out)
        return out
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("backend", ["torch", "ortmodule"])
def test_statistic_subscriber_user_annotate_intermediate_tensors(device, backend):
    input_size = 8
    hidden_size = 16
    num_classes = 32
    model = NeuralNetUserAnnotateIntermediateTensor(input_size, hidden_size, num_classes)
    model.to(device)
    model.train()

    with tempfile.TemporaryDirectory() as temporary_dir:
        output_dir_path = os.path.join(temporary_dir, f"{backend}_out")
        GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)])
        if backend == "ortmodule":
            model = ORTModule(model)

        batch_size = 4
        input1_tensor = torch.randn(batch_size, input_size, device=device)
        input2_tensor = torch.randn(batch_size, input_size, device=device)
        for _ in range(5):
            y = model(input1_tensor, input2_tensor)
            y.sum().backward()

        assert os.path.exists(output_dir_path)

        expected_files = [
            "order.txt",
            "Linear_1_0th_output_forward",
            "Linear_1_0th_output_backward",
            "NeuralNetUserAnnotateIntermediateTensor_0_0th_output_forward",
            "NeuralNetUserAnnotateIntermediateTensor_0_0th_output_backward",
            "ReLU_2_0th_output_forward",
            "ReLU_2_0th_output_backward",
            "Linear_3_0th_output_forward",
            "Linear_3_0th_output_backward",
            "fc1_out_forward",
            "fc1_out_backward",
            "relu_out_forward",
            "relu_out_backward",
        ]

        for i in range(5):
            step_dir = os.path.join(output_dir_path, f"step_{i}")
            for file in expected_files:
                assert os.path.exists(os.path.join(step_dir, file))