diff --git a/functorch/examples/compilation/fuse_module.py b/functorch/examples/compilation/fuse_module.py index 627484b43ab..9e635e8fc33 100644 --- a/functorch/examples/compilation/fuse_module.py +++ b/functorch/examples/compilation/fuse_module.py @@ -1,5 +1,5 @@ import timeit -from functorch import compiled_module, tvm_compile +from functorch.compile import compiled_module, tvm_compile import torch.nn as nn import torch from functools import partial @@ -9,8 +9,8 @@ def nop(f, _): return f -fw_compiler = partial(tvm_compile, name='fw_keops') -bw_compiler = partial(tvm_compile, name='bw_keops') +fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops') +bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops') fw_compiler = nop bw_compiler = nop diff --git a/functorch/examples/compilation/linear_train.py b/functorch/examples/compilation/linear_train.py index ac9b9e60b97..2d5f9d7dd37 100644 --- a/functorch/examples/compilation/linear_train.py +++ b/functorch/examples/compilation/linear_train.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from functorch import nnc_jit, make_functional +from functorch import make_functional +from functorch.compile import nnc_jit import torch import torch.nn as nn import time diff --git a/functorch/examples/compilation/simple_function.py b/functorch/examples/compilation/simple_function.py index 791990b1ad9..14731c7c666 100644 --- a/functorch/examples/compilation/simple_function.py +++ b/functorch/examples/compilation/simple_function.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from functorch import grad, nnc_jit, make_fx, make_nnc +from functorch import grad, make_fx +from functorch.compile import nnc_jit import torch import time @@ -16,9 +17,7 @@ def f(x): inp = torch.randn(100) grad_pt = grad(f) grad_fx = make_fx(grad_pt)(inp) -grad_nnc = nnc_jit(grad_pt, skip_specialization=True) -loopnest = make_nnc(grad_pt)(inp) -print(loopnest) +grad_nnc = nnc_jit(grad_pt) def bench(name, f, iters=10000, warmup=3):