mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
This PR:
- adds the ability to run functorch tests via run_test.py
- changes the functorch shards in PyTorch CI to invoke functorch tests
via run_test.py
The main motivation for this is so that functorch tests hook into the
standard PyTorch test infrastructure.
Questions for reviewers:
- the functorch tests are located outside of the pytorch/test folder
(they're in the pytorch/functorch/test folder). Is this OK? (run_test.py
works locally for me).
Test Plan:
- checked that `python run_test.py --functorch` ran functorch tests
locally
- Local mock test: added `{"test_compilation_for_dynamic_shape
(__main__.TestCompileCache)":
["https://github.com/pytorch/pytorch/issues/82016", ["linux"]]}` to .pytorch-disabled-tests.json, ran functorch tests, verified that the test was skipped.
- Wait for CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82012
Approved by: https://github.com/janeyx99
686 lines
22 KiB
Python
686 lines
22 KiB
Python
# Owner(s): ["module: functorch"]
|
|
|
|
import torch
|
|
|
|
import functorch
|
|
from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS
|
|
import unittest
|
|
|
|
from functorch.compile import aot_function, nop
|
|
|
|
|
|
class TestCompileCache(TestCase):
|
|
def check(self, a, b, aot_fn, fn):
|
|
a_clone = a.clone().detach().requires_grad_(True)
|
|
b_clone = b.clone().detach().requires_grad_(True)
|
|
ref = fn(a, b)
|
|
ref.sum().backward()
|
|
|
|
res = aot_fn(a_clone, b_clone)
|
|
res.sum().backward()
|
|
assert torch.allclose(res, ref)
|
|
assert torch.allclose(a.grad, a_clone.grad)
|
|
assert torch.allclose(b.grad, b_clone.grad)
|
|
|
|
def test_recompilation_on_broadcast(self):
|
|
def fn(x, bias):
|
|
return x + bias
|
|
|
|
for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
|
|
functorch.compile.clear_compile_cache()
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
|
|
|
|
a = torch.randn(10, 20, requires_grad=True)
|
|
b = torch.randn(20, requires_grad=True)
|
|
self.check(a, b, aot_autograd_fn, fn)
|
|
|
|
a = torch.randn(10, 20, requires_grad=True)
|
|
b = torch.randn(10, 20, requires_grad=True)
|
|
self.check(a, b, aot_autograd_fn, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_compilation_for_dynamic_shape(self):
|
|
def fn(x, bias):
|
|
return x + bias
|
|
|
|
for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
|
|
functorch.compile.clear_compile_cache()
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
|
|
|
|
for s in range(10, 20):
|
|
a = torch.randn(s, requires_grad=True)
|
|
b = torch.randn(s, requires_grad=True)
|
|
self.check(a, b, aot_autograd_fn, fn)
|
|
|
|
for s in range(10, 20):
|
|
a = torch.randn(s, requires_grad=True)
|
|
b = torch.randn(s, requires_grad=True)
|
|
self.check(a, b, aot_autograd_fn, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
if hasher_type == "DynamicShapeHasher":
|
|
assert total_recomps == 1
|
|
elif hasher_type == "StaticShapeHasher":
|
|
assert total_recomps == 10
|
|
|
|
for s in range(10, 20):
|
|
a = torch.randn(s, s, requires_grad=True)
|
|
b = torch.randn(s, s, requires_grad=True)
|
|
self.check(a, b, aot_autograd_fn, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
if hasher_type == "DynamicShapeHasher":
|
|
assert total_recomps == 2
|
|
elif hasher_type == "StaticShapeHasher":
|
|
assert total_recomps == 20
|
|
|
|
def test_global_cache_no_recompilations(self):
|
|
def f(x, bias):
|
|
return x + bias
|
|
|
|
def g(x, bias):
|
|
return aot_function(f, nop, nop, hasher_type="DynamicShapeHasher")(x, bias)
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
for _ in range(10):
|
|
a = torch.randn(10, 20, requires_grad=True)
|
|
b = torch.randn(10, 20, requires_grad=True)
|
|
self.check(a, b, g, f)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 1
|
|
|
|
def test_multiple_functions(self):
|
|
def f(x, bias):
|
|
return x + bias
|
|
|
|
def g(x, y):
|
|
return x * y
|
|
|
|
for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
|
|
functorch.compile.clear_compile_cache()
|
|
aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)
|
|
aot_autograd_g = aot_function(g, nop, nop, hasher_type=hasher_type)
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
a = torch.randn(10, requires_grad=True)
|
|
b = torch.randn(10, requires_grad=True)
|
|
self.check(a, b, aot_autograd_f, f)
|
|
|
|
a = torch.randn(10, requires_grad=True)
|
|
b = torch.randn(10, requires_grad=True)
|
|
self.check(a, b, aot_autograd_g, g)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
# Force recompilation for function f and check num of recompilations again
|
|
a = torch.randn(10, 20, requires_grad=True)
|
|
b = torch.randn(10, 20, requires_grad=True)
|
|
self.check(a, b, aot_autograd_f, f)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 3
|
|
|
|
def test_high_number_of_args(self):
|
|
def f(*args):
|
|
res = args[0]
|
|
for arg in args:
|
|
res = res * arg
|
|
return res
|
|
|
|
def check(args, aot_autograd_fn, fn):
|
|
args_clone = [arg.clone().detach().requires_grad_(True) for arg in args]
|
|
ref = fn(*args)
|
|
ref.sum().backward()
|
|
|
|
res = aot_autograd_fn(*args_clone)
|
|
res.sum().backward()
|
|
assert torch.allclose(res, ref)
|
|
for (arg, arg_clone) in zip(args, args_clone):
|
|
assert torch.allclose(arg.grad, arg_clone.grad)
|
|
|
|
for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)
|
|
|
|
args = [torch.randn(10, requires_grad=True) for _ in range(100)]
|
|
check(args, aot_autograd_f, f)
|
|
|
|
def test_multiple_compiler(self):
|
|
def fn(x, bias):
|
|
return x + bias
|
|
|
|
def nop_duplicate(fx_g, _):
|
|
return fx_g
|
|
|
|
for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
|
|
functorch.compile.clear_compile_cache()
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
nop_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
|
|
nop_duplicate_fn = aot_function(
|
|
fn, nop_duplicate, nop_duplicate, hasher_type=hasher_type
|
|
)
|
|
|
|
a = torch.randn(10, 20, requires_grad=True)
|
|
b = torch.randn(20, requires_grad=True)
|
|
nop_fn(a, b)
|
|
nop_duplicate_fn(a, b)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
|
|
@unittest.skipIf(IS_WINDOWS, 'test broken on windows')
|
|
class TestCompileCacheStaticArgs(TestCase):
|
|
def check(self, a, b, aot_autograd_fn, fn):
|
|
a_clone = a.clone().detach().requires_grad_(True)
|
|
ref = fn(a, b)
|
|
ref.sum().backward()
|
|
|
|
res = aot_autograd_fn(a_clone, b)
|
|
res.sum().backward()
|
|
assert torch.allclose(res, ref)
|
|
assert torch.allclose(a.grad, a_clone.grad)
|
|
|
|
def test_failure(self):
|
|
# Test that not setting up static_argnums should raise exception
|
|
def fn(x, p):
|
|
return x * p
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop)
|
|
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = 2
|
|
try:
|
|
# Since b is not marked as static, it should raise exception
|
|
aot_autograd_f(a, b)
|
|
raise AssertionError()
|
|
except RuntimeError:
|
|
pass
|
|
|
|
def test_simple(self):
|
|
def fn(x, static_arg):
|
|
return x * static_arg
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1)
|
|
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = 2
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
# Same type of args, so no recompilation
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = 2
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
# Trigger recompilation
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = 3
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_static_arg_before_tensor_arg(self):
|
|
def fn(static_arg, x):
|
|
return static_arg - x
|
|
|
|
def check(a, b, aot_autograd_fn, fn):
|
|
b_clone = b.clone().detach().requires_grad_(True)
|
|
ref = fn(a, b)
|
|
ref.sum().backward()
|
|
|
|
res = aot_autograd_fn(a, b_clone)
|
|
res.sum().backward()
|
|
assert torch.allclose(res, ref)
|
|
assert torch.allclose(b.grad, b_clone.grad)
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=0)
|
|
|
|
a = 2
|
|
b = torch.randn(2, 2, requires_grad=True)
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
a = 3
|
|
b = torch.randn(2, 2, requires_grad=True)
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_interleaved_static_args(self):
|
|
def fn(static_arg1, x, static_arg2):
|
|
return static_arg1 - x - static_arg2
|
|
|
|
def check(a, b, c, aot_autograd_fn, fn):
|
|
b_clone = b.clone().detach().requires_grad_(True)
|
|
ref = fn(a, b, c)
|
|
ref.sum().backward()
|
|
|
|
res = aot_autograd_fn(a, b_clone, c)
|
|
res.sum().backward()
|
|
assert torch.allclose(res, ref)
|
|
assert torch.allclose(b.grad, b_clone.grad)
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0, 2))
|
|
|
|
a = 2
|
|
b = torch.randn(2, 2, requires_grad=True)
|
|
c = 0.1
|
|
check(a, b, c, aot_autograd_f, fn)
|
|
|
|
a = 3
|
|
b = torch.randn(2, 2, requires_grad=True)
|
|
c = 0.1
|
|
check(a, b, c, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_dropout(self):
|
|
def fn(x, prob):
|
|
return torch.nn.functional.dropout(x, p=prob)
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])
|
|
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = 0.3
|
|
aot_autograd_f(a, b)
|
|
|
|
# Setting the prob to 0. This should cause recompilation.
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = 0
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_if_condition(self):
|
|
def fn(x, state: bool):
|
|
if state:
|
|
return torch.sin(x)
|
|
else:
|
|
return torch.cos(x)
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])
|
|
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = True
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = True
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = False
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_custom(self):
|
|
class Record:
|
|
def __init__(self, name, multiplier):
|
|
self.name = name
|
|
self.multiplier = multiplier
|
|
|
|
def __eq__(self, other):
|
|
return self.name == other.name and self.multiplier == other.multiplier
|
|
|
|
def __hash__(self):
|
|
return hash((self.name, self.multiplier))
|
|
|
|
def fn(x, record):
|
|
return x * record.multiplier
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])
|
|
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = Record("Foo", 0.5)
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = Record("Bar", 10.2)
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_tuple(self):
|
|
def fn(a_tuple, static_arg):
|
|
return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg
|
|
|
|
def check(a_tuple, b, aot_autograd_fn, fn):
|
|
a0 = a_tuple[0]
|
|
a1 = a_tuple[1]
|
|
|
|
a0_clone = a0.clone().detach().requires_grad_(True)
|
|
a1_clone = a1.clone().detach().requires_grad_(True)
|
|
ref = fn(a, b)
|
|
ref.sum().backward()
|
|
|
|
res = aot_autograd_fn((a0_clone, a1_clone), b)
|
|
res.sum().backward()
|
|
assert torch.allclose(res, ref)
|
|
assert torch.allclose(a0.grad, a0_clone.grad)
|
|
assert torch.allclose(a1.grad, a1_clone.grad)
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))
|
|
|
|
a = (
|
|
torch.randn(2, 2, requires_grad=True),
|
|
torch.randn(2, 2, requires_grad=True),
|
|
)
|
|
b = 0.1
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
a = (
|
|
torch.randn(2, 2, requires_grad=True),
|
|
torch.randn(2, 2, requires_grad=True),
|
|
)
|
|
b = 1
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_tuple_with_first_arg_as_static(self):
|
|
def fn(static_arg, a_tuple):
|
|
return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg
|
|
|
|
def check(a, b_tuple, aot_autograd_fn, fn):
|
|
b0 = b_tuple[0]
|
|
b1 = b_tuple[1]
|
|
|
|
b0_clone = b0.clone().detach().requires_grad_(True)
|
|
b1_clone = b1.clone().detach().requires_grad_(True)
|
|
ref = fn(a, b_tuple)
|
|
ref.sum().backward()
|
|
|
|
res = aot_autograd_fn(a, (b0_clone, b1_clone))
|
|
res.sum().backward()
|
|
assert torch.allclose(res, ref)
|
|
assert torch.allclose(b0.grad, b0_clone.grad)
|
|
assert torch.allclose(b1.grad, b1_clone.grad)
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))
|
|
|
|
a = 0.1
|
|
b = (
|
|
torch.randn(2, 2, requires_grad=True),
|
|
torch.randn(2, 2, requires_grad=True),
|
|
)
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
a = 1
|
|
b = (
|
|
torch.randn(2, 2, requires_grad=True),
|
|
torch.randn(2, 2, requires_grad=True),
|
|
)
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_dict(self):
|
|
def fn(a_dict, static_arg):
|
|
return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg
|
|
|
|
def check(a_dict, b, aot_autograd_fn, fn):
|
|
|
|
a0 = a_dict["foo"]
|
|
a1 = a_dict["bar"]
|
|
|
|
a0_clone = a0.clone().detach().requires_grad_(True)
|
|
a1_clone = a1.clone().detach().requires_grad_(True)
|
|
ref = fn(a_dict, b)
|
|
ref.sum().backward()
|
|
|
|
a_clone = {}
|
|
a_clone["foo"] = a0_clone
|
|
a_clone["bar"] = a1_clone
|
|
res = aot_autograd_fn(a_clone, b)
|
|
res.sum().backward()
|
|
assert torch.allclose(res, ref)
|
|
assert torch.allclose(a0.grad, a0_clone.grad)
|
|
assert torch.allclose(a1.grad, a1_clone.grad)
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))
|
|
|
|
a = {}
|
|
a["foo"] = torch.zeros(2, 2, requires_grad=True)
|
|
a["bar"] = torch.ones(2, 2, requires_grad=True)
|
|
b = 0
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
a = {}
|
|
a["foo"] = torch.randn(2, 2, requires_grad=True)
|
|
a["bar"] = torch.randn(2, 2, requires_grad=True)
|
|
b = 0.2
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_dict_with_static_arg_before_dict(self):
|
|
def fn(static_arg, a_dict):
|
|
return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg
|
|
|
|
def check(a, b_dict, aot_autograd_fn, fn):
|
|
|
|
ref = fn(a, b_dict)
|
|
res = aot_autograd_fn(a, b_dict)
|
|
assert torch.allclose(res, ref)
|
|
|
|
b0 = b_dict["foo"]
|
|
b1 = b_dict["bar"]
|
|
|
|
b0_clone = b0.clone().detach().requires_grad_(True)
|
|
b1_clone = b1.clone().detach().requires_grad_(True)
|
|
ref.sum().backward()
|
|
|
|
b_clone = {}
|
|
b_clone["foo"] = b0_clone
|
|
b_clone["bar"] = b1_clone
|
|
res = aot_autograd_fn(a, b_clone)
|
|
res.sum().backward()
|
|
assert torch.allclose(res, ref)
|
|
assert torch.allclose(b0.grad, b0_clone.grad)
|
|
assert torch.allclose(b1.grad, b1_clone.grad)
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))
|
|
|
|
a = 0.1
|
|
b = {}
|
|
b["foo"] = torch.randn(2, 2, requires_grad=True)
|
|
b["bar"] = torch.randn(2, 2, requires_grad=True)
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
a = 0.2
|
|
b = {}
|
|
b["foo"] = torch.randn(2, 2, requires_grad=True)
|
|
b["bar"] = torch.randn(2, 2, requires_grad=True)
|
|
check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_tuple_static_args(self):
|
|
def fn(x, tuple_static_arg):
|
|
return x * tuple_static_arg[0] * tuple_static_arg[1]
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1)
|
|
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = (2, 3)
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
# Same type of args, so no recompilation
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = (2, 3)
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
# Trigger recompilation
|
|
a = torch.randn(2, 2, requires_grad=True)
|
|
b = (3, 4)
|
|
self.check(a, b, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 2
|
|
|
|
def test_arg_none(self):
|
|
def check(a, b, c, aot_autograd_fn, fn):
|
|
def cloner(x):
|
|
if x is not None:
|
|
return x.clone().detach().requires_grad_(True)
|
|
return None
|
|
|
|
def check_grad(x, x_clone):
|
|
if x is not None:
|
|
return torch.allclose(x.grad, x_clone.grad)
|
|
return True
|
|
|
|
ref = fn(a, b, c)
|
|
res = aot_autograd_fn(a, b, c)
|
|
assert torch.allclose(res, ref)
|
|
|
|
a_clone = cloner(a)
|
|
b_clone = cloner(b)
|
|
c_clone = cloner(c)
|
|
ref.sum().backward()
|
|
res = aot_autograd_fn(a_clone, b_clone, c_clone)
|
|
res.sum().backward()
|
|
|
|
check_grad(a, a_clone)
|
|
check_grad(b, b_clone)
|
|
check_grad(c, c_clone)
|
|
|
|
def fn(a, b, c):
|
|
if a is None and b is None:
|
|
return c
|
|
elif a is None and c is None:
|
|
return b
|
|
elif b is None and c is None:
|
|
return a
|
|
elif a is None:
|
|
return b + c
|
|
elif b is None:
|
|
return a + c
|
|
elif c is None:
|
|
return a + b
|
|
return a + b + c
|
|
|
|
functorch.compile.clear_compile_cache()
|
|
|
|
start_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
aot_autograd_f = aot_function(fn, nop, nop)
|
|
|
|
t1 = torch.randn(2, 2, requires_grad=True)
|
|
check(t1, None, None, aot_autograd_f, fn)
|
|
check(None, t1, None, aot_autograd_f, fn)
|
|
check(None, None, t1, aot_autograd_f, fn)
|
|
|
|
t2 = torch.randn(2, 2, requires_grad=True)
|
|
check(t1, t2, None, aot_autograd_f, fn)
|
|
check(t1, None, t2, aot_autograd_f, fn)
|
|
check(None, t1, t2, aot_autograd_f, fn)
|
|
|
|
t3 = torch.randn(2, 2, requires_grad=True)
|
|
check(t1, t2, t3, aot_autograd_f, fn)
|
|
|
|
# Same type of args, so no recompilation
|
|
check(t1, t2, None, aot_autograd_f, fn)
|
|
|
|
end_num_recomps = functorch.compile.num_of_recompilations()
|
|
|
|
total_recomps = end_num_recomps - start_num_recomps
|
|
assert total_recomps == 7
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|