pytorch/test/dynamo/test_compile.py
clr 171fb7f358 easy: Fix missing tab in test/dynamo/test_compile.py (#145013)
It turns out that if you request a merge on a PyTorch PR, then push a fix for a bad rebase, and the test is
relatively new, the merge goes through with the previous commit and does not notice that the test is broken.

Explicitly running the test now passes (previously it failed); this is just the last missing commit from https://github.com/pytorch/pytorch/pull/144817

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145013
Approved by: https://github.com/masnesral, https://github.com/jansel
2025-01-16 22:51:51 +00:00

222 lines
6.6 KiB
Python

# Owner(s): ["module: dynamo"]
import inspect
import io
import os
import tempfile
from unittest.mock import patch
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter
class ToyModel(torch.nn.Module):
    """Tiny fixture network: a 10 -> 10 ``Linear`` layer followed by ``ReLU``."""

    def __init__(self) -> None:
        super().__init__()
        # The attribute names below become the state_dict keys
        # ("linear.weight", "linear.bias"), so they must stay as they are.
        self.linear = torch.nn.Linear(in_features=10, out_features=10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Affine projection first, then element-wise rectification.
        hidden = self.linear(x)
        activated = self.relu(hidden)
        return activated
class InPlaceCompilationTests(TestCase):
    """Tests for ``nn.Module.compile()`` (in-place compilation) and Dynamo compile hooks."""

    def test_compilation(self):
        # Compiling in place with a counting backend and running once should
        # produce exactly one compiled frame.
        torch._dynamo.reset()
        model = ToyModel()
        cnt = CompileCounter()
        model.compile(backend=cnt)
        x = torch.randn(10, 10)
        model(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_overwrite_call_impl(self):
        # ``compile()`` is expected to populate ``_compiled_call_impl``, which
        # starts out unset on a fresh module.
        torch._dynamo.reset()
        model = ToyModel()
        self.assertIsNone(model._compiled_call_impl)
        model.compile()
        self.assertIsNotNone(model._compiled_call_impl)

    def test_save(self):
        # A module compiled in place (and already run) can still be pickled
        # whole with torch.save and reloaded.
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.save(model, path)
            # weights_only=False as this is a legacy use case that loads a module
            loaded_model = torch.load(path, weights_only=False)
            loaded_model(torch.randn(1, 10))

    def test_state_dict_save(self):
        # The state_dict of an in-place-compiled module loads into a fresh,
        # uncompiled ToyModel. load_state_dict is strict by default, so this also
        # checks that compile() did not change the parameter keys.
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.save(model.state_dict(), path)
            loaded_model = ToyModel()
            # A state_dict contains only tensors, so the restricted
            # (safe) weights_only unpickler is sufficient here.
            loaded_model.load_state_dict(torch.load(path, weights_only=True))
            loaded_model(torch.randn(1, 10))

    def test_jit_save(self):
        # An in-place-compiled module can still be TorchScript-scripted,
        # saved, and reloaded.
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        scripted_model = torch.jit.script(model)
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.jit.save(scripted_model, path)
            loaded_model = torch.jit.load(path)
            loaded_model(torch.randn(1, 10))

    def test_compilation_callback(self):
        # Registered start/end callbacks fire once each, in order, around a
        # single compilation; their output is captured from stdout.
        torch._dynamo.reset()

        @torch._dynamo.on_compile_start
        def start_callback():
            print("Compilation started.")

        @torch._dynamo.on_compile_end
        def end_callback():
            print("Compilation ended.")

        mod = ToyModel()
        x = torch.randn(10, 10)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_mod = torch.compile(backend="eager", fullgraph=True)(mod)
            opt_mod(x)
            printed_output = mock_stdout.getvalue().strip()
        self.assertEqual(printed_output, "Compilation started.\nCompilation ended.")

    def test_compile_eager_options(self):
        # Passing ``options`` together with a non-inductor backend must not
        # raise; the test has no further assertions.
        @torch.compile(backend="eager", options={"foo": 2})
        def f(x):
            return x + x

        f(torch.randn(3))

        @torch.compile(backend="aot_eager", options={"foo": 2})
        def g(x):
            return x + x

        g(torch.randn(3))

    def test_compilation_callback_with_graph_break(self):
        # The graph break splits ``fn`` into separately compiled pieces, so the
        # start/end callbacks fire twice each (four increments in total).
        torch._dynamo.reset()
        counter = 0

        @torch._dynamo.on_compile_start
        def start_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch._dynamo.on_compile_end
        def end_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch.compile(backend="eager")
        def fn(x):
            x = x + 1
            torch._dynamo.graph_break()
            return torch.sin(x)

        x = torch.randn(10, 10)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            fn(x)
            printed_output = mock_stdout.getvalue().strip()
        self.assertEqual(
            printed_output, "Counter = 1\nCounter = 2\nCounter = 3\nCounter = 4"
        )

    def test_compilation_constant_hasattr_fail(self):
        @torch.compile(backend="eager")
        def fn(x):
            return x.max()

        # We should fall back to normal mode and raise an AttributeError,
        # not an internal Dynamo exception.
        with self.assertRaises(AttributeError):
            fn(None)

    def test_compilation_evnum_hasattr_fail(self):
        # Same check as the constant case, but with an Enum member as input.
        # (The method name is left unchanged to keep the test ID stable.)
        from enum import Enum

        class TestEnum(Enum):
            VALID = 1

        @torch.compile(backend="eager")
        def fn(x):
            return x.max()

        # We should fall back to normal mode and raise an AttributeError,
        # not an internal Dynamo exception.
        with self.assertRaises(AttributeError):
            fn(TestEnum.VALID)

    def test_compilation_name_error(self):
        # A reference to an undefined global inside compiled code surfaces as
        # the user-facing NameError.
        @torch.compile(backend="eager")
        def fn(x):
            x = x + 1
            does_not_exist()  # noqa: F821
            return x

        x = torch.randn(10, 10)
        with self.assertRaises(NameError):
            fn(x)

    def test_compilation_tensor_invalid_method(self):
        # Calling a nonexistent tensor method inside compiled code surfaces as
        # the user-facing AttributeError.
        @torch.compile(backend="eager")
        def fn(x):
            y = torch.tensor(x)
            return y.doesnotexist()

        x = torch.randn(10, 10)
        with self.assertRaises(AttributeError):
            fn(x)
# The private variants of the below functions are extensively tested,
# so as long as the signatures match we're good.
class PublicTorchCompilerTests(TestCase):
    """Check that public ``torch.compiler`` wrappers mirror their private counterparts."""

    def check_signature(self, public_fn_name, private_fn_name, private_namespace):
        """Assert that two functions have identical signatures.

        Compares ``torch.compiler.<public_fn_name>`` with
        ``private_namespace.<private_fn_name>`` using ``inspect.signature``.

        Raises:
            AttributeError: if either function cannot be found.
            AssertionError: (via ``assertEqual``) if the signatures differ.
        """
        public_fn = getattr(torch.compiler, public_fn_name)
        private_fn = getattr(private_namespace, private_fn_name)

        public_sig = inspect.signature(public_fn)
        private_sig = inspect.signature(private_fn)

        self.assertEqual(
            public_sig,
            private_sig,
            f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}",
        )

    def test_dynamo_signatures(self):
        function_names = [
            "reset",
            "allow_in_graph",
            "list_backends",
            "assume_constant_result",
            "disable",
        ]

        for fn_name in function_names:
            # subTest reports every mismatching function in one run instead of
            # aborting the loop at the first failure.
            with self.subTest(fn_name=fn_name):
                self.check_signature(fn_name, fn_name, torch._dynamo)
if __name__ == "__main__":
    # When executed directly, run this module's tests via run_tests from
    # torch._dynamo.test_case.
    run_tests()