This PR introduces the functionalization of RNG ops. Key points are:
* Introduces a new `philox_rand` prim operator that accepts a seed and an offset.
* Adds decompositions for random operators so that they lower to these `philox_rand` prims (a minimal sketch follows this list).
* Adds a `PhiloxStateTracker` to track the offset for each occurrence of rand ops.
* Changes the calling convention of AOT Autograd, adding <fwd_seed, fwd_base_offset> and <bwd_seed, bwd_base_offset> inputs.
* Monkeypatches `set_rng_state` and `get_rng_state` while AOT Autograd is tracing to record the rng state behavior.
* Raises an assertion for CPU because CPU does not use Philox RNG.
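For intuition, here is a minimal, hypothetical sketch of what such a decomposition could look like. The `philox_rand` signature is read off the graph dump below; the toy tracker only stands in for the PR's `PhiloxStateTracker`, whose real API may differ, and the real offset increment is a Philox-specific quantity rather than `numel()` (the dump below shows +4 for a 16x16 tensor):
~~~
import torch

class ToyPhiloxTracker:
    # Stand-in for the PR's PhiloxStateTracker, just to make the sketch runnable.
    def __init__(self, seed):
        self.seed = torch.tensor(seed, dtype=torch.int64)
        self.offset = torch.tensor(0, dtype=torch.int64)

    def get_state_as_tuple(self):
        return self.seed, self.offset.clone()

    def advance_offset(self, numel):
        # Illustrative only: the real increment is Philox-specific, not numel.
        self.offset = self.offset + numel

tracker = ToyPhiloxTracker(seed=123)

def rand_like_decomp(x):
    # Each rand-like call site reads (seed, relative offset) and advances the
    # offset so the next random op draws from a disjoint slice of the stream.
    seed, offset = tracker.get_state_as_tuple()
    out = torch.ops.prims.philox_rand.default(
        x.shape, seed, offset, x.stride(), x.device, x.dtype
    )
    tracker.advance_offset(x.numel())
    return out
~~~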
Not dealt with in this PR:
* dropout op - offset calculation is different
* other distributions like normal, poisson, etc.
* Inductor support
* Cudagraph support
* Dynamic shape support
An example:
~~~
class Custom(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        a = torch.rand_like(x) * x
        a = torch.rand_like(x) * a
        return a

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * torch.rand_like(grad_out) * torch.cos(x)


====== Forward graph 0 ======
def forward(self, fwd_seed_1: i64[], fwd_base_offset_1: i64[], primals_1: f32[16, 16]):
    # No stacktrace found for following nodes
    add: i64[] = torch.ops.aten.add.Tensor(fwd_base_offset_1, 0)
    philox_rand: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], fwd_seed_1, add, [16, 1], device(type='cuda', index=0), torch.float32); add = None
    mul: f32[16, 16] = torch.ops.aten.mul.Tensor(philox_rand, primals_1); philox_rand = None
    add_1: i64[] = torch.ops.aten.add.Tensor(fwd_base_offset_1, 4); fwd_base_offset_1 = None
    philox_rand_1: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], fwd_seed_1, add_1, [16, 1], device(type='cuda', index=0), torch.float32); fwd_seed_1 = add_1 = None
    mul_1: f32[16, 16] = torch.ops.aten.mul.Tensor(philox_rand_1, mul); philox_rand_1 = mul = None
    return [mul_1, primals_1]

====== Backward graph 0 ======
def forward(self, bwd_seed_1: i64[], bwd_base_offset_1: i64[], primals_1: f32[16, 16], tangents_1: f32[16, 16]):
    # No stacktrace found for following nodes
    add_2: i64[] = torch.ops.aten.add.Tensor(bwd_base_offset_1, 0); bwd_base_offset_1 = None
    philox_rand_2: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], bwd_seed_1, add_2, [16, 1], device(type='cuda', index=0), torch.float32); bwd_seed_1 = add_2 = None
    mul_2: f32[16, 16] = torch.ops.aten.mul.Tensor(tangents_1, philox_rand_2); tangents_1 = philox_rand_2 = None
    cos: f32[16, 16] = torch.ops.aten.cos.default(primals_1); primals_1 = None
    mul_3: f32[16, 16] = torch.ops.aten.mul.Tensor(mul_2, cos); mul_2 = cos = None
    return [mul_3]
~~~
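Note in the forward dump how the first call reads `fwd_base_offset_1 + 0` and the second `fwd_base_offset_1 + 4`: each call site gets a fixed, deterministic offset from the per-graph base offset, so the two calls draw from disjoint Philox substreams. The dumps can be reproduced with a few lines; this is a sketch assuming a CUDA build with this PR applied, reusing the `Custom` class above (`print_graph` is a hypothetical helper, not part of the PR):
~~~
import torch
from functorch.compile import aot_function

torch._functorch.config.functionalize_rng_ops = True  # enable the new path

def print_graph(gm, example_inputs):
    gm.graph.print_tabular()  # shows the philox_rand calls and the offset math
    return gm

x = torch.rand(16, 16, device="cuda", requires_grad=True)
aot_custom = aot_function(Custom.apply, fw_compiler=print_graph, bw_compiler=print_graph)
aot_custom(x).sum().backward()
~~~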
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97377
Approved by: https://github.com/ezyang
# Owner(s): ["oncall: pt2"]

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from functorch.compile import aot_function, nop, min_cut_rematerialization_partition
from unittest import skip
from unittest.mock import patch
import functools
import torch.utils.checkpoint

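# A "compiler" passed to aot_function: instead of compiling, it asserts how
# many philox_rand prims the traced graph contains, then runs the graph as-is.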
def count_philox_rand(gm, args, freq):
    assert [node.target for node in gm.graph.nodes].count(torch.ops.rngprims.philox_rand.default) == freq
    return gm


class TestFunctionalizationRngOps(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand(self, dtype, device):
        shape = (10,)

        def fn(x):
            a = torch.rand(*shape, device=device, dtype=dtype) * x
            a = torch.rand(*shape, device=device, dtype=dtype) * a
            return a

        x = torch.rand(*shape, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_autograd_function(self, dtype, device):
        shape = (16, 16)

        class Custom(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        custom = Custom.apply

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = custom(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_multiple_subgraphs(self, dtype, device):
        # Checks that rng state is maintained when there are multiple aot traced
        # graphs.
        shape = (16, 16)

        class CustomOp1(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        class CustomOp2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)

        custom_op1 = CustomOp1.apply
        custom_op2 = CustomOp2.apply

        def fn(x):
            a = custom_op1(x)
            b = a.sin()
            return custom_op2(b)

        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)

        def aot_fn(x):
            a = aot_custom_op1(x)
            b = a.sin()
            return aot_custom_op2(b)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            res = aot_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

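    # get_rng_state/set_rng_state are monkeypatched during AOT Autograd tracing,
    # so saving and restoring the state should replay the same offsets: all three
    # rand_like calls still trace to philox_rand ops (freq=3), with the third
    # reusing the second call's offset.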
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_set_get_rng_state(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            state = torch.cuda.get_rng_state()
            a = torch.rand_like(x) * a
            torch.cuda.set_rng_state(state)
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            fwd_compiler = functools.partial(count_philox_rand, freq=3)
            aot_fn = aot_function(fn, fwd_compiler)
            res = aot_fn(x)

            self.assertEqual(ref, res)

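    # The min-cut partitioner does not recompute rand ops in the backward: the
    # two philox_rand results are saved from the forward, so the backward graph
    # should contain none (freq=0) while keeping the new calling convention.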
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_min_cut_partitioner(self, dtype, device):
        # Checks that the calling convention is maintained.
        shape = (16, 16)

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            a = torch.sin(a)
            a = torch.sin(a)
            a = torch.sin(a)
            return a

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = fn(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_custom = aot_function(fn, fwd_compiler, bwd_compiler, partition_fn=min_cut_rematerialization_partition)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    # TODO - Dropout needs more work because of offset calculation
    @skip("Dropout needs more work because of offset calculation")
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    @dtypes(torch.float32)
    def test_checkpoint(self, dtype, device):
        def g(x, y):
            return torch.nn.functional.dropout(x, 0.6)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False)

        x = torch.ones(2, 2, device="cuda", requires_grad=True)
        y = torch.rand(2, 2, device="cuda", requires_grad=True)
        torch.cuda.manual_seed(123)
        ref = fn(x, y)

        # With checkpointing we should recompute dropout in bwd, and should see philox_rand
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
        torch.cuda.manual_seed(123)
        res = aot_fn(x, y)
        # res.sum().backward()
        # TODO - This is not same. Debug this further.
        self.assertEqual(ref, res)


only_for = ("cuda",)
instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for)


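# The CPU generator does not use Philox, so functionalizing rng ops on CPU is
# expected to raise.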
class NegativeTest(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_on_cpu(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        aot_fn = aot_function(fn, nop)
        with self.assertRaises(RuntimeError):
            aot_fn(x)


only_for = ("cpu",)
instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()