From 3920ce2f6ef7f93dd121f86371c1b35697e2e744 Mon Sep 17 00:00:00 2001
From: Adnan Akhundov
Date: Mon, 21 Aug 2023 08:32:05 -0700
Subject: [PATCH] [inductor] Adjust dynamic SMEM limit when above default in AOT (#107601)

Summary:
When AOT Inductor runs a Triton matmul kernel (generated from the Triton mm
template) on large inputs of a particular shape, a `RuntimeError: CUDA driver
error: 1` may happen. E.g., when `x @ y` is compiled with AOT Inductor and run
on inputs of shapes `[10285, 96]` and `[96, 1]`. Digging deeper into the
generated AOT Inductor wrapper code, we see this line:

```
launchKernel(triton_unk_fused_mm_0, 81, 1, 1, 4, 55296, kernel_args_var_0, stream);
```

`55296` is the required amount (in bytes) of dynamic shared memory. This is
larger than the default dynamic shared memory limit on A100: `49152` bytes. In
these cases, `cudaFuncSetAttribute` must be called explicitly to set the
`cudaFuncAttributeMaxDynamicSharedMemorySize` attribute of the kernel before
launching it. Or, because the AOT Inductor wrapper relies on the CUDA Driver
API, the equivalent
[`cuFuncSetAttribute`](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g0e37dce0173bc883aa1e5b14dd747f26)
function can be called to set the
`CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES` attribute.

This PR adds the above call in the AOT Inductor codegen for every case where
the required amount of dynamic SMEM is > 0. The call is done *within* the
`loadKernel` function, meaning that it happens only once per kernel and does
not affect the subsequent AOT Inductor-compiled model performance (after the
first run).

P.S. One could, in principle, call `cuFuncSetAttribute` only when the required
amount of dynamic SMEM is above the default limit, but that would require
detecting the default limit, which differs across devices. Assuming that
`cuFuncSetAttribute` is relatively lightweight, and because it is performed
only once per kernel, for simplicity the suggestion is to call the function in
every non-zero dynamic SMEM case.

Test Plan:
```
$ python test/inductor/test_aot_inductor.py
...
----------------------------------------------------------------------
Ran 5 tests in 100.177s

OK
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107601
Approved by: https://github.com/jansel
---
 test/inductor/test_aot_inductor.py | 32 +++++++++++++++++++++++++++---
 torch/_inductor/codegen/wrapper.py | 17 +++++++++++++---
 2 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index b080c95552f..3321972ea8f 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -23,7 +23,7 @@ requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda"
 
 class AOTInductorModelRunner:
     @classmethod
-    def load(cls, model, example_inputs, example_outputs):
+    def load(cls, model, example_inputs, example_outputs, options=None):
         # AOTInductorModel relies on the caller to pass in output_tensors,
         # so we need to explicitly allocate output tensors here.
output_tensors = [] @@ -35,6 +35,7 @@ class AOTInductorModelRunner: so_path, exported = torch._export.aot_compile( model, example_inputs, + options=options, ) # Use a utility function for easier testing @@ -60,10 +61,10 @@ class AOTInductorModelRunner: return optimized, exported, output_tensors, output_spec @classmethod - def run(cls, model, example_inputs, example_outputs): + def run(cls, model, example_inputs, example_outputs, options=None): example_outputs = copy.deepcopy(example_outputs) optimized, exported, output_tensors, output_spec = AOTInductorModelRunner.load( - model, example_inputs, example_outputs + model, example_inputs, example_outputs, options ) param_buffer_values = list(exported.state_dict.values()) flat_example_inputs = fx_pytree.tree_flatten_spec( @@ -193,6 +194,31 @@ class AotInductorTests(TestCase): self.assertTrue(bwd_seq_nr_set.issubset(fwd_seq_nr_set)) + def test_dynamic_smem_above_default_limit(self): + class Repro(torch.nn.Module): + def forward(self, x, y): + return x @ y + + model = Repro() + # on A100, the generated Triton kernel for this MM + # requires 55296 bytes of dynamic SMEM which is above + # the A100's default dynamic SMEM limit of 49152 bytes. + example_inputs = ( + torch.randn(10285, 96, device="cuda"), + torch.randn(96, 1, device="cuda"), + ) + expected = model(*example_inputs) + actual = AOTInductorModelRunner.run( + model, + example_inputs, + expected, + options={ + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + }, + ) + self.assertTrue(same(actual, expected)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0208cab2b86..fa4d382b55e 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -14,6 +14,7 @@ import torch from torch._dynamo.utils import counters, dynamo_timed from torch.fx.experimental.symbolic_shapes import SymTypes from torch.fx.node import _get_qualified_name + from .. import codecache, config, ir from ..codecache import CudaKernelParamCache from ..utils import ( @@ -1348,12 +1349,21 @@ class CudaWrapperCodeGen(CppWrapperCodeGen): } \\ } while (0) - static inline CUfunction loadKernel(const std::string &filePath, - const std::string &funcName) { + static inline CUfunction loadKernel( + const std::string &filePath, + const std::string &funcName, + int sharedMemBytes) { CUmodule mod; CUfunction func; AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleLoad(&mod, filePath.c_str())); AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + AT_CUDA_DRIVER_CHECK_OVERRIDE(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )); + } return func; } @@ -1400,9 +1410,10 @@ class CudaWrapperCodeGen(CppWrapperCodeGen): cubin_path ), "cubin file should already exist at this moment" + shared_mem = params.get("shared_mem", 0) self.writeline(f"if ({name} == nullptr) {{") self.writeline( - f""" {name} = loadKernel("{cubin_path}", "{mangled_name}");""" + f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});""" ) self.writeline("}")
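
As a side note on the P.S. above, the sketch below shows what the per-device
check would look like with the CUDA Driver API: query the default and opt-in
per-block shared memory limits, then raise the kernel's dynamic SMEM cap only
when the request exceeds the default. This is not code from this PR, and the
helper name and error handling are illustrative placeholders; the PR instead
calls `cuFuncSetAttribute` unconditionally whenever the requested dynamic SMEM
is non-zero.

```
#include <cuda.h>

#include <stdexcept>
#include <string>

// Illustrative error check; the generated wrapper uses its own
// AT_CUDA_DRIVER_CHECK_OVERRIDE macro instead.
static void driverCheck(CUresult result, const char* call) {
    if (result != CUDA_SUCCESS) {
        const char* msg = nullptr;
        cuGetErrorString(result, &msg);
        throw std::runtime_error(std::string(call) + " failed: " +
                                 (msg ? msg : "unknown CUDA driver error"));
    }
}

// Hypothetical helper (not part of the PR): raise the kernel's dynamic SMEM
// cap only when the request exceeds the device's default per-block limit
// (49152 bytes on A100). Assumes func and device are already valid.
static void maybeRaiseDynamicSmemLimit(CUfunction func, CUdevice device,
                                       int sharedMemBytes) {
    int defaultLimit = 0;  // dynamic SMEM usable per block without opt-in
    int optInLimit = 0;    // maximum reachable via the opt-in attribute
    driverCheck(cuDeviceGetAttribute(
                    &defaultLimit,
                    CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, device),
                "cuDeviceGetAttribute(MAX_SHARED_MEMORY_PER_BLOCK)");
    driverCheck(cuDeviceGetAttribute(
                    &optInLimit,
                    CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device),
                "cuDeviceGetAttribute(MAX_SHARED_MEMORY_PER_BLOCK_OPTIN)");

    if (sharedMemBytes > optInLimit) {
        throw std::runtime_error("requested dynamic SMEM exceeds device capability");
    }
    if (sharedMemBytes > defaultLimit) {
        // Without this opt-in, launching with sharedMemBytes above the default
        // limit fails (observed as "CUDA driver error: 1" in the summary above).
        driverCheck(cuFuncSetAttribute(
                        func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                        sharedMemBytes),
                    "cuFuncSetAttribute");
    }
}
```

Comparing this to the PR's approach makes the trade-off concrete: the query
avoids one driver call for kernels that fit within the default limit, at the
cost of device-dependent logic, whereas the unconditional `cuFuncSetAttribute`
call is cheap and happens only once per kernel at load time.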