From 42f647219a4eba55315374dd9e6bcf66f6f8cb87 Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Tue, 9 Jul 2024 19:49:11 +0000 Subject: [PATCH] [ROCm] Add int4 support (#129710) - Add AMD support for int4 kernel - Only supports CDNA2 and CDNA3 gpus for now - Uses `mfma_f32_16x16x16bf16` instruction for matrix multiply - Uses `v_and_or_b32` instruction and `__hfma2` instrinsic for unpacking bf16 values - Enable hipify for `__nv_bfloat16` and `__nv_bfloat162` data types - Enable int4 unit tests for CDNA2 and CDNA3 AMD gpus - Fix torchscript issues due to hipify for `__nv_bfloat16` type - TorchScript has its own implementation for bfloat16 type - Implemented in `__nv_bloat16` structure at [resource_strings.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h) - So, we shouldn't hipify any reference of `__nv_bfloat16` in the torchscript implementation - Hence moved the `__nv_bfloat16` direct references in `codegen.cpp` and `cuda_codegen.cpp` to `resource_strings.h` which is already exempted from hipify Fixes #124699 Fixes pytorch-labs/gpt-fast/issues/154 Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/129710 Approved by: https://github.com/malfet --- aten/src/ATen/native/cuda/int4mm.cu | 286 +++++++++++++++++- test/test_linalg.py | 8 +- torch/csrc/jit/codegen/fuser/codegen.cpp | 2 +- .../jit/codegen/fuser/cuda/resource_strings.h | 2 + torch/csrc/jit/tensorexpr/cuda_codegen.cpp | 2 +- torch/testing/_internal/common_cuda.py | 6 + torch/utils/hipify/cuda_to_hip_mappings.py | 2 + 7 files changed, 289 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu index fcfcd2e5ebb..129b2798799 100644 --- a/aten/src/ATen/native/cuda/int4mm.cu +++ b/aten/src/ATen/native/cuda/int4mm.cu @@ -1,9 +1,11 @@ -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) #include #include #include +#if !defined(USE_ROCM) #include #endif +#endif #include #include #include @@ -125,9 +127,38 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { return diff == 0 ? 
0 : uint32_t(Align) - diff; } -constexpr int32_t kWarpSize = 32; +#if defined(USE_ROCM) +// TODO: Support RDNA +constexpr int32_t kWarpSize = 64; + +template +using VecT = T __attribute__((ext_vector_type(Rank))); + +static bool isCDNA2orLater(int index) { + hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); + std::string device_arch = prop->gcnArchName; + static const std::vector archs = {"gfx90a", "gfx940", "gfx941", "gfx942"}; + for (std::string arch : archs) { + size_t substring = device_arch.find(arch); + if (substring != std::string::npos) { + return true; + } + } + return false; +} + +#else +constexpr int32_t kWarpSize = 32; +#endif + +#if defined (__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define CDNA2_OR_LATER 1 +#else +#define CDNA2_OR_LATER 0 +#endif + +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) // f16 vector types struct __align__(2) f16x1 { __half vals[1]; @@ -176,11 +207,19 @@ struct __align__(16) bf16x2x4 { }; struct __align__(16) bf16x2x4_u32 { +#if defined(USE_ROCM) + VecT val[2]; +#else uint32_t vals[4]; +#endif }; struct __align__(8) bf16x2x2_u32 { +#if defined(USE_ROCM) + VecT val; +#else uint32_t vals[2]; +#endif }; struct __align__(4) bf16x2x1_u32 { @@ -202,38 +241,68 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { uint32_t const source_i4s = source; // First, we extract the i4s and construct an intermediate fp16 number. +#if !defined(USE_ROCM) static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; +#endif static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; // We don't have enough mantissa to remove as much shift overhead as FP16, so // we must loop. No shift needed for first item. uint32_t i4s = source_i4s; + +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[0]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif + #pragma unroll for (int ii = 1; ii < kElements / 2; ++ii) { i4s >>= 4; // or is it 8? // (i4s & 0x000f000f) | 0x43004300 +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[ii]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else asm volatile( "lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[ii]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif } // This is the BF16 {-136, -136} represented as an integer. +#if defined(USE_ROCM) +#if ROCM_VERSION >= 60200 + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308})); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80})); +#else + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308}); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80}); +#endif +#else static constexpr uint32_t BF16_BIAS = 0xC308C308; static constexpr uint32_t BF16_ONE = 0x3F803F80; +#endif // Finally, we construct the output numbers. 
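  // Note on the arithmetic above: OR'ing a masked nibble v into 0x4300 yields
  // the bf16 value 128 + v (0x4300 encodes 2^7, and v lands in the low mantissa
  // bits). The fma below then computes (128 + v) * 1.0 + (-136), i.e. v - 8,
  // recovering the signed int4 value on both the CUDA (fma.rn.bf16x2) and
  // ROCm (__hfma2) paths.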
#pragma unroll for (int ii = 0; ii < kElements / 2; ++ii) { // Since this section is for Ampere+, we use bf16 fma to do the bias // subtraction +#if defined(USE_ROCM) + result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS); +#else asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); +#endif } return result; @@ -254,7 +323,11 @@ enum class KReductionType { template struct ALayout_RM { static constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + static constexpr int32_t kNTileSize = 16; +#else static constexpr int32_t kNTileSize = 8; +#endif static constexpr int32_t kKTileSize = 16; template @@ -267,22 +340,37 @@ struct ALayout_RM { int32_t kTiles, int32_t kTileStart, int32_t laneId, - bf16x2x4_u32 out[KTilesToLoad]) { +#if defined(USE_ROCM) + bf16x2x2_u32 out[KTilesToLoad] +#else + bf16x2x4_u32 out[KTilesToLoad] +#endif + ) { +#if defined(USE_ROCM) + const auto mLane = mTile * kMTileSize + (laneId % kMTileSize); + const auto kLane = kTileStart * kKTileSize + (laneId / kMTileSize) * 4; +#else const auto mLane = mTile * kMTileSize + (laneId / 4); const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2; +#endif // access // [mTile * kMTileSize + (laneId / 4)] // [kTileStart * kKTileSize + (laneId % 4) * 2] auto aPtr = reinterpret_cast(A) + mLane * k + kLane; + bool m0InBounds = mLane < m; +#if !defined(USE_ROCM) auto aPtrPlus8Rows = aPtr + 8 * k; - bool m0InBounds = mLane < m; bool m1InBounds = (mLane + 8) < m; +#endif #pragma unroll for (int i = 0; i < KTilesToLoad; ++i) { +#if defined(USE_ROCM) + out[i].val = m0InBounds ? *((VecT *)(aPtr + i * kKTileSize)) : VecT{0, 0, 0, 0}; +#else out[i].vals[0] = m0InBounds ? *reinterpret_cast(aPtr + i * kKTileSize) : uint32_t(0); @@ -296,6 +384,7 @@ struct ALayout_RM { out[i].vals[3] = m1InBounds ? *reinterpret_cast( aPtrPlus8Rows + i * kKTileSize + 8) : uint32_t(0); +#endif } } @@ -312,6 +401,10 @@ struct ALayout_RM { static_assert(ReduceType == KReductionType::None, ""); if constexpr (ReduceType == KReductionType::None) { +#if defined(USE_ROCM) + const int outRow = mTile * kMTileSize + (laneId / kNTileSize) * 4; + const int outCol = nTile * kNTileSize + (laneId % kNTileSize); +#else // sum.x / sum.y are written at // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] // sum.z / sum.w are written at @@ -319,10 +412,21 @@ struct ALayout_RM { // i.e., same columns, different row. 
const int outRow = mTile * kMTileSize + (laneId / 4); const int outCol = nTile * kNTileSize + (laneId % 4) * 2; +#endif // Pointer where sum.x / sum.y is written auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol; +#if defined(USE_ROCM) + if (outRow < m) + cPtr[0] = __float2bfloat16(out.x); + if ((outRow + 1) < m) + cPtr[n] = __float2bfloat16(out.y); + if ((outRow + 2) < m) + cPtr[2*n] = __float2bfloat16(out.z); + if ((outRow + 3) < m) + cPtr[3*n] = __float2bfloat16(out.w); +#else auto v01 = __float22bfloat162_rn(float2{out.x, out.y}); auto v23 = __float22bfloat162_rn(float2{out.z, out.w}); @@ -334,6 +438,7 @@ struct ALayout_RM { if (outRow + 8 < m) { *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23; } +#endif } } }; @@ -342,15 +447,19 @@ template struct BLayout_TC_int4 { static constexpr int32_t kInnerKTiles = InnerKTiles; static constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + static constexpr int32_t kNTileSize = 16; +#else static constexpr int32_t kNTileSize = 8; +#endif static constexpr int32_t kKTileSize = 16; template static __device__ void load( // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2] - // n / 8: n-tiles (n8) - // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16) - // 32: value per warp lane + // n-tiles: n / 8 for NV, n /16 for AMD + // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16 for NV, m16n16k16 for AMD) + // value per warp lane: 32 for NV, 64 for AMD // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile. // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest // value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a @@ -423,7 +532,11 @@ struct BLayout_TC_int4 { __nv_bfloat162 qScaleAndZero[kNumQGroups]; { +#if defined(USE_ROCM) + int32_t laneN = nTile * kNTileSize + (laneId % kNTileSize); +#else int32_t laneN = nTile * kNTileSize + (laneId / 4); +#endif int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize; int32_t n = nTiles * kNTileSize; @@ -514,9 +627,15 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( int32_t nTiles, int32_t kTiles) { constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else constexpr int32_t kNTileSize = 8; +#endif constexpr int32_t kKTileSize = 16; +#if !defined(USE_ROCM) || CDNA2_OR_LATER + static_assert( ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize && ALayout::kKTileSize == kKTileSize, @@ -550,7 +669,11 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( int32_t mTile = blockIdx.z; int32_t nTile = blockIdx.y; +#if defined(USE_ROCM) + VecT c{0.0f, 0.0f, 0.0f, 0.0f}; +#else float4 c{0.0f, 0.0f, 0.0f, 0.0f}; +#endif // First, handle whole multiples of KTilesPerIteration auto kTilesLimit = roundDown(kTiles, KTilesPerIteration); @@ -562,7 +685,11 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // // Load data from A // +#if defined(USE_ROCM) + bf16x2x2_u32 a[KTilesPerIteration]; +#else bf16x2x4_u32 a[KTilesPerIteration]; +#endif ALayout::template load( A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a); @@ -596,15 +723,29 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // We don't simply accumulate into `c` as this creates a too-strong // execution dependency. 
Instead, we only periodically accumulate into // `c` +#if defined(USE_ROCM) + VecT cTmp[2]; +#else float4 cTmp[2]; +#endif #pragma unroll for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = VecT{0.0f, 0.0f, 0.0f, 0.0f}; +#else cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; +#endif } #pragma unroll for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + a[i * kInnerKTiles + j * 2 + k].val, + b[i][(j * 2 + k) / 2].val[((j * 2 + k) % 2)], + cTmp[k], 0, 0, 0); +#else asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" @@ -622,14 +763,22 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( "f"(cTmp[k].y), "f"(cTmp[k].z), "f"(cTmp[k].w)); +#endif } #pragma unroll for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + c[0] += cTmp[k][0]; + c[1] += cTmp[k][1]; + c[2] += cTmp[k][2]; + c[3] += cTmp[k][3]; +#else c.x += cTmp[k].x; c.y += cTmp[k].y; c.z += cTmp[k].z; c.w += cTmp[k].w; +#endif } } } @@ -646,7 +795,11 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // If we have any remainder k-tiles, some warps will handle them, processing // kInnerKTiles k-tiles at a time if (kTileBaseRemaining < kTiles) { +#if defined(USE_ROCM) + bf16x2x2_u32 a[kInnerKTiles]; +#else bf16x2x4_u32 a[kInnerKTiles]; +#endif ALayout::template load( A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a); @@ -668,15 +821,29 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // We don't simply accumulate into `c` as this creates a too-strong // execution dependency. Instead, we only periodically accumulate into // `c` +#if defined(USE_ROCM) + VecT cTmp[2]; +#else float4 cTmp[2]; +#endif #pragma unroll for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = VecT{0.0f, 0.0f, 0.0f, 0.0f}; +#else cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; +#endif } #pragma unroll for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + a[j * 2 + k].val, + b[0][(j * 2 + k) / 2].val[((j * 2 + k) % 2)], + cTmp[k], 0, 0, 0); +#else asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" @@ -691,14 +858,22 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( "f"(cTmp[k].y), "f"(cTmp[k].z), "f"(cTmp[k].w)); +#endif } #pragma unroll for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + c[0] += cTmp[k][0]; + c[1] += cTmp[k][1]; + c[2] += cTmp[k][2]; + c[3] += cTmp[k][3]; +#else c.x += cTmp[k].x; c.y += cTmp[k].y; c.z += cTmp[k].z; c.w += cTmp[k].w; +#endif } } } @@ -711,7 +886,14 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // FIXME: this likely doesn't need to be a true reduction tree, can just be a // serial sum, maybe (unless nvcc/ptxas goes back to its old ways) // smem_sum[warpId][laneId] = TreeReduce4::reduce(c); +#if defined(USE_ROCM) + smem_sum[warpId][laneId].x = c[0]; + smem_sum[warpId][laneId].y = c[1]; + smem_sum[warpId][laneId].z = c[2]; + smem_sum[warpId][laneId].w = c[3]; +#else smem_sum[warpId][laneId] = c; +#endif __syncthreads(); @@ -741,6 +923,9 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( laneId, sum_f32); } +#else + printf("__builtin_amdgcn_mfma_f32_16x16x16bf16_1k is only supported on AMD gpu arch greater than or equal to CDNA2\n"); +#endif } @@ -798,7 +983,12 @@ void launch_tinygemm_kernel( 
cudaFuncAttributes funcAttr; C10_CUDA_CHECK(cudaFuncGetAttributes( &funcAttr, - func)); +#if defined(USE_ROCM) + (void *)func +#else + func +#endif + )); } // FIXME: parallelize better, smem staging etc? @@ -813,7 +1003,11 @@ __global__ void matrix_to_m16n8k16_Bint4_layout( // innermost k-tiles that we can use is 2. static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), ""); +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else constexpr int32_t kNTileSize = 8; +#endif constexpr int32_t kKTileSize = 16; // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles @@ -825,13 +1019,30 @@ __global__ void matrix_to_m16n8k16_Bint4_layout( #pragma unroll for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { // n dimension that this lane loads from +#if defined(USE_ROCM) + auto n0 = nTile * kNTileSize + (t % kNTileSize); +#else auto n0 = nTile * kNTileSize + (t / 4); +#endif bool n0Valid = n0 < in.size(0); int32_t ks[8]; auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; + +#if defined(USE_ROCM) + ks[0] = kBase0 + (t / kNTileSize) * 4; + ks[1] = ks[0] + 1; + ks[2] = ks[0] + 2; + ks[3] = ks[0] + 3; + + auto kBase1 = kBase0 + kKTileSize; + ks[4] = kBase1 + (t / kNTileSize) * 4; + ks[5] = ks[4] + 1; + ks[6] = ks[4] + 2; + ks[7] = ks[4] + 3; +#else ks[0] = kBase0 + (t % 4) * 2; ks[1] = ks[0] + 1; ks[2] = ks[0] + 8; @@ -842,6 +1053,7 @@ __global__ void matrix_to_m16n8k16_Bint4_layout( ks[5] = ks[4] + 1; ks[6] = ks[4] + 8; ks[7] = ks[4] + 8 + 1; +#endif auto pIn = &in[n0][0]; @@ -855,7 +1067,19 @@ __global__ void matrix_to_m16n8k16_Bint4_layout( (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0]; // inner k-tiles pack two at a time +#if defined(USE_ROCM) + // The output tensor shape is [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2], which is specific to Nvidia + // But AMD needs [ceil(n / 16)][ceil(k / (InnerKTiles * 16))][64][InnerKTiles / 2] + // So construct the pointer accordingly + auto bPtr = out.data() + + ((nTile * out.size(1) * kWarpSize * (InnerKTiles / 2)) + + (kOuterTile * kWarpSize * (InnerKTiles / 2)) + + (t * (InnerKTiles / 2)) + + (innerKTile / 2)); + *bPtr = pack; +#else out[nTile][kOuterTile][t][innerKTile / 2] = pack; +#endif } } @@ -872,16 +1096,30 @@ at::Tensor _weight_int4pack_mm_cuda( TORCH_CHECK( A.device() == B.device() && A.device() == qScaleAndZeros.device()); +#if defined(USE_ROCM) + if (!isCDNA2orLater(A.device().index())) { + TORCH_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); + } +#endif + constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else constexpr int32_t kNTileSize = 8; +#endif constexpr int32_t kKTileSize = 16; // row major layout auto m = A.size(0); auto mTiles = divUp(m, kMTileSize); + // To convert the nTiles from tensor storage layout to the actual matrix core layout + constexpr int32_t kNTileSizeTensor = 8; + auto nTileScaleFactor = (kNTileSize / kNTileSizeTensor); + // tensor core layout - auto nTiles = B.size(0); + auto nTiles = (B.size(0) / nTileScaleFactor); auto n = nTiles * kNTileSize; // row major layout @@ -904,7 +1142,7 @@ at::Tensor _weight_int4pack_mm_cuda( TORCH_CHECK(B.is_contiguous()); TORCH_CHECK(B.dim() == 4); TORCH_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize)); - TORCH_CHECK(B.size(2) == kWarpSize); + TORCH_CHECK(B.size(2) == 32); // Validate the scale and zero point tensor for dequantization // These are the only versions handled at the moment @@ -924,7 
+1162,7 @@ at::Tensor _weight_int4pack_mm_cuda( auto C_final = at::empty( {m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device())); -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) auto stream = at::cuda::getCurrentCUDAStream(); #define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ do { \ @@ -1053,10 +1291,27 @@ at::Tensor _convert_weight_to_int4pack_cuda( // which is the maximum vectorized load/store size TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); +#if defined(USE_ROCM) + if (!isCDNA2orLater(in.device().index())) { + TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); + } +#endif + +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else constexpr int32_t kNTileSize = 8; +#endif constexpr int32_t kKTileSize = 16; + // GPT-FAST assumes nTileSize of 8 for quantized weight tensor. + // See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510 + // Torch dynamo also requires the torch ops has the same output shape for each device. + // See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263 + constexpr int32_t kNTileSizeTensor = 8; + auto nTiles = divUp(in.size(0), kNTileSize); + auto nTilesTensor = divUp(in.size(0), kNTileSizeTensor); // k-tiles are packed back to back in the innermost dimension in order to // allow for 4/8/16 byte loads @@ -1066,11 +1321,14 @@ at::Tensor _convert_weight_to_int4pack_cuda( // each block handles `innerKTiles` k-tiles. // 2 k-tiles are a single int32 + // + // We use the same shape for AMD gpus also to match the GPT-FAST spec. + // Will index it correctly when dereferencing the quantized weight tensor pointer. 
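+  // (On ROCm, matrix_to_m16n8k16_Bint4_layout above writes through a flat
+  //  element offset computed from nTile / kOuterTile / lane with kWarpSize == 64
+  //  and kNTileSize == 16, so this CUDA-shaped buffer is only reinterpreted,
+  //  not reshaped, when dereferenced.)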
auto out = at::empty( - {nTiles, kSuperTiles, 32, innerKTiles / 2}, + {nTilesTensor, kSuperTiles, 32, innerKTiles / 2}, at::TensorOptions().dtype(at::kInt).device(in.device())); -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) auto stream = at::cuda::getCurrentCUDAStream(); dim3 grid(kSuperTiles, nTiles); diff --git a/test/test_linalg.py b/test/test_linalg.py index 81db475f1e3..e0ad1b2ede6 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -31,7 +31,7 @@ from torch.testing._internal.common_dtype import ( floating_and_complex_types_and, floating_types_and, complex_types, ) from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \ - _get_torch_cuda_version + _get_torch_cuda_version, CDNA2OrLater from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel from torch.testing._internal.common_mkldnn import bf32_on_and_off from torch.distributions.binomial import Binomial @@ -6127,7 +6127,8 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: self.skipTest("requires SM80 or later") if TEST_WITH_ROCM: - self.skipTest("_int4_mm not compiled for ROCM") + if not CDNA2OrLater(): + self.skipTest("_int4_mm is supported only for CDNA2 or later") q_group = 32 inner_k_tiles = 2 @@ -6175,7 +6176,8 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: self.skipTest("requires SM80 or later") if TEST_WITH_ROCM: - self.skipTest("_int4_mm not compiled for ROCM") + if not CDNA2OrLater(): + self.skipTest("_int4_mm is supported only for CDNA2 or later") q_group = 32 inner_k_tiles = 2 diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index a2d26979c1e..940444f4ce7 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -66,7 +66,7 @@ static const char* scalarTypeName(const at::ScalarType type) { return "half"; } if (type == at::ScalarType::BFloat16) { - return "__nv_bfloat16"; + return cuda::bfloat16_type_string; } switch (type) { diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h index 0eb7299223a..e6114f818e3 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -13,6 +13,8 @@ tensor as input. 
Correct code for this case is generated, however, nvrtc does not know how to handle int*_t integer types, so typedefs help it handle those cases*/ +static constexpr auto bfloat16_type_string = "__nv_bfloat16"; + #if defined(USE_ROCM) static auto type_declarations_template = at::jit::CodeTemplate(R"( ${HalfHeader} diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index 602bc49302c..d8f8f1e5796 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -70,7 +70,7 @@ std::string CudaPrinter::dtypeToCppString(const Dtype& dtype) { case ScalarType::Half: return "half"; case ScalarType::BFloat16: - return "__nv_bfloat16"; + return fuser::cuda::bfloat16_type_string; case ScalarType::Char: return "char"; case ScalarType::Byte: diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 7be663e2171..01eeac86ae1 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -33,6 +33,12 @@ SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)]) +def CDNA2OrLater(): + if TEST_WITH_ROCM: + gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName + return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"}) + return False + def evaluate_gfx_arch_exact(matching_arch): if not torch.cuda.is_available(): return False diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 976e12e42d3..034418afa46 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -537,6 +537,8 @@ CUDA_TYPE_NAME_MAP = collections.OrderedDict( ("CUuuid", ("hipUUID", CONV_TYPE, API_RUNTIME)), ("cudaGraph_t", ("hipGraph_t", CONV_TYPE, API_RAND)), ("cudaGraphExec_t", ("hipGraphExec_t", CONV_TYPE, API_RAND)), + ("__nv_bfloat16", ("__hip_bfloat16", CONV_TYPE, API_RUNTIME)), + ("__nv_bfloat162", ("__hip_bfloat162", CONV_TYPE, API_RUNTIME)), ] )
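For context, the new ROCm path is reached through the same Python ops that the updated `test__int4_mm` cases exercise. A minimal sketch of that flow — sizes are hypothetical, and the helper names are the ones already imported in `test_linalg.py` above:

```python
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor

# Hypothetical sizes; q_group / inner_k_tiles follow the values used in test__int4_mm.
m, k, n = 32, 4096, 4096
q_group, inner_k_tiles = 32, 2

# ROCm builds still expose the device as "cuda". On AMD GPUs older than CDNA2
# the underlying ops raise (see the isCDNA2orLater checks above).
a = torch.rand((m, k), dtype=torch.bfloat16, device="cuda")
b = torch.rand((k, n), dtype=torch.bfloat16, device="cuda")

# Group-quantize the weight to int4, then pack it into the tensor-core layout
# (_convert_weight_to_int4pack_cuda keeps the [ceil(n / 8)][...][32][...] shape on AMD).
b_int32, b_scales_and_zeros = _group_quantize_tensor(b, n_bit=4, q_group_size=q_group)
b_int4pack = torch._convert_weight_to_int4pack(b_int32, inner_k_tiles)

# bf16 activations x int4 weights; on CDNA2/CDNA3 this dispatches to the MFMA kernel added here.
c = torch._weight_int4pack_mm(a, b_int4pack, q_group, b_scales_and_zeros)
```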