mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
_Redo of #86586 with all BC breaking changes granularly placed into separate commits._
---
Per title. Deprecation happened on Feb 25, 2022 in c6f1bbc0ac, which made it into the 1.12 release. Since it is now 245 days later and the next release will be 1.14, the removals later in the stack comply with the [BC policy](https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#minimizing-the-disruption-of-bc-breaking-changes).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87969
Approved by: https://github.com/mruberry
55 lines · 1.4 KiB · Python
import timeit

import torch
import torch.nn as nn

from functorch.compile import compiled_module, tvm_compile
def nop(f, _):
    """Identity "compiler": hand the traced graph module back unchanged.

    The second argument (example inputs) is accepted for interface
    compatibility with real compilers and deliberately ignored.
    """
    return f
# Optional TVM-backed compilers. Previously these were constructed eagerly
# and then immediately overwritten with `nop` below, so the (expensive)
# tvm_compile calls were dead work that also required TVM to be installed.
# Keep them commented out; swap them in manually to benchmark the TVM path.
# fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops')
# bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops')

# Default: no-op "compilers" that return the traced graph unchanged,
# so compiled_module exercises the tracing machinery with eager semantics.
fw_compiler = nop
bw_compiler = nop
def run(mod, input):
    """Forward/backward pass through `mod`; return (output, *parameter grads).

    Note: backward() accumulates into each parameter's .grad, so repeated
    calls on the same module return accumulated gradients.
    """
    output = mod(input)
    output.sum().backward()
    return (output, *(p.grad for p in mod.parameters()))
class Foo(nn.Module):
    """Minimal test module: one learnable scalar parameter and one buffer.

    forward(x) computes (param * x + buf).sum(dim=0), broadcasting the
    shape-(1,) parameter and buffer against x.
    """

    def __init__(self):
        # Modern zero-arg super() (was the legacy super(Foo, self) form).
        super().__init__()
        # Learnable scalar and non-learnable buffer, both shape (1,).
        self.param = nn.Parameter(torch.randn(1))
        self.register_buffer("buf", torch.randn(1))

    def forward(self, x):
        # Broadcast param/buf against x, then reduce the leading dim.
        return (self.param * x + self.buf).sum(dim=0)
# Build an input plus eager and compiled versions of the module.
input = torch.randn(1)
mod = Foo()
compiled_mod = compiled_module(mod, fw_compiler, bw_compiler)

# Parity check: compiled outputs and parameter grads must match eager.
# NOTE(review): run() accumulates into .grad; the interplay between these
# repeated backward passes and the update below is order-sensitive, so the
# statement order is preserved exactly — confirm before reordering.
for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_close(a, b)

# Take one SGD-style step on both copies, clear the compiled grad,
# then re-check parity after the update.
out = mod(input)
out.sum().backward()
mod.param.data -= mod.param.grad
compiled_mod.orig_module.param.data -= compiled_mod.orig_module.param.grad
compiled_mod.orig_module.param.grad = None

for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_close(a, b)

# Benchmark: report microseconds per forward call, eager vs compiled.
iters = 10000  # single source of truth (was duplicated as a bare literal)
for _ in range(5):
    t = timeit.Timer("mod(input)", globals=globals()).timeit(iters)
    print(f"eager {t/iters*1e6}")
    t = timeit.Timer("compiled_mod(input)", globals=globals()).timeit(iters)
    print(f"compiled {t/iters*1e6}")