onnxruntime/cmake/patches/composable_kernel/Add_gfx12x_support.patch
Ted Themistokleous 572e43c5d7
[MIGraphX EP/ ROCm EP] add gfx1200, gfx1201 to CMAKE_HIP_ARCHITECTURES (#22348)
### Description
Add additional gfx targets (gfx1200, gfx1201) for AMD GPU support


### Motivation and Context
Required to integrate support for these AMD GPUs into mainline onnxruntime
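
For context, a minimal sketch of how a build might opt into the new targets. This is illustrative only and not part of the patch; it assumes a CMake 3.21+ HIP toolchain that honors `CMAKE_HIP_ARCHITECTURES`, and the target list shown is just an example.

```cmake
# Hypothetical usage sketch (not from this patch): pre-seed the HIP
# architecture list so gfx12 (RDNA4) kernels are built alongside the
# existing RDNA3 ones. Adjust the list to the GPUs you target.
set(CMAKE_HIP_ARCHITECTURES "gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
    CACHE STRING "AMD GPU architectures to compile for" FORCE)
```

With this patch applied, composable_kernel's WMMA paths are compiled for these targets as well; the diff below extends the existing gfx11 device checks and guard macros to gfx12.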

---------

Co-authored-by: Stefan Sokolovic <stsokolo@amd.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2024-10-11 17:31:36 -07:00

diff --git a/CMakeLists.txt b/CMakeLists.txt
index bc326c8b5..db5ad5052 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -117,7 +117,7 @@ else()
add_definitions(-DPROFILER_ONLY)
set(GPU_TARGETS "" CACHE STRING "" FORCE)
if(GPU_TARGETS)
- message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11")
+ message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12")
endif()
if(GPU_ARCH MATCHES "gfx90")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a")
@@ -127,8 +127,10 @@ else()
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
elseif(GPU_ARCH MATCHES "gfx11")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
+ elseif(GPU_ARCH MATCHES "gfx12")
+ rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
else()
- message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11")
+ message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
endif()
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
endif()
diff --git a/Jenkinsfile b/Jenkinsfile
index 75800bfc9..b72e2ca4e 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
+
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try {
(retimage, image) = getDockerImage(conf)
@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
pipeline {
agent none
- triggers {
- parameterizedCron(CRON_SETTINGS)
- }
options {
parallelsAlwaysFailFast()
}
diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake
index 8654170b3..42070051b 100644
--- a/cmake/EnableCompilerWarnings.cmake
+++ b/cmake/EnableCompilerWarnings.cmake
@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
- -Werror
+ -Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp
index 8c52e4f7d..f8afe8d6d 100644
--- a/example/01_gemm/gemm_wmma_fp16.cpp
+++ b/example/01_gemm/gemm_wmma_fp16.cpp
@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
- < ALayout,
- BLayout,
- CLayout,
- ADataType,
+ < ALayout,
+ BLayout,
+ CLayout,
+ ADataType,
BDataType,
- CDataType,
- AccDataType,
- CShuffleDataType,
- AElementOp,
- BElementOp,
- CElementOp,
- GemmDefault,
+ CDataType,
+ AccDataType,
+ CShuffleDataType,
+ AElementOp,
+ BElementOp,
+ CElementOp,
+ GemmDefault,
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
- 8, // K1
+ 2, // K1
16, // MPerWmma
16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
- S<4, 32, 1>,
- S<1, 0, 2>,
- S<1, 0, 2>,
- 2,
- 8,
- 8,
- true,
- S<4, 32, 1>,
- S<1, 0, 2>,
- S<1, 0, 2>,
- 2,
- 8,
- 8,
- true,
+ S<4, 32, 1>,
+ S<1, 0, 2>,
+ S<1, 0, 2>,
+ 2,
+ 2,
+ 2,
+ true,
+ S<4, 32, 1>,
+ S<1, 0, 2>,
+ S<1, 0, 2>,
+ 2,
+ 2,
+ 2,
+ true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
- S<1, 32, 1, 4>,
+ S<1, 32, 1, 4>,
8>;
// clang-format on
diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc
index b04e4e53a..cb15186c3 100644
--- a/example/01_gemm/run_gemm_example.inc
+++ b/example/01_gemm/run_gemm_example.inc
@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 4:
- ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+ ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break;
case 5:
diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt
index ab19f819e..be47665a2 100644
--- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt
+++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt
@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
set(target 1)
endif()
-endforeach()
\ No newline at end of file
+endforeach()
diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
index 2bbf430c4..f556be887 100644
--- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
+++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
2,
4,
4,
- true,
+ false,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
- true,
+ false,
1,
1,
S<1, 64, 1, 2>,
diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
index 4c92c5497..fac19f8b5 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
+++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
-#define CK_MHA_USE_WAVE_8
+//#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
- MaskingSpec>,
+ MaskingSpec>
#endif
#ifdef CK_MHA_USE_WAVE_8
- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+ ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
index 8e037272b..d463cc871 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
+++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
-#define CK_MHA_USE_WAVE_8
+//#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
- MaskingSpec>,
+ MaskingSpec>
#endif
#ifdef CK_MHA_USE_WAVE_8
- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+ ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
index 5465adb77..7534bff3b 100644
--- a/example/CMakeLists.txt
+++ b/example/CMakeLists.txt
@@ -60,7 +60,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
@@ -134,7 +134,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp
index 55f562061..69a7abf62 100644
--- a/include/ck/ck.hpp
+++ b/include/ck/ck.hpp
@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define __gfx12__
+#endif
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx11__)
+#elif defined(__gfx11__) || defined(__gfx12__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
-#elif defined(__gfx11__)
+#elif defined(__gfx11__) || defined(__gfx12__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_MFMA_GFX940
#endif
-// WMMA instruction
-#ifndef __HIP_DEVICE_COMPILE__ // for host code
-#define CK_USE_AMD_WMMA
-#elif defined(__gfx11__) // for GPU code
-#define CK_USE_AMD_WMMA
-#endif
-
// buffer load
#define CK_USE_AMD_BUFFER_LOAD 1
diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp
index 116bb3ea0..83af2efe8 100644
--- a/include/ck/host_utility/device_prop.hpp
+++ b/include/ck/host_utility/device_prop.hpp
@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
}
+inline bool is_gfx12_supported()
+{
+ return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
+}
+
} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
index f8ee283c6..7eb7d42eb 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -13,6 +13,504 @@
namespace ck {
+#ifdef __gfx12__
+template <index_t BlockSize,
+ typename FloatA,
+ typename FloatB,
+ typename FloatAcc,
+ typename ABlockDesc,
+ typename BBlockDesc,
+ index_t MPerBlock,
+ index_t NPerBlock,
+ index_t KPerBlock,
+ index_t MPerWMMA,
+ index_t NPerWMMA,
+ index_t MRepeat,
+ index_t NRepeat,
+ index_t KPack,
+ bool AEnableLds = true,
+ bool BEnableLds = true,
+ bool TransposeC = false>
+/* Option: Read from LDS, big buffer hold all threads required data
+ * Source
+ * A: K0PerBlock x MPerBlock x K1
+ * B: K0PerBlock x NPerBlock x K1
+ * Destination
+ * C, non-transpose
+ * thread level: MRepeat x NRepeat x MAccVgprs
+ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
+ * KPACK == WMMA_K = 16
+ *
+ * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
+ * Source:
+ * A(if skip LDS): MRepeat x KPack
+ * B(if skip LDS): NRepeat x KPack
+ * Destination
+ * C, non-transpose
+ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
+ */
+struct BlockwiseGemmWMMA
+{
+ static constexpr auto I0 = Number<0>{};
+ static constexpr auto I1 = Number<1>{};
+ static constexpr auto I2 = Number<2>{};
+ static constexpr auto I3 = Number<3>{};
+ static constexpr auto I4 = Number<4>{};
+ static constexpr auto I5 = Number<5>{};
+ static constexpr auto WmmaK = Number<16>{};
+
+ using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+ // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
+ static constexpr index_t WaveSize = 32;
+
+ // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
+ // When not use LDS, each Row read half of whole data from source buffer, exchange the data via
+ // permutation
+ static constexpr index_t A_KRow = 2;
+ static constexpr index_t B_KRow = 2;
+
+ static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
+ static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
+
+ static constexpr auto wmma_gemm =
+ WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
+
+ static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
+ static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
+
+ StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+ FloatAcc,
+ MRepeat * NRepeat,
+ wmma_gemm.GetRegSizePerWmma(),
+ true>
+ c_thread_buf_;
+
+ __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
+
+ __device__ static auto GetWaveIdx()
+ {
+ const index_t thread_id = ThisThreadBlock::GetThreadId();
+
+ constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
+ make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
+ make_tuple(Sequence<0, 1, 2>{}),
+ make_tuple(Sequence<0>{}));
+
+ return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+ }
+
+ // Default, Block buffer in LDS, thread level offset enabled
+ __device__ static auto CalculateAThreadOriginDataIndex()
+ {
+ if constexpr(AEnableLds)
+ {
+ const auto wave_idx = GetWaveIdx();
+ const auto waveId_m = wave_idx[I0];
+ const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
+
+ // |KRepeat |MRepeat|MWave |KRow |MLane |KPack
+ return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
+ }
+ else
+ {
+ return make_tuple(0, 0, 0, 0, 0, 0);
+ }
+ }
+
+ __device__ static auto CalculateBThreadOriginDataIndex()
+ {
+ if constexpr(BEnableLds)
+ {
+ const auto wave_idx = GetWaveIdx();
+ const auto waveId_n = wave_idx[I1];
+ const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
+
+ // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
+ return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
+ }
+ else
+ {
+ return make_tuple(0, 0, 0, 0, 0, 0);
+ }
+ }
+
+ template <index_t m0, index_t n0>
+ __device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
+ {
+ const auto wave_idx = GetWaveIdx();
+
+ const auto waveId_m = wave_idx[I0];
+ const auto waveId_n = wave_idx[I1];
+
+ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
+
+ constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
+ make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
+ make_tuple(Sequence<0>{}),
+ make_tuple(Sequence<0, 1, 2>{}));
+
+ constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
+ make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
+ make_tuple(Sequence<0>{}),
+ make_tuple(Sequence<0, 1, 2>{}));
+
+ const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
+ make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
+ const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
+ make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
+
+ return make_tuple(c_thread_m, c_thread_n);
+ }
+
+ template <index_t m0, index_t n0>
+ __device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
+ {
+ const auto wave_idx = GetWaveIdx();
+
+ const auto waveId_m = wave_idx[I0];
+ const auto waveId_n = wave_idx[I1];
+
+ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
+
+ return make_tuple(
+ Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
+ }
+
+ using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
+ __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
+ Tuple6 b_origin = CalculateBThreadOriginDataIndex())
+ : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
+ {
+ static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
+ "wrong! Desc should be known at compile-time");
+
+ static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
+ "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
+
+ static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
+ NPerBlock % (NPerWMMA * NRepeat) == 0,
+ "wrong!");
+ }
+
+ // transposed WMMA output C' = B' * A'
+ __host__ __device__ static constexpr auto
+ GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
+ {
+ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
+ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
+
+ constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
+
+ return make_naive_tensor_descriptor_packed(
+ // |MRepeat |MWave |MSubGroup |NRepeat |NWave
+ // |NThreadPerSubGroup |MAccVgprs
+ make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
+ }
+
+ // Thread level, register decriptor. Vector-write
+ __host__ __device__ static constexpr auto
+ GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
+ {
+ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
+ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
+
+ constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
+ constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
+ return make_naive_tensor_descriptor(
+ // |MRepeat |MWave |MSubGroup |NRepeat |NWave
+ // |NThreadPerSubGroup |MAccVgprs
+ make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
+ make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
+ Number<NRepeat>{} * MAccVgprs * AccStride,
+ Number<NRepeat>{} * MAccVgprs * AccStride,
+ MAccVgprs * AccStride,
+ MAccVgprs * AccStride,
+ MAccVgprs * AccStride,
+ AccStride));
+ }
+
+ template <typename CGridDesc_M_N>
+ __host__ __device__ static constexpr auto
+ MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+ const CGridDesc_M_N& c_grid_desc_m_n)
+ {
+ const auto M = c_grid_desc_m_n.GetLength(I0);
+ const auto N = c_grid_desc_m_n.GetLength(I1);
+
+ const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
+ transform_tensor_descriptor(
+ c_grid_desc_m_n,
+ make_tuple(
+ make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
+ make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
+ make_tuple(Sequence<0>{}, Sequence<1>{}),
+ make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+
+ return wmma_gemm
+ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+ c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
+ }
+
+ // transposed WMMA output C' = B' * A'
+ __host__ __device__ static constexpr auto
+ GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
+ {
+ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
+ make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+ Number<MWaves>{},
+ Number<MPerWMMA>{},
+ Number<NRepeat>{},
+ Number<NWaves>{},
+ Number<NPerWMMA>{}));
+
+ return wmma_gemm
+ .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
+ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
+ }
+
+ // Provide dimension size
+ __host__ __device__ static constexpr auto
+ GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
+ {
+ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
+ make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+ Number<MWaves>{},
+ Number<MPerWMMA>{},
+ Number<NRepeat>{},
+ Number<NWaves>{},
+ Number<NPerWMMA>{}));
+
+ return wmma_gemm
+ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
+ }
+
+ // Describe how data allocated in thread copy src buffer
+ // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
+ static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
+ static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
+
+ template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+ __device__ void Run(const ABlockBuffer& a_block_buf,
+ const BBlockBuffer& b_block_buf,
+ CThreadBuffer& c_thread_buf) const
+ {
+ auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+ a_thread_desc_.GetElementSpaceSize());
+ auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+ b_thread_desc_.GetElementSpaceSize());
+
+ static_assert(KPack % (A_K1 * A_KRow) == 0, "");
+ static_assert(KPack % (B_K1 * B_KRow) == 0, "");
+
+ // basic intrinsic to determine loopover direction
+ if constexpr(MRepeat < NRepeat)
+ {
+ static_for<0, KPerBlock / KPack, 1>{}(
+ [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
+ static_for<0, MRepeat, 1>{}([&](auto m0) {
+ // read A
+ a_thread_copy_.Run(
+ a_block_desc_k0_m0_m1_m2_k1,
+ make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
+ a_block_buf,
+ a_thread_desc_,
+ make_tuple(I0, m0, I0, I0, I0, I0),
+ a_thread_buf);
+
+ static_for<0, NRepeat, 1>{}([&](auto n0) {
+ // read B
+ b_thread_copy_.Run(
+ b_block_desc_k0_n0_n1_n2_k1,
+ make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
+ b_block_buf,
+ b_thread_desc_,
+ make_tuple(I0, n0, I0, I0, I0, I0),
+ b_thread_buf);
+
+ vector_type<FloatA, KPack / A_KRow> a_thread_vec;
+ vector_type<FloatB, KPack / B_KRow> b_thread_vec;
+
+ static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
+ a_thread_vec.template AsType<FloatA>()(i) =
+ a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+ make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
+ });
+
+ static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
+ b_thread_vec.template AsType<FloatB>()(i) =
+ b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+ make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
+ });
+
+ using wmma_input_type_a =
+ typename vector_type<FloatA, WmmaK / A_KRow>::type;
+ using wmma_input_type_b =
+ typename vector_type<FloatB, WmmaK / B_KRow>::type;
+
+ constexpr index_t c_offset =
+ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+ wmma_gemm.template Run(
+ a_thread_vec.template AsType<wmma_input_type_a>(),
+ b_thread_vec.template AsType<wmma_input_type_b>(),
+ c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+ });
+ });
+ });
+ }
+ else
+ {
+ static_for<0, NRepeat, 1>{}([&](auto n0) {
+ static_for<0, MRepeat, 1>{}([&](auto m0) {
+ static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
+ // k=0,kpack*1, ..
+ // read B
+ b_thread_copy_.Run(
+ b_block_desc_k0_n0_n1_n2_k1,
+ make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
+ b_block_buf,
+ b_thread_desc_,
+ make_tuple(I0, n0, I0, I0, I0, I0),
+ b_thread_buf);
+ // read A
+ a_thread_copy_.Run(
+ a_block_desc_k0_m0_m1_m2_k1,
+ make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
+ a_block_buf,
+ a_thread_desc_,
+ make_tuple(I0, m0, I0, I0, I0, I0),
+ a_thread_buf);
+
+ vector_type<FloatA, KPack / A_KRow> a_thread_vec;
+ vector_type<FloatB, KPack / B_KRow> b_thread_vec;
+
+ static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
+ a_thread_vec.template AsType<FloatA>()(i) =
+ a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+ make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
+ });
+
+ static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
+ b_thread_vec.template AsType<FloatB>()(i) =
+ b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+ make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
+ });
+
+ using wmma_input_type_a =
+ typename vector_type<FloatA, WmmaK / A_KRow>::type;
+ using wmma_input_type_b =
+ typename vector_type<FloatB, WmmaK / B_KRow>::type;
+
+ constexpr index_t c_offset =
+ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+ wmma_gemm.template Run(
+ a_thread_vec.template AsType<wmma_input_type_a>(),
+ b_thread_vec.template AsType<wmma_input_type_b>(),
+ c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+ });
+ });
+ });
+ }
+ }
+
+ protected:
+ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
+ make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
+ make_tuple(Number<A_K1>{},
+ Number<KPack / A_KRow>{},
+ Number<A_K1>{},
+ Number<A_K1>{},
+ Number<A_K1>{},
+ Number<1>{}));
+
+ static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
+ make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
+ make_tuple(Number<B_K1>{},
+ Number<KPack / B_KRow>{},
+ Number<B_K1>{},
+ Number<B_K1>{},
+ Number<B_K1>{},
+ Number<1>{}));
+
+ // C[M, N, NumRegWMMA]
+ static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
+ make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
+
+ template <bool EnableLds>
+ struct AThreadCopySelector;
+
+ template <>
+ struct AThreadCopySelector<true>
+ {
+ using type =
+ ThreadwiseTensorSliceTransfer_v4<FloatA,
+ FloatA,
+ decltype(a_block_desc_k0_m0_m1_m2_k1),
+ decltype(a_thread_desc_),
+ Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+ Sequence<0, 1, 2, 3, 4, 5>,
+ 5,
+ A_K1,
+ A_K1>;
+ };
+
+ template <>
+ struct AThreadCopySelector<false>
+ {
+ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
+ FloatA,
+ FloatA,
+ decltype(a_block_desc_k0_m0_m1_m2_k1),
+ decltype(a_thread_desc_),
+ tensor_operation::element_wise::PassThrough,
+ Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
+ Sequence<0, 1, 2, 3, 4, 5>,
+ 5,
+ A_K1,
+ false>;
+ };
+
+ template <bool EnableLds>
+ struct BThreadCopySelector;
+
+ template <>
+ struct BThreadCopySelector<true>
+ {
+ using type =
+ ThreadwiseTensorSliceTransfer_v4<FloatB,
+ FloatB,
+ decltype(b_block_desc_k0_n0_n1_n2_k1),
+ decltype(b_thread_desc_),
+ Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+ Sequence<0, 1, 2, 3, 4, 5>,
+ 5,
+ B_K1,
+ B_K1>;
+ };
+
+ template <>
+ struct BThreadCopySelector<false>
+ {
+ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
+ FloatB,
+ FloatB,
+ decltype(b_block_desc_k0_n0_n1_n2_k1),
+ decltype(b_thread_desc_),
+ tensor_operation::element_wise::PassThrough,
+ Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+ Sequence<0, 1, 2, 3, 4, 5>,
+ 5,
+ B_K1,
+ false>;
+ };
+
+ typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
+ typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
+};
+#else
template <index_t BlockSize,
typename FloatA,
typename FloatB,
@@ -529,5 +1027,6 @@ struct BlockwiseGemmWMMA
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
+#endif
} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
index e5e6245cb..1f7d50429 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
@@ -488,7 +488,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// sync point.
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
{
+#ifdef __gfx12__
+ asm volatile("\
+ s_barrier_signal -1 \n \
+ s_barrier_wait -1 \
+ " ::);
+#else
asm volatile("s_barrier" ::);
+#endif
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
index a15759559..ab3f3856a 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
- static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
- static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
+ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
+ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
+
+ static constexpr auto AEnableLds_auto =
+ (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
+ static constexpr auto BEnableLds_auto =
+ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}
else
{
- if(!(arg.a_kz_stride_ == 1 &&
- arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
+ if(!(arg.a_kz_stride_ == 1))
{
- printf("DeviceOp: Vector Access A-k check failure\n");
- return false;
+ index_t LastK =
+ AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
+ if(LastK % ABlockTransferSrcScalarPerVector == 0)
+ {
+ printf("DeviceOp: Vector Access A-k check failure\n");
+ return false;
+ }
}
}
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
index 8fd14afc0..1b487502f 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
@@ -70,8 +70,9 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+ defined(__gfx12__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
- ck::is_gfx103_supported() || ck::is_gfx11_supported())
+ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
bool pass = true;
pass = pass && arg.K_ % K1 == 0;
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
index f6b701ab1..102611838 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
@@ -56,7 +56,7 @@ __global__ void
bool input_permute,
bool output_permute)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
@@ -159,6 +159,7 @@ __global__ void
ignore = O;
ignore = G0;
ignore = G1;
+ ignore = alpha;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx11__))
@@ -187,7 +188,7 @@ __global__ void
index_t head_size,
float alpha)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
@@ -321,7 +322,7 @@ __global__ void
index_t head_size,
float alpha)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg)
{
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
index 9d5b74be6..017d28641 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
@@ -601,9 +601,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return false;
}
- if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
- ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
- std::is_same<ADataType, double>::value)
+ if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
{
return false;
}
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
index b84e18130..1edae33be 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
{
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
- ck::is_gfx11_supported()))
+ ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
return false;
}
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
index bf96324d0..553143e28 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
index b1784b385..eb0fb55f5 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
- ck::is_gfx11_supported())
+ ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
index 23858096d..811f1ae93 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
@@ -50,8 +50,9 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+ defined(__gfx12__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
- ck::is_gfx103_supported() || ck::is_gfx11_supported())
+ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
index a1ef37cc8..35f1c77f8 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
index 93ab8a7e1..a7cc546f5 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
- static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
- static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
- static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
-
- static constexpr auto AEnableLds_auto =
- (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
+ static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+ static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+ static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
+ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
+ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
+
+ static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
+ is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+ ? false
+ : true;
static constexpr auto BEnableLds_auto =
- (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
+ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
+ is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+ ? false
+ : true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
index 6f74838fb..6bb5d431c 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
index bd264a3c8..7047e1bda 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
@@ -48,8 +48,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
- defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
+ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
+ defined(__gfx12__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
index 211185dfb..5738be0fb 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
index 7cfbd8a8f..5d5a9de7d 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
@@ -90,8 +90,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
- defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
+ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
+ defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
- ck::is_gfx103_supported() || ck::is_gfx11_supported()))
+ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
return false;
}
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
index 6a4d97d7d..c65370b51 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
@@ -107,7 +107,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
- defined(__gfx11__))
+ defined(__gfx11__) || defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
- ck::is_gfx11_supported()))
+ ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
return false;
}
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
index 24bd0f242..cfb64e0ee 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace ctc = tensor_layout::convolution;
// check device
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
index ac392cddc..060a16d1e 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
@@ -39,8 +39,9 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
- defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+ defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
+ defined(__gfx12__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
- ck::is_gfx103_supported() || ck::is_gfx11_supported())
+ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
index 71f7ac04c..67a100a11 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
@@ -61,7 +61,7 @@ __global__ void
bool input_permute,
bool output_permute)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
@@ -166,6 +166,7 @@ __global__ void
ignore = O;
ignore = G0;
ignore = G1;
+ ignore = alpha;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx11__))
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
index 4e14ed3a5..cc88c1a10 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
@@ -60,7 +60,7 @@ __global__ void
bool input_permute,
bool output_permute)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// clang-format off
// ***************************************************
@@ -165,6 +165,7 @@ __global__ void
ignore = O;
ignore = G0;
ignore = G1;
+ ignore = alpha;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx11__))
@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
index 16717ff81..1754e07e6 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
if constexpr(B0EnableLds)
{
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
- constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
- constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+ constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
+ constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+ constexpr auto B_KRow = I2;
+#else
constexpr auto B_KRow = I1;
+#endif
return transform_tensor_descriptor(
B0BlockDesc_{},
- make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
if constexpr(B1EnableLds)
{
// BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
- constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
- constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+ constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
+ constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+ constexpr auto B_LRow = I2;
+#else
constexpr auto B_LRow = I1;
+#endif
return transform_tensor_descriptor(
B1BlockDesc_{},
- make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_L1>{})),
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
index 499eb7eb0..21dac6f9e 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
@@ -50,7 +50,7 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
+ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+ constexpr auto A_KRow = I2;
+#else
constexpr auto A_KRow = I1;
+#endif
return transform_tensor_descriptor(
ABlockDesc_{},
- make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+ make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
+ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+ constexpr auto B_KRow = I2;
+#else
constexpr auto B_KRow = I1;
+#endif
return transform_tensor_descriptor(
BBlockDesc_{},
- make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
index 82d010a99..fdda649ef 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
@@ -54,7 +54,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -147,7 +147,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_etile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// printf("entry kernel launch");
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
@@ -237,7 +237,7 @@ __global__ void
const CDEElementwiseOperation cde_element_op,
const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
}
else
{
+ constexpr auto A_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
- constexpr auto K0PerWmma = WmmaK / 2 / K1;
+ constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
}
else
{
+ constexpr auto B_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
- constexpr auto K0PerWmma = WmmaK / 2 / K1;
+ constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
@@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
+ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+ constexpr auto A_KRow = I2;
+#else
constexpr auto A_KRow = I1;
+#endif
return transform_tensor_descriptor(
ABlockDesc_{},
- make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+ make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
+ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+ constexpr auto B_KRow = I2;
+#else
constexpr auto B_KRow = I1;
+#endif
return transform_tensor_descriptor(
BBlockDesc_{},
- make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
- constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
- constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
-
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
- Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
+ Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
I1,
- Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
+ Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
- const auto MBlock = M / MPerBlock;
- const auto NBlock = N / NPerBlock;
+ const auto MBlock = M / MPerBlock;
+ const auto NBlock = N / NPerBlock;
+
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
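
The K0PerWmma change above encodes the same dual-row layout on the register-side (non-LDS) path: the per-thread K footprint of one WMMA instruction is WmmaK split across A_KRow rows and K1-wide vectors. A quick consistency check with assumed tile sizes (not taken from any real instance):

    // Hypothetical tile parameters; only the factorization identity matters.
    constexpr int WmmaK         = 16;
    constexpr int K1            = 4;
    constexpr int KPerBlock     = 64;
    constexpr int A_KRow        = 2;                    // gfx12 dual-row layout
    constexpr int KWmmaPerblock = KPerBlock / WmmaK;    // 4
    constexpr int K0PerWmma     = WmmaK / A_KRow / K1;  // 2

    static_assert(KWmmaPerblock * K0PerWmma * A_KRow * K1 == KPerBlock,
                  "per-thread K decomposition must cover KPerBlock");
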
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
index 8e4117593..4458b9356 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
@@ -45,7 +45,7 @@ __global__ void
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
}
else
{
+ constexpr auto A_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
- constexpr auto K0PerWmma = WmmaK / 2 / K1;
+ constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
}
else
{
+
+ constexpr auto B_KRow = I2;
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
- constexpr auto K0PerWmma = WmmaK / 2 / K1;
+ constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
+ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+ constexpr auto A_KRow = I2;
+#else
constexpr auto A_KRow = I1;
+#endif
+
return transform_tensor_descriptor(
ABlockDesc_{},
- make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+ make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
+ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+ constexpr auto B_KRow = I2;
+#else
constexpr auto B_KRow = I1;
+#endif
return transform_tensor_descriptor(
BBlockDesc_{},
- make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma
c_grid_desc_m_n);
}
- using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
- remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
- CGridDesc_M_N{}))>;
- using DefaultBlock2CTileMap =
- remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
-
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma
b_block_space_size_aligned * sizeof(BDataType));
};
+ using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
+ remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+ CGridDesc_M_N{}))>;
+ using DefaultBlock2CTileMap =
+ remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
+
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
index 6772524e0..174074990 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
@@ -35,8 +35,9 @@ __global__ void
const Block2ETileMap block_2_tile_map,
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+ defined(__gfx12__))
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
p_in_global,
out_grid_desc,
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
index bcce930fc..d7a6a3624 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
ElementwiseOperation element_op_;
};
-// Specilized for WMMA
+// Specialized for WMMA-Navi3
// A single Wave32 is composed by double row
// Data exchange allowed between these two rows
// This RowLane Dst buf will be filled from two Src buf
@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
ElementwiseOperation element_op_{};
};
+// Specialized for WMMA-Navi4
+template <typename SrcData,
+ typename DstData,
+ typename SrcDesc,
+ typename DstDesc,
+ typename ElementwiseOperation,
+ typename SliceLengths,
+ typename DimAccessOrder,
+ index_t DstVectorDim,
+ index_t DstScalarPerVector,
+ bool IntraRowSwizzlePerm,
+ typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+ bool>::type = false>
+struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
+{
+ static constexpr index_t nDim = SliceLengths::Size();
+
+ using Index = MultiIndex<nDim>;
+
+ __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx)
+ {
+ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+ "wrong! Desc need to known at compile-time");
+
+ static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
+ "wrong! Not divisible");
+ ignore = src_idx;
+ }
+
+ template <typename SrcSliceOriginIdx,
+ typename DstSliceOriginIdx,
+ typename SrcBuffer,
+ typename DstBuffer>
+ __device__ void Run(const SrcDesc&,
+ const SrcSliceOriginIdx&,
+ const SrcBuffer& src_buf,
+ const DstDesc&,
+ const DstSliceOriginIdx&,
+ DstBuffer& dst_buf) const
+ {
+ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+ "wrong! Desc need to known at compile-time");
+
+ static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
+ is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
+ "wrong! SliceOrigin need to known at compile-time");
+
+ static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
+ "wrong! Buffer need to be StaticBuffer");
+
+ // SrcDesc and src_slice_origin_idx are known at compile-time
+ constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
+ constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
+ constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
+ constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
+
+ // scalar per access on each dim
+ constexpr auto dst_scalar_per_access = generate_sequence(
+ detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+
+ constexpr auto dst_scalar_step_in_vector =
+ generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
+
+ using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+ DimAccessOrder,
+ remove_cv_t<decltype(dst_scalar_per_access)>>;
+
+ static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
+ "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
+
+ constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
+
+ static_for<0, num_access, 1>{}([&](auto idx_1d) {
+ constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
+
+ // copy data from src_buf into dst_vector
+ static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+ // src_desc offset is not constexpr if the descriptor contains a merge transform
+ constexpr index_t src_offset = src_desc.CalculateOffset(
+ src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+ constexpr index_t dst_offset = dst_desc.CalculateOffset(
+ dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+ SrcData v_this_row;
+ // temporary int value, required by the permlane intrinsic's operand type
+ int temp = 0;
+
+ // apply element-wise operation
+ element_op_(v_this_row, src_buf[Number<src_offset>{}]);
+
+ // apply intra-row permute.
+ if constexpr(IntraRowSwizzlePerm)
+ {
+ temp = __builtin_amdgcn_permlane16(
+ temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
+ v_this_row = type_convert_sp<SrcData>(temp);
+ }
+
+ // apply type convert
+ dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
+ });
+ });
+ }
+ ElementwiseOperation element_op_{};
+};
+
} // namespace ck
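
The IntraRow transfer above relies on `__builtin_amdgcn_permlane16` with the select pair (0xb3a29180, 0xf7e6d5c4). Assuming the usual v_permlane16_b32 encoding, one source-lane nibble per destination lane with the low dword covering lanes 0-7 and the high dword lanes 8-15, the pair describes an interleave within each 16-lane row. A host-side decode, for inspection only:

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // Nibble i of hi:lo names the source lane for destination lane i.
        const uint64_t sel = (uint64_t{0xf7e6d5c4} << 32) | 0xb3a29180u;
        for(int lane = 0; lane < 16; ++lane)
            std::printf("dst lane %2d <- src lane %2d\n",
                        lane, int((sel >> (4 * lane)) & 0xF));
        return 0;
    }
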
diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
index 565195f53..9a9ebf559 100644
--- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
@@ -11,12 +11,17 @@ namespace ck {
enum struct WmmaInstr
{
+ // gfx11
wmma_f32_16x16x16_f16 = 0,
wmma_f32_16x16x16_bf16,
wmma_f16_16x16x16_f16,
wmma_bf16_16x16x16_bf16,
wmma_i32_16x16x16_iu8,
- wmma_i32_16x16x16_iu4
+ wmma_i32_16x16x16_iu4,
+ // gfx12
+ wmma_f32_16x16x16_f16_gfx12,
+ wmma_f32_16x16x16_bf16_gfx12,
+ wmma_i32_16x16x16_iu8_gfx12,
};
/*
@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
}
};
+// gfx12
+
+// A-swizzled
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
+ WaveSize,
+ typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
+{
+ // Absolute fixing property
+ // * Data Pixel
+ static constexpr index_t m_per_wmma = 16;
+ static constexpr index_t n_per_wmma = 16;
+ static constexpr index_t k_per_wmma = 16;
+ // static constexpr index_t src_a_data_size = 2;
+ // static constexpr index_t src_b_data_size = 2;
+ // static constexpr index_t acc_data_size = 4;
+ // * Thread mapping inside wave, num_thread_per_subgroups always along N direction
+ static constexpr index_t acc_data_size = 4;
+ static constexpr index_t acc_pack_number = 1;
+ static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+ // Wave mode dependent property
+ static constexpr index_t wave_size = Number<WaveSize>{};
+ // * Fixed in Navi3x, will be wave mode dependent on Navi4x
+ // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
+ // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
+ // * num_acc_vgprs_per_wave along M direction
+ // * num_subgroups along M direction
+ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
+ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
+
+ template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
+ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ {
+ static_assert(wave_size == 32, "only wave32 is supported for gfx12 wmma");
+ if constexpr(wave_size == 32)
+ {
+ intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
+ }
+ }
+};
+
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
+ WaveSize,
+ typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
+{
+ // Absolute fixing property
+ static constexpr index_t m_per_wmma = 16;
+ static constexpr index_t n_per_wmma = 16;
+ static constexpr index_t k_per_wmma = 16;
+ // static constexpr index_t src_a_data_size = 2;
+ // static constexpr index_t src_b_data_size = 2;
+ static constexpr index_t acc_data_size = 4;
+ static constexpr index_t acc_pack_number = 1;
+ static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+ // Wave mode dependent property
+ static constexpr index_t wave_size = Number<WaveSize>{};
+ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
+ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
+ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
+
+ template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
+ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ {
+ static_assert(wave_size == 32, "only wave32 is supported for gfx12 wmma");
+ if constexpr(wave_size == 32)
+ {
+ intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
+ }
+ }
+};
+
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
+ WaveSize,
+ typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
+{
+ // Absolute fixing property
+ static constexpr index_t m_per_wmma = 16;
+ static constexpr index_t n_per_wmma = 16;
+ static constexpr index_t k_per_wmma = 16;
+ // static constexpr index_t src_a_data_size = 2;
+ // static constexpr index_t src_b_data_size = 2;
+ static constexpr index_t acc_data_size = 4;
+ static constexpr index_t acc_pack_number = 1;
+ static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+ // Wave mode dependent property
+ static constexpr index_t wave_size = Number<WaveSize>{};
+ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
+ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
+ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
+
+ template <index_t MPerWmma,
+ index_t NPerWmma,
+ class FloatA,
+ class FloatB,
+ class FloatC,
+ bool neg_a = false,
+ bool neg_b = false,
+ bool clamp = false>
+ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ {
+ static_assert(wave_size == 32, "only wave32 is supported for gfx12 wmma");
+ if constexpr(wave_size == 32)
+ {
+ intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
+ a, b, reg_c);
+ }
+ }
+};
+
template <typename src_type_a,
typename src_type_b,
typename dst_type,
@@ -296,13 +417,21 @@ struct WmmaSelector
template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{
+#ifdef __gfx12__
+ return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
+#else
return WmmaInstr::wmma_f32_16x16x16_f16;
+#endif
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{
+#ifdef __gfx12__
+ return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
+#else
return WmmaInstr::wmma_f32_16x16x16_bf16;
+#endif
}
template <>
@@ -320,8 +449,13 @@ struct WmmaSelector
template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{
+#ifdef __gfx12__
+ return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
+#else
return WmmaInstr::wmma_i32_16x16x16_iu8;
+#endif
}
+
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
@@ -502,6 +636,9 @@ struct WmmaGemm
__device__ static auto GetSubGroupId()
{
+ static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
+ wmma_instr.wave_size,
+ "");
return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
}
@@ -516,12 +653,20 @@ struct WmmaGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
+#ifdef __gfx12__
+ return GetLaneIdUnderSubGroup();
+#else
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
+#endif
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
+#ifdef __gfx12__
+ return GetLaneIdUnderSubGroup();
+#else
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
+#endif
}
__device__ static CIndex GetBeginOfThreadBlk()
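
The new static_assert in GetSubGroupId pins the invariant the gfx12 wmma_type specializations rely on: a wave is exactly num_subgroups rows of num_thread_per_subgroups lanes. A standalone restatement with wave32 numbers (assumed here, matching the wave_size == 32 asserts above):

    constexpr int num_thread_per_subgroups = 16; // n_per_wmma
    constexpr int wave_size                = 32;
    constexpr int num_subgroups            = wave_size / num_thread_per_subgroups;

    constexpr int subgroup_of(int lane)         { return (lane / num_thread_per_subgroups) % num_subgroups; }
    constexpr int lane_under_subgroup(int lane) { return lane % num_thread_per_subgroups; }

    static_assert(num_thread_per_subgroups * num_subgroups == wave_size,
                  "subgroup layout must tile the wave exactly");
    static_assert(subgroup_of(17) == 1 && lane_under_subgroup(17) == 1,
                  "lane 17 maps to subgroup 1, lane 1");
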
diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp
index 1bb0140f3..322a0f94b 100644
--- a/include/ck/utility/amd_wmma.hpp
+++ b/include/ck/utility/amd_wmma.hpp
@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
}
};
+// gfx12
+/********************************WAVE32 MODE***********************************************/
+
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define __gfx12__
+#endif
+
+// src: fp16, dst: fp32
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
+
+template <>
+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>
+{
+ template <class FloatC>
+ __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
+ {
+ // * Inline assembly is needed to eliminate the duplicated data loads; the
+ // compiler won't delete them for you.
+ // amd_assembly_wmma_f32_16x16x16_f16_w32(
+ // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+#if defined(__gfx12__)
+ reg_c.template AsType<float8_t>()(Number<0>{}) =
+ __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
+ reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+ ignore = reg_a;
+ ignore = reg_b;
+ ignore = reg_c;
+#endif
+ }
+};
+
+// src: bf16, dst: fp32
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12;
+
+template <>
+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16>
+{
+ template <class FloatC>
+ __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
+ {
+#if defined(__gfx12__)
+ reg_c.template AsType<float8_t>()(Number<0>{}) =
+ __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
+ reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+ ignore = reg_a;
+ ignore = reg_b;
+ ignore = reg_c;
+#endif
+ }
+};
+
+// src: iu8, dst: i32
+template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12;
+
+template <bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
+{
+ template <class FloatC>
+ __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
+ {
+#if defined(__gfx12__)
+ reg_c.template AsType<int32x8_t>()(Number<0>{}) =
+ __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+ neg_a,
+ bit_cast<int32x2_t>(reg_a),
+ neg_b,
+ bit_cast<int32x2_t>(reg_b),
+ reg_c.template AsType<int32x8_t>()[Number<0>{}],
+ clamp);
+#else
+ ignore = reg_a;
+ ignore = reg_b;
+ ignore = reg_c;
+#endif
+ }
+};
+
} // namespace ck
#endif
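
Each gfx12 intrinsic wrapper above uses a guard-or-ignore shape: when the translation unit is compiled for a non-gfx12 target the builtin does not exist, so the operands are consumed to keep the template instantiable everywhere. A generic sketch of the shape (names are hypothetical, not CK API):

    template <class A, class B, class C>
    __device__ void run_wmma_guarded(const A& a, const B& b, C& c)
    {
    #if defined(__gfx12__)
        c = wmma_intrinsic(a, b, c); // hypothetical stand-in for the builtin
    #else
        (void)a; // compiled out on other targets; consume operands to
        (void)b; // silence unused-parameter warnings under -Werror
        (void)c;
    #endif
    }
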
diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp
index 93a1edefb..4df14c621 100644
--- a/include/ck/utility/data_type.hpp
+++ b/include/ck/utility/data_type.hpp
@@ -203,7 +203,7 @@ struct vector_type<T, 1>
}
};
-int static err = 0;
+static __device__ int err = 0;
template <typename T>
struct vector_type<T, 2>
{
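
On the one-line data_type.hpp fix above: a namespace-scope variable referenced from device code must have device storage; without the `__device__` qualifier the reference resolves to a host global, which device compilation rejects. A minimal repro sketch with hypothetical names:

    #include <hip/hip_runtime.h>

    __device__ int dev_err = 0;  // device-side global: legal to touch in kernels

    __global__ void record_error(int code)
    {
        dev_err = code;          // a plain host global here would not compile
    }
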
diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp
index 4fe5e3950..d6b6eac26 100644
--- a/include/ck/utility/synchronization.hpp
+++ b/include/ck/utility/synchronization.hpp
@@ -10,12 +10,20 @@ namespace ck {
__device__ void block_sync_lds()
{
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
+#ifdef __gfx12__
+ asm volatile("\
+ s_wait_dscnt 0x0 \n \
+ s_barrier_signal -1 \n \
+ s_barrier_wait -1 \
+ " ::);
+#else
// asm volatile("\
// s_waitcnt lgkmcnt(0) \n \
// s_barrier \
// " ::);
__builtin_amdgcn_s_waitcnt(0xc07f);
__builtin_amdgcn_s_barrier();
+#endif
#else
__syncthreads();
#endif
@@ -23,11 +31,20 @@ __device__ void block_sync_lds()
__device__ void block_sync_lds_direct_load()
{
+#ifdef __gfx12__
+ asm volatile("\
+ s_wait_vmcnt 0x0 \n \
+ s_wait_dscnt 0x0 \n \
+ s_barrier_signal -1 \n \
+ s_barrier_wait -1 \
+ " ::);
+#else
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
+#endif
}
__device__ void s_nop()
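
gfx12 retires the monolithic s_barrier and the lgkmcnt counter in favor of split signal/wait barriers and a dedicated DS counter, which is what the two asm blocks above switch on. The same sequence as a standalone helper, a sketch that assumes only the mnemonics already shown in the patch:

    __device__ inline void block_sync_lds_gfx12_sketch()
    {
        asm volatile("s_wait_dscnt 0x0\n"    // drain outstanding LDS (DS) ops
                     "s_barrier_signal -1\n" // arrive at the workgroup barrier
                     "s_barrier_wait -1"     // block until all waves arrive
                     ::: "memory");
    }
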
diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp
index 601aad19b..9dc2b072a 100644
--- a/include/ck_tile/core/config.hpp
+++ b/include/ck_tile/core/config.hpp
@@ -17,6 +17,9 @@
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define __gfx12__
+#endif
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
@@ -155,7 +158,7 @@
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx11__) // for GPU code
+#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
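
The CK_TILE_BUFFER_RESOURCE_3RD_DWORD above is the fourth 32-bit word of the 128-bit buffer resource descriptor (SRD) handed to buffer loads/stores; gfx12 reuses the gfx11 value 0x31004000. A schematic of where that word sits, with the SRD wiring assumed rather than exact:

    #include <cstdint>

    constexpr uint32_t kBufferResource3rdDword = 0x31004000; // gfx11 / gfx12

    struct BufferResourceWords { uint32_t w0, w1, w2, w3; };

    constexpr BufferResourceWords make_srd_sketch(uint64_t base, uint32_t bytes)
    {
        return {static_cast<uint32_t>(base),       // base address, low 32 bits
                static_cast<uint32_t>(base >> 32), // base address, high bits (+stride field)
                bytes,                             // num_records: extent in bytes
                kBufferResource3rdDword};          // format/config word from above
    }
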
diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
index 8c5f36d2e..89c9d6dc6 100644
--- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
@@ -52,7 +52,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach()
# Do not build WMMA instances if gfx11 targets are not on the target list
foreach(source IN LISTS ARGN)
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
@@ -149,7 +149,7 @@ FOREACH(subdir_path ${dir_list})
message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
set(add_inst 0)
endif()
- if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11"))
+ if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12"))
message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
set(add_inst 0)
endif()
@@ -157,11 +157,11 @@ FOREACH(subdir_path ${dir_list})
message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
set(add_inst 0)
endif()
- if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
+ if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9"))
message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
set(add_inst 0)
endif()
- if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
+ if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
set(add_inst 0)
endif()
diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt
index 1cfcbfff6..a9557a9b9 100644
--- a/profiler/src/CMakeLists.txt
+++ b/profiler/src/CMakeLists.txt
@@ -58,7 +58,7 @@ if(GPU_TARGETS MATCHES "gfx9")
endif()
-if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9")
+if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
endif()
@@ -133,7 +133,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
endif()
-if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11")
+if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
endif()
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 25c63ac7f..2a7c52b58 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -53,7 +53,7 @@ function(add_test_executable TEST_NAME)
endif()
endforeach()
foreach(source IN LISTS ARGN)
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
message("removing wmma test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
@@ -118,7 +118,7 @@ function(add_gtest_executable TEST_NAME)
endif()
endforeach()
foreach(source IN LISTS ARGN)
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
message("removing wmma test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
index 1c8082645..21f49ec0f 100644
--- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
@@ -55,7 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
}
}
- if(ck::is_gfx11_supported())
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
// on gfx11 only support for 3d is implemented
if constexpr(NDimSpatial{} != 3)
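
The test now keys on ck::is_gfx12_supported() alongside the existing gfx11 check. The patch assumes that helper exists elsewhere in CK; a plausible shape, purely as a sketch over the reported architecture name:

    #include <string>

    // Hypothetical stand-in for ck::is_gfx12_supported(); the real helper
    // queries the device architecture at runtime.
    inline bool is_gfx12_supported_sketch(const std::string& arch)
    {
        return arch.find("gfx1200") != std::string::npos ||
               arch.find("gfx1201") != std::string::npos;
    }
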
diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp
index 49782bce6..d9ec94771 100644
--- a/test/wmma_op/wmma_op_util.hpp
+++ b/test/wmma_op/wmma_op_util.hpp
@@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele];
}
+#ifdef __gfx12__
+ asm volatile("\
+ s_wait_dscnt 0x0 \n \
+ s_barrier_signal -1 \n \
+ s_barrier_wait -1 \
+ " ::);
+#else
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
+#endif
for(int ele = 0; ele < 16; ++ele)
{
@@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8];
}
+#ifdef __gfx12__
+ asm volatile("\
+ s_wait_dscnt 0x0 \n \
+ s_barrier_signal -1 \n \
+ s_barrier_wait -1 \
+ " ::);
+#else
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
+#endif
// sync threads, similar to mma_sync
// __syncthreads();