[functorch] update compile example imports (pytorch/functorch#834)

This commit is contained in:
Andre 2022-05-26 12:09:18 -07:00 committed by Jon Janzen
parent 36fa9d8295
commit 1feff6bb69
3 changed files with 8 additions and 8 deletions

View file

@ -1,5 +1,5 @@
import timeit
from functorch import compiled_module, tvm_compile
from functorch.compile import compiled_module, tvm_compile
import torch.nn as nn
import torch
from functools import partial
@ -9,8 +9,8 @@ def nop(f, _):
return f
fw_compiler = partial(tvm_compile, name='fw_keops')
bw_compiler = partial(tvm_compile, name='bw_keops')
fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops')
bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops')
fw_compiler = nop
bw_compiler = nop

View file

@ -4,7 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch import nnc_jit, make_functional
from functorch import make_functional
from functorch.compile import nnc_jit
import torch
import torch.nn as nn
import time

View file

@ -4,7 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch import grad, nnc_jit, make_fx, make_nnc
from functorch import grad, make_fx
from functorch.compile import nnc_jit
import torch
import time
@ -16,9 +17,7 @@ def f(x):
inp = torch.randn(100)
grad_pt = grad(f)
grad_fx = make_fx(grad_pt)(inp)
grad_nnc = nnc_jit(grad_pt, skip_specialization=True)
loopnest = make_nnc(grad_pt)(inp)
print(loopnest)
grad_nnc = nnc_jit(grad_pt)
def bench(name, f, iters=10000, warmup=3):