Adding lowering to persistent-tma device kernel for _scaled_mm (#142045)

# Summary
This PR adds an alternative Triton lowering for `_scaled_mm`. It uses an updated mm template that combines persistent scheduling with TMA loads of the A and B matrices.
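
For context, a minimal sketch of the kind of call this lowering targets (shapes here are illustrative; `_scaled_mm` expects a row-major A and column-major B):

```Python
import torch

# Assumes an H100 (SM90+) and a bf16-out FP8 matmul with tensor-wise scales.
M, K, N = 8192, 8192, 8192
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)       # row-major
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()   # col-major
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

def fp8_mm(a, b, scale_a, scale_b):
    return torch._scaled_mm(
        a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
    )

# With max-autotune, Inductor can consider the new persistent-TMA template
# (once the config flag described below is enabled).
compiled = torch.compile(fp8_mm, mode="max-autotune-no-cudagraphs")
out = compiled(a, b, scale_a, scale_b)
```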

Limitations:
* This implementation does not work with bias values: 0602676c8d/torch/_inductor/kernel/mm_scaled.py (L106). The plan is to remove this workaround and enforce that both scaling and bias are applied as epilogues onto the existing templates.
* The K dim must be 32 or greater for this lowering to take effect.
* Gated by a config flag (currently defaults to off; maybe it should be on). See the sketch below for how to enable it.
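
For reference, a minimal sketch of flipping the flag (reusing `fp8_mm` from the sketch above); the updated tests in this PR patch the Inductor config the same way:

```Python
import torch
import torch._inductor.config as inductor_config

# Option 1: export ENABLE_PERSISTENT_TMA_MATMUL=1 before torch._inductor is imported.
# Option 2: patch the Inductor config around compilation, as the updated tests do.
with inductor_config.patch({"triton.enable_persistent_tma_matmul": True}):
    compiled = torch.compile(fp8_mm, mode="max-autotune-no-cudagraphs")
    out = compiled(a, b, scale_a, scale_b)
```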

## Testing
We don't have any tests exercising this code in CI/CD, but I updated the relevant tests in test_fp8 and they are all green:
<img width="1680" alt="Screenshot 2024-12-05 at 7 24 07 PM" src="https://github.com/user-attachments/assets/9c520541-d97a-416f-9af7-e68b366ec90f">

## Follow Ups
* Update the base mm Triton templates so that mm/addmm/scaled_mm can share the same template with their respective epilogues.
* Tune the persistent kernel configs. I found ones that work for my problem shapes but need to do some more NCU work; a sketch of the tile/SM arithmetic involved follows this list.
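
For anyone picking up the config tuning: persistent scheduling launches at most one program per SM and each program loops over multiple output tiles. A rough sketch of the arithmetic, assuming the (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) tuple layout used by the mm config lists:

```Python
import math

def persistent_schedule(M, N, K, block_m, block_n, block_k, num_sms):
    # Output tiles, launched programs (see persistent_grid), and K-steps per tile.
    tiles = math.ceil(M / block_m) * math.ceil(N / block_n)
    programs = min(num_sms, tiles)
    k_steps = math.ceil(K / block_k)
    return tiles, programs, k_steps

# e.g. the 8192^3 shape from the profiling script, a (128, 128, 128, ...) config,
# and 132 SMs (H100 SXM):
print(persistent_schedule(8192, 8192, 8192, 128, 128, 128, 132))
# -> (4096, 132, 64)
```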

### Some profiling code I was using

Code I am using to iterate with:
```Python
import torch
from dataclasses import dataclass
from jsonargparse import CLI
import logging
from pathlib import Path

from transformer_nuggets.utils.benchmark import ProfileConfig, profile_function
from torchao.float8.inference import (
    addmm_float8_unwrapped_inference,
    preprocess_data,
    Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
    matmul_persistent,
    matmul_tma_persistent,
    matmul_device_tma_persistent,
)
from enum import Enum

logging.getLogger("transformer_nuggets").setLevel(logging.INFO)

class FP8Kernel(Enum):
    PERSISTENT = "Persistent"
    PERSISTENT_TMA = "Persistent-TMA"
    DEVICE_TMA = "Device-TMA"
    SCALED_MM = "Scaled-MM"

class ScalingStrategy(Enum):
    PER_TENSOR = "PerTensor"
    PER_ROW = "PerRow"

@dataclass(frozen=True)
class ExperimentConfig:
    M: int
    K: int
    N: int
    scaling_strategy: ScalingStrategy
    fp8_kernel: FP8Kernel
    compile: bool

def get_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    scaling_strategy: ScalingStrategy,
    fp8_kernel: FP8Kernel,
):
    A_fp8 = A.to(torch.float8_e4m3fn)
    B_fp8 = B.to(torch.float8_e4m3fn)
    A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))

    if scaling_strategy == ScalingStrategy.PER_TENSOR:
        a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    elif scaling_strategy == ScalingStrategy.PER_ROW:
        a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
        b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
    else:
        raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")

    # This reduced snippet only exercises the _scaled_mm path.
    assert fp8_kernel == FP8Kernel.SCALED_MM
    return lambda: addmm_float8_unwrapped_inference(
        A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
    )

def run_matmul(config: ExperimentConfig):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
    B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)

    fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)

    if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")

    _ = fp8_matmul()

    return

def main():
    torch.random.manual_seed(123)

    # Define your experiment configuration here
    config = ExperimentConfig(
        M=8192,
        K=8192,
        N=8192,
        scaling_strategy=ScalingStrategy.PER_TENSOR,
        fp8_kernel=FP8Kernel.SCALED_MM,
        compile=True,
    )

    run_matmul(config)

if __name__ == "__main__":
    CLI(main)
```
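
To compare the new lowering against the default one with the script above, something like this is enough (a sketch; `do_bench` from `triton.testing` is used here for timing and is not part of the script):

```Python
import torch
import torch._inductor.config as inductor_config
from triton.testing import do_bench

A = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
B = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
fp8_matmul = get_fp8_matmul(A, B, ScalingStrategy.PER_TENSOR, FP8Kernel.SCALED_MM)

# Default _scaled_mm lowering
base = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")
base_ms = do_bench(base)

# Persistent-TMA lowering; reset dynamo so the patched config triggers recompilation
torch._dynamo.reset()
with inductor_config.patch({"triton.enable_persistent_tma_matmul": True}):
    tma = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")
    tma_ms = do_bench(tma)

print(f"default: {base_ms:.3f} ms | persistent-TMA: {tma_ms:.3f} ms")
```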

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142045
Approved by: https://github.com/eellison
drisspg 2024-12-07 09:13:13 -08:00 committed by PyTorch MergeBot
parent 29e985b7b0
commit 75e72e1408
6 changed files with 443 additions and 53 deletions


@@ -5,7 +5,7 @@ import unittest
import torch
from torch import Tensor
from torch._inductor import utils
from torch._inductor import config, utils
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater
from torch.testing._internal.common_utils import (
@@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import (
TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton_tma_device
torch.set_float32_matmul_precision("high")
@@ -414,8 +415,16 @@ class TestFP8Lowering(TestCase):
@parametrize("shape", ("16,16,32", "1024,1024,512"))
@parametrize("has_bias", (False, True))
@parametrize("use_fast_accum", (False, True))
@parametrize(
"persistent_matmul", [False, True] if has_triton_tma_device() else [False]
)
def test_tensorwise_scaling(
self, dtype: torch.dtype, shape: str, has_bias: bool, use_fast_accum: bool
self,
dtype: torch.dtype,
shape: str,
has_bias: bool,
use_fast_accum: bool,
persistent_matmul: bool,
):
if dtype is torch.float32 and has_bias:
self.skipTest("bias is not supported when output dtype is float32")
@@ -459,28 +468,36 @@ class TestFP8Lowering(TestCase):
w_inverse_scale,
bias,
)
linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
y_compiled = linear_compiled(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
# depending on the kernel config (BLOCK_M size, etc) selected during Inductor
# autotuning for the compiled case, the results can be different because of
# the way blocks of results are accumulated (float addition not associative), so
# setting a small absolute tolerance in these tests
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
with config.patch({"triton.enable_persistent_tma_matmul": True}):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
y_compiled = linear_compiled(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
# depending on the kernel config (BLOCK_M size, etc) selected during Inductor
# autotuning for the compiled case, the results can be different because of
# the way blocks of results are accumulated (float addition not associative), so
# setting a small absolute tolerance in these tests
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize("shape", ("16,16,32", "1024,1024,512"))
@parametrize("has_bias", (False, True))
@parametrize("use_fast_accum", (False, True))
def test_rowwise_scaling(self, shape: str, has_bias: bool, use_fast_accum: bool):
@parametrize(
"persistent_matmul", [False, True] if has_triton_tma_device() else [False]
)
def test_rowwise_scaling(
self, shape: str, has_bias: bool, use_fast_accum: bool, persistent_matmul: bool
):
# Only bf16 output type is supported for row-wise scaling, not fp32
dtype: torch.dtype = torch.bfloat16
device = "cuda"
@@ -521,7 +538,10 @@ class TestFP8Lowering(TestCase):
w_inverse_scale,
bias,
)
linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
with config.patch({"triton.enable_persistent_tma_matmul": True}):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
y_compiled = linear_compiled(
x_fp8,
x_inverse_scale,
@@ -538,7 +558,12 @@ class TestFP8Lowering(TestCase):
@parametrize("M", (1, 3, 33, 257, 1024))
@parametrize("K", (16, 1024))
@parametrize("N", (16, 2048))
def test_tensorwise_scaling_acceptable_input_dims(self, M: int, K: int, N: int):
@parametrize(
"persistent_matmul", [False, True] if has_triton_tma_device() else [False]
)
def test_tensorwise_scaling_acceptable_input_dims(
self, M: int, K: int, N: int, persistent_matmul: bool
):
# alignment requirements: K and N divisible by 16
dtype: torch.dtype = torch.bfloat16
use_fast_accum = True
@@ -571,14 +596,17 @@ class TestFP8Lowering(TestCase):
w_inverse_scale,
bias,
)
linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
y_compiled = linear_compiled(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
with config.patch({"triton.enable_persistent_tma_matmul": True}):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
y_compiled = linear_compiled(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)
@@ -588,7 +616,12 @@ class TestFP8Lowering(TestCase):
@parametrize("M", (1, 3, 33, 257, 1024))
@parametrize("K", (16, 1024))
@parametrize("N", (16, 2048))
def test_rowwise_scaling_acceptable_input_dims(self, M: int, K: int, N: int):
@parametrize(
"persistent_matmul", [False, True] if has_triton_tma_device() else [False]
)
def test_rowwise_scaling_acceptable_input_dims(
self, M: int, K: int, N: int, persistent_matmul: bool
):
dtype: torch.dtype = torch.bfloat16
use_fast_accum = True
device = "cuda"
@@ -622,14 +655,17 @@ class TestFP8Lowering(TestCase):
w_inverse_scale,
bias,
)
linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
y_compiled = linear_compiled(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
with config.patch({"triton.enable_persistent_tma_matmul": True}):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
y_compiled = linear_compiled(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
)
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)


@@ -1027,6 +1027,12 @@ class triton:
# Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
codegen_upcast_to_fp32 = True
# Whether persistent matmul kernels should be enabled; this flag only has effect on h100
# with a version of triton new enough to support TMA
enable_persistent_tma_matmul = (
os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1"
)
class aot_inductor:
# AOTInductor output path


@@ -226,6 +226,7 @@ def mark_nodes_dislike_padding(
ops_dislike_padding = {
aten.convolution,
aten.convolution_backward,
aten._scaled_mm,
}
# what's a better way to collect the reduction ops?
ops_like_padding = {


@@ -2,7 +2,7 @@
import functools
import itertools
import logging
from typing import cast, Sequence, Tuple
from typing import Any, cast, Dict, Sequence, Tuple
import sympy
@@ -216,6 +216,19 @@ mixed_mm_kernel_configs = (
else mm_kernel_configs
)
persistent_mm_kernel_configs = [
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 8), "cond": True},
{"config": (128, 128, 128, 4, 8), "cond": True},
{"config": (128, 128, 128, 4, 4), "cond": True},
{"config": (128, 128, 128, 3, 4), "cond": True},
{"config": (128, 128, 128, 5, 4), "cond": True},
{"config": (128, 128, 128, 5, 8), "cond": True},
{"config": (128, 128, 128, 6, 8), "cond": True},
{"config": (128, 128, 64, 4, 8), "cond": True},
]
scaled_mm_kernel_configs = [
{"config": (128, 256, 32, 3, 8), "cond": True},
{"config": (256, 128, 32, 3, 8), "cond": True},
@@ -344,6 +357,12 @@ scaled_mm_platform_configs = tuple(
if config["cond"]
)
persistent_mm_platform_configs = tuple(
cast(Tuple[int, int, int, int, int], config["config"])
for config in persistent_mm_kernel_configs
if config["cond"]
)
# On ROCm convert num_stages to improve performance
if torch.version.hip:
mm_platform_configs = build_rocm_gemm_configs(mm_platform_configs)
@@ -377,6 +396,10 @@ scaled_mm_configs = functools.partial(
configs=scaled_mm_platform_configs,
)
persistent_mm_configs = functools.partial(
filtered_configs, configs=persistent_mm_platform_configs
)
def mm_grid(m, n, meta):
"""
@@ -385,6 +408,15 @@ def mm_grid(m, n, meta):
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)
def persistent_grid(M: int, N: int, meta: Dict[str, Any]):
"""Defines the grid for persistent kernels."""
return (
min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),
1,
1,
)
def acc_type(dtype):
if dtype in (torch.float16, torch.bfloat16):
return "tl.float32"


@@ -5,9 +5,12 @@ import sympy
import torch
from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from torch.utils._triton import has_triton_tma_device
from .. import config as inductor_config
from ..ir import ChoiceCaller, Layout, StorageBox, TensorBox
from ..codegen.common import WorkspaceArg, WorkspaceZeroMode
from ..config import triton as triton_config
from ..ir import _IntLike, ChoiceCaller, Layout, StorageBox, TensorBox
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
from ..select_algorithm import (
autotune_select_algorithm,
@@ -17,12 +20,194 @@ from ..select_algorithm import (
TritonTemplate,
)
from ..utils import use_aten_gemm_kernels, use_ck_gemm_template, use_triton_template
from .mm_common import _is_static_problem, mm_args, mm_grid, scaled_mm_configs
from .mm_common import (
_is_static_problem,
mm_args,
mm_grid,
persistent_grid,
persistent_mm_configs,
scaled_mm_configs,
)
_TMA_SIZE = 128
log = logging.getLogger(__name__)
aten = torch.ops.aten
load_scales = r"""
@triton.jit
def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr):
if SCALING_ROWWISE:
# For row-wise scaling, we'll return the pointers
return a_scale_ptr, b_scale_ptr
else:
# For per-tensor scaling, we'll load the scalar values
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr)
return a_scale, b_scale
"""
apply_scaling = r"""
@triton.jit
def apply_scaling(
accumulator,
a_scale,
b_scale,
SCALING_ROWWISE: tl.constexpr,
offs_cm,
offs_cn,
M,
N,
stride_a_scale_m,
stride_b_scale_n,
):
if SCALING_ROWWISE:
# For row-wise scaling, we need to load the scales for each row/column
a_scales = tl.load(
a_scale + (offs_cm * stride_a_scale_m),
mask=offs_cm < M,
other=0.0,
)
b_scales = tl.load(
b_scale + (offs_cn * stride_b_scale_n),
mask=offs_cn < N,
other=0.0,
)
acc_scale = a_scales[:, None] * b_scales[None, :]
else:
# For per-tensor scaling, we can directly use the loaded scalar values
acc_scale = a_scale * b_scale
return accumulator * acc_scale
"""
device_tma = r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
if M * N == 0:
# early exit due to zero-size input(s)
return
stride_am = {{stride("A", 0)}}
stride_ak = {{stride("A", 1)}}
stride_bk = {{stride("B", 0)}}
stride_bn = {{stride("B", 1)}}
if SCALING_ROWWISE:
stride_a_scale_m = 1
stride_b_scale_n = 1
else:
stride_a_scale_m = 0
stride_b_scale_n = 0
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
k_tiles = tl.cdiv(K, BLOCK_K)
num_tiles = num_pid_m * num_pid_n
workspace_base = ws_ptr + start_pid * 3 * TMA_SIZE
a_desc_ptr = workspace_base
b_desc_ptr = workspace_base + TMA_SIZE
c_desc_ptr = workspace_base + 2 * TMA_SIZE
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=a_desc_ptr,
global_address=A,
load_size=[BLOCK_M, BLOCK_K],
global_size=[M, K],
element_ty=A.dtype.element_ty,
)
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=b_desc_ptr,
global_address=B,
load_size=[BLOCK_N, BLOCK_K],
global_size=[N, K],
element_ty=B.dtype.element_ty,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
tiles_per_SM = num_tiles // NUM_SMS
if start_pid < num_tiles % NUM_SMS:
tiles_per_SM += 1
tile_id = start_pid - NUM_SMS
ki = -1
pid_m = 0
pid_n = 0
offs_am = 0
offs_bn = 0
num_pid_in_group = GROUP_M * num_pid_n
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE)
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
offs_k = ki * BLOCK_K
a = tl._experimental_descriptor_load(
a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
)
b = tl._experimental_descriptor_load(
b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
)
if USE_FAST_ACCUM:
accumulator = tl.dot(a, b.T, accumulator)
else:
accumulator += tl.dot(a, b.T)
if ki == k_tiles - 1:
# Apply inverse scaling
offs_cm = offs_am + tl.arange(0, BLOCK_M)
offs_cn = offs_bn + tl.arange(0, BLOCK_N)
# Apply scaling
accumulator = apply_scaling(
accumulator,
a_scale,
b_scale,
SCALING_ROWWISE,
offs_cm,
offs_cn,
M,
N,
stride_a_scale_m,
stride_b_scale_n,
)
idx_m = offs_cm[:, None]
idx_n = offs_cn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}}
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
"""
scaled_mm_device_tma_template = TritonTemplate(
name="scaled_mm_device_tma",
grid=persistent_grid,
source=device_tma + load_scales + apply_scaling,
)
scaled_mm_template = TritonTemplate(
name="scaled_mm",
@@ -206,6 +391,66 @@ def are_compatible_scales(size_a: Sequence[int], size_b: Sequence[int]) -> bool:
return False
def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
def is_row_major(stride: Sequence[_IntLike]) -> bool:
return stride[1] == 1
def is_col_major(stride: Sequence[_IntLike]) -> bool:
return stride[0] == 1
def has_zero_dim(size: Sequence[_IntLike]) -> bool:
return bool(size[0] == 0 or size[1] == 0)
# Check mat_a (self) stride requirements
torch._check(
is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()),
lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}",
)
# Check mat_b stride requirements
torch._check(
is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
)
def scaled_mm_options_device_tma( # type: ignore[no-untyped-def]
config, # triton.Config
sym_m: sympy.core.numbers.Integer,
sym_n: sympy.core.numbers.Integer,
sym_k: sympy.core.numbers.Integer,
layout: Layout,
scale_a: StorageBox,
scale_b: StorageBox,
use_fast_accum: bool,
b_prologue_cast_type: Optional[str] = None,
) -> Dict[str, Any]:
even_k_symbolic = (
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
size_a, size_b = scale_a.get_size(), scale_b.get_size()
assert are_compatible_scales(size_a, size_b), (
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
return dict(
GROUP_M=8,
EVEN_K=even_k_symbolic,
ACC_TYPE="tl.float32",
B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
USE_FAST_ACCUM=use_fast_accum,
num_stages=config.num_stages,
num_warps=config.num_warps,
# tensor-wise scaling if scalar scales
SCALING_ROWWISE=len(scale_a.get_size()) == 2,
TMA_SIZE=_TMA_SIZE,
NUM_SMS=NUM_SMS,
**config.kwargs,
)
def scaled_mm_options( # type: ignore[no-untyped-def]
config, # triton.Config
sym_m: sympy.core.numbers.Integer,
@@ -243,6 +488,33 @@ def scaled_mm_options( # type: ignore[no-untyped-def]
add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)
def get_workspace_size(
num_sms: int, TMA_SIZE: int = _TMA_SIZE, NUM_TMA_DESCRIPTORS: int = 3
) -> int:
"""Device side TMA requires a workspace buffer to be allocated in global memory."""
return num_sms * NUM_TMA_DESCRIPTORS * TMA_SIZE
def get_workspace_arg(num_sms: int, device: torch.device) -> WorkspaceArg:
"""Builds and returns a WorkspaceArg for the device side TMA workspace buffer."""
size = get_workspace_size(num_sms)
zero_mode = WorkspaceZeroMode.from_bool(False)
return WorkspaceArg(
count=size,
zero_mode=zero_mode,
device=device,
outer_name=WorkspaceArg.unique_name(),
)
def use_persistent_tma(k: sympy.core.numbers.Integer, has_bias: bool) -> bool:
available = has_triton_tma_device() and triton_config.enable_persistent_tma_matmul
# _determine_swizzle_mode_2d requires BLOCK_K to be at least 32 contiguous bytes
# When K is 16, BLOCK_K = 16 and is not valid
min_k = k >= 32
return available and min_k and not has_bias
@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc]
def tuned_scaled_mm(
mat_a: TensorBox,
@@ -258,6 +530,9 @@ def tuned_scaled_mm(
m, n, k, layout, mat_a, mat_b = mm_args(
mat_a, mat_b, layout=layout, out_dtype=out_dtype
)
check_supported_striding(mat_a, mat_b)
scale_a, scale_b = realize_inputs(scale_a, scale_b)
input_nodes: Tuple[Any, ...]
@@ -279,20 +554,37 @@ def tuned_scaled_mm(
choices.append(aten_choice)
static_shape, is_nonzero = _is_static_problem(layout)
if is_nonzero and use_triton_template(layout, enable_float8=True):
for config in scaled_mm_configs(m, n, k):
if k == 16 and config.kwargs["BLOCK_M"] >= 64:
continue # Triton crashes in this case
kwargs = scaled_mm_options(
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
)
# possibly appends a TritonTemplateCaller to choices
triton_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
)
if use_persistent_tma(k, bias is not None):
for config in persistent_mm_configs(m, n, k):
kwargs = scaled_mm_options_device_tma(
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
)
input_nodes = (mat_a, mat_b, scale_a, scale_b)
scaled_mm_device_tma_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
workspace_arg=get_workspace_arg(
kwargs["NUM_SMS"], mat_a.get_device()
),
**kwargs,
)
else:
for config in scaled_mm_configs(m, n, k):
if k == 16 and config.kwargs["BLOCK_M"] >= 64:
continue # Triton crashes in this case
kwargs = scaled_mm_options(
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
)
# possibly appends a TritonTemplateCaller to choices
triton_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
)
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)


@@ -38,6 +38,29 @@ def has_triton_tma():
return False
@functools.lru_cache(None)
def has_triton_tma_device():
if has_triton_package():
import torch
if (
torch.cuda.is_available()
and torch.cuda.get_device_capability() >= (9, 0)
and not torch.version.hip
):
try:
from triton.language.extra.cuda import ( # noqa: F401
experimental_device_tensormap_create1d,
experimental_device_tensormap_create2d,
)
return True
except ImportError:
pass
return False
@functools.lru_cache(None)
def has_triton() -> bool:
if not has_triton_package():