pytorch/test/test_sparse_semi_structured.py
Jesse Cai a8e354a9a0 [sparse][semi-structured] enable fp32 support, separate sparse and dense constraints (#115550)
Summary:

Both cuSPARSELt and CUTLASS support 1:2 semi-structured sparsity for
fp32, which this PR enables (thanks @alexsamardzic).

Furthermore, this PR also updates the sparse_config to take into account
the different shape constraints for sparse and dense matrices.

Technically, cuSPARSELt supports smaller sparse matrix constraints as it
seems to pad to the CUTLASS constraints under the hood. However, in
practice small sparse matrices are not commonly used and we care more
about the dense constraints for LLM inference.

For now, we keep the CUTLASS constraints in place for both cuSPARSELt
and CUTLASS tensors.

This PR also reconnects the _FUSE_TRANSPOSE flag for cuSPARSELt tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115550
Approved by: https://github.com/cpuhrsch
2023-12-15 02:28:17 +00:00

675 lines
29 KiB
Python

# Owner(s): ["module: sparse"]
import itertools
import random
import unittest
import torch
from torch import nn
from torch.sparse.semi_structured import (
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
SparseSemiStructuredTensor,
to_sparse_semi_structured,
)
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
)
from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
parametrize,
run_tests,
subtest,
TestCase,
TEST_WITH_ROCM,
IS_WINDOWS,
)
from torch.utils._triton import has_triton
# Number of cuSPARSELt algorithm ids exercised by the alg_id tests below.
CUSPARSELT_NUM_ALG_IDS = 4
# Every dtype that has a semi-structured sparse config is treated as supported.
SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
# Populated below with the backends usable on the current machine.
SEMI_STRUCTURED_SUPPORTED_BACKENDS = []

# True only when a CUDA device is present and device 0 has compute capability 8.x.
_IS_SM8X = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] == 8


def _probe_cusparselt():
    """Return True if compressing a small CUDA tensor with cuSPARSELt succeeds."""
    # TODO: once cuSPARSELt is exposed as a backend, replace this probe with
    # torch.cusparselt.is_available().
    try:
        torch._cslt_compress(torch.ones(128, 256).cuda())
    except Exception:
        # Any failure simply means cuSPARSELt is not usable here.
        return False
    return True


if torch.cuda.is_available():
    SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cutlass")
    if _probe_cusparselt():
        SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt")
def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    Return a random (r, c) 1:2 sparsity mask as a contiguous tensor.

    Elements are generated in consecutive pairs (row-major order). Each pair is
    either [0, 1] or [1, 0], chosen at random, unless a truthy ``choice`` is
    given, in which case that pair is used everywhere. Because every pair holds
    exactly one nonzero, the mask is also 2:4 and 4:8 sparse.
    """
    pair_patterns = [[0, 1], [1, 0]]
    num_pairs = r * c // 2
    if choice:
        entries = [choice] * num_pairs
    else:
        entries = [random.choice(pair_patterns) for _ in range(num_pairs)]
    flat_pairs = torch.tensor(entries, dtype=dtype, device=device)
    return flat_pairs.reshape(r, c).contiguous()
def rand_sparse_semi_structured(r, c, dtype, device, pattern='2by4', choice=None):
    """
    Return a random dense (r, c) tensor that follows a semi-structured sparsity pattern.

    The tensor is filled by ``make_tensor``. Exact zeros are replaced by 1, so the
    only zeros in the result are the ones introduced by the mask.

    Args:
        r, c: output shape; ``r * c`` must be divisible by the group size
            (4 for '2by4', 2 for '1by2').
        dtype, device: forwarded to ``make_tensor``.
        pattern: '2by4' (2 nonzeros in every group of 4 consecutive row-major
            elements) or '1by2' (1 nonzero in every group of 2).
        choice: if truthy, the group mask to use for every group instead of a
            random one.

    Raises:
        ValueError: if ``pattern`` is not '2by4' or '1by2'.
    """
    if pattern == '2by4':
        choices = [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1],
        ]
    elif pattern == '1by2':
        choices = [
            [0, 1],
            [1, 0],
        ]
    else:
        # Previously `assert(false)`, which raised NameError instead of a useful error.
        raise ValueError(f"Unsupported sparsity pattern {pattern!r}; expected '2by4' or '1by2'")

    group_size = len(choices[0])
    mask_entries = [choice or random.choice(choices) for _ in range(r * c // group_size)]
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask applied.
    dense = dense.masked_fill(~mask, 0)
    return dense
def rand_sparse_semi_structured_all_patterns(r, c, dtype, device, pattern='2by4'):
    """
    Build a random (r, c) tensor covering all 16 possible 4-element zero patterns,
    together with the expected result of a 2:4 round trip.

    Each group of 4 consecutive row-major elements picks one row of ``choices``
    at random. Column 0 ("inv") is an arbitrary 4-bit mask used to build the
    input. Column 1 ("val") is a 2:4 mask; applying it on top of the input gives
    the values that test_conversions_all_patterns expects to see after
    compressing ``dense_inv`` and calling ``to_dense()``.

    Returns:
        (dense_inv, dense_val): ``dense_inv`` is random data zeroed outside the
        "inv" mask. ``dense_val`` is ``dense_inv`` further zeroed outside the
        "val" mask.

    Raises:
        ValueError: if ``pattern`` is not '2by4' (the only supported pattern).
    """
    if pattern != '2by4':
        # Previously `assert(false)`, which raised NameError instead of a useful error.
        raise ValueError(f"Unsupported sparsity pattern {pattern!r}; only '2by4' is supported")

    choices = [
        [[0, 0, 0, 0], [0, 0, 1, 1]],
        [[0, 0, 0, 1], [0, 0, 1, 1]],
        [[0, 0, 1, 0], [0, 0, 1, 1]],
        [[0, 0, 1, 1], [0, 0, 1, 1]],
        [[0, 1, 0, 0], [0, 1, 0, 1]],
        [[0, 1, 0, 1], [0, 1, 0, 1]],
        [[0, 1, 1, 0], [0, 1, 1, 0]],
        [[0, 1, 1, 1], [0, 1, 1, 0]],
        [[1, 0, 0, 0], [1, 0, 0, 1]],
        [[1, 0, 0, 1], [1, 0, 0, 1]],
        [[1, 0, 1, 0], [1, 0, 1, 0]],
        [[1, 0, 1, 1], [1, 0, 1, 0]],
        [[1, 1, 0, 0], [1, 1, 0, 0]],
        [[1, 1, 0, 1], [1, 1, 0, 0]],
        [[1, 1, 1, 0], [1, 0, 1, 0]],
        [[1, 1, 1, 1], [1, 0, 1, 0]],
    ]
    mask_rows = [random.randint(0, len(choices) - 1) for _ in range(r * c // 4)]

    COL_INV, COL_VAL = 0, 1
    mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
    mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
    mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
    mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)

    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask below applied.
    dense_inv = dense.masked_fill(~mask_inv, 0)
    dense_val = dense_inv.masked_fill(~mask_val, 0)
    return dense_inv, dense_val
class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
    """
    torch.compile (inductor) tests for a module whose nn.Linear weight has been
    converted to a SparseSemiStructuredTensor, run for the CUTLASS and
    cuSPARSELt backends.
    """

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        super().setUp()

    @staticmethod
    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
        (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()`
        (2) Inductor should fuse the .contiguous() call into the relu

        Args:
            backend: "cutlass" sets SparseSemiStructuredTensor._FORCE_CUTLASS to
                True; any other value sets it to False.
            dense_input_shape: shape of the random fp16 input; its last dim must
                be 128 to match nn.Linear(128, 128).

        The previous value of SparseSemiStructuredTensor._FORCE_CUTLASS is
        restored afterwards, even if an assertion fails, so the class-level flag
        does not leak into other tests.
        """

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                x = self.linear(x)
                x = x.contiguous()
                return torch.nn.functional.relu(x)

        orig_force_cutlass = SparseSemiStructuredTensor._FORCE_CUTLASS
        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
        try:
            inp = torch.rand(dense_input_shape, device="cuda").half()
            model = Model().eval().cuda().half()
            mod_linear = model.linear
            m, n = mod_linear.weight.shape
            # The 2:4 mask [1, 0, 0, 1] repeated over every group of 4 columns.
            mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
            # set masked weight
            mod_linear.weight = nn.Parameter(mod_linear.weight * mask)

            dense_result = model(inp)
            mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
            sparse_result = model(inp)

            model = torch.compile(model, backend="inductor", fullgraph=True)
            sparse_compile_result = model(inp)

            # test that sparse_compile_result and dense_result are numerically close
            assert torch.allclose(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
            # assert sparse and sparse_compile have the same strides,
            # as meta registrations may return contiguous tensors when the output is transposed
            # https://github.com/pytorch/pytorch/pull/114477
            assert sparse_result.stride() == sparse_compile_result.stride()
        finally:
            SparseSemiStructuredTensor._FORCE_CUTLASS = orig_force_cutlass

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_mlp_contiguous_relu_compile_cusparselt(self):
        """
        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_mlp_contiguous_relu_compile_cutlass(self):
        """
        test for CUTLASS meta registrations (_sparse_semi_structured_linear) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)
class TestSparseSemiStructured(TestCase):
    """
    Tests for SparseSemiStructuredTensor / to_sparse_semi_structured, parametrized
    over dtype and over the backends detected at import time
    (SEMI_STRUCTURED_SUPPORTED_BACKENDS).

    Each test selects its backend through the class-level flag
    SparseSemiStructuredTensor._FORCE_CUTLASS (True for "cutlass", False for
    "cusparselt"). setUp records the flag and tearDown restores it, so one
    test's choice does not leak into later tests.
    """

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        # The original override skipped the shared TestCase setup; run it.
        super().setUp()
        self._orig_force_cutlass = SparseSemiStructuredTensor._FORCE_CUTLASS

    def tearDown(self):
        SparseSemiStructuredTensor._FORCE_CUTLASS = self._orig_force_cutlass
        super().tearDown()

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_to_sparse_semi_structured(self, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B) is correct for floating-point dtypes and throws an error for int8
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        if dtype is torch.int8:
            # int8 sparse @ row-major dense B (NN layout) is expected to be
            # rejected by both backends.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
                    torch.mm(A_sparse, B)
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    torch.mm(A_sparse, B)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16/float32
        and will throw an error for int8 + padding
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        if dtype is torch.int8 and dense_input_shape == (1, 128):
            # padding with int8 throws an error because transposing B yields a contiguous output
            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
                    torch.mm(A_sparse, B.t())
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    torch.mm(A_sparse, B.t())
        elif dtype is torch.int8:
            # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over.
            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
            # NOTE(review): torch.rand draws from [0, 1), so .to(torch.int8) makes B all
            # zeros and this comparison is weak; consider torch.randint.
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
        """
        Ensure torch.mm(A_sparse.t(), B) throws error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
        ):
            torch.mm(A_sparse.t(), B)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        if dtype is torch.int8:
            # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over.
            # NOTE(review): torch.rand draws from [0, 1), so .to(torch.int8) makes A all
            # zeros and this comparison is weak; consider torch.randint.
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse) throws error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
        ):
            torch.mm(A, B_sparse)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("inference_mode", [subtest(True), subtest(False)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_linear(self, dense_input_shape, inference_mode, device, backend):
        """
        Test nn.Linear has the same numerics
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        inp = torch.rand((dense_input_shape), device=device).half()
        model = nn.Linear(128, 256).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(inp)

        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(inp)
        else:
            sparse_result = model(inp)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mlp(self, device, dense_input_shape, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
        inp = torch.rand(dense_input_shape, device=device).half()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .to(device)
        )

        for i in range(2):
            m, n = model[i].weight.shape
            mask = rand_sparse_semi_structured_mask(
                m, n, device=device, dtype=torch.bool
            )
            # set masked weight
            model[i].weight = nn.Parameter(model[i].weight * mask)

        dense_result = model(inp)

        for i in range(2):
            model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))

        sparse_result = model(inp)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_values(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_indices(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.indices().shape == (128, 8)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_min_sparse_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        config = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[dtype]
        A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
        A_sparse = to_sparse_semi_structured(A)
        B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
        if dtype == torch.int8:
            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
            B_t = B.t().contiguous()
            sparse_res = torch.mm(A_sparse, B_t.t())
        else:
            dense_res = torch.mm(A, B)
            sparse_res = torch.mm(A_sparse, B)
        assert torch.allclose(sparse_res, dense_res, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            to_sparse_semi_structured(A)

    @dtypes(*all_types_and_complex())
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_dtype(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                to_sparse_semi_structured(A)
        else:
            # Supported dtypes must convert without raising.
            to_sparse_semi_structured(A)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_dim(self, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            to_sparse_semi_structured(A)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_linear_cutlass(self, device, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
            # Compare torch._sparse_semi_structured_linear against a float32 dense reference.
            pattern = '2by4' if dtype != torch.float32 else '1by2'
            weight = rand_sparse_semi_structured(m, k, dtype, device, pattern=pattern)
            inp = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
            bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

            dtype_dense = torch.float32
            input_dense = inp.to(dtype_dense)
            weight_dense = weight.to(dtype_dense)
            bias_dense = bias.to(dtype_dense) if add_bias else None
            output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output0 = relu(output0)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output0 = silu(output0)

            compressed = to_sparse_semi_structured(weight)

            weight_sparse = compressed.values()
            meta = compressed.indices()

            output1 = torch._sparse_semi_structured_linear(inp, weight_sparse, meta, bias=bias, activation=activation)
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        batch_shapes = [[], [3], [3, 1]]
        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 5e-1

        # Inputs are converted to TF32 internally for sparse GEMM,
        # so make dense GEMM to do the same for matching results.
        # The global flag is restored in `finally`, so a failing assertion
        # cannot leave TF32 enabled for later tests.
        orig_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = True
        try:
            for batch_shape, m_exp, n_exp, k_exp, add_bias, activation in \
                    itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
                if activation == "silu" and dtype == torch.int8:
                    continue  # SiLU not supported for integer inputs

                m = 2 ** m_exp * 32
                n = 2 ** n_exp * 32
                k = 2 ** k_exp * 128
                run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)
        finally:
            torch.backends.cuda.matmul.allow_tf32 = orig_allow_tf32

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_conversions(self, device, dtype, backend):
        if dtype == torch.float32:
            # Report the gap explicitly instead of silently passing.
            self.skipTest("float32 is not covered by this conversion test")

        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        def run_test(r, c, device, dtype):
            pattern = '2by4' if dtype != torch.float32 else '1by2'
            dense_ref = rand_sparse_semi_structured(r, c, dtype, device, pattern=pattern)

            compressed = to_sparse_semi_structured(dense_ref)

            # The torch.ops.aten._to_sparse_semi_structured operator
            # uses CUTLASS to perform conversion from given dense
            # matrix to the pair of corresponding sparse and metadata
            # matrices, with the later used here as a reference to
            # compare the metadata matrix produced by conversion
            # performed by SparseSemiStructuredTensor class
            # constructor against.
            _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref)

            meta = compressed.indices()
            torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0)

            dense = compressed.to_dense()
            torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0)

        shapes = [[32, 128], [32, 256], [64, 128], [64, 256]]
        for r, c in shapes:
            run_test(r, c, device, dtype)

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_conversions_all_patterns(self, device, dtype, backend):
        if dtype == torch.float32:
            # Report the gap explicitly instead of silently passing.
            self.skipTest("float32 is not covered by this conversion test")

        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        r, c = 32, 128

        dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)

        compressed = to_sparse_semi_structured(dense_inv)
        dense = compressed.to_dense()

        torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)
class TestCUSPARSELT(TestCase):
    """
    This contains cuSPARSELt specific tests.

    Tests call the torch._cslt_* ops directly. setUp forces
    SparseSemiStructuredTensor._FORCE_CUTLASS to False, and tearDown restores
    its previous value so the class-level flag does not leak into other tests.
    """

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('cuSPARSELt not enabled')
        # The original override skipped the shared TestCase setup; run it.
        super().setUp()
        self._orig_force_cutlass = SparseSemiStructuredTensor._FORCE_CUTLASS
        SparseSemiStructuredTensor._FORCE_CUTLASS = False

    def tearDown(self):
        SparseSemiStructuredTensor._FORCE_CUTLASS = self._orig_force_cutlass
        super().tearDown()

    @parametrize("dense_input_shape", [(128, 128)])
    def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
        A_compressed = torch._cslt_compress(A)

        # NOTE(review): torch.rand draws from [0, 1), so .to(torch.int8) makes B all
        # zeros and this comparison is weak; consider torch.randint.
        B = torch.rand(dense_input_shape, device=device).to(torch.int8)

        # Dense int reference is computed on CPU in int64, then cast to the expected fp16 output.
        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=torch.float16)
        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(torch.float16, torch.bfloat16)
    def test_cslt_sparse_mm_alpha(self, dtype, device):
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
        B = torch.ones((256, 128), device=device).to(dtype)
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha)

        # alpha_scaled[i, j] == alpha[i], i.e. row i of the product is scaled by alpha[i].
        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    def test_cslt_sparse_mm_alpha_int8_in_f16_out(self, device):
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
        B = torch.ones((128, 256), device=device).to(torch.int8).t()
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.float16).cpu()

        # alpha_scaled[i, j] == alpha[i], i.e. row i of the product is scaled by alpha[i].
        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int32).cpu(), B.to(torch.int32).cpu())
        dense_result = dense_result.to(torch.float16)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
        if dtype == torch.float32 and alg_id == 3:
            # Report the gap explicitly instead of silently passing.
            self.skipTest("alg_id=3 not supported for float32 dtype")

        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        # Compress once (previously this was computed twice).
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)

        # Mirror the B.t() used for the sparse call.
        dense_result = torch.mm(A.to(torch.float32), B.t().to(torch.float32))
        dense_result = dense_result.to(dtype)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_cslt_sparse_mm_search(self, device, dtype):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        # Compress once (previously this was computed twice).
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
        # for cuSPARSELt v0.4.0 there is a bug where although there are 5 alg_ids, we run into an error
        # when setting using the last one (4)
        # in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
        assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
# Generate the CUDA-specific concrete test classes from each generic class.
# instantiate_device_type_tests removes the generic class name from globals();
# the loop variable is deleted afterwards so test discovery cannot pick it up.
for _generic_cls in (TestSparseSemiStructured, TestCUSPARSELT):
    instantiate_device_type_tests(_generic_cls, globals(), only_for="cuda")
del _generic_cls

if __name__ == "__main__":
    run_tests()