[ghstack-poisoned]
This commit is contained in:
vasiliy 2025-02-08 21:57:55 -08:00
parent 5b34e6afea
commit 7c2cd8ea25
3 changed files with 0 additions and 139 deletions

11
1
View file

@ -1,11 +0,0 @@
ERROR: file or directory not found: 2
============================= test session starts ==============================
platform linux -- Python 3.11.0, pytest-7.4.0, pluggy-1.5.0
rootdir: /data/users/vasiliy/pytorch
configfile: pytest.ini
plugins: hypothesis-6.124.3, typeguard-4.3.0
collected 0 items
Running 0 items in this shard
============================ no tests ran in 0.05s =============================

View file

@ -1444,9 +1444,6 @@ void scaled_gemm(
const auto scaleType = CUDA_R_32F;
const float alpha_val = 1.0;
const float beta_val = 0.0;
if (scale_dtype == DataType::UFP8){
// k = k * 2;
}
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));

View file

@ -843,131 +843,6 @@ class TestFP8MatmulCuda(TestCase):
)
torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
# TODO(before land): real skip condition, should only run on CUDA 12.8+ with CUDA capability 10.0+
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_blockwise_nvfp4(self) -> None:
    """Smoke-test ``torch._scaled_mm`` with nvfp4 (fp4_e2m1) inputs and fp8 block scales.

    Builds bf16 reference operands (A: zeros with a strip of ones in row 0,
    B: all ones), hand-packs them into fp4_e2m1_x2, runs the blockwise-scaled
    matmul, and prints the result for manual inspection. The exact-match
    check against the bf16 reference is intentionally left disabled until
    numerics are settled.

    Inspiration: https://github.com/pytorch/ao/pull/1625
    """
    # Not-for-land stand-in for in-core dtype enums; the integer values
    # mirror what the cuBLASLt binding expects.
    # TODO(before land): need to move to in-core dtypes.
    from enum import IntEnum

    class DataType(IntEnum):
        DEFAULT = 0
        E8M0 = 1
        FP4 = 2
        UFP8 = 3

    device = "cuda"
    # TODO other shapes
    M, K, N = 128, 128, 128
    BLOCK_SIZE = 16  # elements covered by each scale value
    torch.set_printoptions(edgeitems=7, linewidth=100)
    assert M == K == N

    # A: zeros with ones in the first 68 columns of row 0; B: all ones.
    A_ref = torch.zeros(M, K, device=device, dtype=torch.bfloat16)
    B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
    A_ref[0][0:64 + 4] = 1
    print('A_ref', A_ref)

    C_ref = A_ref @ B_ref
    print('C_ref', C_ref)

    A = A_ref.to(torch.uint8)
    B = B_ref.to(torch.uint8)
    # super hacky "cast" to fp4_e2m1_x1:
    # * we only have 1s and 0s in uint8
    # * 1 in uint8 is 0b00000001
    # * 1 in fp4_e2m1_x1 is 0b00000010
    # * just replace the values
    # TODO(before land): use a generic cast and test numerics more
    # comprehensively
    A[A == 1] = 2
    B[B == 1] = 2
    # now, pack fp4_e2m1_x1 into fp4_e2m1_x2
    A = pack_uint4(A)
    B = pack_uint4(B)
    B = B.t()
    print('A', A, A.shape)
    print('B after t')
    print(B, B.shape)

    # All scales are 1.0 for now (stored as fp8 e4m3 bit patterns viewed
    # as uint8). TODO more scale correctness testing.
    A_scale = torch.full((M, K // BLOCK_SIZE), 1, device=device, dtype=torch.float8_e4m3fn).view(torch.uint8)
    B_scale = torch.full((N, K // BLOCK_SIZE), 1, device=device, dtype=torch.float8_e4m3fn).view(torch.uint8)
    # convert to the swizzled layout cuBLASLt expects, see
    # https://docs.nvidia.com/cuda/cublas/index.html?highlight=blocked#d-block-quantization
    A_scale = to_blocked(A_scale)
    B_scale = to_blocked(B_scale)

    # e2m1 max = 6
    # e4m3 max = 448
    # calculation max = 128
    scale_result = torch.tensor(6.0 * 448.0 / 128.0, device=device)
    print(scale_result)

    # TODO sweep fast_accum
    C = torch._scaled_mm(
        A,
        B,
        # scales are switched
        B_scale,
        A_scale,
        bias=None,
        # scale_result=scale_result,
        out_dtype=torch.bfloat16,
        use_fast_accum=False,
        a_dtype=DataType.FP4,
        b_dtype=DataType.FP4,
        scale_dtype=DataType.UFP8,
    )
    print('C', C, C.shape)
    print('A_max', torch.max(A))
    print('C_max', torch.max(C))
    # TODO(before land): re-enable once numerics match the bf16 reference
    # torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")