mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[AOTI] Improve the two-pass wrapper codegen (#114067)
Summary: For the second pass, we don't have to rerun the whole Inductor flow. This PR moves the second pass to codegen time. This change not only speeds up compilation, but also removes the kernel scheduling inconsistency between the two passes. A future improvement is to make the second pass reuse the scheduler and do only the wrapper codegen. This is a copy of https://github.com/pytorch/pytorch/pull/113762 to land in GitHub first. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114067 Approved by: https://github.com/chenyang78
This commit is contained in:
parent
226384b460
commit
5a96a42cea
5 changed files with 106 additions and 106 deletions
|
|
@ -18,6 +18,7 @@ from torch._inductor.utils import aot_inductor_launcher, cache_dir
|
|||
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal import common_utils
|
||||
from torch.testing._internal.common_quantization import skip_if_no_torchvision
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_CI,
|
||||
|
|
@ -1060,6 +1061,42 @@ class AOTInductorTestsTemplate:
|
|||
example_inputs = (torch.randn(3, 10, device=self.device),)
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
@skip_if_no_torchvision
|
||||
def test_missing_cubin(self):
|
||||
from torchvision.models.resnet import Bottleneck, ResNet
|
||||
|
||||
class Model(ResNet):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
replace_stride_with_dilation=[False, False, True],
|
||||
norm_layer=None,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
f1 = x
|
||||
x = self.maxpool(x)
|
||||
x = self.layer1(x)
|
||||
f2 = x
|
||||
x = self.layer2(x)
|
||||
f3 = x
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
f4 = x
|
||||
return [f1, f2, f3, f4]
|
||||
|
||||
# Call eval() here so that batch_norm won't update the running stats
|
||||
# Use float64 to avoid numeric difference failure
|
||||
model = Model().to(device=self.device, dtype=torch.float64).eval()
|
||||
example_inputs = (
|
||||
torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64),
|
||||
)
|
||||
self.check_model(model, example_inputs)
|
||||
|
||||
@common_utils.parametrize("grid_type", [1, 2, 3])
|
||||
@common_utils.parametrize("num_dims", [1, 2])
|
||||
@common_utils.parametrize("dynamic", [False, True])
|
||||
|
|
@ -1194,6 +1231,8 @@ copy_tests(
|
|||
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
|
||||
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
|
||||
"test_freezing": TestFailure(("abi_compatible_cpu",), is_skip=True),
|
||||
# Need to support convolution
|
||||
"test_missing_cubin": TestFailure(("abi_compatible_cpu",)),
|
||||
"test_normal_functional": TestFailure(("abi_compatible_cpu",)),
|
||||
"test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
|
||||
# There is a double-free issue which will be fixed in another PR
|
||||
|
|
@ -1219,6 +1258,8 @@ copy_tests(
|
|||
# test_failures, xfail by default, set is_skip=True to skip
|
||||
{
|
||||
"test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)),
|
||||
# Need to support convolution
|
||||
"test_missing_cubin": TestFailure(("abi_compatible_cuda",)),
|
||||
"test_normal_functional": TestFailure(("abi_compatible_cuda",)),
|
||||
# There is a double-free issue which will be fixed in another PR
|
||||
"test_repeat_output": TestFailure(("abi_compatible_cuda",), is_skip=True),
|
||||
|
|
|
|||
|
|
@ -50,10 +50,6 @@ class TestPatternMatcher(TestCase):
|
|||
if len(codes) == 1:
|
||||
codes = codes[0]
|
||||
torch.testing.assert_close(actual, expected)
|
||||
if inductor_config.cpp_wrapper:
|
||||
# CPP wrapper runs everything twice, so we'll match the pattern twice
|
||||
expected_matches *= 2
|
||||
expected_nodes *= 2
|
||||
|
||||
self.assertEqual(
|
||||
counters["inductor"]["pattern_matcher_count"], expected_matches
|
||||
|
|
@ -519,13 +515,6 @@ class TestPatternMatcher(TestCase):
|
|||
self.common(fn, args, 2, 5)
|
||||
|
||||
def test_cat_slice_cat(self):
|
||||
def check_counter(counter, expected):
|
||||
if not inductor_config.cpp_wrapper:
|
||||
self.assertEqual(counter, expected)
|
||||
else:
|
||||
# cpp_wrapper for the CUDA backend runs two passes
|
||||
self.assertEqual(counter, 2 * expected)
|
||||
|
||||
def fn(a, b):
|
||||
cat_1 = torch.ops.aten.cat.default([a, b], 1)
|
||||
slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
|
||||
|
|
@ -548,8 +537,8 @@ class TestPatternMatcher(TestCase):
|
|||
torch.testing.assert_close(actual, expected)
|
||||
# We don't recompile for dynamic-shape cases.
|
||||
if dynamo_config.assume_static_by_default:
|
||||
check_counter(counters["inductor"]["pattern_matcher_count"], 1)
|
||||
check_counter(counters["inductor"]["pattern_matcher_nodes"], 3)
|
||||
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
|
||||
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
|
||||
|
||||
# Verify we fallback to non-optimal path for negative `end`.
|
||||
def fn(a, b):
|
||||
|
|
|
|||
|
|
@ -44,13 +44,6 @@ def patches(fn):
|
|||
|
||||
|
||||
class TestSelectAlgorithm(TestCase):
|
||||
def check_counter(self, counter, expected):
|
||||
if not inductor_config.cpp_wrapper:
|
||||
self.assertEqual(counter, expected)
|
||||
else:
|
||||
# cpp_wrapper for the CUDA backend runs two passes
|
||||
self.assertEqual(counter, 2 * expected)
|
||||
|
||||
@expectedFailureDynamicWrapper
|
||||
@patches
|
||||
def test_linear_relu(self):
|
||||
|
|
@ -64,7 +57,7 @@ class TestSelectAlgorithm(TestCase):
|
|||
torch.randn(1, 16, device="cuda"),
|
||||
)
|
||||
# Autotuning checks correctness of each version
|
||||
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
# It would be nice to assert this got fused into a single kernel, but that
|
||||
# only happens if we select a triton template (and not aten).
|
||||
|
||||
|
|
@ -82,7 +75,7 @@ class TestSelectAlgorithm(TestCase):
|
|||
)
|
||||
|
||||
foo(*inps)
|
||||
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
|
||||
@patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
|
||||
@patches
|
||||
|
|
@ -112,7 +105,7 @@ class TestSelectAlgorithm(TestCase):
|
|||
torch.randn(8, 32, device="cuda"),
|
||||
torch.randn(32, 8, device="cuda"),
|
||||
)
|
||||
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
|
||||
@patches
|
||||
def test__int_mm(self):
|
||||
|
|
@ -206,7 +199,7 @@ class TestSelectAlgorithm(TestCase):
|
|||
torch.randn(512, 512, device="cuda"),
|
||||
)
|
||||
# Autotuning checks correctness of each version
|
||||
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
|
||||
@patches
|
||||
def test_mm_dup_args(self):
|
||||
|
|
@ -215,7 +208,7 @@ class TestSelectAlgorithm(TestCase):
|
|||
return torch.mm(a, a)
|
||||
|
||||
foo(torch.randn(32, 32, device="cuda"))
|
||||
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
|
||||
@patches
|
||||
def test_mm_dup_args_view(self):
|
||||
|
|
@ -226,7 +219,7 @@ class TestSelectAlgorithm(TestCase):
|
|||
return torch.mm(q, k.transpose(0, 1))
|
||||
|
||||
foo(torch.randn(64, 64, device="cuda"))
|
||||
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
|
||||
@skipIfRocm
|
||||
@expectedFailureDynamicWrapper
|
||||
|
|
@ -252,7 +245,7 @@ class TestSelectAlgorithm(TestCase):
|
|||
torch.randn(34, device="cuda"),
|
||||
)
|
||||
# Autotuning checks correctness of each version
|
||||
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
|
||||
@skipIfRocm
|
||||
@patches
|
||||
|
|
|
|||
|
|
@ -51,7 +51,6 @@ from .fx_passes.post_grad import post_grad_passes, view_to_reshape
|
|||
from .fx_passes.pre_grad import pre_grad_passes
|
||||
from .graph import GraphLowering
|
||||
from .ir import ExternKernelNode
|
||||
from .pattern_matcher import clone_graph
|
||||
from .utils import get_dtype_size, has_incompatible_cudagraph_ops
|
||||
from .virtualized import V
|
||||
|
||||
|
|
@ -217,79 +216,6 @@ def count_bytes_inner(
|
|||
return make_boxed_func(gm.forward)
|
||||
|
||||
|
||||
def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
|
||||
@functools.wraps(inner_compile)
|
||||
def wrapper(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
|
||||
"""
|
||||
Compile into cpp wrapper:
|
||||
For CPU, this is currently done in one pass.
|
||||
For GPU, this is done in two passes: JIT-compile the model with python wrapper code
|
||||
and run it to generate autotuned kernel binaries in the first pass; and then generate
|
||||
cpp wrapper code and compile it to a dynamic library in the second pass.
|
||||
"""
|
||||
devices = (
|
||||
{t.device.type for t in gm.parameters()}
|
||||
| {t.device.type for t in gm.buffers()}
|
||||
| {t.device.type for t in example_inputs if isinstance(t, torch.Tensor)}
|
||||
)
|
||||
|
||||
if "cuda" not in devices:
|
||||
kwargs_patched = {**kwargs, "cpp_wrapper": True}
|
||||
return inner_compile(gm, example_inputs, **kwargs_patched)
|
||||
else:
|
||||
with config.patch(
|
||||
{
|
||||
"triton.store_cubin": True,
|
||||
}
|
||||
):
|
||||
# first pass with regular python wrapper code
|
||||
kwargs_patched = {
|
||||
**kwargs,
|
||||
"cpp_wrapper": False,
|
||||
}
|
||||
# clone_graph(gm) makes sure no graph modification from the first pass will
|
||||
# leak to the second pass. It does increase memory pressure, but the problem
|
||||
# can be alleviated once we have parameters as FakeTensor.
|
||||
|
||||
compiled = inner_compile(
|
||||
clone_graph(gm), example_inputs, **kwargs_patched
|
||||
)
|
||||
|
||||
def materialize(x):
|
||||
if isinstance(x, (torch.SymInt, torch.SymFloat)):
|
||||
# Need concrete value to run dynamic shapes and tune the result
|
||||
return x.node.hint
|
||||
else:
|
||||
assert not isinstance(x, FakeTensor)
|
||||
return x
|
||||
|
||||
if tracing_context := torch._guards.TracingContext.try_get():
|
||||
if tracing_context.output_strides:
|
||||
tracing_context.output_strides.clear()
|
||||
|
||||
params_flat = [
|
||||
param
|
||||
for param in tracing_context.params_flat # type: ignore[union-attr]
|
||||
if param is not None
|
||||
]
|
||||
real_inputs = [
|
||||
materialize(x) for x in (params_flat + V.real_inputs)
|
||||
]
|
||||
else:
|
||||
real_inputs = [materialize(x) for x in V.real_inputs]
|
||||
|
||||
with torch.utils._python_dispatch._disable_current_modes():
|
||||
compiled(real_inputs)
|
||||
|
||||
del real_inputs
|
||||
|
||||
# second pass
|
||||
kwargs_patched = {**kwargs, "cpp_wrapper": True}
|
||||
return inner_compile(gm, example_inputs, **kwargs_patched)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def fake_tensor_prop(
|
||||
gm: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor],
|
||||
|
|
@ -592,6 +518,10 @@ def fx_codegen_and_compile(
|
|||
with V.set_fake_mode(fake_mode):
|
||||
graph = GraphLowering(
|
||||
gm,
|
||||
# example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
|
||||
# For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
|
||||
# we currently use fake tensors and defake them later.
|
||||
example_inputs=V.real_inputs if is_inference else example_inputs,
|
||||
shape_env=shape_env,
|
||||
num_static_inputs=num_fixed,
|
||||
graph_id=graph_id,
|
||||
|
|
@ -1033,6 +963,7 @@ def compile_fx(
|
|||
"cpp_wrapper": False,
|
||||
"triton.autotune_cublasLt": False,
|
||||
"triton.cudagraphs": False,
|
||||
"triton.store_cubin": True,
|
||||
}
|
||||
), V.set_real_inputs(example_inputs_):
|
||||
inputs_ = example_inputs_
|
||||
|
|
@ -1055,7 +986,7 @@ def compile_fx(
|
|||
return compile_fx(
|
||||
model_,
|
||||
inputs_,
|
||||
inner_compile=inner_compile_with_cpp_wrapper(inner_compile),
|
||||
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
|
||||
decompositions=decompositions,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ import torch
|
|||
import torch._logging
|
||||
import torch.fx
|
||||
from torch._decomp import get_decompositions
|
||||
from torch._dynamo.utils import dynamo_timed
|
||||
from torch._dynamo.utils import defake, dynamo_timed
|
||||
from torch._logging import LazyString
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
|
||||
from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
|
|
@ -164,6 +165,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
def __init__(
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
example_inputs: Optional[List[torch.Tensor]] = None,
|
||||
shape_env=None,
|
||||
num_static_inputs=None,
|
||||
graph_id=None,
|
||||
|
|
@ -176,6 +178,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
):
|
||||
super().__init__(gm)
|
||||
|
||||
self.example_inputs = example_inputs
|
||||
self.layout_opt = (
|
||||
layout_opt if layout_opt is not None else self.decide_layout_opt(gm)
|
||||
)
|
||||
|
|
@ -921,6 +924,46 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
|
||||
self.wrapper_code = wrapper_code_gen_cls()
|
||||
|
||||
def codegen_with_cpp_wrapper(self):
|
||||
"""
|
||||
For CPU, the cpp wrapper codegen is done in one pass.
|
||||
For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
|
||||
wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
|
||||
generate cpp wrapper code and compile it to a dynamic library in the second pass.
|
||||
"""
|
||||
if "cuda" in self.device_types:
|
||||
# first pass
|
||||
self.cpp_wrapper = False
|
||||
compiled = self.compile_to_module().call
|
||||
|
||||
def materialize(x):
|
||||
if isinstance(x, (torch.SymInt, torch.SymFloat)):
|
||||
# Need concrete value to run dynamic shapes and tune the result
|
||||
return x.node.hint
|
||||
elif isinstance(x, FakeTensor):
|
||||
return defake(x)
|
||||
else:
|
||||
assert isinstance(
|
||||
x, torch.Tensor
|
||||
), "Unknown type when creating real inputs"
|
||||
return x
|
||||
|
||||
with torch.utils._python_dispatch._disable_current_modes():
|
||||
assert self.example_inputs is not None
|
||||
real_inputs = [materialize(x) for x in self.example_inputs]
|
||||
compiled(real_inputs)
|
||||
del real_inputs
|
||||
|
||||
# second pass
|
||||
# TODO: reuse self.scheduler from the first pass to speed up the second pass
|
||||
self.cpp_wrapper = True
|
||||
self.removed_buffers.clear()
|
||||
self.inplaced_to_remove.clear()
|
||||
return self.codegen()
|
||||
else:
|
||||
# cpu
|
||||
return self.codegen()
|
||||
|
||||
def codegen(self):
|
||||
from .scheduler import Scheduler
|
||||
|
||||
|
|
@ -952,7 +995,9 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
def compile_to_module(self):
|
||||
from .codecache import PyCodeCache
|
||||
|
||||
code, linemap = self.codegen()
|
||||
code, linemap = (
|
||||
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
|
||||
)
|
||||
linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
|
||||
key, path = PyCodeCache.write(code)
|
||||
mod = PyCodeCache.load_by_key_path(
|
||||
|
|
@ -975,10 +1020,11 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
return mod
|
||||
|
||||
def compile_to_fn(self):
|
||||
if self.aot_mode and self.cpp_wrapper:
|
||||
if self.aot_mode:
|
||||
from .codecache import AotCodeCache
|
||||
|
||||
code, linemap = self.codegen()
|
||||
assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
|
||||
code, linemap = self.codegen_with_cpp_wrapper()
|
||||
output_code_log.debug("Output code: \n%s", code)
|
||||
|
||||
serialized_extern_kernel_nodes = None
|
||||
|
|
|
|||
Loading…
Reference in a new issue