mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[reland2][inductor] Add an AOT compilation mode for Inductor CPP backend (#96520)
Summary: This is a reland of https://github.com/pytorch/pytorch/pull/94822. Solved the long compilation issue for inductor cpp tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/96520 Approved by: https://github.com/huydhn, https://github.com/malfet
This commit is contained in:
parent
178d2a38e0
commit
f03db8d6cb
13 changed files with 317 additions and 56 deletions
|
|
@ -484,7 +484,7 @@ Inside this directory, each run will have a separate folder named with the times
|
|||
$ ls
|
||||
run_2023_03_01_08_20_52_143510-pid_180167
|
||||
|
||||
In the run folder there will be a torchdynamo directory which contains debug logs, and an aot_torchinductor
|
||||
In the run folder there will be a torchdynamo directory which contains debug logs, and a torchinductor
|
||||
folder which contains a subfolder for each compiled kernel with inductor debug artifacts.
|
||||
|
||||
::
|
||||
|
|
@ -492,13 +492,13 @@ folder which contains a subfolder for each compiled kernel with inductor debug a
|
|||
$ cd
|
||||
run_2023_03_01_08_20_52_143510-pid_180167
|
||||
$ ls
|
||||
aot_torchinductor torchdynamo
|
||||
torchinductor torchdynamo
|
||||
|
||||
Moving further into the aot_torchinductor directory, the \*.log files are logs from the aot autograd phase of compilation, model__0_forward_1.0 contains the inductor debug artifacts.
|
||||
Moving further into the torchinductor directory, the \*.log files are logs from the aot autograd phase of compilation, and model__0_forward_1.0 contains the inductor debug artifacts.
|
||||
|
||||
::
|
||||
|
||||
$ cd aot_torchinductor
|
||||
$ cd torchinductor
|
||||
$ ls
|
||||
aot_model___0_debug.log model__0_forward_1.0
|
||||
$ cd model__0_forward_1.0
|
||||
|
|
|
|||
23
test/inductor/aot/cpp/CMakeLists.txt
Normal file
23
test/inductor/aot/cpp/CMakeLists.txt
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Build script for the AOT-Inductor C++ smoke test.
#
# test.py is executed at build time and writes aot_inductor_output.h and
# aot_inductor_output.so into the build directory. test.cpp is compiled against
# that header and linked with the generated shared library.
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(test)

# Spelled as an absolute path so the lookup location is explicit and does not
# depend on how relative package directories are resolved.
set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/share/cmake/Torch")
find_package(Torch REQUIRED)

set(AOT_INDUCTOR_HEADER "${CMAKE_CURRENT_BINARY_DIR}/aot_inductor_output.h")
set(AOT_INDUCTOR_LIBRARY "${CMAKE_CURRENT_BINARY_DIR}/aot_inductor_output.so")

# Rule that produces the generated header. The generator writes its outputs
# into its working directory, so that directory is pinned to the build tree.
add_custom_command(
  OUTPUT "${AOT_INDUCTOR_HEADER}"
  COMMAND python "${CMAKE_CURRENT_SOURCE_DIR}/test.py"
  DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/test.py"
  WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
  COMMENT "Generating AOT-Inductor header and shared library"
  VERBATIM
)
add_custom_target(generate_header ALL
  DEPENDS "${AOT_INDUCTOR_HEADER}")

add_executable(test test.cpp "${AOT_INDUCTOR_HEADER}")

# Both `test` and `generate_header` reference the generated header. Without an
# ordering edge the two targets are independent, and a parallel build may run
# the generator rule twice at the same time. Making `test` depend on the custom
# target serializes them.
add_dependencies(test generate_header)

# The shared library is produced by the same generator run as the header.
add_library(aot_inductor_output SHARED IMPORTED)
set_property(TARGET aot_inductor_output PROPERTY
  IMPORTED_LOCATION "${AOT_INDUCTOR_LIBRARY}")

target_link_libraries(test PRIVATE "${TORCH_LIBRARIES}" aot_inductor_output)

set_target_properties(test PROPERTIES
  CXX_STANDARD 17
  CXX_STANDARD_REQUIRED ON
)
|
||||
41
test/inductor/aot/cpp/test.cpp
Normal file
41
test/inductor/aot/cpp/test.cpp
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
//#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "build/aot_inductor_output.h"
|
||||
|
||||
/*
|
||||
class Net(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = torch.ones(32, 64)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.relu(x + self.weight)
|
||||
return x
|
||||
*/
|
||||
struct Net : torch::nn::Module {
|
||||
Net() {
|
||||
weight = register_parameter("weight", torch::ones({32, 64}));
|
||||
}
|
||||
torch::Tensor forward(torch::Tensor input) {
|
||||
return torch::relu(input + weight);
|
||||
}
|
||||
torch::Tensor weight;
|
||||
};
|
||||
|
||||
int main() {
|
||||
torch::Tensor x = at::randn({32, 64});
|
||||
Net net;
|
||||
torch::Tensor results_ref = net.forward(x);
|
||||
|
||||
// TODO: we need to provide an API to concatenate args and weights
|
||||
std::vector<torch::Tensor> inputs = {x};
|
||||
for (const auto& pair : net.named_parameters()) {
|
||||
inputs.push_back(pair.value());
|
||||
}
|
||||
torch::Tensor results_opt = aot_inductor_entry(inputs);
|
||||
|
||||
assert(torch::allclose(results_ref, results_opt));
|
||||
printf("PASS\n");
|
||||
return 0;
|
||||
}
|
||||
22
test/inductor/aot/cpp/test.py
Normal file
22
test/inductor/aot/cpp/test.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
import torch._dynamo
|
||||
import torch._inductor
|
||||
import torch._inductor.config
|
||||
|
||||
# Base name for the generated .h and .so files; the CMake build for this test
# looks for aot_inductor_output.h and aot_inductor_output.so.
torch._inductor.config.aot_codegen_output_prefix = "aot_inductor_output"
|
||||
|
||||
|
||||
class Net(torch.nn.Module):
    """Toy module computing ``relu(x + weight)`` with an all-ones 32x64 weight."""

    def __init__(self):
        super().__init__()
        # Plain tensor attribute (not wrapped in nn.Parameter).
        self.weight = torch.ones(32, 64)

    def forward(self, x):
        """Add the weight to ``x`` and apply ReLU."""
        shifted = x + self.weight
        return torch.relu(shifted)
|
||||
|
||||
|
||||
# Example input used both for export and for AOT compilation.
inp = torch.randn((32, 64), device="cpu")
# Only the first element of the pair returned by export is used here.
module, _ = torch._dynamo.export(Net(), inp)
# aot_compile returns the path of the generated shared library.
so_path = torch._inductor.aot_compile(module, [inp])
print(so_path)
|
||||
8
test/inductor/aot/cpp/test.sh
Executable file
8
test/inductor/aot/cpp/test.sh
Executable file
|
|
@ -0,0 +1,8 @@
|
|||
#!/bin/bash
# Configures, builds and runs the AOT-Inductor C++ smoke test in ./build.
set -euxo pipefail

# Work from the directory that contains this script: `cmake ..` below expects
# CMakeLists.txt to be the parent of the build directory, which only holds when
# the build directory is created next to this script.
cd "$(dirname "${BASH_SOURCE[0]}")"

mkdir -p build
cd build
cmake ..
make
./test
|
||||
|
|
@ -27,6 +27,27 @@ def compile(
|
|||
return compile_fx(gm, example_inputs, config_patches=options)
|
||||
|
||||
|
||||
def aot_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    options: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Compile an FX graph ahead of time with TorchInductor, producing a shared library.

    Args:
        gm: FX graph module to compile.
        example_inputs: Example tensor inputs for the graph.
        options: Optional config overrides, forwarded as ``config_patches``.
            See `torch._inductor.config`.

    Returns:
        Path to the generated shared library
    """
    from .compile_fx import compile_fx

    # With aot_mode=True, compile_fx returns a callable; calling it with no
    # arguments produces the return value of this function.
    compiled = compile_fx(
        gm,
        example_inputs,
        config_patches=options,
        aot_mode=True,
    )
    return compiled()
|
||||
|
||||
|
||||
def list_mode_options(mode: str = None) -> Dict[str, Any]:
|
||||
r"""Returns a dictionary describing the optimizations that each of the available
|
||||
modes passed to `torch.compile()` performs.
|
||||
|
|
|
|||
|
|
@ -534,6 +534,52 @@ def cpp_compile_command(
|
|||
).strip()
|
||||
|
||||
|
||||
class AotCodeCache:
    """Compiles AOT-Inductor C++ source into a shared library plus a header.

    ``cache`` maps the key returned by ``write`` to the path of the shared
    library produced for that source within the current process.
    """

    cache = dict()
    clear = staticmethod(cache.clear)

    @classmethod
    def compile(cls, source_code):
        """Write ``source_code`` to disk, compile it, and return the .so path.

        A header declaring the wrapper entry point
        (``CppWrapperCodeGen.decl_str``) is written next to the shared library.
        When ``config.aot_codegen_output_prefix`` is set, both files are placed
        in the current working directory under that name; otherwise they sit
        next to the written source file.

        Raises:
            exc.CppCompileError: if the compiler command exits with an error.
        """
        from .codegen.wrapper import CppWrapperCodeGen

        # TODO: update cpp_compile_command for different platforms
        picked_vec_isa = pick_vec_isa()
        key, input_path = write(
            source_code,
            "cpp",
            code_hash(repr(cpp_compile_command("i", "o", vec_isa=picked_vec_isa))),
        )
        if key not in cls.cache:
            from filelock import FileLock

            lock_dir = get_lock_dir()
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
            with lock:
                # Derive the extension-less stem once. Slicing the last three
                # characters off "<name>.cpp" leaves the trailing dot behind and
                # yields "<name>..so" / "<name>..h"; splitext strips it properly.
                if config.aot_codegen_output_prefix:
                    output_stem = os.path.join(
                        os.getcwd(), config.aot_codegen_output_prefix
                    )
                else:
                    output_stem = os.path.splitext(input_path)[0]
                output_so = f"{output_stem}.so"
                output_header = f"{output_stem}.h"

                with open(output_header, "w") as header_file:
                    header_file.write("#include <torch/torch.h>\n\n")
                    header_file.write(f"{CppWrapperCodeGen.decl_str};\n")

                log.info(f"AOT-Inductor compiles code into: {output_so}")
                # NOTE(review): with a fixed output prefix, an existing .so from
                # an earlier, different source is reused as-is while the header
                # is rewritten - confirm this is intended.
                if not os.path.exists(output_so):
                    cmd = cpp_compile_command(
                        input=input_path, output=output_so, vec_isa=picked_vec_isa
                    ).split(" ")
                    try:
                        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
                    except subprocess.CalledProcessError as e:
                        raise exc.CppCompileError(cmd, e.output) from e

                cls.cache[key] = output_so
        return cls.cache[key]
|
||||
|
||||
|
||||
class CppCodeCache:
|
||||
cache = dict()
|
||||
clear = staticmethod(cache.clear)
|
||||
|
|
|
|||
|
|
@ -2266,7 +2266,12 @@ class KernelGroup:
|
|||
)
|
||||
if enable_kernel_profile:
|
||||
code.writelines(["#include <ATen/record_function.h>"])
|
||||
code.writelines([cpp_prefix(), "" f'extern "C" void kernel({arg_defs})'])
|
||||
kernel_decl_name = kernel_name if V.graph.aot_mode else "kernel"
|
||||
|
||||
if not V.graph.aot_mode or self.count == 1:
|
||||
code.writeline(cpp_prefix())
|
||||
|
||||
code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
|
||||
with code.indent():
|
||||
if enable_kernel_profile:
|
||||
graph_id = V.graph.graph_id
|
||||
|
|
@ -2281,9 +2286,12 @@ class KernelGroup:
|
|||
code.splice(self.loops_code)
|
||||
|
||||
codecache_def = IndentedBuffer()
|
||||
codecache_def.writeline("async_compile.cpp('''")
|
||||
codecache_def.splice(code)
|
||||
codecache_def.writeline("''')")
|
||||
if V.graph.aot_mode:
|
||||
codecache_def.splice(code)
|
||||
else:
|
||||
codecache_def.writeline("async_compile.cpp('''")
|
||||
codecache_def.splice(code)
|
||||
codecache_def.writeline("''')")
|
||||
|
||||
codecache_str = codecache_def.getvalue()
|
||||
# TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
|
||||
|
|
|
|||
|
|
@ -275,6 +275,33 @@ class WrapperCodeGen(CodeGen):
|
|||
self.wrapper_call = IndentedBuffer()
|
||||
self.kernels = {}
|
||||
self.lines = []
|
||||
|
||||
self.set_header()
|
||||
self.write_prefix()
|
||||
|
||||
for name, value in V.graph.constants.items():
|
||||
# include a hash so our code cache gives different constants different files
|
||||
hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
|
||||
self.header.writeline(f"{name} = None # {hashed}")
|
||||
|
||||
self.allocated = set()
|
||||
self.freed = set()
|
||||
|
||||
# maps from reusing buffer to reused buffer
|
||||
self.reuses = dict()
|
||||
|
||||
self.write_get_cuda_stream = functools.lru_cache(None)(
|
||||
self.write_get_cuda_stream
|
||||
)
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def add_import_once(line):
|
||||
self.header.writeline(line)
|
||||
|
||||
self.add_import_once = add_import_once
|
||||
self._metas = {}
|
||||
|
||||
def set_header(self):
|
||||
self.header.splice(
|
||||
f"""
|
||||
from ctypes import c_void_p, c_long
|
||||
|
|
@ -302,30 +329,6 @@ class WrapperCodeGen(CodeGen):
|
|||
"""
|
||||
)
|
||||
|
||||
self.write_prefix()
|
||||
|
||||
for name, value in V.graph.constants.items():
|
||||
# include a hash so our code cache gives different constants different files
|
||||
hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
|
||||
self.header.writeline(f"{name} = None # {hashed}")
|
||||
|
||||
self.allocated = set()
|
||||
self.freed = set()
|
||||
|
||||
# maps from reusing buffer to reused buffer
|
||||
self.reuses = dict()
|
||||
|
||||
self.write_get_cuda_stream = functools.lru_cache(None)(
|
||||
self.write_get_cuda_stream
|
||||
)
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def add_import_once(line):
|
||||
self.header.writeline(line)
|
||||
|
||||
self.add_import_once = add_import_once
|
||||
self._metas = {}
|
||||
|
||||
def add_meta_once(self, meta):
|
||||
meta = repr(meta)
|
||||
if meta not in self._metas:
|
||||
|
|
@ -661,6 +664,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
"""
|
||||
|
||||
call_func_id = count()
|
||||
decl_str = None
|
||||
|
||||
def __init__(self):
|
||||
self._call_func_id = next(CppWrapperCodeGen.call_func_id)
|
||||
|
|
@ -680,7 +684,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
for x in V.graph.graph_outputs
|
||||
]
|
||||
|
||||
def write_prefix(self):
|
||||
def write_prefix_header(self):
|
||||
self.prefix.splice(
|
||||
"""
|
||||
async_compile.wait(globals())
|
||||
|
|
@ -702,21 +706,30 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
|
||||
"""
|
||||
)
|
||||
with self.wrapper_call.indent():
|
||||
inputs_len = len(V.graph.graph_inputs.keys())
|
||||
output_refs = self.get_output_refs()
|
||||
if output_refs:
|
||||
if len(output_refs) == 1:
|
||||
output_types = "at::Tensor"
|
||||
else:
|
||||
output_types = "std::vector<at::Tensor>"
|
||||
else:
|
||||
output_types = "void"
|
||||
|
||||
inputs_types = "std::vector<at::Tensor>"
|
||||
self.wrapper_call.writeline(
|
||||
f"{output_types} call_{self._call_func_id}({inputs_types} args) {{"
|
||||
)
|
||||
def call_func_name(self):
|
||||
return f"call_{self._call_func_id}"
|
||||
|
||||
def write_prefix(self):
|
||||
self.write_prefix_header()
|
||||
|
||||
inputs_len = len(V.graph.graph_inputs.keys())
|
||||
output_refs = self.get_output_refs()
|
||||
if output_refs:
|
||||
if len(output_refs) == 1:
|
||||
output_types = "at::Tensor"
|
||||
else:
|
||||
output_types = "std::vector<at::Tensor>"
|
||||
else:
|
||||
output_types = "void"
|
||||
|
||||
inputs_types = "std::vector<at::Tensor>"
|
||||
|
||||
CppWrapperCodeGen.decl_str = (
|
||||
f"{output_types} {self.call_func_name()}({inputs_types} args)"
|
||||
)
|
||||
self.prefix.splice(f"{CppWrapperCodeGen.decl_str} {{")
|
||||
with self.wrapper_call.indent():
|
||||
if inputs_len != 0:
|
||||
inputs_keys_str = ", ".join(V.graph.graph_inputs.keys())
|
||||
self.wrapper_call.writeline(f"at::Tensor {inputs_keys_str};")
|
||||
|
|
@ -778,18 +791,24 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
def wrap_kernel_call(self, name, call_args):
|
||||
return "{}({});".format(name, ", ".join(call_args))
|
||||
|
||||
def return_end_str(self):
|
||||
return "\n}\n'''\n)"
|
||||
|
||||
def generate_return(self, output_refs):
|
||||
if output_refs:
|
||||
if len(output_refs) == 1:
|
||||
self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )")
|
||||
self.wrapper_call.writeline(
|
||||
f"return {output_refs[0]};{self.return_end_str()}"
|
||||
)
|
||||
else:
|
||||
self.wrapper_call.writeline(
|
||||
"return std::vector<at::Tensor>({"
|
||||
+ ", ".join(output_refs)
|
||||
+ "}); }''' )"
|
||||
+ "});"
|
||||
+ self.return_end_str()
|
||||
)
|
||||
else:
|
||||
self.wrapper_call.writeline("return; }''' )")
|
||||
self.wrapper_call.writeline(f"return;{self.return_end_str()}")
|
||||
|
||||
def generate_end(self, result):
|
||||
shared = codecache.get_shared()
|
||||
|
|
@ -839,3 +858,36 @@ class CppWrapperCodeGen(WrapperCodeGen):
|
|||
else:
|
||||
args.insert(0, f"{codegen_reference}")
|
||||
self.writeline(f"{cpp_kernel}({', '.join(args)});")
|
||||
|
||||
|
||||
class CppAotWrapperCodeGen(CppWrapperCodeGen):
    """
    Ahead-of-time flavour of the C++ outer wrapper that calls the kernels.

    Several hooks inherited from the parent are overridden as no-ops here.
    """

    def set_header(self):
        """No-op override: nothing is written to the header."""

    def write_prefix_header(self):
        """Start the prefix with the ATen include."""
        aten_include = "\n#include <ATen/ATen.h>"
        self.prefix.splice(aten_include)

    def call_func_name(self):
        """Fixed name of the generated entry point."""
        return "aot_inductor_entry"

    def define_kernel(self, name: str, kernel: str):
        """Splice the kernel source into the header buffer."""
        kernel_block = f"\n{kernel}\n"
        self.header.splice(kernel_block)

    def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
        """No-op override."""

    def wrap_kernel_call(self, name, call_args):
        """Render a kernel invocation as ``name(arg0, arg1, ...);``."""
        joined_args = ", ".join(call_args)
        return "{}({});".format(name, joined_args)

    def return_end_str(self):
        """Text appended after the return statement: just the closing brace."""
        return "\n}"

    def generate_end(self, result):
        """No-op override."""

    def add_benchmark_harness(self, output):
        """No-op override."""
|
||||
|
|
|
|||
|
|
@ -138,6 +138,7 @@ def compile_fx_inner(
|
|||
num_fixed=0,
|
||||
is_backward=False,
|
||||
graph_id=None,
|
||||
aot_mode=False,
|
||||
):
|
||||
if is_tf32_warning_applicable(gm):
|
||||
_warn_tf32_disabled()
|
||||
|
|
@ -174,10 +175,13 @@ def compile_fx_inner(
|
|||
shape_env=shape_env,
|
||||
num_static_inputs=num_fixed,
|
||||
graph_id=graph_id,
|
||||
aot_mode=aot_mode,
|
||||
)
|
||||
with V.set_graph_handler(graph):
|
||||
graph.run(*example_inputs)
|
||||
compiled_fn = graph.compile_to_fn()
|
||||
if aot_mode:
|
||||
return compiled_fn
|
||||
|
||||
if cudagraphs:
|
||||
complex_memory_overlap_inputs = any(
|
||||
|
|
@ -399,6 +403,7 @@ def compile_fx(
|
|||
inner_compile=compile_fx_inner,
|
||||
config_patches: Optional[Dict[str, Any]] = None,
|
||||
decompositions: Optional[Dict[OpOverload, Callable]] = None,
|
||||
aot_mode=False,
|
||||
):
|
||||
"""Main entrypoint to a compile given FX graph"""
|
||||
if config_patches:
|
||||
|
|
@ -409,7 +414,23 @@ def compile_fx(
|
|||
# need extra layer of patching as backwards is compiled out of scope
|
||||
inner_compile=config.patch(config_patches)(inner_compile),
|
||||
decompositions=decompositions,
|
||||
aot_mode=aot_mode,
|
||||
)
|
||||
|
||||
if aot_mode:
|
||||
aot_config_patches = {
|
||||
"cpp_wrapper": True,
|
||||
"debug": True,
|
||||
"triton.cudagraphs": False,
|
||||
}
|
||||
with config.patch(aot_config_patches):
|
||||
return compile_fx(
|
||||
model_,
|
||||
example_inputs_,
|
||||
inner_compile=functools.partial(inner_compile, aot_mode=aot_mode),
|
||||
decompositions=decompositions,
|
||||
)
|
||||
|
||||
recursive_compile_fx = functools.partial(
|
||||
compile_fx,
|
||||
inner_compile=inner_compile,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ disable_progress = True
|
|||
# Whether to enable printing the source code for each future
|
||||
verbose_progress = False
|
||||
|
||||
# Name for generated .h and .so files
|
||||
aot_codegen_output_prefix = None
|
||||
|
||||
# use cpp wrapper instead of python wrapper
|
||||
cpp_wrapper = False
|
||||
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ def enable_aot_logging():
|
|||
stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
|
||||
stack.enter_context(patch("functorch.compile.config.debug_joint", True))
|
||||
|
||||
path = os.path.join(get_debug_dir(), "aot_torchinductor")
|
||||
path = os.path.join(get_debug_dir(), "torchinductor")
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
|
|
@ -245,7 +245,7 @@ class DebugContext:
|
|||
for n in DebugContext._counter:
|
||||
dirname = os.path.join(
|
||||
get_debug_dir(),
|
||||
"aot_torchinductor",
|
||||
"torchinductor",
|
||||
f"{folder_name}.{n}",
|
||||
)
|
||||
if not os.path.exists(dirname):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from torch.utils._mode_utils import no_dispatch
|
|||
from .._dynamo import config as dynamo_config
|
||||
|
||||
from . import config, ir
|
||||
from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen
|
||||
from .codegen.wrapper import CppAotWrapperCodeGen, CppWrapperCodeGen, WrapperCodeGen
|
||||
from .exc import (
|
||||
LoweringException,
|
||||
MissingOperatorWithDecomp,
|
||||
|
|
@ -114,6 +114,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
shape_env=None,
|
||||
num_static_inputs=None,
|
||||
graph_id=None,
|
||||
aot_mode=False,
|
||||
):
|
||||
super().__init__(gm)
|
||||
self.extra_traceback = False # we do our own error wrapping
|
||||
|
|
@ -143,6 +144,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
self.creation_time = time.time()
|
||||
self.name = "GraphLowering"
|
||||
self._can_use_cpp_wrapper = config.cpp_wrapper
|
||||
self.aot_mode = aot_mode
|
||||
self.graph_id = graph_id
|
||||
self.scheduler = None
|
||||
self._warned_fallback = {"aten.convolution_backward"}
|
||||
|
|
@ -554,10 +556,13 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
self.check_cpp_wrapper()
|
||||
if self._can_use_cpp_wrapper:
|
||||
self.sizevars = CppSizeVarAllocator(self._shape_env)
|
||||
self.wrapper_code = CppWrapperCodeGen()
|
||||
self.wrapper_code = (
|
||||
CppAotWrapperCodeGen() if self.aot_mode else CppWrapperCodeGen()
|
||||
)
|
||||
return
|
||||
else:
|
||||
assert not self.aot_mode, "Model does not support AOT compilation"
|
||||
self.wrapper_code = WrapperCodeGen()
|
||||
return
|
||||
|
||||
def codegen(self):
|
||||
from .scheduler import Scheduler
|
||||
|
|
@ -630,7 +635,18 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
return mod
|
||||
|
||||
def compile_to_fn(self):
|
||||
return self.compile_to_module().call
|
||||
if self.aot_mode:
|
||||
from .codecache import AotCodeCache
|
||||
|
||||
code = self.codegen()
|
||||
if config.debug:
|
||||
print(code)
|
||||
|
||||
# return the generated .so file path
|
||||
output_path = AotCodeCache.compile(code)
|
||||
return lambda dummy: output_path
|
||||
else:
|
||||
return self.compile_to_module().call
|
||||
|
||||
def get_output_names(self):
|
||||
assert self.graph_outputs is not None
|
||||
|
|
|
|||
Loading…
Reference in a new issue