mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Adding lowering to persistent-tma device kernel for _scaled_mm (#142045)
# Summary
This PR adds an alternative triton lowering for _scaled_mm. This uses an updated mm template that utilizes persistent scheduling + TMAs on A and B matrices.
Limitations:
* This implementation does not work with bias values: 0602676c8d/torch/_inductor/kernel/mm_scaled.py (L106). The plan is to remove this workaround and enforce that both scaling and bias are properly applied as epilogues onto the existing templates
* K dim must be 32 or greater for these to take effect
* Gated by a config flag (currently defaults to off; it should perhaps default to on)
## Testing
We don't have any tests exercising this code in CI/CD, but I updated the relevant tests in test_fp8 and they are all green:
<img width="1680" alt="Screenshot 2024-12-05 at 7 24 07 PM" src="https://github.com/user-attachments/assets/9c520541-d97a-416f-9af7-e68b366ec90f">
## Follow Ups
* Work to update the base mm triton templates and utilize the same template from mm/addmm/scaled_mm w/ respective epilogues
* Tuning on Persistent kernel configs. I found ones that work for my problem shapes but need to do some more NCU work
### Some profiling code I was using
Code I am using to iterate w/
```Python
import torch
from dataclasses import dataclass
from jsonargparse import CLI
import logging
from pathlib import Path
from transformer_nuggets.utils.benchmark import ProfileConfig, profile_function
from torchao.float8.inference import (
addmm_float8_unwrapped_inference,
preprocess_data,
Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_device_tma_persistent,
)
from enum import Enum
logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
class FP8Kernel(Enum):
    """Identifies which FP8 matmul kernel implementation to benchmark."""

    PERSISTENT = "Persistent"
    PERSISTENT_TMA = "Persistent-TMA"
    DEVICE_TMA = "Device-TMA"
    SCALED_MM = "Scaled-MM"
class ScalingStrategy(Enum):
    """Granularity of the FP8 inverse-scale factors applied to each operand."""

    PER_TENSOR = "PerTensor"
    PER_ROW = "PerRow"
@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of one FP8 matmul benchmark run.

    Bundles the GEMM problem size with the scaling strategy, the kernel
    under test, and whether to wrap the call in ``torch.compile``.
    """

    # GEMM problem dimensions: output is (M, N), reduction over K.
    M: int
    K: int
    N: int
    scaling_strategy: ScalingStrategy
    fp8_kernel: FP8Kernel
    compile: bool
def get_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    scaling_strategy: ScalingStrategy,
    fp8_kernel: FP8Kernel,
):
    """Build a zero-argument callable that runs an FP8 matmul of ``A @ B``.

    Casts both operands to ``float8_e4m3fn``, preprocesses them for the
    fast-accum path, constructs inverse-scale tensors according to
    ``scaling_strategy``, and returns a thunk invoking
    ``addmm_float8_unwrapped_inference`` with bfloat16 output.

    Raises:
        ValueError: for an unknown scaling strategy or an unsupported kernel.
    """
    A_fp8 = A.to(torch.float8_e4m3fn)
    B_fp8 = B.to(torch.float8_e4m3fn)
    A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))

    if scaling_strategy == ScalingStrategy.PER_TENSOR:
        a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    elif scaling_strategy == ScalingStrategy.PER_ROW:
        a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
        # Transposed so the B-side scale has shape (1, N) and broadcasts
        # along output columns.
        b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
    else:
        raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")

    # Only the torch._scaled_mm path is wired up in this script. Use an
    # explicit raise instead of `assert`, which is stripped under `python -O`.
    if fp8_kernel != FP8Kernel.SCALED_MM:
        raise ValueError(f"Unsupported FP8 kernel: {fp8_kernel}")

    return lambda: addmm_float8_unwrapped_inference(
        A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
    )
def run_matmul(config: ExperimentConfig):
    """Execute a single matmul experiment described by ``config``."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
    B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)

    fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)
    should_compile = config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM
    if should_compile:
        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")

    _ = fp8_matmul()
    return
def main():
    """Entry point: run one tensor-wise-scaled 8192x8192x8192 FP8 matmul."""
    torch.random.manual_seed(123)
    # Define your experiment configuration here
    experiment = ExperimentConfig(
        M=8192,
        K=8192,
        N=8192,
        scaling_strategy=ScalingStrategy.PER_TENSOR,
        fp8_kernel=FP8Kernel.SCALED_MM,
        compile=True,
    )
    run_matmul(experiment)


if __name__ == "__main__":
    CLI(main)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142045
Approved by: https://github.com/eellison
This commit is contained in:
parent
29e985b7b0
commit
75e72e1408
6 changed files with 443 additions and 53 deletions
|
|
@ -5,7 +5,7 @@ import unittest
|
|||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._inductor import utils
|
||||
from torch._inductor import config, utils
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater
|
||||
from torch.testing._internal.common_utils import (
|
||||
|
|
@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_WITH_ROCM,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA
|
||||
from torch.utils._triton import has_triton_tma_device
|
||||
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
|
@ -414,8 +415,16 @@ class TestFP8Lowering(TestCase):
|
|||
@parametrize("shape", ("16,16,32", "1024,1024,512"))
|
||||
@parametrize("has_bias", (False, True))
|
||||
@parametrize("use_fast_accum", (False, True))
|
||||
@parametrize(
|
||||
"persistent_matmul", [False, True] if has_triton_tma_device() else [False]
|
||||
)
|
||||
def test_tensorwise_scaling(
|
||||
self, dtype: torch.dtype, shape: str, has_bias: bool, use_fast_accum: bool
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
shape: str,
|
||||
has_bias: bool,
|
||||
use_fast_accum: bool,
|
||||
persistent_matmul: bool,
|
||||
):
|
||||
if dtype is torch.float32 and has_bias:
|
||||
self.skipTest("bias is not supported when output dtype is float32")
|
||||
|
|
@ -459,28 +468,36 @@ class TestFP8Lowering(TestCase):
|
|||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
|
||||
y_compiled = linear_compiled(
|
||||
x_fp8,
|
||||
x_inverse_scale,
|
||||
w_t_fp8,
|
||||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
self.assertEqual(y_eager.dtype, dtype)
|
||||
self.assertEqual(y_compiled.dtype, dtype)
|
||||
# depending on the kernel config (BLOCK_M size, etc) selected during Inductor
|
||||
# autotuning for the compiled case, the results can be different because of
|
||||
# the way blocks of results are accumulated (float addition not associative), so
|
||||
# setting a small absolute tolerance in these tests
|
||||
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
|
||||
with config.patch({"triton.enable_persistent_tma_matmul": True}):
|
||||
linear_compiled = torch.compile(
|
||||
linear, backend="inductor", mode="max-autotune"
|
||||
)
|
||||
y_compiled = linear_compiled(
|
||||
x_fp8,
|
||||
x_inverse_scale,
|
||||
w_t_fp8,
|
||||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
self.assertEqual(y_eager.dtype, dtype)
|
||||
self.assertEqual(y_compiled.dtype, dtype)
|
||||
# depending on the kernel config (BLOCK_M size, etc) selected during Inductor
|
||||
# autotuning for the compiled case, the results can be different because of
|
||||
# the way blocks of results are accumulated (float addition not associative), so
|
||||
# setting a small absolute tolerance in these tests
|
||||
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
|
||||
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
|
||||
@parametrize("shape", ("16,16,32", "1024,1024,512"))
|
||||
@parametrize("has_bias", (False, True))
|
||||
@parametrize("use_fast_accum", (False, True))
|
||||
def test_rowwise_scaling(self, shape: str, has_bias: bool, use_fast_accum: bool):
|
||||
@parametrize(
|
||||
"persistent_matmul", [False, True] if has_triton_tma_device() else [False]
|
||||
)
|
||||
def test_rowwise_scaling(
|
||||
self, shape: str, has_bias: bool, use_fast_accum: bool, persistent_matmul: bool
|
||||
):
|
||||
# Only bf16 output type is supported for row-wise scaling, not fp32
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
device = "cuda"
|
||||
|
|
@ -521,7 +538,10 @@ class TestFP8Lowering(TestCase):
|
|||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
|
||||
with config.patch({"triton.enable_persistent_tma_matmul": True}):
|
||||
linear_compiled = torch.compile(
|
||||
linear, backend="inductor", mode="max-autotune"
|
||||
)
|
||||
y_compiled = linear_compiled(
|
||||
x_fp8,
|
||||
x_inverse_scale,
|
||||
|
|
@ -538,7 +558,12 @@ class TestFP8Lowering(TestCase):
|
|||
@parametrize("M", (1, 3, 33, 257, 1024))
|
||||
@parametrize("K", (16, 1024))
|
||||
@parametrize("N", (16, 2048))
|
||||
def test_tensorwise_scaling_acceptable_input_dims(self, M: int, K: int, N: int):
|
||||
@parametrize(
|
||||
"persistent_matmul", [False, True] if has_triton_tma_device() else [False]
|
||||
)
|
||||
def test_tensorwise_scaling_acceptable_input_dims(
|
||||
self, M: int, K: int, N: int, persistent_matmul: bool
|
||||
):
|
||||
# alignment requirements: K and N divisible by 16
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
use_fast_accum = True
|
||||
|
|
@ -571,14 +596,17 @@ class TestFP8Lowering(TestCase):
|
|||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
|
||||
y_compiled = linear_compiled(
|
||||
x_fp8,
|
||||
x_inverse_scale,
|
||||
w_t_fp8,
|
||||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
with config.patch({"triton.enable_persistent_tma_matmul": True}):
|
||||
linear_compiled = torch.compile(
|
||||
linear, backend="inductor", mode="max-autotune"
|
||||
)
|
||||
y_compiled = linear_compiled(
|
||||
x_fp8,
|
||||
x_inverse_scale,
|
||||
w_t_fp8,
|
||||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
self.assertEqual(y_eager.dtype, dtype)
|
||||
self.assertEqual(y_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)
|
||||
|
|
@ -588,7 +616,12 @@ class TestFP8Lowering(TestCase):
|
|||
@parametrize("M", (1, 3, 33, 257, 1024))
|
||||
@parametrize("K", (16, 1024))
|
||||
@parametrize("N", (16, 2048))
|
||||
def test_rowwise_scaling_acceptable_input_dims(self, M: int, K: int, N: int):
|
||||
@parametrize(
|
||||
"persistent_matmul", [False, True] if has_triton_tma_device() else [False]
|
||||
)
|
||||
def test_rowwise_scaling_acceptable_input_dims(
|
||||
self, M: int, K: int, N: int, persistent_matmul: bool
|
||||
):
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
use_fast_accum = True
|
||||
device = "cuda"
|
||||
|
|
@ -622,14 +655,17 @@ class TestFP8Lowering(TestCase):
|
|||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
|
||||
y_compiled = linear_compiled(
|
||||
x_fp8,
|
||||
x_inverse_scale,
|
||||
w_t_fp8,
|
||||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
with config.patch({"triton.enable_persistent_tma_matmul": True}):
|
||||
linear_compiled = torch.compile(
|
||||
linear, backend="inductor", mode="max-autotune"
|
||||
)
|
||||
y_compiled = linear_compiled(
|
||||
x_fp8,
|
||||
x_inverse_scale,
|
||||
w_t_fp8,
|
||||
w_inverse_scale,
|
||||
bias,
|
||||
)
|
||||
self.assertEqual(y_eager.dtype, dtype)
|
||||
self.assertEqual(y_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)
|
||||
|
|
|
|||
|
|
@ -1027,6 +1027,12 @@ class triton:
|
|||
# Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
|
||||
codegen_upcast_to_fp32 = True
|
||||
|
||||
# Whether persistent matmul kernels should be enabled this flag only has effect when on h100
|
||||
# with a verison of triton new enough to support TMA
|
||||
enable_persistent_tma_matmul = (
|
||||
os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1"
|
||||
)
|
||||
|
||||
|
||||
class aot_inductor:
|
||||
# AOTInductor output path
|
||||
|
|
|
|||
|
|
@ -226,6 +226,7 @@ def mark_nodes_dislike_padding(
|
|||
ops_dislike_padding = {
|
||||
aten.convolution,
|
||||
aten.convolution_backward,
|
||||
aten._scaled_mm,
|
||||
}
|
||||
# what's a better way to collect the reduction ops?
|
||||
ops_like_padding = {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
from typing import cast, Sequence, Tuple
|
||||
from typing import Any, cast, Dict, Sequence, Tuple
|
||||
|
||||
import sympy
|
||||
|
||||
|
|
@ -216,6 +216,19 @@ mixed_mm_kernel_configs = (
|
|||
else mm_kernel_configs
|
||||
)
|
||||
|
||||
persistent_mm_kernel_configs = [
|
||||
{"config": (128, 128, 64, 3, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 3, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 4, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 4, 4), "cond": True},
|
||||
{"config": (128, 128, 128, 3, 4), "cond": True},
|
||||
{"config": (128, 128, 128, 5, 4), "cond": True},
|
||||
{"config": (128, 128, 128, 5, 8), "cond": True},
|
||||
{"config": (128, 128, 128, 6, 8), "cond": True},
|
||||
{"config": (128, 128, 64, 4, 8), "cond": True},
|
||||
]
|
||||
|
||||
|
||||
scaled_mm_kernel_configs = [
|
||||
{"config": (128, 256, 32, 3, 8), "cond": True},
|
||||
{"config": (256, 128, 32, 3, 8), "cond": True},
|
||||
|
|
@ -344,6 +357,12 @@ scaled_mm_platform_configs = tuple(
|
|||
if config["cond"]
|
||||
)
|
||||
|
||||
persistent_mm_platform_configs = tuple(
|
||||
cast(Tuple[int, int, int, int, int], config["config"])
|
||||
for config in persistent_mm_kernel_configs
|
||||
if config["cond"]
|
||||
)
|
||||
|
||||
# On ROCm convert num_stages to improve performance
|
||||
if torch.version.hip:
|
||||
mm_platform_configs = build_rocm_gemm_configs(mm_platform_configs)
|
||||
|
|
@ -377,6 +396,10 @@ scaled_mm_configs = functools.partial(
|
|||
configs=scaled_mm_platform_configs,
|
||||
)
|
||||
|
||||
persistent_mm_configs = functools.partial(
|
||||
filtered_configs, configs=persistent_mm_platform_configs
|
||||
)
|
||||
|
||||
|
||||
def mm_grid(m, n, meta):
|
||||
"""
|
||||
|
|
@ -385,6 +408,15 @@ def mm_grid(m, n, meta):
|
|||
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)
|
||||
|
||||
|
||||
def persistent_grid(M: int, N: int, meta: Dict[str, Any]):
|
||||
"""Defines the grid for persistent kernels."""
|
||||
return (
|
||||
min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),
|
||||
1,
|
||||
1,
|
||||
)
|
||||
|
||||
|
||||
def acc_type(dtype):
|
||||
if dtype in (torch.float16, torch.bfloat16):
|
||||
return "tl.float32"
|
||||
|
|
|
|||
|
|
@ -5,9 +5,12 @@ import sympy
|
|||
|
||||
import torch
|
||||
from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
|
||||
from torch.utils._triton import has_triton_tma_device
|
||||
|
||||
from .. import config as inductor_config
|
||||
from ..ir import ChoiceCaller, Layout, StorageBox, TensorBox
|
||||
from ..codegen.common import WorkspaceArg, WorkspaceZeroMode
|
||||
from ..config import triton as triton_config
|
||||
from ..ir import _IntLike, ChoiceCaller, Layout, StorageBox, TensorBox
|
||||
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
|
||||
from ..select_algorithm import (
|
||||
autotune_select_algorithm,
|
||||
|
|
@ -17,12 +20,194 @@ from ..select_algorithm import (
|
|||
TritonTemplate,
|
||||
)
|
||||
from ..utils import use_aten_gemm_kernels, use_ck_gemm_template, use_triton_template
|
||||
from .mm_common import _is_static_problem, mm_args, mm_grid, scaled_mm_configs
|
||||
from .mm_common import (
|
||||
_is_static_problem,
|
||||
mm_args,
|
||||
mm_grid,
|
||||
persistent_grid,
|
||||
persistent_mm_configs,
|
||||
scaled_mm_configs,
|
||||
)
|
||||
|
||||
|
||||
_TMA_SIZE = 128
|
||||
log = logging.getLogger(__name__)
|
||||
aten = torch.ops.aten
|
||||
|
||||
load_scales = r"""
|
||||
@triton.jit
|
||||
def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr):
|
||||
if SCALING_ROWWISE:
|
||||
# For row-wise scaling, we'll return the pointers
|
||||
return a_scale_ptr, b_scale_ptr
|
||||
else:
|
||||
# For per-tensor scaling, we'll load the scalar values
|
||||
a_scale = tl.load(a_scale_ptr)
|
||||
b_scale = tl.load(b_scale_ptr)
|
||||
return a_scale, b_scale
|
||||
"""
|
||||
|
||||
|
||||
apply_scaling = r"""
|
||||
@triton.jit
|
||||
def apply_scaling(
|
||||
accumulator,
|
||||
a_scale,
|
||||
b_scale,
|
||||
SCALING_ROWWISE: tl.constexpr,
|
||||
offs_cm,
|
||||
offs_cn,
|
||||
M,
|
||||
N,
|
||||
stride_a_scale_m,
|
||||
stride_b_scale_n,
|
||||
):
|
||||
if SCALING_ROWWISE:
|
||||
# For row-wise scaling, we need to load the scales for each row/column
|
||||
a_scales = tl.load(
|
||||
a_scale + (offs_cm * stride_a_scale_m),
|
||||
mask=offs_cm < M,
|
||||
other=0.0,
|
||||
)
|
||||
b_scales = tl.load(
|
||||
b_scale + (offs_cn * stride_b_scale_n),
|
||||
mask=offs_cn < N,
|
||||
other=0.0,
|
||||
)
|
||||
acc_scale = a_scales[:, None] * b_scales[None, :]
|
||||
else:
|
||||
# For per-tensor scaling, we can directly use the loaded scalar values
|
||||
acc_scale = a_scale * b_scale
|
||||
|
||||
return accumulator * acc_scale
|
||||
"""
|
||||
|
||||
|
||||
device_tma = r"""
|
||||
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
|
||||
M = {{size("A", 0)}}
|
||||
N = {{size("B", 1)}}
|
||||
K = {{size("A", 1)}}
|
||||
if M * N == 0:
|
||||
# early exit due to zero-size input(s)
|
||||
return
|
||||
|
||||
stride_am = {{stride("A", 0)}}
|
||||
stride_ak = {{stride("A", 1)}}
|
||||
stride_bk = {{stride("B", 0)}}
|
||||
stride_bn = {{stride("B", 1)}}
|
||||
|
||||
if SCALING_ROWWISE:
|
||||
stride_a_scale_m = 1
|
||||
stride_b_scale_n = 1
|
||||
else:
|
||||
stride_a_scale_m = 0
|
||||
stride_b_scale_n = 0
|
||||
|
||||
start_pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_N)
|
||||
k_tiles = tl.cdiv(K, BLOCK_K)
|
||||
num_tiles = num_pid_m * num_pid_n
|
||||
|
||||
workspace_base = ws_ptr + start_pid * 3 * TMA_SIZE
|
||||
a_desc_ptr = workspace_base
|
||||
b_desc_ptr = workspace_base + TMA_SIZE
|
||||
c_desc_ptr = workspace_base + 2 * TMA_SIZE
|
||||
|
||||
triton.language.extra.cuda.experimental_device_tensormap_create2d(
|
||||
desc_ptr=a_desc_ptr,
|
||||
global_address=A,
|
||||
load_size=[BLOCK_M, BLOCK_K],
|
||||
global_size=[M, K],
|
||||
element_ty=A.dtype.element_ty,
|
||||
)
|
||||
triton.language.extra.cuda.experimental_device_tensormap_create2d(
|
||||
desc_ptr=b_desc_ptr,
|
||||
global_address=B,
|
||||
load_size=[BLOCK_N, BLOCK_K],
|
||||
global_size=[N, K],
|
||||
element_ty=B.dtype.element_ty,
|
||||
)
|
||||
|
||||
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
|
||||
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
|
||||
|
||||
tiles_per_SM = num_tiles // NUM_SMS
|
||||
if start_pid < num_tiles % NUM_SMS:
|
||||
tiles_per_SM += 1
|
||||
|
||||
tile_id = start_pid - NUM_SMS
|
||||
ki = -1
|
||||
|
||||
pid_m = 0
|
||||
pid_n = 0
|
||||
offs_am = 0
|
||||
offs_bn = 0
|
||||
|
||||
num_pid_in_group = GROUP_M * num_pid_n
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE)
|
||||
|
||||
for _ in range(0, k_tiles * tiles_per_SM):
|
||||
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
|
||||
if ki == 0:
|
||||
tile_id += NUM_SMS
|
||||
group_id = tile_id // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
|
||||
pid_m = first_pid_m + (tile_id % group_size_m)
|
||||
pid_n = (tile_id % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_M
|
||||
offs_bn = pid_n * BLOCK_N
|
||||
|
||||
offs_k = ki * BLOCK_K
|
||||
|
||||
a = tl._experimental_descriptor_load(
|
||||
a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty
|
||||
)
|
||||
b = tl._experimental_descriptor_load(
|
||||
b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty
|
||||
)
|
||||
if USE_FAST_ACCUM:
|
||||
accumulator = tl.dot(a, b.T, accumulator)
|
||||
else:
|
||||
accumulator += tl.dot(a, b.T)
|
||||
|
||||
if ki == k_tiles - 1:
|
||||
# Apply inverse scaling
|
||||
offs_cm = offs_am + tl.arange(0, BLOCK_M)
|
||||
offs_cn = offs_bn + tl.arange(0, BLOCK_N)
|
||||
# Apply scaling
|
||||
accumulator = apply_scaling(
|
||||
accumulator,
|
||||
a_scale,
|
||||
b_scale,
|
||||
SCALING_ROWWISE,
|
||||
offs_cm,
|
||||
offs_cn,
|
||||
M,
|
||||
N,
|
||||
stride_a_scale_m,
|
||||
stride_b_scale_n,
|
||||
)
|
||||
|
||||
idx_m = offs_cm[:, None]
|
||||
idx_n = offs_cn[None, :]
|
||||
mask = (idx_m < M) & (idx_n < N)
|
||||
# inductor generates a suffix
|
||||
{{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}}
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
"""
|
||||
|
||||
|
||||
scaled_mm_device_tma_template = TritonTemplate(
|
||||
name="scaled_mm_device_tma",
|
||||
grid=persistent_grid,
|
||||
source=device_tma + load_scales + apply_scaling,
|
||||
)
|
||||
|
||||
|
||||
scaled_mm_template = TritonTemplate(
|
||||
name="scaled_mm",
|
||||
|
|
@ -206,6 +391,66 @@ def are_compatible_scales(size_a: Sequence[int], size_b: Sequence[int]) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None:
|
||||
def is_row_major(stride: Sequence[_IntLike]) -> bool:
|
||||
return stride[1] == 1
|
||||
|
||||
def is_col_major(stride: Sequence[_IntLike]) -> bool:
|
||||
return stride[0] == 1
|
||||
|
||||
def has_zero_dim(size: Sequence[_IntLike]) -> bool:
|
||||
return bool(size[0] == 0 or size[1] == 0)
|
||||
|
||||
# Check mat_a (self) stride requirements
|
||||
torch._check(
|
||||
is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()),
|
||||
lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}",
|
||||
)
|
||||
|
||||
# Check mat_b stride requirements
|
||||
torch._check(
|
||||
is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()),
|
||||
lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}",
|
||||
)
|
||||
|
||||
|
||||
def scaled_mm_options_device_tma( # type: ignore[no-untyped-def]
|
||||
config, # triton.Config
|
||||
sym_m: sympy.core.numbers.Integer,
|
||||
sym_n: sympy.core.numbers.Integer,
|
||||
sym_k: sympy.core.numbers.Integer,
|
||||
layout: Layout,
|
||||
scale_a: StorageBox,
|
||||
scale_b: StorageBox,
|
||||
use_fast_accum: bool,
|
||||
b_prologue_cast_type: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
even_k_symbolic = (
|
||||
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
|
||||
)
|
||||
|
||||
size_a, size_b = scale_a.get_size(), scale_b.get_size()
|
||||
assert are_compatible_scales(size_a, size_b), (
|
||||
"Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
|
||||
f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
|
||||
)
|
||||
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
return dict(
|
||||
GROUP_M=8,
|
||||
EVEN_K=even_k_symbolic,
|
||||
ACC_TYPE="tl.float32",
|
||||
B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
|
||||
USE_FAST_ACCUM=use_fast_accum,
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
# tensor-wise scaling if scalar scales
|
||||
SCALING_ROWWISE=len(scale_a.get_size()) == 2,
|
||||
TMA_SIZE=_TMA_SIZE,
|
||||
NUM_SMS=NUM_SMS,
|
||||
**config.kwargs,
|
||||
)
|
||||
|
||||
|
||||
def scaled_mm_options( # type: ignore[no-untyped-def]
|
||||
config, # triton.Config
|
||||
sym_m: sympy.core.numbers.Integer,
|
||||
|
|
@ -243,6 +488,33 @@ def scaled_mm_options( # type: ignore[no-untyped-def]
|
|||
add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)
|
||||
|
||||
|
||||
def get_workspace_size(
|
||||
num_sms: int, TMA_SIZE: int = _TMA_SIZE, NUM_TMA_DESCRIPTORS: int = 3
|
||||
) -> int:
|
||||
"""Device side TMA requires a workspace buffer to be allocated in global memory."""
|
||||
return num_sms * NUM_TMA_DESCRIPTORS * TMA_SIZE
|
||||
|
||||
|
||||
def get_workspace_arg(num_sms: int, device: torch.device) -> WorkspaceArg:
|
||||
"""Builds and returns a WorkspaceArg for the device side TMA workspace buffer."""
|
||||
size = get_workspace_size(num_sms)
|
||||
zero_mode = WorkspaceZeroMode.from_bool(False)
|
||||
return WorkspaceArg(
|
||||
count=size,
|
||||
zero_mode=zero_mode,
|
||||
device=device,
|
||||
outer_name=WorkspaceArg.unique_name(),
|
||||
)
|
||||
|
||||
|
||||
def use_persistent_tma(k: sympy.core.numbers.Integer, has_bias: bool) -> bool:
|
||||
available = has_triton_tma_device() and triton_config.enable_persistent_tma_matmul
|
||||
# _determine_swizzle_mode_2d requires BLOCK_K to be at least 32 contiguous bytes
|
||||
# When K is 16, BLOCK_K = 16 and is not valid
|
||||
min_k = k >= 32
|
||||
return available and min_k and not has_bias
|
||||
|
||||
|
||||
@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc]
|
||||
def tuned_scaled_mm(
|
||||
mat_a: TensorBox,
|
||||
|
|
@ -258,6 +530,9 @@ def tuned_scaled_mm(
|
|||
m, n, k, layout, mat_a, mat_b = mm_args(
|
||||
mat_a, mat_b, layout=layout, out_dtype=out_dtype
|
||||
)
|
||||
|
||||
check_supported_striding(mat_a, mat_b)
|
||||
|
||||
scale_a, scale_b = realize_inputs(scale_a, scale_b)
|
||||
|
||||
input_nodes: Tuple[Any, ...]
|
||||
|
|
@ -279,20 +554,37 @@ def tuned_scaled_mm(
|
|||
choices.append(aten_choice)
|
||||
|
||||
static_shape, is_nonzero = _is_static_problem(layout)
|
||||
|
||||
if is_nonzero and use_triton_template(layout, enable_float8=True):
|
||||
for config in scaled_mm_configs(m, n, k):
|
||||
if k == 16 and config.kwargs["BLOCK_M"] >= 64:
|
||||
continue # Triton crashes in this case
|
||||
kwargs = scaled_mm_options(
|
||||
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
|
||||
)
|
||||
# possibly appends a TritonTemplateCaller to choices
|
||||
triton_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
)
|
||||
if use_persistent_tma(k, bias is not None):
|
||||
for config in persistent_mm_configs(m, n, k):
|
||||
kwargs = scaled_mm_options_device_tma(
|
||||
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
|
||||
)
|
||||
input_nodes = (mat_a, mat_b, scale_a, scale_b)
|
||||
scaled_mm_device_tma_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
workspace_arg=get_workspace_arg(
|
||||
kwargs["NUM_SMS"], mat_a.get_device()
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
for config in scaled_mm_configs(m, n, k):
|
||||
if k == 16 and config.kwargs["BLOCK_M"] >= 64:
|
||||
continue # Triton crashes in this case
|
||||
kwargs = scaled_mm_options(
|
||||
config, m, n, k, layout, scale_a, scale_b, use_fast_accum
|
||||
)
|
||||
# possibly appends a TritonTemplateCaller to choices
|
||||
triton_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
|
||||
CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)
|
||||
|
|
|
|||
|
|
@ -38,6 +38,29 @@ def has_triton_tma():
|
|||
return False
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def has_triton_tma_device():
|
||||
if has_triton_package():
|
||||
import torch
|
||||
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability() >= (9, 0)
|
||||
and not torch.version.hip
|
||||
):
|
||||
try:
|
||||
from triton.language.extra.cuda import ( # noqa: F401
|
||||
experimental_device_tensormap_create1d,
|
||||
experimental_device_tensormap_create2d,
|
||||
)
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def has_triton() -> bool:
|
||||
if not has_triton_package():
|
||||
|
|
|
|||
Loading…
Reference in a new issue