[inductor] Fix AOTInductor (#99203)

Summary: Fix the broken AOTInductor flow and add a smoketest on CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99203
Approved by: https://github.com/jansel
commit e43918b93a
parent 3afa60bf0f
Author: Bin Bao, 2023-04-24 00:58:21 +00:00 (committed by PyTorch MergeBot)
20 changed files with 188 additions and 216 deletions
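
For orientation, the flow this change repairs looks roughly like this on the Python side (a sketch distilled from the new test files in this diff; the module and shapes are illustrative, not part of the change itself):

import torch
import torch._inductor
from torch.fx.experimental.proxy_tensor import make_fx

net = torch.nn.Linear(64, 10).cuda()
x = torch.randn(32, 64, device="cuda")

with torch.no_grad():
    # trace to an FX graph; dynamo export is blocked by pytorch/pytorch#99000
    graph = make_fx(net)(x)
    # AOT-compile the graph into a shared library; returns the .so path
    lib_path = torch._inductor.aot_compile(graph, [x])

A C++ harness (see the new test.cpp below) then links that library and calls the generated inductor_entry_cpp() with the flattened list [parameters..., inputs...].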

@@ -251,6 +251,11 @@ test_inductor() {
  python test/run_test.py --inductor --include test_modules test_ops test_ops_gradients test_torch --verbose
  # Do not add --inductor for the following inductor unit tests, otherwise we will fail because of nested dynamo state
  python test/run_test.py --include inductor/test_torchinductor inductor/test_torchinductor_opinfo --verbose
  # docker build uses bdist_wheel which does not work with test_aot_inductor
  # TODO: need a faster way to build
  BUILD_AOT_INDUCTOR_TEST=1 python setup.py develop
  LD_LIBRARY_PATH="$TORCH_LIB_DIR" "$TORCH_BIN_DIR"/test_aot_inductor
}
# "Global" flags for inductor benchmarking controlled by TEST_CONFIG
@@ -551,6 +556,7 @@ test_libtorch() {
      # TODO: Consider to run static_runtime_test from $TORCH_BIN_DIR (may need modify build script)
      "$BUILD_BIN_DIR"/static_runtime_test --gtest_output=xml:$TEST_REPORTS_DIR/static_runtime_test.xml
    fi
    assert_git_not_dirty
  fi
}

@@ -178,6 +178,7 @@ cmake_dependent_option(
    CAFFE2_USE_MSVC_STATIC_RUNTIME "Using MSVC static runtime libraries" ON
    "NOT BUILD_SHARED_LIBS" OFF)
option(BUILD_TEST "Build C++ test binaries (need gtest and gbenchmark)" OFF)
option(BUILD_AOT_INDUCTOR_TEST "Build C++ test binaries for aot-inductor" OFF)
option(BUILD_STATIC_RUNTIME_BENCHMARK "Build C++ binaries for static runtime benchmarks (need gbenchmark)" OFF)
option(BUILD_TENSOREXPR_BENCHMARK "Build C++ binaries for tensorexpr benchmarks (need gbenchmark)" OFF)
option(BUILD_MOBILE_BENCHMARK "Build C++ test binaries for mobile (ARM) targets(need gtest and gbenchmark)" OFF)

@@ -1174,6 +1174,11 @@ if(BUILD_TEST)
    add_subdirectory(${TORCH_ROOT}/test/cpp/lazy
                     ${CMAKE_BINARY_DIR}/test_lazy)
  endif()
  if(BUILD_AOT_INDUCTOR_TEST)
    add_subdirectory(
      ${TORCH_ROOT}/test/cpp/aot_inductor
      ${CMAKE_BINARY_DIR}/test_aot_inductor)
  endif()
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")

@@ -0,0 +1,53 @@
set(AOT_INDUCTOR_TEST_ROOT ${TORCH_ROOT}/test/cpp/aot_inductor)

# Build the cpp gtest binary containing the cpp-only tests.
set(INDUCTOR_TEST_SRCS
  ${AOT_INDUCTOR_TEST_ROOT}/test.cpp
)

add_executable(test_aot_inductor
  ${TORCH_ROOT}/test/cpp/common/main.cpp
  ${INDUCTOR_TEST_SRCS}
)

# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_aot_inductor PRIVATE USE_GTEST)

# Define a custom command to generate the library
add_custom_command(
  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/libaot_inductor_output.so
  COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/test.py
  DEPENDS ${AOT_INDUCTOR_TEST_ROOT}/test.py
)
add_custom_target(aot_inductor_output_target ALL
  DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libaot_inductor_output.so)
add_dependencies(test_aot_inductor aot_inductor_output_target)

target_link_libraries(test_aot_inductor PRIVATE
  torch
  gtest
  ${CMAKE_CURRENT_BINARY_DIR}/libaot_inductor_output.so
)

if(USE_CUDA)
  target_link_libraries(test_aot_inductor PRIVATE
    ${C10_CUDA_BUILD_SHARED_LIBS}
    ${CUDA_LIBRARIES}
    ${CUDA_NVRTC_LIB}
    ${CUDA_CUDA_LIB}
    ${TORCH_CUDA_LIBRARIES}
  )
  target_include_directories(test_aot_inductor PRIVATE ${ATen_CUDA_INCLUDE})
  target_compile_definitions(test_aot_inductor PRIVATE USE_CUDA)
endif()

if(INSTALL_TEST)
  install(TARGETS test_aot_inductor DESTINATION bin)
  # Install PDB files for MSVC builds
  if(MSVC AND BUILD_SHARED_LIBS)
    install(FILES $<TARGET_PDB_FILE:test_aot_inductor> DESTINATION bin OPTIONAL)
  endif()
endif()

@@ -0,0 +1,46 @@
#include <gtest/gtest.h>

#include <string>
#include <vector>

#include <torch/torch.h>

extern std::vector<at::Tensor> inductor_entry_cpp(
    const std::vector<at::Tensor>& args);

namespace torch {
namespace aot_inductor {

struct Net : torch::nn::Module {
  Net() : linear(register_module("linear", torch::nn::Linear(64, 10))) {}

  torch::Tensor forward(torch::Tensor x, torch::Tensor y) {
    return linear(torch::sin(x) + torch::cos(y));
  }

  torch::nn::Linear linear;
};

TEST(AotInductorTest, BasicTest) {
  torch::NoGradGuard no_grad;

  Net net;
  net.to(torch::kCUDA);

  torch::Tensor x =
      at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
  torch::Tensor y =
      at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
  torch::Tensor results_ref = net.forward(x, y);

  // TODO: we need to provide an API to concatenate args and weights
  std::vector<torch::Tensor> inputs;
  for (const auto& pair : net.named_parameters()) {
    inputs.push_back(pair.value());
  }
  inputs.push_back(x);
  inputs.push_back(y);
  auto results_opt = inductor_entry_cpp(inputs);

  ASSERT_TRUE(torch::allclose(results_ref, results_opt[0]));
}

} // namespace aot_inductor
} // namespace torch
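
The input layout the loop above builds (parameters first, in named_parameters() order, then the runtime arguments) is the implicit calling convention of the generated entry point until the TODO is addressed. The same flattening on the Python side, as a sketch:

import torch

net = torch.nn.Linear(64, 10)
x = torch.randn(32, 64)
y = torch.randn(32, 64)

# weight and bias first, then the actual arguments, mirroring the C++ loop
flat_inputs = [param for _, param in net.named_parameters()] + [x, y]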

@@ -0,0 +1,28 @@
import shutil

import torch
import torch._dynamo
import torch._inductor


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(64, 10)

    def forward(self, x, y):
        return self.fc(torch.sin(x) + torch.cos(y))


x = torch.randn((32, 64), device="cuda")
y = torch.randn((32, 64), device="cuda")

with torch.no_grad():
    from torch.fx.experimental.proxy_tensor import make_fx

    # Using export is blocked by https://github.com/pytorch/pytorch/issues/99000
    # module, _ = torch._dynamo.export(Net().cuda(), inp)
    module = make_fx(Net().cuda())(x, y)
    lib_path = torch._inductor.aot_compile(module, [x, y])
    shutil.copy(lib_path, "libaot_inductor_output.so")

@@ -1,21 +0,0 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(test)

set(Torch_DIR "../../../../torch/share/cmake/Torch")
find_package(Torch REQUIRED)

add_library(aot_inductor_output SHARED IMPORTED)
set_property(TARGET aot_inductor_output PROPERTY
  IMPORTED_LOCATION ${CMAKE_BINARY_DIR}/aot_inductor_output.so)

add_custom_command(
  OUTPUT ${CMAKE_BINARY_DIR}/aot_inductor_output.so
  COMMAND python ${CMAKE_SOURCE_DIR}/test.py
  DEPENDS ${CMAKE_SOURCE_DIR}/test.py
)
add_custom_target(aot_inductor_output_target ALL
  DEPENDS ${CMAKE_BINARY_DIR}/aot_inductor_output.so)

add_executable(test test.cpp)
target_link_libraries(test ${TORCH_LIBRARIES} aot_inductor_output)
add_dependencies(test aot_inductor_output_target)
set_property(TARGET test PROPERTY CXX_STANDARD 17)

@@ -1,44 +0,0 @@
//#include <gtest/gtest.h>
#include <iostream>
#include <vector>

#include <torch/torch.h>

extern std::vector<at::Tensor> inductor_cpp_entry(const std::vector<at::Tensor>& args);

/*
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.ones(32, 64)

    def forward(self, x):
        x = torch.relu(x + self.weight)
        return x
*/
struct Net : torch::nn::Module {
  Net() {
    weight = register_parameter("weight", torch::ones({32, 64}));
  }
  torch::Tensor forward(torch::Tensor input) {
    return torch::relu(input + weight);
  }
  torch::Tensor weight;
};

int main() {
  torch::Tensor x = at::randn({32, 64});
  Net net;
  torch::Tensor results_ref = net.forward(x);

  // TODO: we need to provide an API to concatenate args and weights
  std::vector<torch::Tensor> inputs;
  for (const auto& pair : net.named_parameters()) {
    inputs.push_back(pair.value());
  }
  inputs.push_back(x);

  auto results_opt = inductor_cpp_entry(inputs);
  assert(torch::allclose(results_ref, results_opt[0]));
  printf("PASS\n");
  return 0;
}

@@ -1,21 +0,0 @@
import shutil

import torch
import torch._dynamo
import torch._inductor


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.ones(32, 64)

    def forward(self, x):
        x = torch.relu(x + self.weight)
        return x


inp = torch.randn((32, 64), device="cpu")
module, _ = torch._dynamo.export(Net(), inp)
lib_path = torch._inductor.aot_compile(module, [inp])
shutil.copy(lib_path, "aot_inductor_output.so")

@@ -1,9 +0,0 @@
#!/bin/bash
set -euxo pipefail
rm -rf build
mkdir -p build
cd build
cmake ..
make
./test

@@ -1,21 +0,0 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(test)

set(Torch_DIR "../../../../torch/share/cmake/Torch")
find_package(Torch REQUIRED)

add_library(aot_inductor_output SHARED IMPORTED)
set_property(TARGET aot_inductor_output PROPERTY
  IMPORTED_LOCATION ${CMAKE_BINARY_DIR}/aot_inductor_output.so)

add_custom_command(
  OUTPUT ${CMAKE_BINARY_DIR}/aot_inductor_output.so
  COMMAND python ${CMAKE_SOURCE_DIR}/test.py
  DEPENDS ${CMAKE_SOURCE_DIR}/test.py
)
add_custom_target(aot_inductor_output_target ALL
  DEPENDS ${CMAKE_BINARY_DIR}/aot_inductor_output.so)

add_executable(test test.cpp)
target_link_libraries(test ${TORCH_LIBRARIES} aot_inductor_output)
add_dependencies(test aot_inductor_output_target)
set_property(TARGET test PROPERTY CXX_STANDARD 17)

@@ -1,46 +0,0 @@
//#include <gtest/gtest.h>
#include <iostream>
#include <vector>

#include <torch/torch.h>

extern std::vector<at::Tensor> inductor_cpp_entry(const std::vector<at::Tensor>& args);

/*
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.ones(32, 64)

    def forward(self, x):
        x = torch.relu(x + self.weight)
        return x
*/
struct Net : torch::nn::Module {
  Net() {
    weight = register_parameter("weight", torch::ones({32, 64}, at::TensorOptions(at::kCUDA).dtype(at::ScalarType::Float)));
  }
  torch::Tensor forward(torch::Tensor input) {
    return torch::relu(input + weight);
  }
  torch::Tensor weight;
};

int main() {
  torch::Tensor x = at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
  Net net;
  torch::Tensor results_ref = net.forward(x);

  // TODO: we need to provide an API to concatenate args and weights
  std::vector<torch::Tensor> inputs;
  for (const auto& pair : net.named_parameters()) {
    inputs.push_back(pair.value());
  }
  inputs.push_back(x);

  auto results_opt = inductor_cpp_entry(inputs);
  assert(torch::allclose(results_ref, results_opt[0]));
  printf("PASS\n");
  return 0;
}

@@ -1,22 +0,0 @@
import shutil

import torch
import torch._dynamo
import torch._inductor


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.ones((32, 64), device="cuda")

    def forward(self, x):
        x = torch.relu(x + self.weight)
        return x


inp = torch.randn((32, 64), device="cuda")
module, _ = torch._dynamo.export(Net().cuda(), inp)
lib_path = torch._inductor.aot_compile(module, [inp])
shutil.copy(lib_path, "aot_inductor_output.so")

@@ -1,9 +0,0 @@
#!/bin/bash
set -euxo pipefail
rm -rf build
mkdir -p build
cd build
cmake ..
make
./test

@@ -45,12 +45,12 @@ def aot_compile(
    """
    from .compile_fx import compile_fx_aot

    compiled = compile_fx_aot(
    result = compile_fx_aot(
        gm,
        example_inputs,
        config_patches=options,
    )
    )()
    lib_path = compiled()
    lib_path = result[0] if isinstance(result, tuple) else result
    return lib_path

@@ -604,7 +604,7 @@ class AotCodeCache:
    clear = staticmethod(cache.clear)

    @classmethod
    def compile(cls, source_code, cuda):
    def compile(cls, graph, source_code, cuda):
        # TODO: update cpp_compile_command for different platforms
        picked_vec_isa = invalid_vec_isa if cuda else pick_vec_isa()
        key, input_path = write(
@@ -635,7 +635,11 @@ class AotCodeCache:
                cls.cache[key] = output_so

        return cls.cache[key]

        def wrapper_call(*args):
            assert len(graph.graph_outputs) > 0
            return cls.cache[key], *(None for i in range(len(graph.graph_outputs) - 1))

        return wrapper_call
class CppCodeCache:
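
The returned wrapper_call imitates a regular compiled-graph callable: it yields one value per graph output, with the path to the compiled .so in the first slot and None padding for the rest. This is what the result[0] unwrapping in aot_compile above relies on. A standalone sketch of that shape (so_path and num_graph_outputs stand in for cls.cache[key] and len(graph.graph_outputs)):

def wrapper_call_sketch(so_path, num_graph_outputs):
    assert num_graph_outputs > 0
    # slot 0 carries the library path; the remaining output slots are padded
    return (so_path, *(None for _ in range(num_graph_outputs - 1)))

result = wrapper_call_sketch("/tmp/aot_inductor_output.so", 3)
lib_path = result[0] if isinstance(result, tuple) else result
assert lib_path == "/tmp/aot_inductor_output.so"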

@@ -722,7 +722,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
        self.extern_call_ops = set()
        self.size = "sizes()"
        self.stride = "strides()"
        self.call_func_name = "inductor_cpp_entry"
        self.call_func_name = "inductor_entry_cpp"
        self.cuda = False

    def seed(self):
@@ -737,7 +737,13 @@ class CppWrapperCodeGen(WrapperCodeGen):
    def write_header(self):
        if V.graph.aot_mode:
            self.header.splice("\n#include <ATen/ATen.h>")
            self.header.splice(
                """
                /* AOTInductor generated code */
                #include <ATen/ScalarOps.h>
                """
            )
        else:
            self.header.splice(
                """
@@ -881,6 +887,11 @@ class CppWrapperCodeGen(WrapperCodeGen):
            args.insert(0, f"{codegen_reference}")
        self.writeline(self.wrap_kernel_call(kernel, args))

    def add_benchmark_harness(self, output):
        if V.graph.aot_mode:
            return
        super().add_benchmark_harness(output)

    def codegen_sizevar(self, x: Expr) -> str:
        from .cpp import cexpr
@@ -972,7 +983,6 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
    def write_prefix(self):
        self.prefix.splice(
            """
            #include <ATen/ATen.h>
            #include <c10/util/Exception.h>
            #include <c10/cuda/CUDAGuard.h>
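
The renamed entry point has to stay in sync with the extern declaration in the new test.cpp above; a toy illustration of the coupling (the assembled string is illustrative, not real codegen output):

call_func_name = "inductor_entry_cpp"  # must match the C++ test's extern declaration
cpp_decl = (
    f"extern std::vector<at::Tensor> {call_func_name}("
    "const std::vector<at::Tensor>& args);"
)
print(cpp_decl)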

@@ -516,6 +516,7 @@ def compile_fx_with_cpp_wrapper(
    example_inputs: List[torch.Tensor],
    inner_compile,
    decompositions: Optional[Dict[OpOverload, Callable]] = None,
    aot_mode=False,
):
    """
    Compile into cpp wrapper:
@@ -536,7 +537,9 @@
        return compile_fx(
            module,
            example_inputs,
            inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
            inner_compile=functools.partial(
                inner_compile, cpp_wrapper=True, aot_mode=aot_mode
            ),
            decompositions=decompositions,
        )
    else:
@@ -557,7 +560,9 @@
        compiled = compile_fx(
            module_copy,
            inputs_copy,
            inner_compile=functools.partial(inner_compile, cpp_wrapper=False),
            inner_compile=functools.partial(
                inner_compile, cpp_wrapper=False, aot_mode=False
            ),
            decompositions=decompositions,
        )
        if fake_mode:
@@ -580,7 +585,9 @@
        return compile_fx(
            module,
            example_inputs,
            inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
            inner_compile=functools.partial(
                inner_compile, cpp_wrapper=True, aot_mode=aot_mode
            ),
            decompositions=decompositions,
        )
@@ -592,12 +599,17 @@ def compile_fx_aot(
    config_patches: Optional[Dict[str, Any]] = None,
    decompositions: Optional[Dict[OpOverload, Callable]] = None,
):
    return compile_fx(
        model_,
        example_inputs_,
        inner_compile=functools.partial(inner_compile, aot_mode=True),
        config_patches=config_patches,
        decompositions=decompositions,
    if config_patches:
        with config.patch(config_patches):
            return compile_fx_aot(
                model_,
                example_inputs_,
                # need extra layer of patching as backwards is compiled out of scope
                inner_compile=config.patch(config_patches)(inner_compile),
                decompositions=decompositions,
            )
    return compile_fx_with_cpp_wrapper(
        model_, example_inputs_, inner_compile, decompositions, aot_mode=True
    )
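
The config_patches indirection exists because backward graphs are compiled after this function returns, outside the with config.patch(...) scope; wrapping inner_compile itself keeps the patched options alive for those late compilations. The pattern in isolation (patch, compile_inner and compile_aot below are simplified stand-ins for torch._inductor's config.patch, inner_compile and compile_fx_aot):

import contextlib

_options = {}

@contextlib.contextmanager
def patch(overrides):
    # stand-in for torch._inductor.config.patch; usable as a context
    # manager and, like the real one, as a function decorator
    saved = dict(_options)
    _options.update(overrides)
    try:
        yield
    finally:
        _options.clear()
        _options.update(saved)

def compile_inner(graph):
    # reads _options at compile time, possibly long after compile_aot
    # returned (this is what happens for backward graphs)
    return f"compiled {graph!r} with {_options}"

def compile_aot(graph, options=None, inner=compile_inner):
    if options:
        with patch(options):
            # re-enter with inner wrapped, so later compilations still
            # see the patched options
            return compile_aot(graph, None, inner=patch(options)(inner))
    return inner(graph)

print(compile_aot("fw_graph", {"aot_mode": True}))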

@@ -691,10 +691,9 @@ class GraphLowering(torch.fx.Interpreter):
            code, linemap = self.codegen()
            output_code_log.debug("Output code: \n%s", code)
            libpath = AotCodeCache.compile(
                code, cuda=(self.get_single_device() == "cuda")
            return AotCodeCache.compile(
                self, code, cuda=(self.get_single_device() == "cuda")
            )
            return lambda dummy: libpath
        else:
            return self.compile_to_module().call

@@ -3148,9 +3148,10 @@ class MultiOutputLayout(IRNode):
class MultiOutput(ExternKernel):
    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}"
        )
        line = V.graph.wrapper_code.declare
        line += f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}"
        line += V.graph.wrapper_code.ending
        wrapper.writeline(line)
        self.codegen_size_asserts(wrapper)

    def __init__(self, layout, input, index: str):
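
The declare/ending attributes let the same IR node emit valid code for both wrapper backends: the Python wrapper uses empty strings, while the C++ wrapper uses a declaration prefix and a statement terminator ("auto " and ";" by my reading of CppWrapperCodeGen, so treat the exact values as illustrative):

def multi_output_line(declare, ending, name="buf1", src="buf0", index="[0]"):
    # mirrors MultiOutput.codegen above: prefix, assignment, suffix
    return declare + f"{name} = {src}{index}" + ending

print(multi_output_line("", ""))        # Python wrapper: buf1 = buf0[0]
print(multi_output_line("auto ", ";"))  # C++ wrapper:    auto buf1 = buf0[0];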