[reland2][inductor] Add an AOT compilation mode for Inductor CPP backend (#96520)

Summary: This is a reland of https://github.com/pytorch/pytorch/pull/94822.
It resolves the long compilation issue with the Inductor cpp tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96520
Approved by: https://github.com/huydhn, https://github.com/malfet
Bin Bao 2023-03-13 14:20:29 +00:00 committed by PyTorch MergeBot
parent 178d2a38e0
commit f03db8d6cb
13 changed files with 317 additions and 56 deletions


@@ -484,7 +484,7 @@ Inside this directory, each run will have a separate folder named with the times
$ ls
run_2023_03_01_08_20_52_143510-pid_180167
In the run folder there will be a torchdynamo directory which contains debug logs, and an aot_torchinductor
In the run folder there will be a torchdynamo directory which contains debug logs, and a torchinductor
folder which contains a subfolder for each compiled kernel with inductor debug artifacts.
::
@@ -492,13 +492,13 @@ folder which contains a subfolder for each compiled kernel with inductor debug a
$ cd
run_2023_03_01_08_20_52_143510-pid_180167
$ ls
aot_torchinductor torchdynamo
torchinductor torchdynamo
Moving further into the aot_torchinductor directory, the \*.log files are logs from the aot autograd phase of compilation, model__0_forward_1.0 contains the inductor debug artifacts.
Moving further into the torchinductor directory, the \*.log files are logs from the aot autograd phase of compilation, and model__0_forward_1.0 contains the inductor debug artifacts.
::
$ cd aot_torchinductor
$ cd torchinductor
$ ls
aot_model___0_debug.log model__0_forward_1.0
$ cd model__0_forward_1.0


@@ -0,0 +1,23 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(test)
set(Torch_DIR "../../../../torch/share/cmake/Torch")
find_package(Torch REQUIRED)
add_executable(test test.cpp ${CMAKE_BINARY_DIR}/aot_inductor_output.h)
add_custom_command(
OUTPUT ${CMAKE_BINARY_DIR}/aot_inductor_output.h
COMMAND python ${CMAKE_SOURCE_DIR}/test.py
DEPENDS ${CMAKE_SOURCE_DIR}/test.py
)
add_custom_target(generate_header ALL
DEPENDS ${CMAKE_BINARY_DIR}/aot_inductor_output.h)
add_library(aot_inductor_output SHARED IMPORTED)
set_property(TARGET aot_inductor_output PROPERTY
IMPORTED_LOCATION ${CMAKE_BINARY_DIR}/aot_inductor_output.so)
target_link_libraries(test "${TORCH_LIBRARIES}" aot_inductor_output)
set_property(TARGET test PROPERTY CXX_STANDARD 17)


@@ -0,0 +1,41 @@
//#include <gtest/gtest.h>
#include <iostream>
#include "build/aot_inductor_output.h"
/*
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.ones(32, 64)
def forward(self, x):
x = torch.relu(x + self.weight)
return x
*/
struct Net : torch::nn::Module {
Net() {
weight = register_parameter("weight", torch::ones({32, 64}));
}
torch::Tensor forward(torch::Tensor input) {
return torch::relu(input + weight);
}
torch::Tensor weight;
};
int main() {
torch::Tensor x = at::randn({32, 64});
Net net;
torch::Tensor results_ref = net.forward(x);
// TODO: we need to provide an API to concatenate args and weights
std::vector<torch::Tensor> inputs = {x};
for (const auto& pair : net.named_parameters()) {
inputs.push_back(pair.value());
}
torch::Tensor results_opt = aot_inductor_entry(inputs);
assert(torch::allclose(results_ref, results_opt));
printf("PASS\n");
return 0;
}


@@ -0,0 +1,22 @@
import torch
import torch._dynamo
import torch._inductor
import torch._inductor.config
torch._inductor.config.aot_codegen_output_prefix = "aot_inductor_output"
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.ones(32, 64)
def forward(self, x):
x = torch.relu(x + self.weight)
return x
inp = torch.randn((32, 64), device="cpu")
module, _ = torch._dynamo.export(Net(), inp)
so_path = torch._inductor.aot_compile(module, [inp])
print(so_path)

test/inductor/aot/cpp/test.sh (executable file)

@@ -0,0 +1,8 @@
#!/bin/bash
set -euxo pipefail
mkdir -p build
cd build
cmake ..
make
./test


@@ -27,6 +27,27 @@ def compile(
return compile_fx(gm, example_inputs, config_patches=options)
def aot_compile(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
options: Optional[Dict[str, Any]] = None,
) -> str:
"""
Ahead-of-time compile a given FX graph with TorchInductor into a shared library.
Args:
gm: The FX graph to compile.
example_inputs: List of tensor inputs.
options: Optional dict of config options. See `torch._inductor.config`.
Returns:
Path to the generated shared library
"""
from .compile_fx import compile_fx
return compile_fx(gm, example_inputs, config_patches=options, aot_mode=True)()
def list_mode_options(mode: str = None) -> Dict[str, Any]:
r"""Returns a dictionary describing the optimizations that each of the available
modes passed to `torch.compile()` performs.
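
For reference, the new torch._inductor.aot_compile API above is exercised end to end by test/inductor/aot/cpp/test.py in this PR. A condensed sketch of that usage (the toy module and shapes come from the test):

import torch
import torch._dynamo
import torch._inductor
import torch._inductor.config

# Optional: name the generated .so/.h pair; see aot_codegen_output_prefix in torch/_inductor/config.py.
torch._inductor.config.aot_codegen_output_prefix = "aot_inductor_output"

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.ones(32, 64)

    def forward(self, x):
        return torch.relu(x + self.weight)

inp = torch.randn((32, 64), device="cpu")
# Export the module to an FX GraphModule with dynamo, then AOT-compile it with Inductor.
gm, _ = torch._dynamo.export(Net(), inp)
so_path = torch._inductor.aot_compile(gm, [inp])
print(so_path)  # filesystem path of the compiled shared library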


@@ -534,6 +534,52 @@ def cpp_compile_command(
).strip()
class AotCodeCache:
cache = dict()
clear = staticmethod(cache.clear)
@classmethod
def compile(cls, source_code):
from .codegen.wrapper import CppWrapperCodeGen
# TODO: update cpp_compile_command for different platforms
picked_vec_isa = pick_vec_isa()
key, input_path = write(
source_code,
"cpp",
code_hash(repr(cpp_compile_command("i", "o", vec_isa=picked_vec_isa))),
)
if key not in cls.cache:
from filelock import FileLock
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
output_so = (
os.path.join(os.getcwd(), f"{config.aot_codegen_output_prefix}.so")
if config.aot_codegen_output_prefix
else f"{input_path[:-3]}.so"
)
output_header = f"{output_so[:-3]}.h"
with open(output_header, "w") as header_file:
header_file.writelines("#include <torch/torch.h>\n\n")
header_file.writelines(f"{CppWrapperCodeGen.decl_str};\n")
log.info(f"AOT-Inductor compiles code into: {output_so}")
if not os.path.exists(output_so):
cmd = cpp_compile_command(
input=input_path, output=output_so, vec_isa=picked_vec_isa
).split(" ")
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
raise exc.CppCompileError(cmd, e.output) from e
cls.cache[key] = output_so
return cls.cache[key]
class CppCodeCache:
cache = dict()
clear = staticmethod(cache.clear)


@@ -2266,7 +2266,12 @@ class KernelGroup:
)
if enable_kernel_profile:
code.writelines(["#include <ATen/record_function.h>"])
code.writelines([cpp_prefix(), "" f'extern "C" void kernel({arg_defs})'])
kernel_decl_name = kernel_name if V.graph.aot_mode else "kernel"
if not V.graph.aot_mode or self.count == 1:
code.writeline(cpp_prefix())
code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
with code.indent():
if enable_kernel_profile:
graph_id = V.graph.graph_id
@@ -2281,9 +2286,12 @@ class KernelGroup:
code.splice(self.loops_code)
codecache_def = IndentedBuffer()
codecache_def.writeline("async_compile.cpp('''")
codecache_def.splice(code)
codecache_def.writeline("''')")
if V.graph.aot_mode:
codecache_def.splice(code)
else:
codecache_def.writeline("async_compile.cpp('''")
codecache_def.splice(code)
codecache_def.writeline("''')")
codecache_str = codecache_def.getvalue()
# TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does


@@ -275,6 +275,33 @@ class WrapperCodeGen(CodeGen):
self.wrapper_call = IndentedBuffer()
self.kernels = {}
self.lines = []
self.set_header()
self.write_prefix()
for name, value in V.graph.constants.items():
# include a hash so our code cache gives different constants different files
hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
self.header.writeline(f"{name} = None # {hashed}")
self.allocated = set()
self.freed = set()
# maps from reusing buffer to reused buffer
self.reuses = dict()
self.write_get_cuda_stream = functools.lru_cache(None)(
self.write_get_cuda_stream
)
@functools.lru_cache(None)
def add_import_once(line):
self.header.writeline(line)
self.add_import_once = add_import_once
self._metas = {}
def set_header(self):
self.header.splice(
f"""
from ctypes import c_void_p, c_long
@@ -302,30 +329,6 @@ class WrapperCodeGen(CodeGen):
"""
)
self.write_prefix()
for name, value in V.graph.constants.items():
# include a hash so our code cache gives different constants different files
hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
self.header.writeline(f"{name} = None # {hashed}")
self.allocated = set()
self.freed = set()
# maps from reusing buffer to reused buffer
self.reuses = dict()
self.write_get_cuda_stream = functools.lru_cache(None)(
self.write_get_cuda_stream
)
@functools.lru_cache(None)
def add_import_once(line):
self.header.writeline(line)
self.add_import_once = add_import_once
self._metas = {}
def add_meta_once(self, meta):
meta = repr(meta)
if meta not in self._metas:
@@ -661,6 +664,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
"""
call_func_id = count()
decl_str = None
def __init__(self):
self._call_func_id = next(CppWrapperCodeGen.call_func_id)
@@ -680,7 +684,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
for x in V.graph.graph_outputs
]
def write_prefix(self):
def write_prefix_header(self):
self.prefix.splice(
"""
async_compile.wait(globals())
@@ -702,21 +706,30 @@ class CppWrapperCodeGen(WrapperCodeGen):
"""
)
with self.wrapper_call.indent():
inputs_len = len(V.graph.graph_inputs.keys())
output_refs = self.get_output_refs()
if output_refs:
if len(output_refs) == 1:
output_types = "at::Tensor"
else:
output_types = "std::vector<at::Tensor>"
else:
output_types = "void"
inputs_types = "std::vector<at::Tensor>"
self.wrapper_call.writeline(
f"{output_types} call_{self._call_func_id}({inputs_types} args) {{"
)
def call_func_name(self):
return f"call_{self._call_func_id}"
def write_prefix(self):
self.write_prefix_header()
inputs_len = len(V.graph.graph_inputs.keys())
output_refs = self.get_output_refs()
if output_refs:
if len(output_refs) == 1:
output_types = "at::Tensor"
else:
output_types = "std::vector<at::Tensor>"
else:
output_types = "void"
inputs_types = "std::vector<at::Tensor>"
CppWrapperCodeGen.decl_str = (
f"{output_types} {self.call_func_name()}({inputs_types} args)"
)
self.prefix.splice(f"{CppWrapperCodeGen.decl_str} {{")
with self.wrapper_call.indent():
if inputs_len != 0:
inputs_keys_str = ", ".join(V.graph.graph_inputs.keys())
self.wrapper_call.writeline(f"at::Tensor {inputs_keys_str};")
@@ -778,18 +791,24 @@ class CppWrapperCodeGen(WrapperCodeGen):
def wrap_kernel_call(self, name, call_args):
return "{}({});".format(name, ", ".join(call_args))
def return_end_str(self):
return "\n}\n'''\n)"
def generate_return(self, output_refs):
if output_refs:
if len(output_refs) == 1:
self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )")
self.wrapper_call.writeline(
f"return {output_refs[0]};{self.return_end_str()}"
)
else:
self.wrapper_call.writeline(
"return std::vector<at::Tensor>({"
+ ", ".join(output_refs)
+ "}); }''' )"
+ "});"
+ self.return_end_str()
)
else:
self.wrapper_call.writeline("return; }''' )")
self.wrapper_call.writeline(f"return;{self.return_end_str()}")
def generate_end(self, result):
shared = codecache.get_shared()
@@ -839,3 +858,36 @@ class CppWrapperCodeGen(WrapperCodeGen):
else:
args.insert(0, f"{codegen_reference}")
self.writeline(f"{cpp_kernel}({', '.join(args)});")
class CppAotWrapperCodeGen(CppWrapperCodeGen):
"""
The AOT-version outer wrapper that calls the kernels in C++
"""
def set_header(self):
return
def write_prefix_header(self):
self.prefix.splice("\n#include <ATen/ATen.h>")
def call_func_name(self):
return "aot_inductor_entry"
def define_kernel(self, name: str, kernel: str):
self.header.splice(f"\n{kernel}\n")
def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
return
def wrap_kernel_call(self, name, call_args):
return f"{name}({', '.join(call_args)});"
def return_end_str(self):
return "\n}"
def generate_end(self, result):
return
def add_benchmark_harness(self, output):
return


@@ -138,6 +138,7 @@ def compile_fx_inner(
num_fixed=0,
is_backward=False,
graph_id=None,
aot_mode=False,
):
if is_tf32_warning_applicable(gm):
_warn_tf32_disabled()
@@ -174,10 +175,13 @@ def compile_fx_inner(
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
aot_mode=aot_mode,
)
with V.set_graph_handler(graph):
graph.run(*example_inputs)
compiled_fn = graph.compile_to_fn()
if aot_mode:
return compiled_fn
if cudagraphs:
complex_memory_overlap_inputs = any(
@@ -399,6 +403,7 @@ def compile_fx(
inner_compile=compile_fx_inner,
config_patches: Optional[Dict[str, Any]] = None,
decompositions: Optional[Dict[OpOverload, Callable]] = None,
aot_mode=False,
):
"""Main entrypoint to a compile given FX graph"""
if config_patches:
@@ -409,7 +414,23 @@ def compile_fx(
# need extra layer of patching as backwards is compiled out of scope
inner_compile=config.patch(config_patches)(inner_compile),
decompositions=decompositions,
aot_mode=aot_mode,
)
if aot_mode:
aot_config_patches = {
"cpp_wrapper": True,
"debug": True,
"triton.cudagraphs": False,
}
with config.patch(aot_config_patches):
return compile_fx(
model_,
example_inputs_,
inner_compile=functools.partial(inner_compile, aot_mode=aot_mode),
decompositions=decompositions,
)
recursive_compile_fx = functools.partial(
compile_fx,
inner_compile=inner_compile,


@@ -12,6 +12,9 @@ disable_progress = True
# Whether to enable printing the source code for each future
verbose_progress = False
# Name for generated .h and .so files
aot_codegen_output_prefix = None
# use cpp wrapper instead of python wrapper
cpp_wrapper = False
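
AotCodeCache.compile above uses this option to decide where the generated artifacts land; a rough sketch of the behavior (the paths in the comments are illustrative):

import torch._inductor.config as config

# With a prefix set, the compiled library and its header are written into the
# current working directory as <prefix>.so and <prefix>.h.
config.aot_codegen_output_prefix = "aot_inductor_output"
# -> ./aot_inductor_output.so, ./aot_inductor_output.h

# With the default of None, they are placed next to the generated .cpp inside
# the Inductor code cache, named after the source hash.
config.aot_codegen_output_prefix = None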


@@ -206,7 +206,7 @@ def enable_aot_logging():
stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
stack.enter_context(patch("functorch.compile.config.debug_joint", True))
path = os.path.join(get_debug_dir(), "aot_torchinductor")
path = os.path.join(get_debug_dir(), "torchinductor")
if not os.path.exists(path):
os.makedirs(path)
@@ -245,7 +245,7 @@ class DebugContext:
for n in DebugContext._counter:
dirname = os.path.join(
get_debug_dir(),
"aot_torchinductor",
"torchinductor",
f"{folder_name}.{n}",
)
if not os.path.exists(dirname):


@@ -23,7 +23,7 @@ from torch.utils._mode_utils import no_dispatch
from .._dynamo import config as dynamo_config
from . import config, ir
from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen
from .codegen.wrapper import CppAotWrapperCodeGen, CppWrapperCodeGen, WrapperCodeGen
from .exc import (
LoweringException,
MissingOperatorWithDecomp,
@@ -114,6 +114,7 @@ class GraphLowering(torch.fx.Interpreter):
shape_env=None,
num_static_inputs=None,
graph_id=None,
aot_mode=False,
):
super().__init__(gm)
self.extra_traceback = False # we do our own error wrapping
@@ -143,6 +144,7 @@ class GraphLowering(torch.fx.Interpreter):
self.creation_time = time.time()
self.name = "GraphLowering"
self._can_use_cpp_wrapper = config.cpp_wrapper
self.aot_mode = aot_mode
self.graph_id = graph_id
self.scheduler = None
self._warned_fallback = {"aten.convolution_backward"}
@@ -554,10 +556,13 @@ class GraphLowering(torch.fx.Interpreter):
self.check_cpp_wrapper()
if self._can_use_cpp_wrapper:
self.sizevars = CppSizeVarAllocator(self._shape_env)
self.wrapper_code = CppWrapperCodeGen()
self.wrapper_code = (
CppAotWrapperCodeGen() if self.aot_mode else CppWrapperCodeGen()
)
return
else:
assert not self.aot_mode, "Model does not support AOT compilation"
self.wrapper_code = WrapperCodeGen()
return
def codegen(self):
from .scheduler import Scheduler
@@ -630,7 +635,18 @@ class GraphLowering(torch.fx.Interpreter):
return mod
def compile_to_fn(self):
return self.compile_to_module().call
if self.aot_mode:
from .codecache import AotCodeCache
code = self.codegen()
if config.debug:
print(code)
# return the generated .so file path
output_path = AotCodeCache.compile(code)
return lambda dummy: output_path
else:
return self.compile_to_module().call
def get_output_names(self):
assert self.graph_outputs is not None