pytorch/test/distributed/_composable/test_checkpoint.py
Sanket Jayant Purandare d67923b955 Adding kwargs to composable AC API to enable full capabilities (#128516)
Summary:
Firstly, this does not change any existing behaviour, since all the default values
for the kwargs were already hardcoded into the ``_checkpoint_without_reentrant_generator`` call.

Secondly, this is needed to unlock the full capabilities of composable checkpointing,
making it equivalent to ``torch.utils.checkpoint.checkpoint(use_reentrant=False)``.

Finally, an added benefit is that composable checkpointing can now be used under
``FakeTensorMode`` by passing ``preserve_rng_state=False``.
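
A minimal usage sketch (illustrative only, not part of this diff; the toy module is hypothetical) of the composable API with the newly exposed kwargs, mirroring what the tests below exercise:

    import torch
    import torch.nn as nn
    from torch.distributed._composable import checkpoint

    # Any nn.Module works as the checkpointed region.
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

    # Apply non-reentrant activation checkpointing in place, forwarding the
    # newly exposed kwargs (their defaults match the previously hardcoded values).
    checkpoint(model, preserve_rng_state=False, determinism_check="default")

    model(torch.randn(2, 8, requires_grad=True)).sum().backward()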

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128516
Approved by: https://github.com/awgu
2024-06-15 00:23:48 +00:00


# Owner(s): ["oncall: distributed"]

import unittest
from collections import deque, OrderedDict
from contextlib import ContextDecorator, contextmanager, nullcontext
from copy import deepcopy
from functools import partial
from typing import Tuple

import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.checkpoint import CheckpointError


class MemoryDelta(ContextDecorator):
    def __init__(self, device: torch.device):
        self.device: torch.device = device
        self.active_memory_enter: int = 0
        self.active_memory_exit: int = 0

    def __enter__(self):
        self.active_memory_enter = (
            torch.cuda.memory_stats()["active_bytes.all.current"]
            if self.device.type == "cuda"
            else 0
        )
        return self

    def __exit__(self, *exc):
        self.active_memory_exit = (
            torch.cuda.memory_stats()["active_bytes.all.current"]
            if self.device.type == "cuda"
            else 0
        )

    def delta(self) -> int:
        return self.active_memory_exit - self.active_memory_enter


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.seq(self.l1(x))


class RandomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.p = nn.Parameter(torch.randn(100, 100))

    def forward(self, x):
        y = torch.matmul(self.p, torch.randn(100, 100, device=self.p.device))
        return torch.matmul(x, y)


class MultiOutputModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
        self.w2 = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        z = x @ self.w1
        z = nn.functional.relu(z)
        z = z @ self.w2
        return z.sin(), z.cos()


class MultiInputModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.w = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
        x, y = xs
        z = x + y
        z = z @ self.w
        return nn.functional.relu(z)


class TestCheckpoint(TestCase):
    def _get_graph_size(self, out: torch.Tensor) -> int:
        q = deque([out.grad_fn])
        num_functions = 0
        while len(q):
            fn = q.pop()
            num_functions += 1
            for next_fn, _ in fn.next_functions:
                if next_fn:
                    q.append(next_fn)
        return num_functions

    def _test_tensor_only(
        self,
        net: nn.Module,
        x: torch.Tensor,
    ) -> None:
        x1 = x.clone()
        x2 = x.clone()
        x1.requires_grad = True
        x2.requires_grad = True

        net1 = net
        net2 = deepcopy(net)

        # no checkpoint
        with MemoryDelta(x.device) as mem1:
            loss1 = net1(x1).sum()
        graph_size1 = self._get_graph_size(loss1)
        loss1.backward()

        # with checkpoint
        checkpoint(net2.seq)
        with MemoryDelta(x.device) as mem2:
            loss2 = net2(x2).sum()
        loss2.backward()

        if x.is_cuda:
            self.assertTrue(mem2.delta() < mem1.delta())

        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_tensor_only_cpu(self):
        x = torch.randn(20, 100)
        net = ToyModel()
        self._test_tensor_only(net, x)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_tensor_only_gpu(self):
        x = torch.randn(20, 100, device="cuda:0")
        net = ToyModel().to("cuda:0")
        self._test_tensor_only(net, x)

    def test_random_cpu(self):
        x1 = torch.randn(20, 100, requires_grad=True)
        x2 = x1.clone()

        net1 = RandomModel()
        net2 = deepcopy(net1)

        cpu_rng_state = torch.get_rng_state()
        net1(x1).sum().backward()
        torch.set_rng_state(cpu_rng_state)
        checkpoint(net2)(x2).sum().backward()

        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_multi_args(self):
        """
        Tests checkpoint for modules with multiple output args and hence
        multiple backward function input args.
        """
        device = torch.device("cpu")
        net1 = nn.Sequential(
            MultiOutputModel(device),
            MultiInputModel(device),
            MultiOutputModel(device),
            MultiInputModel(device),
        )
        net2 = deepcopy(net1)
        checkpoint(net2[0])
        checkpoint(net2[2])

        x1 = torch.randn(20, 100, requires_grad=True)
        x2 = x1.clone()
        net1(x1).sum().backward()
        net2(x2).sum().backward()

        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_clears_state_on_error_in_forward(self):
        class MyModel(torch.nn.Module):
            def __init__(self, raise_in_recomp):
                super().__init__()
                self.fwd_count = 0
                self.raise_in_recomp = raise_in_recomp
                self.a = torch.nn.Linear(2, 2)

            def forward(self, x):
                if self.raise_in_recomp and self.fwd_count == 1:
                    raise RuntimeError("foo")
                else:
                    if not self.raise_in_recomp:
                        # raise in the first forward
                        raise RuntimeError("foo")
                self.fwd_count += 1
                return self.a(x)

        m = MyModel(raise_in_recomp=True)
        m_seq = torch.nn.Sequential(OrderedDict({"m": m}))
        checkpoint(m_seq.m)
        inp = torch.randn(1, 2)
        out = m_seq(inp).sum()
        # Should raise in forward recomputation
        with self.assertRaisesRegex(RuntimeError, "foo"):
            out.backward()

        # Check that _ac_generator is cleared out
        self.assertEqual(None, checkpoint.state(m)._ac_generator)

        m = MyModel(raise_in_recomp=False)
        checkpoint(m)
        inp = torch.randn(1, 2)
        # Should raise in first forward
        with self.assertRaises(RuntimeError):
            m(inp)

        self.assertEqual(None, checkpoint.state(m)._ac_generator)

    def test_checkpoint_kwargs(self):
        class MyModel(torch.nn.Module):
            def __init__(self, raise_exp: bool, change_shape_in_recomp: bool):
                super().__init__()
                self.fwd_count = 0
                self.raise_exp = raise_exp
                self.change_shape_in_recomp = change_shape_in_recomp
                self.a = torch.nn.Linear(2, 2)

            def forward(self, x):
                if self.raise_exp and self.fwd_count == 0:
                    raise RuntimeError("foo")
                if self.raise_exp and self.fwd_count == 1:
                    raise RuntimeError("bar")
                if self.change_shape_in_recomp and self.fwd_count == 1:
                    x.relu_()
                random_tensor = torch.randn(1, 2)
                x = self.a(x + random_tensor)
                self.fwd_count += 1
                return x

        m = MyModel(True, False)
        m0, m1, m2, m3 = (deepcopy(m) for _ in range(4))

        # composable checkpoint does not support use_reentrant=True
        with self.assertRaisesRegex(
            NotImplementedError,
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead.",
        ):
            checkpoint(m, use_reentrant=True)

        # check giving an unsupported kwarg
        with self.assertRaisesRegex(ValueError, "Unexpected keyword arguments: foo"):
            checkpoint(m0, foo="bar")

        handled_fwd_exp = False
        handled_recomp_exp = False

        @contextmanager
        def fwd_ctx(mod: MyModel):
            try:
                mod.raise_exp = False
                yield
            finally:
                nonlocal handled_fwd_exp
                handled_fwd_exp = True
                mod.raise_exp = True

        @contextmanager
        def recomp_ctx(mod: MyModel):
            try:
                mod.raise_exp = False
                yield
            finally:
                nonlocal handled_recomp_exp
                handled_recomp_exp = True
                mod.raise_exp = True

        # Test different context functions
        x = torch.randn(1, 2, requires_grad=True)
        checkpoint(
            m1, context_fn=lambda: (partial(fwd_ctx, m1)(), partial(recomp_ctx, m1)())
        )
        m1(x.clone()).sum().backward()
        self.assertEqual((handled_fwd_exp, handled_recomp_exp), (True, True))

        checkpoint(m2, context_fn=lambda: (nullcontext(), partial(recomp_ctx, m2)()))
        with self.assertRaisesRegex(RuntimeError, "foo"):
            m2(x.clone())

        handled_fwd_exp = False  # Reset flag
        checkpoint(m3, context_fn=lambda: (partial(fwd_ctx, m3)(), nullcontext()))
        with self.assertRaisesRegex(RuntimeError, "bar"):
            m3(x.clone()).sum().backward()
        self.assertEqual(handled_fwd_exp, True)

        # Test determinism check failure
        m4 = MyModel(False, True)
        m5 = deepcopy(m4)

        # Determinism check should not throw an error,
        # but autograd should throw a RuntimeError
        checkpoint(m4, determinism_check="none")
        with self.assertRaises(RuntimeError):
            m4(x.clone()).sum().backward()

        # Determinism check should throw a CheckpointError
        checkpoint(m5, determinism_check="default")
        with self.assertRaises(CheckpointError):
            m5(x.clone()).sum().backward()

        # Test preserving random state
        m6 = MyModel(False, False)
        m7, m8 = (deepcopy(m6) for _ in range(2))
        checkpoint(m7, preserve_rng_state=False)
        checkpoint(m8, preserve_rng_state=True)

        for mi in (m6, m7, m8):
            torch.manual_seed(42)
            loss = mi(x.clone()).sum()
            torch.manual_seed(41)
            loss.backward()

        # check that m6 and m7 have at least one different grad
        self.assertNotEqual(
            (p1.grad for p1 in m6.parameters()), (p2.grad for p2 in m7.parameters())
        )

        # check that m6 and m8 have identical grads
        for p1, p2 in zip(m6.parameters(), m8.parameters()):
            self.assertEqual(p1.grad, p2.grad)


if __name__ == "__main__":
    run_tests()