mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
This PR presents a mixed integer linear programming (MILP) formulation that can be utilized to determine, under a memory budget, which modules to apply activation checkpointing (AC) and the amount of activation memory that should be discarded for each module. The MILP uses information collected from MemTracker, Runtime Estimator, and SAC Estimator, introduced in these PRs: * https://github.com/pytorch/pytorch/pull/124688 * https://github.com/pytorch/pytorch/pull/134243 * https://github.com/pytorch/pytorch/pull/135208 End-to-end example and its sample output: ``` import copy from typing import Tuple import torch from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed._tools.ilp_utils import ( aggregate_stats, get_peak_memory_runtime_baseline, parse_module_info, ) from torch.distributed._tools.mem_tracker import _ModState, MemTracker from torch.distributed._tools.runtime_estimator import RuntimeEstimator from torch.distributed._tools.sac_estimator import SACEstimator from torch.distributed._tools.sac_ilp import sac_milp from torch.testing._internal.distributed._tensor.common_dtensor import ( ModelArgs, Transformer, ) def _init_model_input_optimizer() -> Tuple[ torch.nn.Module, torch.optim.Optimizer, torch.Tensor ]: bsz = 8 model_args = ModelArgs( n_layers=4, n_heads=12, vocab_size=8192, max_seq_len=1024, dim=768, dropout_p=0.1, ) with torch.device(torch.cuda.current_device()): model = Transformer(model_args) optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True) inp = torch.randint( 0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=torch.cuda.current_device(), ) return (model, optimizer, inp) def _run_and_get_mem_tracker( model: torch.nn.Module, optimizer: torch.optim.Optimizer, inp: torch.Tensor, ) -> MemTracker: mem_tracker = MemTracker() mem_tracker.track_external(model, optimizer) with mem_tracker as mt: for iter_idx in range(2): # running twice to initialize optimizer output = model(inp) 
output.sum().backward() if iter_idx == 1: last_snapshot = mt.get_tracker_snapshot("current") optimizer.step() optimizer.zero_grad() if iter_idx == 0: mt.reset_mod_stats() assert last_snapshot is not None for mod_stats in mem_tracker.memory_tracking.values(): if _ModState.POST_BW not in mod_stats.snapshots.keys(): mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append( copy.deepcopy(last_snapshot) ) return mem_tracker def _run_and_get_runtime_estimator( model: torch.nn.Module, optimizer: torch.optim.Optimizer, inp: torch.Tensor, ) -> RuntimeEstimator: def _run_one_step() -> None: output = model(inp) output.sum().backward() optimizer.step() optimizer.zero_grad() # Initializing optimizer states and warm-up _run_one_step() runtime_estimator = RuntimeEstimator() with runtime_estimator(estimate_mode_type="operator-level-cost-model"): _run_one_step() # We use only one iteration for estimation return runtime_estimator def _run_and_get_sac_estimator( model: torch.nn.Module, inp: torch.Tensor, ) -> SACEstimator: sac_estimator = SACEstimator() with sac_estimator(estimate_mode_type="operator-level-cost-model"): loss = model(inp).sum() loss.backward() return sac_estimator def main(): with FakeTensorMode(): model, optimizer, inp = _init_model_input_optimizer() mem_tracker = _run_and_get_mem_tracker(model, optimizer, inp) runtime_estimator = _run_and_get_runtime_estimator(model, optimizer, inp) sac_estimator = _run_and_get_sac_estimator(model, inp) mod_info = aggregate_stats( model, mem_tracker, runtime_estimator, sac_estimator, torch.device(torch.cuda.current_device()), ) g = parse_module_info(mod_info) peak_mem, compute_time = get_peak_memory_runtime_baseline(g) print("=== WITHOUT AC ===") print(f"peak_mem: {round(peak_mem / 2**30, 2)} GiB") print(f"compute_time: {round(compute_time, 2)} ms") ac_decisions, recomputation_time, peak_mem = sac_milp(g, memory_budget=1.75) print("=== WITH AC ===") print(f"ac_decisions: {ac_decisions}") print(f"peak_mem: {round(peak_mem / 
2**30, 2)} GiB") print(f"recomputation_time: {recomputation_time} ms") if __name__ == "__main__": main() ``` ``` === WITHOUT AC === peak_mem: 2.41 GiB compute_time: 97.97 ms === WITH AC === ac_decisions: {'Transformer.layers.0': 0.5232, 'Transformer.layers.1': 0.5232, 'Transformer.layers.2': 0.6849, 'Transformer.layers.3': 0.5232} peak_mem: 1.75 GiB recomputation_time: 5.92 ms ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/137908 Approved by: https://github.com/weifengpy |
| | | |
|---|---|---|
| .. | ||
| __init__.py | ||
| fsdp2_mem_tracker.py | ||
| ilp_utils.py | ||
| mem_tracker.py | ||
| memory_tracker.py | ||
| mod_tracker.py | ||
| runtime_estimator.py | ||
| sac_estimator.py | ||
| sac_ilp.py | ||