pytorch/test/dynamo/test_ctx_manager.py
Chen, Zejun 8e60d646b9 [dynamo][stream]support device-agnostic stream in dynamo and capture stream/event method in fx graph (#108312)
This PR implements two things:
1. Support for device-agnostic streams and runtime APIs captured by dynamo.
2. Support for stream methods (including events) captured by dynamo.

Here are the details for the first item.
Previously, streams captured by dynamo were tightly bound to CUDA. Here we implement a global singleton container named `StreamMethodContainer` through which different backends register their associated stream methods with dynamo. When a backend's package is imported, its stream operations can be registered directly by calling

```
device_stream_method = {'current_stream': method_1,
                         'create_stream_context': method_2,
                         'set_stream': method_3,
                         'set_stream_by_id': method_4}
torch._dynamo.stream.register_stream_method(device_name, device_stream_method)
```

Stream methods must be passed to this API under the keys of `device_stream_method` whose semantics they precisely implement. After registration, dynamo can use these methods to capture stream operations in users' scripts, for example getting the current stream or setting a specific stream. Additionally, the wrapped stream variable and the stream context variable are now device-agnostic: the proxy functions of these variables are assigned from the associated methods in the container. All of this is illustrated in the image below.

![image](https://github.com/pytorch/pytorch/assets/74231238/37ac7350-c539-4167-9886-c3744ecab65d)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108312
Approved by: https://github.com/jansel, https://github.com/jgong5
2023-10-22 13:22:58 +00:00

982 lines
34 KiB
Python

# Owner(s): ["module: dynamo"]
import unittest
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.onnx.operators
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
from torch.nn import functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
class CutomizedCtxManager:
    """Plain user-defined context manager that switches grad mode.

    ``mode`` is applied on ``__enter__``. The grad mode that was active when
    the object was *constructed* is restored on ``__exit__``.
    """

    def __init__(self, mode):
        # Snapshot the ambient grad mode now, at construction time.
        self.prev = torch.is_grad_enabled()
        self.mode = mode

    def __enter__(self):
        torch._C._set_grad_enabled(self.mode)
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch._C._set_grad_enabled(self.prev)
        # A falsy return value lets any in-flight exception propagate.
        return None
class CtxManagerTests(torch._dynamo.test_case.TestCase):
    """Dynamo tests for context managers.

    Covers grad-mode managers, profilers, CUDA streams and events, autocast,
    user-defined context managers and ``disable_saved_tensors_hooks``, both
    with and without graph breaks inside the managed region.
    """

    def test_no_grad(self):
        def fn1(a, b):
            x = a + 1
            # redundant no_grad should get ignored
            with torch.no_grad():
                x = x + b
            x = x + 2
            return x

        def fn2(a, b):
            x = a + 1
            with torch.set_grad_enabled(False):
                x = x + b
            x = x + 2
            return x

        def fn3(a, b):
            x = a + 1
            with torch.enable_grad():
                x = x + b
            x = x + 2
            return x

        def fn4(a, b):
            x = a + 1
            with torch.set_grad_enabled(True):
                if torch.is_grad_enabled():
                    x = x + b
            x = x + 2
            return x

        # Run every variant under both ambient grad modes.
        with torch.no_grad():
            torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
            torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
            torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
            torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
        with torch.enable_grad():
            torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
            torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
            torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
            torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)

    def test_grad_mode_guard(self):
        def fn(a, b):
            prev_grad = torch.is_grad_enabled()
            torch.set_grad_enabled(False)
            a = a + 1
            a.tolist()  # graph break
            ret = a + b
            torch.set_grad_enabled(prev_grad)
            return ret

        a = torch.randn([3, 4])
        b = torch.randn([3, 4])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        # Repeated calls must reuse the two compiled frames (before and after
        # the graph break) rather than recompiling.
        for _ in range(10):
            opt_fn(a, b)
        self.assertEqual(cnts.frame_count, 2)

    def test_nested_grad_mode_graph_break(self):
        def fn(x):
            before = torch.is_grad_enabled()
            with torch.set_grad_enabled(False):
                torch._dynamo.graph_break()
                with torch.set_grad_enabled(True):
                    x = torch.mul(x, 5)
                    torch._dynamo.graph_break()
                    x = torch.sqrt(x)
                    assert torch.is_grad_enabled()
                assert not torch.is_grad_enabled()
            assert torch.is_grad_enabled() == before
            return x

        a = torch.randn([3, 4])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        for _ in range(10):
            opt_fn(a)
        self.assertEqual(cnts.frame_count, 2)

    def test_torch_profiler(self):
        # wrap torch.profiler.* as NullContextVariable and do nothing
        def fn(x):
            y = x**2
            with torch.profiler.profile():
                y = y + 2
                with torch.profiler.record_function("my_function"):
                    z = y**3
                    z.tolist()  # graph break
                    z = z + 1
            return z

        x = torch.randn((2, 2), requires_grad=True)
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
        self.assertEqual(cnts.frame_count, 2)

    def test_autograd_profiler(self):
        # wrap torch.autograd.profiler.* as NullContextVariable and do nothing
        def fn(x):
            y = x**2
            with torch.autograd.profiler.profile():
                y = y + 2
                with torch.autograd.profiler.record_function("my_function"):
                    z = y**3
                    z.tolist()  # graph break
                    z = z + 1
            return z

        x = torch.randn((2, 2), requires_grad=True)
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
        self.assertEqual(cnts.frame_count, 2)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_stream_context_manager1(self):
        def fn(x):
            s = torch.cuda.Stream()
            x = torch.mul(x, 5)
            x = torch.add(x, 2)
            with torch.cuda.stream(s):
                x = torch.relu(x)
            x = torch.add(x, 1)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 9)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_stream_context_manager2(self):
        def fn(x, s):
            x = torch.mul(x, 5)
            x = torch.add(x, 2)
            with torch.cuda.stream(s):
                x = torch.relu(x)

            s1 = torch.cuda.current_stream()
            with torch.cuda.stream(s1):
                x = torch.relu(x)

            s2 = torch.cuda.Stream()
            with torch.cuda.stream(s2):
                x = torch.relu(x)

            x = torch.add(x, 1)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        s = torch.cuda.Stream()
        ref = fn(x, s)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        res = opt_fn(x, s)
        self.assertTrue(same(ref, res))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 18)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_stream_method(self):
        def fn(x):
            x = torch.mul(x, 1)
            x = torch.add(x, 2)

            new_stream = torch.cuda.Stream()
            with torch.cuda.stream(new_stream):
                x = torch.sin(x)
                x = torch.add(x, 3)

            cur_stream = torch.cuda.current_stream()
            cur_stream.wait_stream(new_stream)

            x = torch.add(x, 4)
            # Results below are deliberately unused: the calls exist so the
            # nopython compile must handle Stream.query() and Stream.__eq__.
            is_idle = cur_stream.query()  # noqa: F841
            cur_stream.synchronize()

            with torch.cuda.stream(new_stream):
                x = torch.add(x, 5)
            new_stream.synchronize()

            is_equal = cur_stream == new_stream  # noqa: F841

            x = torch.relu(x)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 20)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_event_method(self):
        def fn(x):
            x = torch.mul(x, 1)
            x = torch.add(x, 2)

            cur_stream = torch.cuda.current_stream()
            new_stream = torch.cuda.Stream()

            x = torch.add(x, 3)

            event = cur_stream.record_event()
            # Deliberately unused: exercises Event.query() under nopython.
            is_idle = event.query()  # noqa: F841

            new_stream.wait_event(event)
            with torch.cuda.stream(new_stream):
                x = torch.add(x, 4)

            new_event = torch.cuda.Event()
            new_event.record(new_stream)

            x = torch.add(x, 5)
            new_event.wait(cur_stream)

            # use new event to sync
            new_event.synchronize()

            x = torch.relu(x)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 19)

    def test_autograd_profiler_enabled(self):
        def fn(x):
            if torch.autograd._profiler_enabled():
                return x + 1
            else:
                return x - 1

        x = torch.randn((2, 2), requires_grad=True)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)

        # Start from a known "profiler off" state.
        if torch.autograd._profiler_enabled():
            torch.autograd._disable_profiler()
        assert not torch.autograd._profiler_enabled()
        ref = fn(x)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

        with torch.autograd.profiler.profile():
            assert torch.autograd._profiler_enabled()
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_autocast(self):
        if not torch.cuda.is_bf16_supported():
            self.skipTest("requires bf16")

        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cuda")
                b_float32 = torch.rand((8, 8), device="cuda")
                d_float32 = torch.rand((8, 8), device="cuda")

                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    e_float16 = torch.mm(a_float32, b_float32)
                    f_float16 = torch.mm(d_float32, e_float16)
                return f_float16

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.type, "cuda")
        self.assertEqual(exported.device.index, 0)
        self.assertEqual(exported.dtype, torch.bfloat16)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_amp_autocast(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cuda")
                b_float32 = torch.rand((8, 8), device="cuda")

                with torch.cuda.amp.autocast(dtype=torch.float64):
                    c_float64 = torch.mm(a_float32, b_float32)
                return c_float64

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.type, "cuda")
        self.assertEqual(exported.device.index, 0)
        self.assertEqual(exported.dtype, torch.float64)

    def test_is_autocast_cpu_enabled(self):
        def fn(a_float32, b_float32):
            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
                c_float16 = torch.mm(a_float32, b_float32)
                if torch.is_autocast_cpu_enabled():
                    c_float16 = c_float16 + 1
            return c_float16

        a = torch.rand((8, 8))
        b = torch.rand((8, 8))
        ref = fn(a, b)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(a, b)
        self.assertTrue(same(ref, res))

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Can't run fused SDPA on this platform",
    )
    def test_autocast_sdpa(self):
        class MyModule(torch.nn.Module):
            def forward(self, query, key, value):
                with torch.autocast("cpu"):
                    with torch.autocast("cuda", dtype=torch.float32):
                        out = F.scaled_dot_product_attention(
                            query, key, value, None, 0.0, True
                        )
                return out

        dtype = torch.float32
        seq_len_q = 1
        seq_len_k = 1
        head_dim = 8
        query = torch.ones(
            1, 8, seq_len_q, head_dim, device="cuda", dtype=dtype, requires_grad=True
        )
        key = torch.ones(
            1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
        )
        value = torch.ones(
            1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
        )

        module = MyModule()
        real = module(query, key, value)
        real_device = real.device
        real_dtype = real.dtype

        opt_mod = torch._dynamo.optimize("inductor")(module)
        compiled = opt_mod(query, key, value)

        self.assertEqual(compiled.device, real_device)
        self.assertEqual(compiled.dtype, real_dtype)

        self.assertEqual(compiled.device.type, "cuda")
        self.assertEqual(compiled.device.index, 0)
        self.assertEqual(compiled.dtype, torch.float32)

    def test_autocast_cpu(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cpu")
                b_float32 = torch.rand((8, 8), device="cpu")
                d_float32 = torch.rand((8, 8), device="cpu")

                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    e_float16 = torch.mm(a_float32, b_float32)
                    f_float16 = torch.mm(d_float32, e_float16)
                return f_float16

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.type, "cpu")
        self.assertEqual(exported.dtype, torch.bfloat16)

    def test_autocast_cpu_graph_break(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cpu")
                b_float32 = torch.rand((8, 8), device="cpu")
                torch._dynamo.graph_break()
                d_float32 = torch.rand((8, 8), device="cpu")

                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    e_float16 = torch.mm(a_float32, b_float32)
                    torch._dynamo.graph_break()
                    f_float16 = torch.mm(d_float32, e_float16)
                return f_float16

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        opt = torch._dynamo.optimize("eager")(module)
        res = opt(torch.tensor([0.5]))
        self.assertEqual(res.device, real_device)
        self.assertEqual(res.dtype, real_dtype)

        self.assertEqual(res.device.type, "cpu")
        self.assertEqual(res.dtype, torch.bfloat16)

    def test_autocast_cpu_graph_break_2(self):
        # Regression for: https://github.com/pytorch/pytorch/issues/93890
        def fn(x):
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                x = torch.mm(x, x)
                torch._dynamo.graph_break()
                x = torch.relu(x)
            return x

        x = torch.rand([4, 4])
        self.assertEqual(x.dtype, torch.float32)
        res = fn(x)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_res = opt_fn(x)
        self.assertTrue(torch.allclose(res, opt_res))
        self.assertEqual(res.dtype, torch.bfloat16)
        self.assertEqual(opt_res.dtype, torch.bfloat16)

    def test_autocast_cpu_graph_break_inner_fn(self):
        class MyModule(torch.nn.Module):
            @staticmethod
            def mm_breaks(x, y):
                torch._dynamo.graph_break()
                return torch.mm(x, y)

            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cpu")
                b_float32 = torch.rand((8, 8), device="cpu")

                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    torch._dynamo.graph_break()
                    with torch.autocast(
                        device_type="cpu", dtype=torch.bfloat16, enabled=False
                    ):
                        torch._dynamo.graph_break()
                        g_float32 = torch.mm(a_float32, b_float32)
                        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                            # Check that nested with non-inlineable function with graph break
                            torch._dynamo.graph_break()
                            f_float16_1 = self.mm_breaks(a_float32, b_float32)
                    # We remember to exit the inner autocast correctly to outer
                    # even after graph breaks
                    f_float16 = self.mm_breaks(a_float32, b_float32)
                    assert f_float16.dtype == f_float16_1.dtype
                return f_float16, g_float32

        module = MyModule()
        real_16, real_32 = module(torch.tensor([0.5]))
        real_device_16 = real_16.device
        real_dtype_16 = real_16.dtype
        real_device_32 = real_32.device
        real_dtype_32 = real_32.dtype

        graph = torch._dynamo.optimize("eager")(module)
        out_16, out_32 = graph(torch.tensor([0.5]))
        self.assertEqual(out_16.device, real_device_16)
        self.assertEqual(out_16.dtype, real_dtype_16)
        self.assertEqual(out_32.device, real_device_32)
        self.assertEqual(out_32.dtype, real_dtype_32)

        self.assertEqual(out_16.device.type, "cpu")
        self.assertEqual(out_16.dtype, torch.bfloat16)
        self.assertEqual(out_32.device.type, "cpu")
        self.assertEqual(out_32.dtype, torch.float32)

    def test_autocast_graph_break_method(self):
        class MyModule(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.bias = bias

            def mm_not_break(self, x, y):
                return torch.mm(x, y) + self.bias

            def mm_breaks(self, x, y):
                torch._dynamo.graph_break()
                return torch.mm(x, y) + self.bias

            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cpu")
                b_float32 = torch.rand((8, 8), device="cpu")

                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    with torch.autocast(
                        device_type="cpu", dtype=torch.bfloat16, enabled=False
                    ):
                        g_float32 = torch.mm(a_float32, b_float32)
                    f_float16 = self.mm_breaks(a_float32, b_float32)

                    assert (
                        f_float16[0][0] == self.mm_not_break(a_float32, b_float32)[0][0]
                    )

                return f_float16, g_float32

        module = MyModule(bias=torch.rand((8, 8), device="cpu", dtype=torch.bfloat16))

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            # Autocast doesn't work on addition, so we need the bias to be `bfloat16`
            res = torch.rand((8, 8), device="cpu", dtype=torch.float32) + torch.rand(
                (8, 8), device="cpu", dtype=torch.bfloat16
            )
            self.assertEqual(res.dtype, torch.float32)

        real_16, real_32 = module(torch.tensor([0.5]))
        real_device_16 = real_16.device
        real_dtype_16 = real_16.dtype
        real_device_32 = real_32.device
        real_dtype_32 = real_32.dtype

        graph = torch._dynamo.optimize("eager")(module)
        out_16, out_32 = graph(torch.tensor([0.5]))
        self.assertEqual(out_16.device, real_device_16)
        self.assertEqual(out_16.dtype, real_dtype_16)
        self.assertEqual(out_32.device, real_device_32)
        self.assertEqual(out_32.dtype, real_dtype_32)

        self.assertEqual(out_16.device.type, "cpu")
        self.assertEqual(out_16.dtype, torch.bfloat16)
        self.assertEqual(out_32.device.type, "cpu")
        self.assertEqual(out_32.dtype, torch.float32)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_autocast_float64(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cuda")
                b_float32 = torch.rand((8, 8), device="cuda")
                d_float32 = torch.rand((8, 8), device="cuda")

                with torch.autocast(device_type="cuda", dtype=torch.float64):
                    e_float64 = torch.mm(a_float32, b_float32)
                    f_float64 = torch.mm(d_float32, e_float64)
                return f_float64

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.index, 0)
        self.assertEqual(exported.dtype, torch.float64)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_autocast_device(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cuda")
                b_float32 = torch.rand((8, 8), device="cuda")
                d_float32 = torch.rand((8, 8), device="cuda")

                # No dtype given: the result is asserted to be float16 below.
                with torch.autocast("cuda"):
                    e_float16 = torch.mm(a_float32, b_float32)
                    f_float16 = torch.mm(d_float32, e_float16)
                return f_float16

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.index, 0)
        self.assertEqual(exported.dtype, torch.float16)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_autocast_arguments_binding(self):
        # `enabled` is passed positionally; dynamo must bind it correctly.
        def f1(x):
            with torch.cuda.amp.autocast(False):
                x = torch.sin(x + 1)
            return x

        def f2(x):
            with torch.cpu.amp.autocast(False):
                x = torch.cos(x + 1)
            return x

        x = torch.rand([2, 3])
        ref1 = f1(x)
        ref2 = f2(x)
        opt_f1 = torch.compile(backend="eager")(f1)
        opt_f2 = torch.compile(backend="eager")(f2)
        res1 = opt_f1(x)
        res2 = opt_f2(x)
        self.assertTrue(same(ref1, res1))
        self.assertTrue(same(ref2, res2))

    def test_generic_context_manager(self):
        def fn(x):
            with CutomizedCtxManager(True):
                x = x + 1
                if torch.is_grad_enabled():
                    x = x * 2
                x = torch.relu(x)
            return x - 1

        with torch.no_grad():
            torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6)
        with torch.enable_grad():
            torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6)

    def test_nested_generic_context_manager(self):
        def fn(x):
            with CutomizedCtxManager(True):
                x = x + 1
                if torch.is_grad_enabled():
                    x = x * 2
                with CutomizedCtxManager(False):
                    if torch.is_grad_enabled():
                        x = x - 3
                    x = x * 1.5
                x = torch.relu(x)
            return x - 1

        with torch.no_grad():
            torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9)
        with torch.enable_grad():
            torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9)

    def test_generic_context_manager_with_graph_break(self):
        def fn(x):
            with CutomizedCtxManager(True):
                x = x + 1
                if torch.is_grad_enabled():
                    x = x * 2
                torch._dynamo.graph_break()
                x = torch.relu(x)
            return x - 1

        x = torch.rand(2, 3)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)

        with torch.no_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 2)
            self.assertEqual(cnts.op_count, 2)

        # Counts accumulate on the same counter across both grad modes.
        with torch.enable_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 4)
            self.assertEqual(cnts.op_count, 4)

    def test_nested_generic_context_manager_with_graph_break(self):
        def fn(x):
            with CutomizedCtxManager(True):
                x = x + 1
                if torch.is_grad_enabled():
                    x = x * 2
                with CutomizedCtxManager(False):
                    if torch.is_grad_enabled():
                        x = x - 3
                    torch._dynamo.graph_break()
                    x = x * 1.5
                x = torch.relu(x)
            return x - 1

        x = torch.rand(2, 3)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)

        with torch.no_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 4)
            self.assertEqual(cnts.op_count, 4)

        # Fresh dynamo state and counter for the enable_grad run.
        torch._dynamo.reset()
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)

        with torch.enable_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 4)
            self.assertEqual(cnts.op_count, 4)

    def test_graph_break_inlining_grad(self):
        def gn(z):
            with torch.no_grad():
                torch._dynamo.graph_break()
                return torch.sin(z)

        def fn(x, y, z):
            a = torch.mm(x, y)
            z = gn(z)
            return a

        torch._dynamo.reset()
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        z = torch.randn(4)
        # backward() would fail if the no_grad region leaked into `a`.
        opt_fn(x, y, z).sum().backward()

        self.assertEqual(cnts.frame_count, 2)

    def _graph_break_inlining_autocast_test_helper(self, device):
        """Compare eager vs. compiled ``fn`` whose inlined ``gn`` graph-breaks
        inside an autocast region on ``device``."""

        def gn(x, y):
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                z = torch.mm(x, y)
                torch._dynamo.graph_break()
                return torch.sin(z)

        def fn(x, y):
            z = torch.mm(x, y)
            z = z + gn(x, y)
            return z

        x = torch.rand(3, 3).to(device)
        y = torch.rand(3, 3).to(device)
        opt_fn = torch.compile(backend="eager")(fn)
        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertEqual(ref, res)

    def test_graph_break_inlining_autocast(self):
        for device in ["cuda", "cpu"]:
            if device == "cuda" and not (
                torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            ):
                continue
            self._graph_break_inlining_autocast_test_helper(device)

    def test_disable_saved_tensors_hooks(self):
        def fn(z):
            @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
            def f(x, y):
                return x + y

            x, y = torch.ones(1), torch.zeros(1)
            return f(x, y)

        eager = EagerAndRecordGraphs()
        torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))

        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        expected = """\
class GraphModule(torch.nn.Module):
    def forward(self):
        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')

        x = torch.ones(1)

        y = torch.zeros(1)

        add = x + y;  x = y = None

        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
        return (add,)
"""
        self.assertExpectedInline(actual, expected)

    def test_disable_saved_tensors_hooks_prev_disabled(self):
        def fn(z):
            @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
            def f(x, y):
                return x + y

            x, y = torch.ones(1), torch.zeros(1)
            return f(x, y)

        eager = EagerAndRecordGraphs()
        with torch.autograd.graph.disable_saved_tensors_hooks(
            "Previously disabled message"
        ):
            torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))

        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        # On exit the graph re-disables with the outer message instead of
        # re-enabling hooks.
        expected = """\
class GraphModule(torch.nn.Module):
    def forward(self):
        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')

        x = torch.ones(1)

        y = torch.zeros(1)

        add = x + y;  x = y = None

        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message')
        return (add,)
"""
        self.assertExpectedInline(actual, expected)

    def test_disable_saved_tensors_hooks_prev_disabled_nested(self):
        def fn(z):
            @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
            def f(x, y):
                @torch.autograd.graph.disable_saved_tensors_hooks(
                    "This is not supported inner"
                )
                def inner_fn(x, y):
                    return x + y

                return inner_fn(x, y) + x

            x, y = torch.ones(1), torch.zeros(1)
            return f(x, y)

        eager = EagerAndRecordGraphs()
        with torch.autograd.graph.disable_saved_tensors_hooks(
            "Previously disabled message"
        ):
            torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))

        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        expected = """\
class GraphModule(torch.nn.Module):
    def forward(self):
        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')

        x = torch.ones(1)

        y = torch.zeros(1)

        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported inner')

        add = x + y;  y = None

        _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')

        add_1 = add + x;  add = x = None

        _saved_tensors_hooks_disable_3 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message')
        return (add_1,)
"""
        self.assertExpectedInline(actual, expected)

    def test_disable_saved_tensors_hooks_graph_break(self):
        def fn(x):
            with torch.autograd.graph.disable_saved_tensors_hooks(
                "This is not supported"
            ):
                y = x + 1
                torch._dynamo.graph_break()
                return y * 2

        eager = EagerAndRecordGraphs()
        torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(()))

        def check_graph(actual, expected):
            self.assertExpectedInline(actual, expected)

        # First graph: code before the graph break.
        expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')

        y = l_x_ + 1;  l_x_ = None

        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
        return (y,)
"""
        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))
        check_graph(actual, expected)

        # Second graph: the resumed code re-enters the disabled region.
        expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_y_ : torch.Tensor):
        l_y_ = L_y_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')

        mul = l_y_ * 2;  l_y_ = None

        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
        return (mul,)
"""
        graph = eager.graphs[1]
        actual = normalize_gm(graph.print_readable(False))
        check_graph(actual, expected)
# Entry point: run this module's tests with dynamo's test runner.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()