Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-15 21:00:47 +00:00
Summary:
Both cuSPARSELt and CUTLASS support 1:2 semi-structured sparsity for fp32, which this PR enables (thanks @alexsamardzic).

Furthermore, this PR also updates the sparse_config to take into account the different shape constraints for sparse and dense matrices.

Technically, cuSPARSELt supports smaller sparse matrix constraints, as it seems to pad to the CUTLASS constraints under the hood. However, in practice small sparse matrices are not commonly used and we care more about the dense constraints for LLM inference. For now, we keep the CUTLASS constraints in place for both cuSPARSELt and CUTLASS tensors.

This PR also reconnects the _FUSE_TRANSPOSE flag for cuSPARSELt tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115550
Approved by: https://github.com/cpuhrsch
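For context, a minimal sketch of the conversion path the tests below exercise (assuming an SM8x CUDA device and shapes that meet the CUTLASS minimums; with fp32 the pattern is 1:2 rather than 2:4, which is what this PR enables). This is an illustrative usage example, not part of the PR's diff:

```python
import torch
from torch.sparse.semi_structured import to_sparse_semi_structured

# Hypothetical usage sketch: an fp16 weight that already follows the 2:4
# pattern, i.e. two nonzeros in every group of four elements along each row.
W = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()   # shape (128, 128)
W_sparse = to_sparse_semi_structured(W)                        # compressed values + metadata

X = torch.rand(128, 64, dtype=torch.float16, device="cuda")
out = torch.mm(W_sparse, X)   # dispatched to cuSPARSELt or CUTLASS under the hood
```

The test file below checks exactly this kind of call against a dense reference for each supported dtype and backend.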
675 lines
29 KiB
Python
# Owner(s): ["module: sparse"]
import itertools
import random
import unittest

import torch
from torch import nn

from torch.sparse.semi_structured import (
    _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
    SparseSemiStructuredTensor,
    to_sparse_semi_structured,
)

from torch.testing import make_tensor

from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case

from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
    TEST_WITH_ROCM,
    IS_WINDOWS,
)

from torch.utils._triton import has_triton

CUSPARSELT_NUM_ALG_IDS = 4

SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
SEMI_STRUCTURED_SUPPORTED_BACKENDS = []

_IS_SM8X = False
if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
    SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cutlass")

    # check if cslt is available for now using this:
    # TODO when we add cusparselt as a backend, we can update this to use torch.cusparselt.is_available()
    try:
        torch._cslt_compress(torch.ones(128, 256).cuda())
        SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt")
    except Exception:
        pass


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this matrix will also be 2:4 and 4:8 sparse.
    """

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )

def rand_sparse_semi_structured(r, c, dtype, device, pattern='2by4', choice=None):
    if pattern == '2by4':
        choices = [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1]
        ]
        mask_entries = [choice or random.choice(choices) for i in range(r * c // 4)]
    elif pattern == '1by2':
        choices = [
            [0, 1],
            [1, 0]
        ]
        mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
    else:
        assert False
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask applied.
    dense = dense.masked_fill(~mask, 0)
    return dense

def rand_sparse_semi_structured_all_patterns(r, c, dtype, device, pattern='2by4'):
    if pattern == '2by4':
        choices = [
            [[0, 0, 0, 0], [0, 0, 1, 1]],
            [[0, 0, 0, 1], [0, 0, 1, 1]],
            [[0, 0, 1, 0], [0, 0, 1, 1]],
            [[0, 0, 1, 1], [0, 0, 1, 1]],
            [[0, 1, 0, 0], [0, 1, 0, 1]],
            [[0, 1, 0, 1], [0, 1, 0, 1]],
            [[0, 1, 1, 0], [0, 1, 1, 0]],
            [[0, 1, 1, 1], [0, 1, 1, 0]],
            [[1, 0, 0, 0], [1, 0, 0, 1]],
            [[1, 0, 0, 1], [1, 0, 0, 1]],
            [[1, 0, 1, 0], [1, 0, 1, 0]],
            [[1, 0, 1, 1], [1, 0, 1, 0]],
            [[1, 1, 0, 0], [1, 1, 0, 0]],
            [[1, 1, 0, 1], [1, 1, 0, 0]],
            [[1, 1, 1, 0], [1, 0, 1, 0]],
            [[1, 1, 1, 1], [1, 0, 1, 0]],
        ]
        mask_rows = [random.randint(0, len(choices) - 1) for i in range(r * c // 4)]
    else:
        assert False
    COL_INV, COL_VAL = 0, 1
    mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
    mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
    mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
    mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask below applied.
    dense_inv = dense.masked_fill(~mask_inv, 0)
    dense_val = dense_inv.masked_fill(~mask_val, 0)
    return dense_inv, dense_val


class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @staticmethod
    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
        (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()`
        (2) Inductor should fuse the .contiguous() call into the relu
        """

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                x = self.linear(x)
                x = x.contiguous()
                return torch.nn.functional.relu(x)

        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"

        input = torch.rand(dense_input_shape, device="cuda").half()
        model = Model().eval().cuda().half()
        mod_linear = model.linear
        m, n = mod_linear.weight.shape
        mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
        # set masked weight
        mod_linear.weight = nn.Parameter(mod_linear.weight * mask)

        dense_result = model(input)
        mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
        sparse_result = model(input)

        model = torch.compile(model, backend="inductor", fullgraph=True)
        sparse_compile_result = model(input)

        # test that sparse_compile_result and dense_result are numerically close
        assert torch.allclose(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
        # assert sparse and sparse_compile have the same strides,
        # as meta registrations may return contiguous tensors when the output is transposed
        # https://github.com/pytorch/pytorch/pull/114477
        assert sparse_result.stride() == sparse_compile_result.stride()

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_mlp_contiguous_relu_compile_cusparselt(self):
        """
        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_mlp_contiguous_relu_compile_cutlass(self):
        """
        test for CUTLASS meta registrations (_sparse_semi_structured_linear) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)


class TestSparseSemiStructured(TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_to_sparse_semi_structured(self, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B) is correct for float16 and will throw an error for int8
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            # This should fail
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B)
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
        and will throw an error for int8 + padding
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8 and dense_input_shape in {(1, 128)}:
            # padding with int8 throws an error because transposing B yields a contiguous output
            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B.t())
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B.t())
        elif dtype is torch.int8:
            # test transpose
            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to an int8 tensor
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
        """
        Ensure torch.mm(A_sparse.t(), B) throws an error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
        ):
            torch.mm(A_sparse.t(), B)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse) throws an error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
        ):
            sparse_result = torch.mm(A, B_sparse)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("inference_mode", [subtest(True), subtest(False)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_linear(self, dense_input_shape, inference_mode, device, backend):
        """
        Test that nn.Linear produces the same numerics with dense and sparse weights
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        input = torch.rand(dense_input_shape, device=device).half()
        model = nn.Linear(128, 256).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(input)

        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(input)
        else:
            sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mlp(self, device, dense_input_shape, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
        input = torch.rand(dense_input_shape, device=device).half()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .to(device)
        )

        for i in range(2):
            m, n = model[i].weight.shape
            mask = rand_sparse_semi_structured_mask(
                m, n, device=device, dtype=torch.bool
            )
            # set masked weight
            model[i].weight = nn.Parameter(model[i].weight * mask)

        dense_result = model(input)

        for i in range(2):
            model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))

        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_values(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_indices(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.indices().shape == (128, 8)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_min_sparse_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        config = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[dtype]
        A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
        A_sparse = to_sparse_semi_structured(A)
        B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
        if dtype == torch.int8:
            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
            B_t = B.t().contiguous()
            sparse_res = torch.mm(A_sparse, B_t.t())
        else:
            dense_res = torch.mm(A, B)
            sparse_res = torch.mm(A_sparse, B)
        assert torch.allclose(sparse_res, dense_res, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            A_sparse = to_sparse_semi_structured(A)

    @dtypes(*all_types_and_complex())
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_dtype(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                A_sparse = to_sparse_semi_structured(A)
        else:
            A_sparse = to_sparse_semi_structured(A)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_dim(self, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            A_sparse = to_sparse_semi_structured(A)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_linear_cutlass(self, device, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
            pattern = '2by4' if dtype != torch.float32 else '1by2'
            weight = rand_sparse_semi_structured(m, k, dtype, device, pattern=pattern)
            input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
            bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

            dtype_dense = torch.float32
            input_dense = input.to(dtype_dense)
            weight_dense = weight.to(dtype_dense)
            bias_dense = bias.to(dtype_dense) if add_bias else None
            output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output0 = relu(output0)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output0 = silu(output0)

            compressed = to_sparse_semi_structured(weight)

            weight_sparse = compressed.values()
            meta = compressed.indices()

            output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation)
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make the dense GEMM do the same for matching results.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        batch_shapes = [[], [3], [3, 1]]
        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 5e-1
        for batch_shape, m, n, k, add_bias, activation in \
                itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
            if activation == "silu" and dtype == torch.int8:
                continue  # SiLU not supported for integer inputs

            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_conversions(self, device, dtype, backend):
        if dtype == torch.float32:
            return

        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        def run_test(r, c, device, dtype):
            pattern = '2by4' if dtype != torch.float32 else '1by2'
            dense_ref = rand_sparse_semi_structured(r, c, dtype, device, pattern=pattern)

            compressed = to_sparse_semi_structured(dense_ref)

            # The torch.ops.aten._to_sparse_semi_structured operator
            # uses CUTLASS to convert a given dense matrix into the
            # pair of corresponding sparse and metadata matrices, with
            # the latter used here as a reference to compare against
            # the metadata matrix produced by the conversion performed
            # by the SparseSemiStructuredTensor class constructor.
            _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref)

            meta = compressed.indices()
            torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0)

            dense = compressed.to_dense()
            torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0)

        shapes = [[32, 128], [32, 256], [64, 128], [64, 256]]
        for r, c in shapes:
            run_test(r, c, device, dtype)

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_conversions_all_patterns(self, device, dtype, backend):
        if dtype == torch.float32:
            return

        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        r, c = 32, 128

        dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)

        compressed = to_sparse_semi_structured(dense_inv)
        dense = compressed.to_dense()

        torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)


class TestCUSPARSELT(TestCase):
    """
    This contains cuSPARSELt specific tests.
    """

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('cuSPARSELt not enabled')
        else:
            SparseSemiStructuredTensor._FORCE_CUTLASS = False

    @parametrize("dense_input_shape", [(128, 128)])
    def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
        A_compressed = torch._cslt_compress(A)

        B = torch.rand(dense_input_shape, device=device).to(torch.int8)

        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=torch.float16)
        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(torch.float16, torch.bfloat16)
    def test_cslt_sparse_mm_alpha(self, dtype, device):
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
        B = torch.ones((256, 128), device=device).to(dtype)
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha)

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    def test_cslt_sparse_mm_alpha_int8_in_f16_out(self, device):
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
        B = torch.ones((128, 256), device=device).to(torch.int8).t()
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.float16).cpu()

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int32).cpu(), B.to(torch.int32).cpu())
        dense_result = dense_result.to(torch.float16)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
        # alg_id=3 not supported for float32 dtype
        if dtype == torch.float32 and alg_id == 3:
            return
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)

        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_cslt_sparse_mm_search(self, device, dtype):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
        # for cuSPARSELt v0.4.0 there is a bug where although there are 5 alg_ids, we run into an error
        # when using the last one (4)
        # in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
        assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)


instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
instantiate_device_type_tests(TestCUSPARSELT, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()