mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Enhance StatisticsSubscriber (#16098)
### Enhance StatisticsSubscriber There are few improvements for `StatisticsSubscriber`: - Reduce peak memory impact for tensors (having many many many elements, consuming too much GPU memory, causing original recipe run failed with OOM), by split the statistics into two phases (split into buckets, and merge result across buckets). - Allow dump intermediate tensors. Originally only nn.Module forward()'s return value are dumped, there are requirements we want to inspect some specific intermediate tensor in the forward() function, now we support it. - Add documents for collecting dumps on multiple ranks Docs link on this branch for better view: https://github.com/microsoft/onnxruntime/blob/pengwa/conv_tool_v2/docs/ORTModule_Convergence_Notes.md --------- Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>
This commit is contained in:
parent
eed02a3f78
commit
40bcc0441b
5 changed files with 245 additions and 46 deletions
|
|
@ -18,7 +18,8 @@ Before looking into this further, we should clarify a few things (if possible):
|
|||
|
||||
## 2. Collect Activation Statistics
|
||||
|
||||
### Add a few lines of code, run script to collect statistics:
|
||||
|
||||
### 2.1 Use `GlobalSubscriberManager` to collect `nn.Module` forward() outputs
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
|
|
@ -29,12 +30,11 @@ Before looking into this further, we should clarify a few things (if possible):
|
|||
<td>
|
||||
<sub>
|
||||
|
||||
```diff
|
||||
+ from onnxruntime.training.utils.hooks import SubscriberManager,
|
||||
+ StatisticsSubscriber
|
||||
+ sub_m = SubscriberManager()
|
||||
+ sub_m.subscribe(model, [StatisticsSubscriber(output_dir="pt_out",
|
||||
+ override_output_dir=True)])
|
||||
```python
|
||||
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
|
||||
GlobalSubscriberManager.subscribe(
|
||||
model, [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True)]
|
||||
)
|
||||
```
|
||||
|
||||
</sub>
|
||||
|
|
@ -42,13 +42,12 @@ Before looking into this further, we should clarify a few things (if possible):
|
|||
<td>
|
||||
<sub>
|
||||
|
||||
```diff
|
||||
```python
|
||||
model = ORTModule(model)
|
||||
+ from onnxruntime.training.utils.hooks import SubscriberManager,
|
||||
+ StatisticsSubscriber
|
||||
+ sub_m = SubscriberManager()
|
||||
+ sub_m.subscribe(model, [StatisticsSubscriber(output_dir="ort_out",
|
||||
+ override_output_dir=True)])
|
||||
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
|
||||
GlobalSubscriberManager.subscribe(
|
||||
model, [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True)]
|
||||
)
|
||||
```
|
||||
|
||||
</sub>
|
||||
|
|
@ -77,16 +76,64 @@ model = ORTModule(model)
|
|||
|
||||
Arguments:
|
||||
- output_dir: the directory in all activation statistics files will be stored.
|
||||
- start_step [optional]: the first step that runs subscriber actions.
|
||||
- end_step [optional]: the end step (exclusively) that runs subscriber actions.
|
||||
- override_output_dir: whether `output_dir` can be overridden if it already exists.
|
||||
- `start_step` [optional]: the first step that runs subscriber actions.
|
||||
- `end_step` [optional]: the end step (exclusively) that runs subscriber actions.
|
||||
- `override_output_dir`: whether `output_dir` can be overridden if it already exists.
|
||||
- `run_on_cpu`: whether to run the subscriber actions on CPU, this should be the last resort when inserted
|
||||
inspector node affects memory peak causing the original recipe run to fail with OOM.
|
||||
- `bucket_size`: the size of the bucket to split the statistic calculation.
|
||||
|
||||
Check [StatisticsSubscriber implementation](../orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py) for more information.
|
||||
### 2.2 Use `_InspectActivation` to collect intermediate tensors in a `nn.Module` forward()
|
||||
|
||||
### Run command to generate per-step summary
|
||||
The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output tensors will be dumped, if you want to
|
||||
dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example:
|
||||
|
||||
```diff
|
||||
class BloomForCausalLM(BloomPreTrainedModel):
|
||||
def __init__(self, config: BloomConfig):
|
||||
...
|
||||
|
||||
def forward(self, input_ids, ...):
|
||||
...
|
||||
transformer_outputs = self.transformer(...)
|
||||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
+ lm_logits = _InspectActivation.apply("lm_logits", None, GlobalSubscriberManager.get_run_context(), lm_logits)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
+ shift_logits = _InspectActivation.apply("shift_logits", None, GlobalSubscriberManager.get_run_context(), shift_logits)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
batch_size, seq_length, vocab_size = shift_logits.shape
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
|
||||
)
|
||||
|
||||
return loss
|
||||
```
|
||||
|
||||
Be noted, make sure the activation name (as the first argument of `_InspectActivation.apply`) is unique, otherwise
|
||||
stat file using the activation name will be overwritten by the last write. The dumped data are stored in the `output_dir`.
|
||||
|
||||
|
||||
### 2.3 Collect on multiple ranks
|
||||
|
||||
`GlobalSubscriberManager` does not explicitly handle the racing condition when multiple ranks write into the same file path,
|
||||
here is the example if you want to collect statistics on multiple ranks:
|
||||
|
||||
```python
|
||||
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
|
||||
GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir="ort_out_" + str(torch.distributed.get_rank()),
|
||||
override_output_dir=True)])
|
||||
```
|
||||
|
||||
Check [StatisticsSubscriber implementation](../orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py) for more information.
|
||||
|
||||
### 2.4 Run command to generate per-step summary
|
||||
|
||||
```bash
|
||||
python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
|
||||
```
|
||||
|
||||
### Manually compare the generated per-step summary to find the first big diff.
|
||||
### 2.5 Manually compare the generated per-step summary to find the first big diff.
|
||||
|
|
|
|||
|
|
@ -5,8 +5,12 @@
|
|||
|
||||
__all__ = [
|
||||
"StatisticsSubscriber",
|
||||
"SubscriberManager",
|
||||
"GlobalSubscriberManager",
|
||||
"_InspectActivation",
|
||||
]
|
||||
|
||||
from ._statistics_subscriber import StatisticsSubscriber
|
||||
from ._subscriber_manager import SubscriberManager
|
||||
from ._subscriber_manager import SubscriberManager, _InspectActivation
|
||||
|
||||
# Define a global uninitialized subscriber manager for usage where it is needed by different Python files.
|
||||
GlobalSubscriberManager = SubscriberManager()
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ class StatisticsSubscriber(SubscriberBase):
|
|||
start_step: Union[None, int] = None,
|
||||
end_step: Union[None, int] = None,
|
||||
override_output_dir: bool = False,
|
||||
run_on_cpu: bool = False,
|
||||
bucket_size: int = 1024 * 1024 * 1024 // 2,
|
||||
):
|
||||
"""
|
||||
Steps in [start_step, end_step) will run subscriber actions.
|
||||
|
|
@ -48,9 +50,14 @@ class StatisticsSubscriber(SubscriberBase):
|
|||
start_step: the first step that runs subscriber actions.
|
||||
end_step: the end step (exclusively) that runs subscriber actions.
|
||||
override_output_dir: whether `output_dir` can be overridden if it already exists.
|
||||
run_on_cpu: whether to run the subscriber actions on CPU, this should be the last resort when inserted
|
||||
inspector node affects memory peak causing the original recipe run to fail with OOM.
|
||||
bucket_size: the size of the bucket to split the statistic calculation.
|
||||
"""
|
||||
super().__init__(start_step=start_step, end_step=end_step)
|
||||
self._output_dir = output_dir
|
||||
self._run_on_cpu = run_on_cpu
|
||||
self._bucket_size = bucket_size
|
||||
if os.path.exists(self._output_dir):
|
||||
if override_output_dir:
|
||||
warnings.warn(f"Output directory {self._output_dir} already exists, overriding it.")
|
||||
|
|
@ -87,26 +94,82 @@ class StatisticsSubscriber(SubscriberBase):
|
|||
# though it does not always guarantee to do this way.
|
||||
torch.set_printoptions(precision=6, linewidth=128)
|
||||
|
||||
flatten_array = tensor.flatten()
|
||||
zero_tensor = torch.tensor(0, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
num_nan = torch.isnan(flatten_array).sum()
|
||||
num_inf = torch.isinf(flatten_array).sum()
|
||||
tensor_shape = tensor.shape
|
||||
tensor_dtype = tensor.dtype
|
||||
flatten_array = tensor.flatten().view(-1)
|
||||
|
||||
if self._run_on_cpu:
|
||||
flatten_array = flatten_array.to("cpu")
|
||||
|
||||
if self._run_on_cpu:
|
||||
num_nan = torch.isnan(flatten_array).sum()
|
||||
num_inf = torch.isinf(flatten_array).sum()
|
||||
num_neg = (flatten_array < 0).sum()
|
||||
num_pos = (flatten_array > 0).sum()
|
||||
num_zero = (flatten_array == 0).sum()
|
||||
min_value = flatten_array.min()
|
||||
max_value = flatten_array.max()
|
||||
mean_value = flatten_array.mean()
|
||||
std_value = flatten_array.std()
|
||||
else:
|
||||
# Split the calculation for each bucket, then do another round of calculation on the bucket results.
|
||||
# This can at the best effort reduce the peak memory impact.
|
||||
bucket_size = self._bucket_size
|
||||
element_count = flatten_array.numel()
|
||||
ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size)
|
||||
nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device)
|
||||
|
||||
# Summary for each bucket
|
||||
element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device)
|
||||
for i in range(ceil_bucket_count):
|
||||
end = min((i + 1) * bucket_size, element_count)
|
||||
bucket = flatten_array[i * bucket_size : end]
|
||||
element_count_per_bucket[i] = bucket.numel()
|
||||
|
||||
nan_buckets[i] = torch.isnan(bucket).sum()
|
||||
inf_buckets[i] = torch.isinf(bucket).sum()
|
||||
neg_buckets[i] = (bucket < 0).sum()
|
||||
pos_buckets[i] = (bucket > 0).sum()
|
||||
zero_buckets[i] = (bucket == 0).sum()
|
||||
min_buckets[i] = bucket.min()
|
||||
max_buckets[i] = bucket.max()
|
||||
mean_buckets[i] = bucket.sum()
|
||||
std_buckets[i] = bucket.std()
|
||||
|
||||
# Reduction across all buckets
|
||||
num_nan = nan_buckets.sum()
|
||||
num_inf = inf_buckets.sum()
|
||||
num_neg = neg_buckets.sum()
|
||||
num_pos = pos_buckets.sum()
|
||||
num_zero = zero_buckets.sum()
|
||||
min_value = min_buckets.min()
|
||||
max_value = max_buckets.max()
|
||||
mean_value = float(mean_buckets.sum()) / float(element_count)
|
||||
# Here we refer https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
|
||||
# to calculate the combined standard deviation of all buckets.
|
||||
s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * (
|
||||
(mean_buckets - mean_value) ** 2
|
||||
)
|
||||
std_value = torch.sqrt(s.sum() / (element_count - 1))
|
||||
|
||||
with order_file_path.open(mode="a", encoding="utf-8") as f:
|
||||
f.write(f"{output_file_name}\n")
|
||||
|
||||
with tensor_file_path.open(mode="w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
f"{'>'*depth + display_name} shape: {tensor.shape} dtype: {tensor.dtype} size: {flatten_array.size()} \n"
|
||||
f"min: {flatten_array.min()} max: {flatten_array.max()}, mean: {flatten_array.mean()}, "
|
||||
f"std: {flatten_array.std()} \n"
|
||||
f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n"
|
||||
f"min: {min_value} max: {max_value}, mean: {mean_value}, "
|
||||
f"std: {std_value} \n"
|
||||
f"nan: {num_nan}, inf: {num_inf}\n"
|
||||
)
|
||||
f.write(f"samples(top 128): {flatten_array[:128]}\n")
|
||||
|
||||
f.write(
|
||||
f"neg: {torch.less(flatten_array, zero_tensor).to(torch.int64).sum()}, "
|
||||
f"pos: {torch.greater(flatten_array, zero_tensor).to(torch.int64).sum()}, "
|
||||
f"zero: {torch.eq(flatten_array, zero_tensor).to(torch.int64).sum()},\n"
|
||||
)
|
||||
f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n")
|
||||
f.write(f"{'='*16}\n")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
|
||||
from collections import abc
|
||||
from typing import Callable, List, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -52,15 +52,34 @@ class _RuntimeStates:
|
|||
class _InspectActivation(torch.autograd.Function):
|
||||
"""
|
||||
This class is used to run the subscriber's forward and backward functions.
|
||||
The function will be called by two kinds of callers:
|
||||
1. SubscriberManager calls it for each registered nn.Module.
|
||||
2. Users who want to inspect the activation tensor at any place of model definition code.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, activation_name: str, module_idx: int, run_ctx: _RuntimeStates, input_tensor):
|
||||
def forward(
|
||||
ctx, activation_name: str, module_idx: Optional[int], run_ctx: _RuntimeStates, input_tensor: torch.Tensor
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
ctx: context object to store intermediate information.
|
||||
activation_name: the name of the activation tensor.
|
||||
module_idx:
|
||||
For call case 1 - the unique id of the module that the activation belongs to, it is detected by the
|
||||
SubscriberManager automatically.
|
||||
For call case 2 - e.g, _InspectActivation is called by users (NOT by SubscriberManager), module_idx can
|
||||
be None.
|
||||
run_ctx: runtime context.
|
||||
For call case 2 - need retrieve the runtime state from GlobalSubscriberManager.
|
||||
input_tensor: the activation tensor.
|
||||
|
||||
Make sure there is a same number of `tensor` type inputs and outputs.
|
||||
This is enforced by ORT's PythonOp's schema check.
|
||||
"""
|
||||
depth = run_ctx.global_states.module_index_to_depth[module_idx]
|
||||
depth = -1
|
||||
if module_idx is not None:
|
||||
depth = run_ctx.global_states.module_index_to_depth[module_idx]
|
||||
|
||||
input_tensor_copied = None
|
||||
if input_tensor is None or not isinstance(input_tensor, torch.Tensor):
|
||||
|
|
@ -81,7 +100,7 @@ class _InspectActivation(torch.autograd.Function):
|
|||
return input_tensor.detach() if input_tensor is not None else None
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
def backward(ctx, grad_output: torch.Tensor):
|
||||
val = None
|
||||
if grad_output is None or not isinstance(grad_output, torch.Tensor):
|
||||
val = grad_output
|
||||
|
|
@ -105,7 +124,7 @@ class _IncrementStep(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_ctx, input_tensor):
|
||||
def forward(ctx, run_ctx: _RuntimeStates, input_tensor: torch.Tensor):
|
||||
"""
|
||||
Make sure there is a same number of `tensor` inputs and outputs.
|
||||
This is enforced by ORT's PythonOp's schema check.
|
||||
|
|
@ -135,7 +154,7 @@ class _IncrementStep(torch.autograd.Function):
|
|||
return input_tensor.detach() if isinstance(input_tensor, torch.Tensor) else input_tensor
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
def backward(ctx, grad_output: torch.Tensor):
|
||||
# In case there are multiple backward calls for multiple outputs of the outside-most module.
|
||||
if ctx.current_step == ctx.run_ctx.global_states.execution_step:
|
||||
if ctx.current_step >= 0:
|
||||
|
|
@ -183,6 +202,9 @@ class SubscriberManager:
|
|||
|
||||
self._initialize(module)
|
||||
|
||||
def get_run_context(self) -> _RuntimeStates:
|
||||
return self._run_ctx
|
||||
|
||||
def _reset_all_states(self):
|
||||
self._run_ctx = _RuntimeStates()
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from onnxruntime.training.utils.hooks import StatisticsSubscriber, SubscriberManager
|
||||
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber, _InspectActivation
|
||||
|
||||
|
||||
class NeuralNetSingleOutput(torch.nn.Module):
|
||||
|
|
@ -55,8 +55,7 @@ def test_statistic_subscriber_single_output(device, backend):
|
|||
|
||||
with tempfile.TemporaryDirectory() as temporary_dir:
|
||||
output_dir_path = os.path.join(temporary_dir, f"{backend}_out")
|
||||
sub_manager = SubscriberManager()
|
||||
sub_manager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)])
|
||||
GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)])
|
||||
|
||||
if backend == "ortmodule":
|
||||
model = ORTModule(model)
|
||||
|
|
@ -100,9 +99,7 @@ def test_statistic_subscriber_multiple_outputs(device, backend):
|
|||
|
||||
with tempfile.TemporaryDirectory() as temporary_dir:
|
||||
output_dir_path = os.path.join(temporary_dir, f"{backend}_out")
|
||||
|
||||
sub_manager = SubscriberManager()
|
||||
sub_manager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)])
|
||||
GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)])
|
||||
|
||||
if backend == "ortmodule":
|
||||
model = ORTModule(model)
|
||||
|
|
@ -137,3 +134,69 @@ def test_statistic_subscriber_multiple_outputs(device, backend):
|
|||
assert len(os.listdir(step_dir)) == len(expected_files)
|
||||
for file in expected_files:
|
||||
assert os.path.exists(os.path.join(step_dir, file))
|
||||
|
||||
|
||||
class NeuralNetUserAnnotateIntermediateTensor(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
super().__init__()
|
||||
|
||||
self.fc1 = torch.nn.Linear(input_size, hidden_size)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, input1, input2):
|
||||
model_input = input1 + input2
|
||||
out = self.fc1(model_input)
|
||||
out = _InspectActivation.apply("fc1_out", None, GlobalSubscriberManager.get_run_context(), out)
|
||||
out = self.relu(out)
|
||||
out = _InspectActivation.apply("relu_out", None, GlobalSubscriberManager.get_run_context(), out)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
||||
@pytest.mark.parametrize("backend", ["torch", "ortmodule"])
|
||||
def test_statistic_subscriber_user_annotate_intermediate_tensors(device, backend):
|
||||
input_size = 8
|
||||
hidden_size = 16
|
||||
num_classes = 32
|
||||
model = NeuralNetUserAnnotateIntermediateTensor(input_size, hidden_size, num_classes)
|
||||
model.to(device)
|
||||
model.train()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temporary_dir:
|
||||
output_dir_path = os.path.join(temporary_dir, f"{backend}_out")
|
||||
GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir_path, override_output_dir=True)])
|
||||
|
||||
if backend == "ortmodule":
|
||||
model = ORTModule(model)
|
||||
|
||||
batch_size = 4
|
||||
input1_tensor = torch.randn(batch_size, input_size, device=device)
|
||||
input2_tensor = torch.randn(batch_size, input_size, device=device)
|
||||
for _ in range(5):
|
||||
y = model(input1_tensor, input2_tensor)
|
||||
y.sum().backward()
|
||||
|
||||
assert os.path.exists(output_dir_path)
|
||||
|
||||
expected_files = [
|
||||
"order.txt",
|
||||
"Linear_1_0th_output_forward",
|
||||
"Linear_1_0th_output_backward",
|
||||
"NeuralNetUserAnnotateIntermediateTensor_0_0th_output_forward",
|
||||
"NeuralNetUserAnnotateIntermediateTensor_0_0th_output_backward",
|
||||
"ReLU_2_0th_output_forward",
|
||||
"ReLU_2_0th_output_backward",
|
||||
"Linear_3_0th_output_forward",
|
||||
"Linear_3_0th_output_backward",
|
||||
"fc1_out_forward",
|
||||
"fc1_out_backward",
|
||||
"relu_out_forward",
|
||||
"relu_out_backward",
|
||||
]
|
||||
|
||||
for i in range(5):
|
||||
step_dir = os.path.join(output_dir_path, f"step_{i}")
|
||||
for file in expected_files:
|
||||
assert os.path.exists(os.path.join(step_dir, file))
|
||||
|
|
|
|||
Loading…
Reference in a new issue