[inductor] Adjust dynamic SMEM limit when above default in AOT (#107601)
Summary: When AOT Inductor runs a Triton matmul kernel (generated from the Triton mm template) on large inputs of a particular shape, a `RuntimeError: CUDA driver error: 1` may be raised. E.g., this happens when `x @ y` is compiled with AOT Inductor and run on inputs of shapes `[10285, 96]` and `[96, 1]`. Digging deeper into the generated AOT Inductor wrapper code, we see this line:

```
launchKernel(triton_unk_fused_mm_0, 81, 1, 1, 4, 55296, kernel_args_var_0, stream);
```

`55296` is the required amount (in bytes) of dynamic shared memory. This is larger than the default dynamic shared memory limit on A100: `49152` bytes. In such cases, `cudaFuncSetAttribute` must be called explicitly to set the `cudaFuncAttributeMaxDynamicSharedMemorySize` attribute of the kernel before launching it. Or, because the AOT Inductor wrapper relies on the CUDA Driver API, the equivalent [`cuFuncSetAttribute`](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g0e37dce0173bc883aa1e5b14dd747f26) function can be called to set the `CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES` attribute.

This PR adds the above call to the AOT Inductor codegen for every case where the required amount of dynamic SMEM is > 0. The call is made *within* the `loadKernel` function, so it happens only once per kernel and does not affect the performance of the AOT Inductor-compiled model on subsequent runs (i.e., after the first one).

P.S. One could, in principle, call `cuFuncSetAttribute` only when the required amount of dynamic SMEM is above the default limit, but that would require detecting the default limit, which differs from device to device. Assuming that `cuFuncSetAttribute` is relatively lightweight, and because it is performed only once per kernel, for simplicity the suggestion is to call the function in every case where dynamic SMEM is non-zero.

Test Plan:

```
$ python test/inductor/test_aot_inductor.py
...
----------------------------------------------------------------------
Ran 5 tests in 100.177s

OK
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107601
Approved by: https://github.com/jansel
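For illustration, here is a minimal standalone sketch of the driver-API pattern this change generates, mirroring the `loadKernel` hunk in the diff below. The `CU_CHECK` macro is an illustrative stand-in for the wrapper's `AT_CUDA_DRIVER_CHECK_OVERRIDE` and is not code from this PR:

```cpp
#include <cuda.h>

#include <cstdio>
#include <cstdlib>
#include <string>

// Illustrative stand-in for AT_CUDA_DRIVER_CHECK_OVERRIDE in the wrapper.
#define CU_CHECK(expr)                                  \
  do {                                                  \
    CUresult code_ = (expr);                            \
    if (code_ != CUDA_SUCCESS) {                        \
      const char* msg_ = nullptr;                       \
      cuGetErrorString(code_, &msg_);                   \
      std::fprintf(stderr, "CUDA driver error: %s\n",   \
                   msg_ ? msg_ : "unknown");            \
      std::abort();                                     \
    }                                                   \
  } while (0)

// Load a kernel from a cubin and, if it uses dynamic shared memory,
// raise its per-kernel cap so that later launches passing sharedMemBytes
// above the device default (49152 bytes on A100) do not fail.
static CUfunction loadKernel(
    const std::string& filePath,
    const std::string& funcName,
    int sharedMemBytes) {
  CUmodule mod;
  CUfunction func;
  CU_CHECK(cuModuleLoad(&mod, filePath.c_str()));
  CU_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
  if (sharedMemBytes > 0) {
    CU_CHECK(cuFuncSetAttribute(
        func,
        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        sharedMemBytes));
  }
  return func;
}
```

The same `sharedMemBytes` value is then passed as the dynamic shared memory argument of `cuLaunchKernel` at launch time, which is why setting the attribute once at load time is sufficient.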
This commit is contained in:
parent
cfd98d3c42
commit
3920ce2f6e
2 changed files with 43 additions and 6 deletions
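For context on the `launchKernel(...)` line quoted in the summary above: in the generated wrapper it boils down to a `cuLaunchKernel` call where `81, 1, 1` is the grid, `4` is the number of warps, and `55296` is the dynamic shared memory size. A simplified sketch of that mapping follows; the signature mirrors the quoted wrapper line, but the body is an approximation of the generated code, assuming Triton's 32-threads-per-warp block convention:

```cpp
#include <cuda.h>

// Sketch: how a wrapper-style launchKernel call maps onto cuLaunchKernel.
// gridX/gridY/gridZ and numWarps come from the autotuned Triton config.
static CUresult launchKernel(
    CUfunction func,
    unsigned gridX, unsigned gridY, unsigned gridZ,
    unsigned numWarps,
    unsigned sharedMemBytes,
    void** args,
    CUstream stream) {
  return cuLaunchKernel(
      func,
      gridX, gridY, gridZ,  // grid dimensions
      32 * numWarps, 1, 1,  // block: numWarps warps of 32 threads each
      sharedMemBytes,       // dynamic SMEM; requires the raised cap
      stream, args, /*extra=*/nullptr);
}
```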
```diff
@@ -23,7 +23,7 @@ requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda"
 class AOTInductorModelRunner:
     @classmethod
-    def load(cls, model, example_inputs, example_outputs):
+    def load(cls, model, example_inputs, example_outputs, options=None):
         # AOTInductorModel relies on the caller to pass in output_tensors,
         # so we need to explicitly allocate output tensors here.
         output_tensors = []
@@ -35,6 +35,7 @@ class AOTInductorModelRunner:
         so_path, exported = torch._export.aot_compile(
             model,
             example_inputs,
+            options=options,
         )
 
         # Use a utility function for easier testing
@@ -60,10 +61,10 @@ class AOTInductorModelRunner:
         return optimized, exported, output_tensors, output_spec
 
     @classmethod
-    def run(cls, model, example_inputs, example_outputs):
+    def run(cls, model, example_inputs, example_outputs, options=None):
         example_outputs = copy.deepcopy(example_outputs)
         optimized, exported, output_tensors, output_spec = AOTInductorModelRunner.load(
-            model, example_inputs, example_outputs
+            model, example_inputs, example_outputs, options
         )
         param_buffer_values = list(exported.state_dict.values())
         flat_example_inputs = fx_pytree.tree_flatten_spec(
@@ -193,6 +194,31 @@ class AotInductorTests(TestCase):
 
         self.assertTrue(bwd_seq_nr_set.issubset(fwd_seq_nr_set))
 
+    def test_dynamic_smem_above_default_limit(self):
+        class Repro(torch.nn.Module):
+            def forward(self, x, y):
+                return x @ y
+
+        model = Repro()
+        # on A100, the generated Triton kernel for this MM
+        # requires 55296 bytes of dynamic SMEM which is above
+        # the A100's default dynamic SMEM limit of 49152 bytes.
+        example_inputs = (
+            torch.randn(10285, 96, device="cuda"),
+            torch.randn(96, 1, device="cuda"),
+        )
+        expected = model(*example_inputs)
+        actual = AOTInductorModelRunner.run(
+            model,
+            example_inputs,
+            expected,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            },
+        )
+        self.assertTrue(same(actual, expected))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
```
```diff
@@ -14,6 +14,7 @@ import torch
 from torch._dynamo.utils import counters, dynamo_timed
 from torch.fx.experimental.symbolic_shapes import SymTypes
 from torch.fx.node import _get_qualified_name
 
 from .. import codecache, config, ir
+from ..codecache import CudaKernelParamCache
 from ..utils import (
@@ -1348,12 +1349,21 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
     } \\
 } while (0)
 
-static inline CUfunction loadKernel(const std::string &filePath,
-                                    const std::string &funcName) {
+static inline CUfunction loadKernel(
+        const std::string &filePath,
+        const std::string &funcName,
+        int sharedMemBytes) {
     CUmodule mod;
     CUfunction func;
     AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleLoad(&mod, filePath.c_str()));
     AT_CUDA_DRIVER_CHECK_OVERRIDE(cuModuleGetFunction(&func, mod, funcName.c_str()));
+    if (sharedMemBytes > 0) {
+        AT_CUDA_DRIVER_CHECK_OVERRIDE(cuFuncSetAttribute(
+            func,
+            CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+            sharedMemBytes
+        ));
+    }
     return func;
 }
@@ -1400,9 +1410,10 @@ class CudaWrapperCodeGen(CppWrapperCodeGen):
             cubin_path
         ), "cubin file should already exist at this moment"
 
+        shared_mem = params.get("shared_mem", 0)
         self.writeline(f"if ({name} == nullptr) {{")
         self.writeline(
-            f"""    {name} = loadKernel("{cubin_path}", "{mangled_name}");"""
+            f"""    {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
         )
         self.writeline("}")
```
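As a side note on the P.S. in the summary: implementing the "only when above the default limit" variant would mean querying per-device limits first. A sketch of what that query could look like with the driver API; this is a standalone illustration, not part of this PR, and the printed output format is arbitrary:

```cpp
#include <cuda.h>

#include <cstdio>

int main() {
  CUdevice dev;
  int defaultLimit = 0;
  int optinLimit = 0;

  cuInit(0);
  cuDeviceGet(&dev, 0);
  // Default per-block shared memory limit (49152 bytes on A100).
  cuDeviceGetAttribute(
      &defaultLimit, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, dev);
  // Opt-in maximum reachable via cuFuncSetAttribute; larger than the
  // default on recent architectures, and device-dependent.
  cuDeviceGetAttribute(
      &optinLimit, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
  std::printf("default: %d bytes, opt-in max: %d bytes\n",
              defaultLimit, optinLimit);
  return 0;
}
```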