[AOTI] Improve the two-pass wrapper codegen (#114067)

Summary: For the second pass, we don't have to rerun the whole Inductor flow. This PR moves the second pass to codegen time. This change not only speeds up compilation, but also removes kernel-scheduling inconsistencies between the two passes. A future improvement is to make the second pass reuse the scheduler and do only the wrapper codegen.

This is a copy of https://github.com/pytorch/pytorch/pull/113762, opened so that the change can land on GitHub first.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114067
Approved by: https://github.com/chenyang78
This commit is contained in:
Bin Bao 2023-11-19 09:05:46 -08:00 committed by PyTorch MergeBot
parent 226384b460
commit 5a96a42cea
5 changed files with 106 additions and 106 deletions

View file

@ -18,6 +18,7 @@ from torch._inductor.utils import aot_inductor_launcher, cache_dir
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_quantization import skip_if_no_torchvision
from torch.testing._internal.common_utils import (
IS_CI,
@ -1060,6 +1061,42 @@ class AOTInductorTestsTemplate:
example_inputs = (torch.randn(3, 10, device=self.device),)
self.check_model(Model(), example_inputs)
# Builds a torchvision ResNet subclass (Bottleneck blocks, dilation instead of
# stride in the last stage) whose forward returns intermediate feature maps,
# then validates it through self.check_model on a float64 random input.
# NOTE(review): the name suggests this guards against a kernel binary (cubin)
# missing from the AOT-compiled output -- confirm against the original PR.
@skip_if_no_torchvision
def test_missing_cubin(self):
# Imported locally; the decorator above skips the test when torchvision is absent.
from torchvision.models.resnet import Bottleneck, ResNet
# ResNet variant that exposes four intermediate activations rather than the
# default forward result.
class Model(ResNet):
def __init__(self):
super().__init__(
block=Bottleneck,
layers=[3, 4, 6, 3],
replace_stride_with_dilation=[False, False, True],
norm_layer=None,
)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
# f1: activation right after the stem's conv/bn/relu, before max-pooling.
f1 = x
x = self.maxpool(x)
x = self.layer1(x)
# f2: output of layer1.
f2 = x
x = self.layer2(x)
# f3: output of layer2.
f3 = x
# layer3's output is only fed into layer4; it is not returned.
x = self.layer3(x)
x = self.layer4(x)
# f4: output of layer4.
f4 = x
return [f1, f2, f3, f4]
# Call eval() here so that batch_norm won't update the running stats
# Use float64 to avoid numeric difference failure
model = Model().to(device=self.device, dtype=torch.float64).eval()
# Single-element tuple holding one random 4x3x64x64 tensor on the test device.
example_inputs = (
torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64),
)
self.check_model(model, example_inputs)
@common_utils.parametrize("grid_type", [1, 2, 3])
@common_utils.parametrize("num_dims", [1, 2])
@common_utils.parametrize("dynamic", [False, True])
@ -1194,6 +1231,8 @@ copy_tests(
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": TestFailure(("abi_compatible_cpu",), is_skip=True),
# Need to support convolution
"test_missing_cubin": TestFailure(("abi_compatible_cpu",)),
"test_normal_functional": TestFailure(("abi_compatible_cpu",)),
"test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
# There is a double-free issue which will be fixed in another PR
@ -1219,6 +1258,8 @@ copy_tests(
# test_failures, xfail by default, set is_skip=True to skip
{
"test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)),
# Need to support convolution
"test_missing_cubin": TestFailure(("abi_compatible_cuda",)),
"test_normal_functional": TestFailure(("abi_compatible_cuda",)),
# There is a double-free issue which will be fixed in another PR
"test_repeat_output": TestFailure(("abi_compatible_cuda",), is_skip=True),

View file

@ -50,10 +50,6 @@ class TestPatternMatcher(TestCase):
if len(codes) == 1:
codes = codes[0]
torch.testing.assert_close(actual, expected)
if inductor_config.cpp_wrapper:
# CPP wrapper runs everything twice, so we'll match the pattern twice
expected_matches *= 2
expected_nodes *= 2
self.assertEqual(
counters["inductor"]["pattern_matcher_count"], expected_matches
@ -519,13 +515,6 @@ class TestPatternMatcher(TestCase):
self.common(fn, args, 2, 5)
def test_cat_slice_cat(self):
# Local helper: assert that `counter` equals `expected`, or twice `expected`
# when inductor_config.cpp_wrapper is enabled.
def check_counter(counter, expected):
if not inductor_config.cpp_wrapper:
self.assertEqual(counter, expected)
else:
# cpp_wrapper for the CUDA backend runs two passes
self.assertEqual(counter, 2 * expected)
def fn(a, b):
cat_1 = torch.ops.aten.cat.default([a, b], 1)
slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
@ -548,8 +537,8 @@ class TestPatternMatcher(TestCase):
torch.testing.assert_close(actual, expected)
# We don't recompile for dynamic-shape cases.
if dynamo_config.assume_static_by_default:
check_counter(counters["inductor"]["pattern_matcher_count"], 1)
check_counter(counters["inductor"]["pattern_matcher_nodes"], 3)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
# Verify we fallback to non-optimal path for negative `end`.
def fn(a, b):

View file

@ -44,13 +44,6 @@ def patches(fn):
class TestSelectAlgorithm(TestCase):
# Assert that `counter` equals `expected`, or twice `expected` when
# inductor_config.cpp_wrapper is enabled.
def check_counter(self, counter, expected):
if not inductor_config.cpp_wrapper:
self.assertEqual(counter, expected)
else:
# cpp_wrapper for the CUDA backend runs two passes
self.assertEqual(counter, 2 * expected)
@expectedFailureDynamicWrapper
@patches
def test_linear_relu(self):
@ -64,7 +57,7 @@ class TestSelectAlgorithm(TestCase):
torch.randn(1, 16, device="cuda"),
)
# Autotuning checks correctness of each version
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
# It would be nice to assert this got fused into a single kernel, but that
# only happens if we select a triton template (and not aten).
@ -82,7 +75,7 @@ class TestSelectAlgorithm(TestCase):
)
foo(*inps)
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
@patches
@ -112,7 +105,7 @@ class TestSelectAlgorithm(TestCase):
torch.randn(8, 32, device="cuda"),
torch.randn(32, 8, device="cuda"),
)
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test__int_mm(self):
@ -206,7 +199,7 @@ class TestSelectAlgorithm(TestCase):
torch.randn(512, 512, device="cuda"),
)
# Autotuning checks correctness of each version
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_mm_dup_args(self):
@ -215,7 +208,7 @@ class TestSelectAlgorithm(TestCase):
return torch.mm(a, a)
foo(torch.randn(32, 32, device="cuda"))
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@patches
def test_mm_dup_args_view(self):
@ -226,7 +219,7 @@ class TestSelectAlgorithm(TestCase):
return torch.mm(q, k.transpose(0, 1))
foo(torch.randn(64, 64, device="cuda"))
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@skipIfRocm
@expectedFailureDynamicWrapper
@ -252,7 +245,7 @@ class TestSelectAlgorithm(TestCase):
torch.randn(34, device="cuda"),
)
# Autotuning checks correctness of each version
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@skipIfRocm
@patches

View file

@ -51,7 +51,6 @@ from .fx_passes.post_grad import post_grad_passes, view_to_reshape
from .fx_passes.pre_grad import pre_grad_passes
from .graph import GraphLowering
from .ir import ExternKernelNode
from .pattern_matcher import clone_graph
from .utils import get_dtype_size, has_incompatible_cudagraph_ops
from .virtualized import V
@ -217,79 +216,6 @@ def count_bytes_inner(
return make_boxed_func(gm.forward)
# Decorator factory: wraps `inner_compile` so that it is invoked with the
# `cpp_wrapper` keyword set appropriately (once for non-CUDA graphs, twice for
# graphs that touch CUDA). Returns the wrapping callable.
def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
@functools.wraps(inner_compile)
def wrapper(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
"""
Compile into cpp wrapper:
For CPU, this is currently done in one pass.
For GPU, this is done in two passes: JIT-compile the model with python wrapper code
and run it to generate autotuned kernel binaries in the first pass; and then generate
cpp wrapper code and compile it to a dynamic library in the second pass.
"""
# Device types seen across the module's parameters, its buffers, and the
# tensor-valued example inputs (non-tensor inputs are ignored).
devices = (
{t.device.type for t in gm.parameters()}
| {t.device.type for t in gm.buffers()}
| {t.device.type for t in example_inputs if isinstance(t, torch.Tensor)}
)
if "cuda" not in devices:
# No CUDA involved: a single compile with cpp_wrapper forced to True.
kwargs_patched = {**kwargs, "cpp_wrapper": True}
return inner_compile(gm, example_inputs, **kwargs_patched)
else:
# The "triton.store_cubin" config flag is enabled for this branch.
# NOTE(review): indentation is not visible here, so it is unclear
# whether the patch covers only the first pass or both -- confirm.
with config.patch(
{
"triton.store_cubin": True,
}
):
# first pass with regular python wrapper code
kwargs_patched = {
**kwargs,
"cpp_wrapper": False,
}
# clone_graph(gm) makes sure no graph modification from the first pass will
# leak to the second pass. It does increase memory pressure, but the problem
# can be alleviated once we have parameters as FakeTensor.
compiled = inner_compile(
clone_graph(gm), example_inputs, **kwargs_patched
)
# Turn one input into something the first-pass module can actually run on:
# symbolic ints/floats become their hint value, everything else is passed
# through unchanged (and must not be a FakeTensor).
def materialize(x):
if isinstance(x, (torch.SymInt, torch.SymFloat)):
# Need concrete value to run dynamic shapes and tune the result
return x.node.hint
else:
assert not isinstance(x, FakeTensor)
return x
if tracing_context := torch._guards.TracingContext.try_get():
# Drop any output strides already recorded on the tracing context.
# NOTE(review): presumably so the second compile does not append to the
# strides left by the first pass -- confirm.
if tracing_context.output_strides:
tracing_context.output_strides.clear()
# Non-None flattened params from the tracing context are placed ahead of
# the real user inputs.
params_flat = [
param
for param in tracing_context.params_flat # type: ignore[union-attr]
if param is not None
]
real_inputs = [
materialize(x) for x in (params_flat + V.real_inputs)
]
else:
# No tracing context: only the real user inputs are used.
real_inputs = [materialize(x) for x in V.real_inputs]
# Execute the first-pass module on the materialized inputs with the current
# dispatch modes disabled; per the docstring this run is what produces the
# autotuned kernel binaries.
with torch.utils._python_dispatch._disable_current_modes():
compiled(real_inputs)
# Drop the local reference to the input list once the run is done.
del real_inputs
# second pass
kwargs_patched = {**kwargs, "cpp_wrapper": True}
return inner_compile(gm, example_inputs, **kwargs_patched)
return wrapper
def fake_tensor_prop(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
@ -592,6 +518,10 @@ def fx_codegen_and_compile(
with V.set_fake_mode(fake_mode):
graph = GraphLowering(
gm,
# example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
# For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
# we currently use fake tensors and defake them later.
example_inputs=V.real_inputs if is_inference else example_inputs,
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
@ -1033,6 +963,7 @@ def compile_fx(
"cpp_wrapper": False,
"triton.autotune_cublasLt": False,
"triton.cudagraphs": False,
"triton.store_cubin": True,
}
), V.set_real_inputs(example_inputs_):
inputs_ = example_inputs_
@ -1055,7 +986,7 @@ def compile_fx(
return compile_fx(
model_,
inputs_,
inner_compile=inner_compile_with_cpp_wrapper(inner_compile),
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
decompositions=decompositions,
)

View file

@ -15,8 +15,9 @@ import torch
import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import dynamo_timed
from torch._dynamo.utils import defake, dynamo_timed
from torch._logging import LazyString
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
from torch.utils._mode_utils import no_dispatch
@ -164,6 +165,7 @@ class GraphLowering(torch.fx.Interpreter):
def __init__(
self,
gm: torch.fx.GraphModule,
example_inputs: Optional[List[torch.Tensor]] = None,
shape_env=None,
num_static_inputs=None,
graph_id=None,
@ -176,6 +178,7 @@ class GraphLowering(torch.fx.Interpreter):
):
super().__init__(gm)
self.example_inputs = example_inputs
self.layout_opt = (
layout_opt if layout_opt is not None else self.decide_layout_opt(gm)
)
@ -921,6 +924,46 @@ class GraphLowering(torch.fx.Interpreter):
assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
self.wrapper_code = wrapper_code_gen_cls()
# Returns whatever self.codegen() returns. For graphs whose device_types include
# "cuda", a python-wrapper compile-and-run is performed first (see docstring).
def codegen_with_cpp_wrapper(self):
"""
For CPU, the cpp wrapper codegen is done in one pass.
For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
generate cpp wrapper code and compile it to a dynamic library in the second pass.
"""
if "cuda" in self.device_types:
# first pass
# Temporarily turn cpp_wrapper off before building the module; the `.call`
# attribute of the module returned by compile_to_module() is the callable
# that gets run below.
# NOTE(review): there is no try/finally here, so if the first pass raises,
# self.cpp_wrapper is left as False.
self.cpp_wrapper = False
compiled = self.compile_to_module().call
# Convert one example input into a concrete value for the dry run:
# SymInt/SymFloat -> its hint, FakeTensor -> defake(x), any other value must
# already be a torch.Tensor and is used as is.
def materialize(x):
if isinstance(x, (torch.SymInt, torch.SymFloat)):
# Need concrete value to run dynamic shapes and tune the result
return x.node.hint
elif isinstance(x, FakeTensor):
return defake(x)
else:
assert isinstance(
x, torch.Tensor
), "Unknown type when creating real inputs"
return x
# Materialize self.example_inputs and run the first-pass callable with the
# current dispatch modes disabled.
with torch.utils._python_dispatch._disable_current_modes():
assert self.example_inputs is not None
real_inputs = [materialize(x) for x in self.example_inputs]
compiled(real_inputs)
# Drop the local reference to the materialized inputs.
del real_inputs
# second pass
# TODO: reuse self.scheduler from the first pass to speed up the second pass
self.cpp_wrapper = True
# Reset these two collections before running codegen again.
# NOTE(review): presumably they were populated by the first pass's codegen
# and would otherwise carry over into the second -- confirm.
self.removed_buffers.clear()
self.inplaced_to_remove.clear()
return self.codegen()
else:
# cpu
return self.codegen()
def codegen(self):
from .scheduler import Scheduler
@ -952,7 +995,9 @@ class GraphLowering(torch.fx.Interpreter):
def compile_to_module(self):
from .codecache import PyCodeCache
code, linemap = self.codegen()
code, linemap = (
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
)
linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
key, path = PyCodeCache.write(code)
mod = PyCodeCache.load_by_key_path(
@ -975,10 +1020,11 @@ class GraphLowering(torch.fx.Interpreter):
return mod
def compile_to_fn(self):
if self.aot_mode and self.cpp_wrapper:
if self.aot_mode:
from .codecache import AotCodeCache
code, linemap = self.codegen()
assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
code, linemap = self.codegen_with_cpp_wrapper()
output_code_log.debug("Output code: \n%s", code)
serialized_extern_kernel_nodes = None