diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index b1dc7431816..fa7411dd5cf 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -18,6 +18,7 @@ from torch._inductor.utils import aot_inductor_launcher, cache_dir
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_quantization import skip_if_no_torchvision
 from torch.testing._internal.common_utils import (
     IS_CI,
@@ -1060,6 +1061,42 @@ class AOTInductorTestsTemplate:
         example_inputs = (torch.randn(3, 10, device=self.device),)
         self.check_model(Model(), example_inputs)
 
+    @skip_if_no_torchvision
+    def test_missing_cubin(self):
+        from torchvision.models.resnet import Bottleneck, ResNet
+
+        class Model(ResNet):
+            def __init__(self):
+                super().__init__(
+                    block=Bottleneck,
+                    layers=[3, 4, 6, 3],
+                    replace_stride_with_dilation=[False, False, True],
+                    norm_layer=None,
+                )
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                x = self.relu(x)
+                f1 = x
+                x = self.maxpool(x)
+                x = self.layer1(x)
+                f2 = x
+                x = self.layer2(x)
+                f3 = x
+                x = self.layer3(x)
+                x = self.layer4(x)
+                f4 = x
+                return [f1, f2, f3, f4]
+
+        # Call eval() here so that batch_norm won't update the running stats
+        # Use float64 to avoid numeric difference failure
+        model = Model().to(device=self.device, dtype=torch.float64).eval()
+        example_inputs = (
+            torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64),
+        )
+        self.check_model(model, example_inputs)
+
     @common_utils.parametrize("grid_type", [1, 2, 3])
     @common_utils.parametrize("num_dims", [1, 2])
     @common_utils.parametrize("dynamic", [False, True])
@@ -1194,6 +1231,8 @@ copy_tests(
         # TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
         #   NotImplementedError: Cannot access storage of OpaqueTensorImpl
         "test_freezing": TestFailure(("abi_compatible_cpu",), is_skip=True),
+        # Need to support convolution
+        "test_missing_cubin": TestFailure(("abi_compatible_cpu",)),
         "test_normal_functional": TestFailure(("abi_compatible_cpu",)),
         "test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
         # There is a double-free issue which will be fixed in another PR
@@ -1219,6 +1258,8 @@ copy_tests(
     # test_failures, xfail by default, set is_skip=True to skip
     {
         "test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)),
+        # Need to support convolution
+        "test_missing_cubin": TestFailure(("abi_compatible_cuda",)),
         "test_normal_functional": TestFailure(("abi_compatible_cuda",)),
         # There is a double-free issue which will be fixed in another PR
         "test_repeat_output": TestFailure(("abi_compatible_cuda",), is_skip=True),
diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index 726cfcea4ed..6de2d05a481 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -50,10 +50,6 @@ class TestPatternMatcher(TestCase):
         if len(codes) == 1:
             codes = codes[0]
         torch.testing.assert_close(actual, expected)
-        if inductor_config.cpp_wrapper:
-            # CPP wrapper runs everything twice, so we'll match the pattern twice
-            expected_matches *= 2
-            expected_nodes *= 2
         self.assertEqual(
             counters["inductor"]["pattern_matcher_count"], expected_matches
         )
@@ -519,13 +515,6 @@ class TestPatternMatcher(TestCase):
         self.common(fn, args, 2, 5)
 
     def test_cat_slice_cat(self):
-        def check_counter(counter, expected):
-            if not inductor_config.cpp_wrapper:
-                self.assertEqual(counter, expected)
-            else:
-                # cpp_wrapper for the CUDA backend runs two passes
-                self.assertEqual(counter, 2 * expected)
-
         def fn(a, b):
             cat_1 = torch.ops.aten.cat.default([a, b], 1)
             slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
@@ -548,8 +537,8 @@ class TestPatternMatcher(TestCase):
         torch.testing.assert_close(actual, expected)
         # We don't recompile for dynamic-shape cases.
         if dynamo_config.assume_static_by_default:
-            check_counter(counters["inductor"]["pattern_matcher_count"], 1)
-            check_counter(counters["inductor"]["pattern_matcher_nodes"], 3)
+            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
+            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
 
         # Verify we fallback to non-optimal path for negative `end`.
         def fn(a, b):
diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py
index 28c17610610..c1ee714eb2f 100644
--- a/test/inductor/test_select_algorithm.py
+++ b/test/inductor/test_select_algorithm.py
@@ -44,13 +44,6 @@ def patches(fn):
 
 
 class TestSelectAlgorithm(TestCase):
-    def check_counter(self, counter, expected):
-        if not inductor_config.cpp_wrapper:
-            self.assertEqual(counter, expected)
-        else:
-            # cpp_wrapper for the CUDA backend runs two passes
-            self.assertEqual(counter, 2 * expected)
-
     @expectedFailureDynamicWrapper
     @patches
     def test_linear_relu(self):
@@ -64,7 +57,7 @@ class TestSelectAlgorithm(TestCase):
             torch.randn(1, 16, device="cuda"),
         )
         # Autotuning checks correctness of each version
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
         # It would be nice to assert this got fused into a single kernel, but that
         # only happens if we select a triton template (and not aten).
 
@@ -82,7 +75,7 @@ class TestSelectAlgorithm(TestCase):
         )
 
         foo(*inps)
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
     @patches
@@ -112,7 +105,7 @@ class TestSelectAlgorithm(TestCase):
             torch.randn(8, 32, device="cuda"),
             torch.randn(32, 8, device="cuda"),
         )
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @patches
     def test__int_mm(self):
@@ -206,7 +199,7 @@ class TestSelectAlgorithm(TestCase):
             torch.randn(512, 512, device="cuda"),
         )
         # Autotuning checks correctness of each version
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @patches
     def test_mm_dup_args(self):
@@ -215,7 +208,7 @@ class TestSelectAlgorithm(TestCase):
             return torch.mm(a, a)
 
         foo(torch.randn(32, 32, device="cuda"))
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @patches
     def test_mm_dup_args_view(self):
@@ -226,7 +219,7 @@ class TestSelectAlgorithm(TestCase):
             return torch.mm(q, k.transpose(0, 1))
 
         foo(torch.randn(64, 64, device="cuda"))
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @skipIfRocm
     @expectedFailureDynamicWrapper
@@ -252,7 +245,7 @@ class TestSelectAlgorithm(TestCase):
             torch.randn(34, device="cuda"),
         )
         # Autotuning checks correctness of each version
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @skipIfRocm
     @patches
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index b81088e7757..740c80de78e 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -51,7 +51,6 @@ from .fx_passes.post_grad import post_grad_passes, view_to_reshape
 from .fx_passes.pre_grad import pre_grad_passes
 from .graph import GraphLowering
 from .ir import ExternKernelNode
-from .pattern_matcher import clone_graph
 from .utils import get_dtype_size, has_incompatible_cudagraph_ops
 from .virtualized import V
 
@@ -217,79 +216,6 @@ def count_bytes_inner(
     return make_boxed_func(gm.forward)
 
 
-def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
-    @functools.wraps(inner_compile)
-    def wrapper(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
-        """
-        Compile into cpp wrapper:
-        For CPU, this is currently done in one pass.
-        For GPU, this is done in two passes: JIT-compile the model with python wrapper code
-        and run it to generate autotuned kernel binaries in the first pass; and then generate
-        cpp wrapper code and compile it to a dynamic library in the second pass.
-        """
-        devices = (
-            {t.device.type for t in gm.parameters()}
-            | {t.device.type for t in gm.buffers()}
-            | {t.device.type for t in example_inputs if isinstance(t, torch.Tensor)}
-        )
-
-        if "cuda" not in devices:
-            kwargs_patched = {**kwargs, "cpp_wrapper": True}
-            return inner_compile(gm, example_inputs, **kwargs_patched)
-        else:
-            with config.patch(
-                {
-                    "triton.store_cubin": True,
-                }
-            ):
-                # first pass with regular python wrapper code
-                kwargs_patched = {
-                    **kwargs,
-                    "cpp_wrapper": False,
-                }
-                # clone_graph(gm) makes sure no graph modification from the first pass will
-                # leak to the second pass. It does increase memory pressure, but the problem
-                # can be alleviated once we have parameters as FakeTensor.
-
-                compiled = inner_compile(
-                    clone_graph(gm), example_inputs, **kwargs_patched
-                )
-
-                def materialize(x):
-                    if isinstance(x, (torch.SymInt, torch.SymFloat)):
-                        # Need concrete value to run dynamic shapes and tune the result
-                        return x.node.hint
-                    else:
-                        assert not isinstance(x, FakeTensor)
-                        return x
-
-                if tracing_context := torch._guards.TracingContext.try_get():
-                    if tracing_context.output_strides:
-                        tracing_context.output_strides.clear()
-
-                    params_flat = [
-                        param
-                        for param in tracing_context.params_flat  # type: ignore[union-attr]
-                        if param is not None
-                    ]
-                    real_inputs = [
-                        materialize(x) for x in (params_flat + V.real_inputs)
-                    ]
-                else:
-                    real_inputs = [materialize(x) for x in V.real_inputs]
-
-                with torch.utils._python_dispatch._disable_current_modes():
-                    compiled(real_inputs)
-
-                del real_inputs
-
-            # second pass
-            kwargs_patched = {**kwargs, "cpp_wrapper": True}
-            return inner_compile(gm, example_inputs, **kwargs_patched)
-
-    return wrapper
-
-
 def fake_tensor_prop(
     gm: torch.fx.GraphModule,
     example_inputs: List[torch.Tensor],
@@ -592,6 +518,10 @@ def fx_codegen_and_compile(
     with V.set_fake_mode(fake_mode):
         graph = GraphLowering(
             gm,
+            # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
+            # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
+            # we currently use fake tensors and defake them later.
+            example_inputs=V.real_inputs if is_inference else example_inputs,
             shape_env=shape_env,
             num_static_inputs=num_fixed,
             graph_id=graph_id,
@@ -1033,6 +963,7 @@ def compile_fx(
                 "cpp_wrapper": False,
                 "triton.autotune_cublasLt": False,
                 "triton.cudagraphs": False,
+                "triton.store_cubin": True,
             }
         ), V.set_real_inputs(example_inputs_):
             inputs_ = example_inputs_
@@ -1055,7 +986,7 @@ def compile_fx(
             return compile_fx(
                 model_,
                 inputs_,
-                inner_compile=inner_compile_with_cpp_wrapper(inner_compile),
+                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
                 decompositions=decompositions,
             )
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 3c320766770..1027fbf7c85 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -15,8 +15,9 @@ import torch
 import torch._logging
 import torch.fx
 from torch._decomp import get_decompositions
-from torch._dynamo.utils import dynamo_timed
+from torch._dynamo.utils import defake, dynamo_timed
 from torch._logging import LazyString
+from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.sym_node import magic_methods, method_to_operator
 from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
 from torch.utils._mode_utils import no_dispatch
@@ -164,6 +165,7 @@ class GraphLowering(torch.fx.Interpreter):
     def __init__(
         self,
         gm: torch.fx.GraphModule,
+        example_inputs: Optional[List[torch.Tensor]] = None,
         shape_env=None,
         num_static_inputs=None,
         graph_id=None,
@@ -176,6 +178,7 @@ class GraphLowering(torch.fx.Interpreter):
     ):
         super().__init__(gm)
 
+        self.example_inputs = example_inputs
         self.layout_opt = (
             layout_opt if layout_opt is not None else self.decide_layout_opt(gm)
         )
@@ -921,6 +924,46 @@ class GraphLowering(torch.fx.Interpreter):
         assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
         self.wrapper_code = wrapper_code_gen_cls()
 
+    def codegen_with_cpp_wrapper(self):
+        """
+        For CPU, the cpp wrapper codegen is done in one pass.
+        For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
+        wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
+        generate cpp wrapper code and compile it to a dynamic library in the second pass.
+        """
+        if "cuda" in self.device_types:
+            # first pass
+            self.cpp_wrapper = False
+            compiled = self.compile_to_module().call
+
+            def materialize(x):
+                if isinstance(x, (torch.SymInt, torch.SymFloat)):
+                    # Need concrete value to run dynamic shapes and tune the result
+                    return x.node.hint
+                elif isinstance(x, FakeTensor):
+                    return defake(x)
+                else:
+                    assert isinstance(
+                        x, torch.Tensor
+                    ), "Unknown type when creating real inputs"
+                    return x
+
+            with torch.utils._python_dispatch._disable_current_modes():
+                assert self.example_inputs is not None
+                real_inputs = [materialize(x) for x in self.example_inputs]
+                compiled(real_inputs)
+                del real_inputs
+
+            # second pass
+            # TODO: reuse self.scheduler from the first pass to speed up the second pass
+            self.cpp_wrapper = True
+            self.removed_buffers.clear()
+            self.inplaced_to_remove.clear()
+            return self.codegen()
+        else:
+            # cpu
+            return self.codegen()
+
     def codegen(self):
         from .scheduler import Scheduler
 
@@ -952,7 +995,9 @@ class GraphLowering(torch.fx.Interpreter):
     def compile_to_module(self):
         from .codecache import PyCodeCache
 
-        code, linemap = self.codegen()
+        code, linemap = (
+            self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
+        )
         linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
         key, path = PyCodeCache.write(code)
         mod = PyCodeCache.load_by_key_path(
@@ -975,10 +1020,11 @@ class GraphLowering(torch.fx.Interpreter):
         return mod
 
     def compile_to_fn(self):
-        if self.aot_mode and self.cpp_wrapper:
+        if self.aot_mode:
             from .codecache import AotCodeCache
 
-            code, linemap = self.codegen()
+            assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
+            code, linemap = self.codegen_with_cpp_wrapper()
 
             output_code_log.debug("Output code: \n%s", code)
             serialized_extern_kernel_nodes = None