Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
Update
[ghstack-poisoned]
parent 5b34e6afea
commit 7c2cd8ea25
3 changed files with 0 additions and 139 deletions
@@ -1,11 +0,0 @@
-ERROR: file or directory not found: 2
-
-============================= test session starts ==============================
-platform linux -- Python 3.11.0, pytest-7.4.0, pluggy-1.5.0
-rootdir: /data/users/vasiliy/pytorch
-configfile: pytest.ini
-plugins: hypothesis-6.124.3, typeguard-4.3.0
-collected 0 items
-Running 0 items in this shard
-
-============================ no tests ran in 0.05s =============================
@@ -1444,9 +1444,6 @@ void scaled_gemm(
   const auto scaleType = CUDA_R_32F;
   const float alpha_val = 1.0;
   const float beta_val = 0.0;
-  if (scale_dtype == DataType::UFP8){
-    // k = k * 2;
-  }
   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
@@ -843,131 +843,6 @@ class TestFP8MatmulCuda(TestCase):
         )
         torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
 
-    # TODO(before land): real skip condition, should only run on CUDA 12.8+ with CUDA capability 10.0+
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    def test_blockwise_nvfp4(self) -> None:
-
-        # inspiration: https://github.com/pytorch/ao/pull/1625
-
-        # not for land, need to move to in-core dtypes
-        from enum import IntEnum
-        class DataType(IntEnum):
-            DEFAULT = 0
-            E8M0 = 1
-            FP4 = 2
-            UFP8 = 3
-
-        device = "cuda"
-        # TODO other shapes
-        M, K, N = 128, 128, 128
-        BLOCK_SIZE = 16
-        # torch.set_printoptions(profile='full', linewidth=320)
-        torch.set_printoptions(edgeitems=7, linewidth=100)
-
-        assert M == K == N
-        # A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
-        # B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
-        A_ref = torch.zeros(M, K, device=device, dtype=torch.bfloat16)
-        # B_ref = torch.zeros(N, K, device=device, dtype=torch.bfloat16)
-        # A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
-        B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
-
-        A_ref[0][0:64 + 4] = 1
-        # A_ref[1][0] = 1
-        print('A_ref', A_ref)
-
-        # probe the kernel
-        if False:
-            idx = M-1
-            idx = 1
-            A_ref[idx][idx] = 1
-            B_ref[idx][idx] = 1
-
-        C_ref = A_ref @ B_ref
-        print('C_ref', C_ref)
-
-        A = A_ref.to(torch.uint8)
-        B = B_ref.to(torch.uint8)
-
-        # super hacky "cast" to fp4_e2m1_x1:
-        # * we only have 1s and 0s in uint8
-        # * 1 in uint8 is 0b00000001
-        # * 1 in fp4_e2m1_x1 is 0b00000010
-        # * just replace the values
-        # TODO(before land): use a generic cast and test numerics more
-        # comprehensively
-        A[A == 1] = 2
-        B[B == 1] = 2
-
-        # now, pack fp4_e2m1_x1 into fp4_e2m1_x2
-        A = pack_uint4(A)
-        B = pack_uint4(B)
-        B = B.t()
-        if False:
-            vala = 0b00100001
-            valb = 0b00101001
-            A[0][0] = 0b00100001
-            A[0][1] = 0b00100001
-            A[1][0] = 0b00100001
-            A[1][1] = 0b00100001
-            B[0][0] = 0b00100001
-            B[0][1] = 0b00100001
-            B[1][0] = 0b00100001
-            B[1][1] = 0b00100001
-
-        print('A', A, A.shape)
-        print('B after t')
-        print(B, B.shape)
-
-        # TODO more scale correctness testing
-        A_scale = torch.full((M, K // BLOCK_SIZE), 1, device=device, dtype=torch.float8_e4m3fn).view(torch.uint8)
-        B_scale = torch.full((N, K // BLOCK_SIZE), 1, device=device, dtype=torch.float8_e4m3fn).view(torch.uint8)
-        # A_scale = torch.full((M, 256), 1, device=device, dtype=torch.float8_e4m3fn).view(torch.uint8)
-        # B_scale = torch.full((N, 256), 1, device=device, dtype=torch.float8_e4m3fn).view(torch.uint8)
-
-        # convert to swizzled format
-        A_scale = to_blocked(A_scale)
-        B_scale = to_blocked(B_scale)
-
-        # https://docs.nvidia.com/cuda/cublas/index.html?highlight=blocked#d-block-quantization
-        # e2m1 max = 6
-        # e4m3 max = 448
-        # calculation max = 128
-        scale_result = torch.tensor(6.0 * 448.0 / 128.0, device=device)
-        print(scale_result)
-
-        # TODO sweep fast_accum
-
-        C = torch._scaled_mm(
-            A,
-            B,
-            # scales are switched
-            B_scale,
-            A_scale,
-            bias=None,
-            # scale_result=scale_result,
-            out_dtype=torch.bfloat16,
-            use_fast_accum=False,
-            a_dtype=DataType.FP4,
-            b_dtype=DataType.FP4,
-            scale_dtype=DataType.UFP8,
-        )
-        print('C', C, C.shape)
-        print('A_max', torch.max(A))
-        print('C_max', torch.max(C))
-
-        # get the indices of the set element
-        # print('in_idx', (idx, idx), 'ref_idx', (C_ref == torch.max(C_ref)).nonzero(), 'result_idx', (C == torch.max(C)).nonzero())
-
-        if False:
-            print('C_ref')
-            print(C_ref, C_ref.shape)
-            print('C')
-            print(C, C.shape)
-            print(C.max(dim=0))
-            print(C.max(dim=1))
-        # torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
-
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
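Note for readers skimming the removed test: its inline comments hint at two details worth spelling out. fp4 e2m1 has 1 sign, 2 exponent, and 1 mantissa bit, so the bit pattern 0b0010 encodes 1.0 (which is why rewriting uint8 value 1 to 2 works as a "cast" when the tensor holds only 0s and 1s), and the largest representable magnitude is 0b0111 == 6.0; two such 4-bit codes are then packed into each uint8 byte before the matmul. The sketch below is a self-contained illustration under those assumptions only: fake_cast_zeros_ones_to_e2m1, pack_e2m1_pairs, and the choice of putting the second element in the high nibble are illustrative names and conventions, not the pack_uint4 / to_blocked helpers the test actually uses.

# Minimal sketch (assumptions noted above), not the test's own helpers.
import torch

# e2m1: 1 sign, 2 exponent, 1 mantissa bit; 0b0010 encodes 1.0, 0b0111 encodes 6.0.
E2M1_ONE = 0b0010

def fake_cast_zeros_ones_to_e2m1(t: torch.Tensor) -> torch.Tensor:
    # Re-encode a uint8 tensor containing only 0s and 1s as e2m1 bit patterns.
    out = t.clone()
    out[out == 1] = E2M1_ONE  # 0 is already the e2m1 encoding of 0.0
    return out

def pack_e2m1_pairs(t: torch.Tensor) -> torch.Tensor:
    # Pack two 4-bit e2m1 codes per byte along the last dimension.
    assert t.dtype == torch.uint8 and t.shape[-1] % 2 == 0
    lo = t[..., ::2]
    hi = t[..., 1::2]
    return (hi << 4) | lo  # assumption: second element goes to the high nibble

if __name__ == "__main__":
    x = torch.tensor([[0, 1, 1, 0]], dtype=torch.uint8)
    packed = pack_e2m1_pairs(fake_cast_zeros_ones_to_e2m1(x))
    print(packed)  # tensor([[32, 2]]): 0b00100000 and 0b00000010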
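The commented-out scale_result in the removed test follows the cuBLASLt block-scaling note linked in its comments: with e2m1 elements (max value 6), float8_e4m3fn per-block scales (max value 448), and an expected output magnitude of 128 for this particular test, the suggested global scale is 6 * 448 / 128 = 21. A minimal sketch of that arithmetic, with illustrative constant names not taken from the test:

# Sketch of the scale_result arithmetic only; names are illustrative.
E2M1_MAX = 6.0    # largest fp4 e2m1 value
E4M3_MAX = 448.0  # largest float8_e4m3fn value
CALC_MAX = 128.0  # expected maximum of the matmul output in the test

scale_result = E2M1_MAX * E4M3_MAX / CALC_MAX
print(scale_result)  # 21.0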