mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
### Description Add additional gfx targets for AMD GPU support ### Motivation and Context Required to integrate mainline onnxruntime support for AMD GPUs --------- Co-authored-by: Stefan Sokolovic <stsokolo@amd.com> Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2280 lines
104 KiB
Diff
2280 lines
104 KiB
Diff
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
|
index bc326c8b5..db5ad5052 100644
|
|
--- a/CMakeLists.txt
|
|
+++ b/CMakeLists.txt
|
|
@@ -117,7 +117,7 @@ else()
|
|
add_definitions(-DPROFILER_ONLY)
|
|
set(GPU_TARGETS "" CACHE STRING "" FORCE)
|
|
if(GPU_TARGETS)
|
|
- message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11")
|
|
+ message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12")
|
|
endif()
|
|
if(GPU_ARCH MATCHES "gfx90")
|
|
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a")
|
|
@@ -127,8 +127,10 @@ else()
|
|
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
|
|
elseif(GPU_ARCH MATCHES "gfx11")
|
|
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
|
|
+ elseif(GPU_ARCH MATCHES "gfx12")
|
|
+ rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
|
|
else()
|
|
- message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11")
|
|
+ message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
|
|
endif()
|
|
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
|
|
endif()
|
|
diff --git a/Jenkinsfile b/Jenkinsfile
|
|
index 75800bfc9..b72e2ca4e 100644
|
|
--- a/Jenkinsfile
|
|
+++ b/Jenkinsfile
|
|
@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){
|
|
|
|
def variant = env.STAGE_NAME
|
|
def retimage
|
|
+
|
|
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
|
try {
|
|
(retimage, image) = getDockerImage(conf)
|
|
@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
|
|
|
|
pipeline {
|
|
agent none
|
|
- triggers {
|
|
- parameterizedCron(CRON_SETTINGS)
|
|
- }
|
|
options {
|
|
parallelsAlwaysFailFast()
|
|
}
|
|
diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake
|
|
index 8654170b3..42070051b 100644
|
|
--- a/cmake/EnableCompilerWarnings.cmake
|
|
+++ b/cmake/EnableCompilerWarnings.cmake
|
|
@@ -66,7 +66,7 @@ else()
|
|
-Wunreachable-code
|
|
-Wunused
|
|
-Wno-reserved-identifier
|
|
- -Werror
|
|
+ -Werror
|
|
-Wno-option-ignored
|
|
-Wsign-compare
|
|
-Wno-extra-semi-stmt
|
|
diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp
|
|
index 8c52e4f7d..f8afe8d6d 100644
|
|
--- a/example/01_gemm/gemm_wmma_fp16.cpp
|
|
+++ b/example/01_gemm/gemm_wmma_fp16.cpp
|
|
@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
|
|
|
|
// clang-format off
|
|
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
|
|
- < ALayout,
|
|
- BLayout,
|
|
- CLayout,
|
|
- ADataType,
|
|
+ < ALayout,
|
|
+ BLayout,
|
|
+ CLayout,
|
|
+ ADataType,
|
|
BDataType,
|
|
- CDataType,
|
|
- AccDataType,
|
|
- CShuffleDataType,
|
|
- AElementOp,
|
|
- BElementOp,
|
|
- CElementOp,
|
|
- GemmDefault,
|
|
+ CDataType,
|
|
+ AccDataType,
|
|
+ CShuffleDataType,
|
|
+ AElementOp,
|
|
+ BElementOp,
|
|
+ CElementOp,
|
|
+ GemmDefault,
|
|
1, // Prefetch stage
|
|
128, // BlockSize
|
|
64, // MPerBlock
|
|
128, // NPerBlock
|
|
64, // KPerBlock
|
|
- 8, // K1
|
|
+ 2, // K1
|
|
16, // MPerWmma
|
|
16, // NPerWmma
|
|
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
|
|
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
|
|
- S<4, 32, 1>,
|
|
- S<1, 0, 2>,
|
|
- S<1, 0, 2>,
|
|
- 2,
|
|
- 8,
|
|
- 8,
|
|
- true,
|
|
- S<4, 32, 1>,
|
|
- S<1, 0, 2>,
|
|
- S<1, 0, 2>,
|
|
- 2,
|
|
- 8,
|
|
- 8,
|
|
- true,
|
|
+ S<4, 32, 1>,
|
|
+ S<1, 0, 2>,
|
|
+ S<1, 0, 2>,
|
|
+ 2,
|
|
+ 2,
|
|
+ 2,
|
|
+ true,
|
|
+ S<4, 32, 1>,
|
|
+ S<1, 0, 2>,
|
|
+ S<1, 0, 2>,
|
|
+ 2,
|
|
+ 2,
|
|
+ 2,
|
|
+ true,
|
|
1, // C shuffle (M Repeat) Per store
|
|
1, // C shuffle (N Repeat) Per store
|
|
- S<1, 32, 1, 4>,
|
|
+ S<1, 32, 1, 4>,
|
|
8>;
|
|
// clang-format on
|
|
|
|
diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc
|
|
index b04e4e53a..cb15186c3 100644
|
|
--- a/example/01_gemm/run_gemm_example.inc
|
|
+++ b/example/01_gemm/run_gemm_example.inc
|
|
@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
|
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
|
break;
|
|
case 4:
|
|
- ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
|
|
+ ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
|
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
|
|
break;
|
|
case 5:
|
|
diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt
|
|
index ab19f819e..be47665a2 100644
|
|
--- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt
|
|
+++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt
|
|
@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
|
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
|
|
set(target 1)
|
|
endif()
|
|
-endforeach()
|
|
\ No newline at end of file
|
|
+endforeach()
|
|
diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
|
|
index 2bbf430c4..f556be887 100644
|
|
--- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
|
|
+++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
|
|
@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
|
|
2,
|
|
4,
|
|
4,
|
|
- true,
|
|
+ false,
|
|
S<4, 32, 1>,
|
|
S<1, 0, 2>,
|
|
S<1, 0, 2>,
|
|
2,
|
|
4,
|
|
4,
|
|
- true,
|
|
+ false,
|
|
1,
|
|
1,
|
|
S<1, 64, 1, 2>,
|
|
diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
|
|
index 4c92c5497..fac19f8b5 100644
|
|
--- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
|
|
+++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
|
|
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
|
|
#define CK_MHA_USE_WAVE_1
|
|
#define CK_MHA_USE_WAVE_2
|
|
#define CK_MHA_USE_WAVE_4
|
|
-#define CK_MHA_USE_WAVE_8
|
|
+//#define CK_MHA_USE_WAVE_8
|
|
using DeviceMHAFactory =
|
|
std::tuple<
|
|
#ifdef CK_MHA_USE_WAVE_1
|
|
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
|
|
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
|
// CShuffleBlockTransfer MN
|
|
1, 1, S<1, 64, 1, 2>, 8,
|
|
- MaskingSpec>,
|
|
+ MaskingSpec>
|
|
#endif
|
|
#ifdef CK_MHA_USE_WAVE_8
|
|
- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
|
+ ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
|
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
|
|
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
|
|
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
|
diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
|
|
index 8e037272b..d463cc871 100644
|
|
--- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
|
|
+++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
|
|
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
|
|
#define CK_MHA_USE_WAVE_1
|
|
#define CK_MHA_USE_WAVE_2
|
|
#define CK_MHA_USE_WAVE_4
|
|
-#define CK_MHA_USE_WAVE_8
|
|
+//#define CK_MHA_USE_WAVE_8
|
|
using DeviceMHAFactory =
|
|
std::tuple<
|
|
#ifdef CK_MHA_USE_WAVE_1
|
|
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
|
|
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
|
// CShuffleBlockTransfer MN
|
|
1, 1, S<1, 64, 1, 2>, 8,
|
|
- MaskingSpec>,
|
|
+ MaskingSpec>
|
|
#endif
|
|
#ifdef CK_MHA_USE_WAVE_8
|
|
- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
|
+ ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
|
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
|
|
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
|
|
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
|
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
|
|
index 5465adb77..7534bff3b 100644
|
|
--- a/example/CMakeLists.txt
|
|
+++ b/example/CMakeLists.txt
|
|
@@ -60,7 +60,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
|
endforeach()
|
|
#Do not build any WMMA examples if gfx11 targets are not on the list
|
|
foreach(source IN LISTS FILE_NAME)
|
|
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
|
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
|
message("removing wmma example ${source} ")
|
|
list(REMOVE_ITEM FILE_NAME "${source}")
|
|
endif()
|
|
@@ -134,7 +134,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
|
endforeach()
|
|
#Do not build any WMMA examples if gfx11 targets are not on the list
|
|
foreach(source IN LISTS FILE_NAME)
|
|
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
|
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
|
message("removing wmma example ${source} ")
|
|
list(REMOVE_ITEM FILE_NAME "${source}")
|
|
endif()
|
|
diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp
|
|
index 55f562061..69a7abf62 100644
|
|
--- a/include/ck/ck.hpp
|
|
+++ b/include/ck/ck.hpp
|
|
@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
|
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
|
|
#define __gfx11__
|
|
#endif
|
|
+#if defined(__gfx1200__) || defined(__gfx1201__)
|
|
+#define __gfx12__
|
|
+#endif
|
|
|
|
// buffer resource
|
|
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
|
@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
|
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
|
#elif defined(__gfx103__)
|
|
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
|
-#elif defined(__gfx11__)
|
|
+#elif defined(__gfx11__) || defined(__gfx12__)
|
|
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
|
#endif
|
|
|
|
@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
|
#define CK_USE_AMD_V_FMAC_F32
|
|
#define CK_USE_AMD_V_DOT2_F32_F16
|
|
#define CK_USE_AMD_V_DOT4_I32_I8
|
|
-#elif defined(__gfx11__)
|
|
+#elif defined(__gfx11__) || defined(__gfx12__)
|
|
#define CK_USE_AMD_V_FMAC_F32
|
|
#define CK_USE_AMD_V_DOT2_F32_F16
|
|
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
|
|
@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
|
#define CK_USE_AMD_MFMA_GFX940
|
|
#endif
|
|
|
|
-// WMMA instruction
|
|
-#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
|
-#define CK_USE_AMD_WMMA
|
|
-#elif defined(__gfx11__) // for GPU code
|
|
-#define CK_USE_AMD_WMMA
|
|
-#endif
|
|
-
|
|
// buffer load
|
|
#define CK_USE_AMD_BUFFER_LOAD 1
|
|
|
|
diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp
|
|
index 116bb3ea0..83af2efe8 100644
|
|
--- a/include/ck/host_utility/device_prop.hpp
|
|
+++ b/include/ck/host_utility/device_prop.hpp
|
|
@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
|
|
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
|
|
}
|
|
|
|
+inline bool is_gfx12_supported()
|
|
+{
|
|
+ return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
|
|
+}
|
|
+
|
|
} // namespace ck
|
|
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
|
|
index f8ee283c6..7eb7d42eb 100644
|
|
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
|
|
@@ -13,6 +13,504 @@
|
|
|
|
namespace ck {
|
|
|
|
+#ifdef __gfx12__
|
|
+template <index_t BlockSize,
|
|
+ typename FloatA,
|
|
+ typename FloatB,
|
|
+ typename FloatAcc,
|
|
+ typename ABlockDesc,
|
|
+ typename BBlockDesc,
|
|
+ index_t MPerBlock,
|
|
+ index_t NPerBlock,
|
|
+ index_t KPerBlock,
|
|
+ index_t MPerWMMA,
|
|
+ index_t NPerWMMA,
|
|
+ index_t MRepeat,
|
|
+ index_t NRepeat,
|
|
+ index_t KPack,
|
|
+ bool AEnableLds = true,
|
|
+ bool BEnableLds = true,
|
|
+ bool TransposeC = false>
|
|
+/* Option: Read from LDS, big buffer hold all threads required data
|
|
+ * Source
|
|
+ * A: K0PerBlock x MPerBlock x K1
|
|
+ * B: K0PerBlock x NPerBlock x K1
|
|
+ * Destination
|
|
+ * C, non-transpose
|
|
+ * thread level: MRepeat x NRepeat x MAccVgprs
|
|
+ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
|
|
+ * KPACK == WMMA_K = 16
|
|
+ *
|
|
+ * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
|
|
+ * Source:
|
|
+ * A(if skip LDS): MRepeat x KPack
|
|
+ * B(if skip LDS): NRepeat x KPack
|
|
+ * Destination
|
|
+ * C, non-transpose
|
|
+ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
|
|
+ */
|
|
+struct BlockwiseGemmWMMA
|
|
+{
|
|
+ static constexpr auto I0 = Number<0>{};
|
|
+ static constexpr auto I1 = Number<1>{};
|
|
+ static constexpr auto I2 = Number<2>{};
|
|
+ static constexpr auto I3 = Number<3>{};
|
|
+ static constexpr auto I4 = Number<4>{};
|
|
+ static constexpr auto I5 = Number<5>{};
|
|
+ static constexpr auto WmmaK = Number<16>{};
|
|
+
|
|
+ using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
|
+
|
|
+ // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
|
|
+ static constexpr index_t WaveSize = 32;
|
|
+
|
|
+ // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
|
|
+ // When not use LDS, each Row read half of whole data from source buffer, exchange the data via
|
|
+ // permutation
|
|
+ static constexpr index_t A_KRow = 2;
|
|
+ static constexpr index_t B_KRow = 2;
|
|
+
|
|
+ static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
|
|
+ static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
|
|
+
|
|
+ static constexpr auto wmma_gemm =
|
|
+ WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
|
|
+
|
|
+ static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
|
|
+ static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
|
|
+
|
|
+ StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
|
+ FloatAcc,
|
|
+ MRepeat * NRepeat,
|
|
+ wmma_gemm.GetRegSizePerWmma(),
|
|
+ true>
|
|
+ c_thread_buf_;
|
|
+
|
|
+ __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
|
+
|
|
+ __device__ static auto GetWaveIdx()
|
|
+ {
|
|
+ const index_t thread_id = ThisThreadBlock::GetThreadId();
|
|
+
|
|
+ constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
|
+ make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
|
+ make_tuple(Sequence<0, 1, 2>{}),
|
|
+ make_tuple(Sequence<0>{}));
|
|
+
|
|
+ return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
|
+ }
|
|
+
|
|
+ // Default, Block buffer in LDS, thread level offset enabled
|
|
+ __device__ static auto CalculateAThreadOriginDataIndex()
|
|
+ {
|
|
+ if constexpr(AEnableLds)
|
|
+ {
|
|
+ const auto wave_idx = GetWaveIdx();
|
|
+ const auto waveId_m = wave_idx[I0];
|
|
+ const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
|
|
+
|
|
+ // |KRepeat |MRepeat|MWave |KRow |MLane |KPack
|
|
+ return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
|
|
+ }
|
|
+ else
|
|
+ {
|
|
+ return make_tuple(0, 0, 0, 0, 0, 0);
|
|
+ }
|
|
+ }
|
|
+
|
|
+ __device__ static auto CalculateBThreadOriginDataIndex()
|
|
+ {
|
|
+ if constexpr(BEnableLds)
|
|
+ {
|
|
+ const auto wave_idx = GetWaveIdx();
|
|
+ const auto waveId_n = wave_idx[I1];
|
|
+ const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
|
|
+
|
|
+ // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
|
|
+ return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
|
|
+ }
|
|
+ else
|
|
+ {
|
|
+ return make_tuple(0, 0, 0, 0, 0, 0);
|
|
+ }
|
|
+ }
|
|
+
|
|
+ template <index_t m0, index_t n0>
|
|
+ __device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
|
|
+ {
|
|
+ const auto wave_idx = GetWaveIdx();
|
|
+
|
|
+ const auto waveId_m = wave_idx[I0];
|
|
+ const auto waveId_n = wave_idx[I1];
|
|
+
|
|
+ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
|
|
+
|
|
+ constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
|
|
+ make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
|
|
+ make_tuple(Sequence<0>{}),
|
|
+ make_tuple(Sequence<0, 1, 2>{}));
|
|
+
|
|
+ constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
|
|
+ make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
|
|
+ make_tuple(Sequence<0>{}),
|
|
+ make_tuple(Sequence<0, 1, 2>{}));
|
|
+
|
|
+ const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
|
|
+ make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
|
|
+ const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
|
|
+ make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
|
|
+
|
|
+ return make_tuple(c_thread_m, c_thread_n);
|
|
+ }
|
|
+
|
|
+ template <index_t m0, index_t n0>
|
|
+ __device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
|
|
+ {
|
|
+ const auto wave_idx = GetWaveIdx();
|
|
+
|
|
+ const auto waveId_m = wave_idx[I0];
|
|
+ const auto waveId_n = wave_idx[I1];
|
|
+
|
|
+ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
|
|
+
|
|
+ return make_tuple(
|
|
+ Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
|
|
+ }
|
|
+
|
|
+ using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
|
|
+ __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
|
|
+ Tuple6 b_origin = CalculateBThreadOriginDataIndex())
|
|
+ : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
|
+ {
|
|
+ static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
|
|
+ "wrong! Desc should be known at compile-time");
|
|
+
|
|
+ static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
|
+ "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
|
+
|
|
+ static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
|
|
+ NPerBlock % (NPerWMMA * NRepeat) == 0,
|
|
+ "wrong!");
|
|
+ }
|
|
+
|
|
+ // transposed WMMA output C' = B' * A'
|
|
+ __host__ __device__ static constexpr auto
|
|
+ GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
|
|
+ {
|
|
+ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
|
|
+ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
|
|
+
|
|
+ constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
|
|
+
|
|
+ return make_naive_tensor_descriptor_packed(
|
|
+ // |MRepeat |MWave |MSubGroup |NRepeat |NWave
|
|
+ // |NThreadPerSubGroup |MAccVgprs
|
|
+ make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
|
|
+ }
|
|
+
|
|
+ // Thread level, register decriptor. Vector-write
|
|
+ __host__ __device__ static constexpr auto
|
|
+ GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
|
+ {
|
|
+ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
|
|
+ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
|
|
+
|
|
+ constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
|
|
+ constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
|
|
+ return make_naive_tensor_descriptor(
|
|
+ // |MRepeat |MWave |MSubGroup |NRepeat |NWave
|
|
+ // |NThreadPerSubGroup |MAccVgprs
|
|
+ make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
|
|
+ make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
|
|
+ Number<NRepeat>{} * MAccVgprs * AccStride,
|
|
+ Number<NRepeat>{} * MAccVgprs * AccStride,
|
|
+ MAccVgprs * AccStride,
|
|
+ MAccVgprs * AccStride,
|
|
+ MAccVgprs * AccStride,
|
|
+ AccStride));
|
|
+ }
|
|
+
|
|
+ template <typename CGridDesc_M_N>
|
|
+ __host__ __device__ static constexpr auto
|
|
+ MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
|
+ const CGridDesc_M_N& c_grid_desc_m_n)
|
|
+ {
|
|
+ const auto M = c_grid_desc_m_n.GetLength(I0);
|
|
+ const auto N = c_grid_desc_m_n.GetLength(I1);
|
|
+
|
|
+ const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
|
|
+ transform_tensor_descriptor(
|
|
+ c_grid_desc_m_n,
|
|
+ make_tuple(
|
|
+ make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
|
|
+ make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
|
|
+ make_tuple(Sequence<0>{}, Sequence<1>{}),
|
|
+ make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
|
|
+
|
|
+ return wmma_gemm
|
|
+ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
|
+ c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
|
|
+ }
|
|
+
|
|
+ // transposed WMMA output C' = B' * A'
|
|
+ __host__ __device__ static constexpr auto
|
|
+ GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
|
|
+ {
|
|
+ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
|
|
+ make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
|
+ Number<MWaves>{},
|
|
+ Number<MPerWMMA>{},
|
|
+ Number<NRepeat>{},
|
|
+ Number<NWaves>{},
|
|
+ Number<NPerWMMA>{}));
|
|
+
|
|
+ return wmma_gemm
|
|
+ .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
|
|
+ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
|
|
+ }
|
|
+
|
|
+ // Provide dimension size
|
|
+ __host__ __device__ static constexpr auto
|
|
+ GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
|
+ {
|
|
+ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
|
|
+ make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
|
+ Number<MWaves>{},
|
|
+ Number<MPerWMMA>{},
|
|
+ Number<NRepeat>{},
|
|
+ Number<NWaves>{},
|
|
+ Number<NPerWMMA>{}));
|
|
+
|
|
+ return wmma_gemm
|
|
+ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
|
+ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
|
|
+ }
|
|
+
|
|
+ // Describe how data allocated in thread copy src buffer
|
|
+ // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
|
|
+ static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
|
|
+ static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
|
|
+
|
|
+ template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
|
+ __device__ void Run(const ABlockBuffer& a_block_buf,
|
|
+ const BBlockBuffer& b_block_buf,
|
|
+ CThreadBuffer& c_thread_buf) const
|
|
+ {
|
|
+ auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
|
|
+ a_thread_desc_.GetElementSpaceSize());
|
|
+ auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
|
+ b_thread_desc_.GetElementSpaceSize());
|
|
+
|
|
+ static_assert(KPack % (A_K1 * A_KRow) == 0, "");
|
|
+ static_assert(KPack % (B_K1 * B_KRow) == 0, "");
|
|
+
|
|
+ // basic intrinsic to determine loopover direction
|
|
+ if constexpr(MRepeat < NRepeat)
|
|
+ {
|
|
+ static_for<0, KPerBlock / KPack, 1>{}(
|
|
+ [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
|
+ static_for<0, MRepeat, 1>{}([&](auto m0) {
|
|
+ // read A
|
|
+ a_thread_copy_.Run(
|
|
+ a_block_desc_k0_m0_m1_m2_k1,
|
|
+ make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
|
+ a_block_buf,
|
|
+ a_thread_desc_,
|
|
+ make_tuple(I0, m0, I0, I0, I0, I0),
|
|
+ a_thread_buf);
|
|
+
|
|
+ static_for<0, NRepeat, 1>{}([&](auto n0) {
|
|
+ // read B
|
|
+ b_thread_copy_.Run(
|
|
+ b_block_desc_k0_n0_n1_n2_k1,
|
|
+ make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
|
+ b_block_buf,
|
|
+ b_thread_desc_,
|
|
+ make_tuple(I0, n0, I0, I0, I0, I0),
|
|
+ b_thread_buf);
|
|
+
|
|
+ vector_type<FloatA, KPack / A_KRow> a_thread_vec;
|
|
+ vector_type<FloatB, KPack / B_KRow> b_thread_vec;
|
|
+
|
|
+ static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
|
|
+ a_thread_vec.template AsType<FloatA>()(i) =
|
|
+ a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
|
+ make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
|
|
+ });
|
|
+
|
|
+ static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
|
|
+ b_thread_vec.template AsType<FloatB>()(i) =
|
|
+ b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
|
+ make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
|
|
+ });
|
|
+
|
|
+ using wmma_input_type_a =
|
|
+ typename vector_type<FloatA, WmmaK / A_KRow>::type;
|
|
+ using wmma_input_type_b =
|
|
+ typename vector_type<FloatB, WmmaK / B_KRow>::type;
|
|
+
|
|
+ constexpr index_t c_offset =
|
|
+ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
|
+
|
|
+ wmma_gemm.template Run(
|
|
+ a_thread_vec.template AsType<wmma_input_type_a>(),
|
|
+ b_thread_vec.template AsType<wmma_input_type_b>(),
|
|
+ c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
|
+ });
|
|
+ });
|
|
+ });
|
|
+ }
|
|
+ else
|
|
+ {
|
|
+ static_for<0, NRepeat, 1>{}([&](auto n0) {
|
|
+ static_for<0, MRepeat, 1>{}([&](auto m0) {
|
|
+ static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
|
|
+ // k=0,kpack*1, ..
|
|
+ // read B
|
|
+ b_thread_copy_.Run(
|
|
+ b_block_desc_k0_n0_n1_n2_k1,
|
|
+ make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
|
+ b_block_buf,
|
|
+ b_thread_desc_,
|
|
+ make_tuple(I0, n0, I0, I0, I0, I0),
|
|
+ b_thread_buf);
|
|
+ // read A
|
|
+ a_thread_copy_.Run(
|
|
+ a_block_desc_k0_m0_m1_m2_k1,
|
|
+ make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
|
+ a_block_buf,
|
|
+ a_thread_desc_,
|
|
+ make_tuple(I0, m0, I0, I0, I0, I0),
|
|
+ a_thread_buf);
|
|
+
|
|
+ vector_type<FloatA, KPack / A_KRow> a_thread_vec;
|
|
+ vector_type<FloatB, KPack / B_KRow> b_thread_vec;
|
|
+
|
|
+ static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
|
|
+ a_thread_vec.template AsType<FloatA>()(i) =
|
|
+ a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
|
+ make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
|
|
+ });
|
|
+
|
|
+ static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
|
|
+ b_thread_vec.template AsType<FloatB>()(i) =
|
|
+ b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
|
+ make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
|
|
+ });
|
|
+
|
|
+ using wmma_input_type_a =
|
|
+ typename vector_type<FloatA, WmmaK / A_KRow>::type;
|
|
+ using wmma_input_type_b =
|
|
+ typename vector_type<FloatB, WmmaK / B_KRow>::type;
|
|
+
|
|
+ constexpr index_t c_offset =
|
|
+ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
|
+
|
|
+ wmma_gemm.template Run(
|
|
+ a_thread_vec.template AsType<wmma_input_type_a>(),
|
|
+ b_thread_vec.template AsType<wmma_input_type_b>(),
|
|
+ c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
|
+ });
|
|
+ });
|
|
+ });
|
|
+ }
|
|
+ }
|
|
+
|
|
+ protected:
|
|
+ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
|
+ make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
|
|
+ make_tuple(Number<A_K1>{},
|
|
+ Number<KPack / A_KRow>{},
|
|
+ Number<A_K1>{},
|
|
+ Number<A_K1>{},
|
|
+ Number<A_K1>{},
|
|
+ Number<1>{}));
|
|
+
|
|
+ static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
|
|
+ make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
|
|
+ make_tuple(Number<B_K1>{},
|
|
+ Number<KPack / B_KRow>{},
|
|
+ Number<B_K1>{},
|
|
+ Number<B_K1>{},
|
|
+ Number<B_K1>{},
|
|
+ Number<1>{}));
|
|
+
|
|
+ // C[M, N, NumRegWMMA]
|
|
+ static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
|
+ make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
|
|
+
|
|
+ template <bool EnableLds>
|
|
+ struct AThreadCopySelector;
|
|
+
|
|
+ template <>
|
|
+ struct AThreadCopySelector<true>
|
|
+ {
|
|
+ using type =
|
|
+ ThreadwiseTensorSliceTransfer_v4<FloatA,
|
|
+ FloatA,
|
|
+ decltype(a_block_desc_k0_m0_m1_m2_k1),
|
|
+ decltype(a_thread_desc_),
|
|
+ Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
|
+ Sequence<0, 1, 2, 3, 4, 5>,
|
|
+ 5,
|
|
+ A_K1,
|
|
+ A_K1>;
|
|
+ };
|
|
+
|
|
+ template <>
|
|
+ struct AThreadCopySelector<false>
|
|
+ {
|
|
+ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
|
|
+ FloatA,
|
|
+ FloatA,
|
|
+ decltype(a_block_desc_k0_m0_m1_m2_k1),
|
|
+ decltype(a_thread_desc_),
|
|
+ tensor_operation::element_wise::PassThrough,
|
|
+ Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
|
+ Sequence<0, 1, 2, 3, 4, 5>,
|
|
+ 5,
|
|
+ A_K1,
|
|
+ false>;
|
|
+ };
|
|
+
|
|
+ template <bool EnableLds>
|
|
+ struct BThreadCopySelector;
|
|
+
|
|
+ template <>
|
|
+ struct BThreadCopySelector<true>
|
|
+ {
|
|
+ using type =
|
|
+ ThreadwiseTensorSliceTransfer_v4<FloatB,
|
|
+ FloatB,
|
|
+ decltype(b_block_desc_k0_n0_n1_n2_k1),
|
|
+ decltype(b_thread_desc_),
|
|
+ Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
|
+ Sequence<0, 1, 2, 3, 4, 5>,
|
|
+ 5,
|
|
+ B_K1,
|
|
+ B_K1>;
|
|
+ };
|
|
+
|
|
+ template <>
|
|
+ struct BThreadCopySelector<false>
|
|
+ {
|
|
+ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
|
|
+ FloatB,
|
|
+ FloatB,
|
|
+ decltype(b_block_desc_k0_n0_n1_n2_k1),
|
|
+ decltype(b_thread_desc_),
|
|
+ tensor_operation::element_wise::PassThrough,
|
|
+ Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
|
+ Sequence<0, 1, 2, 3, 4, 5>,
|
|
+ 5,
|
|
+ B_K1,
|
|
+ false>;
|
|
+ };
|
|
+
|
|
+ typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
|
|
+ typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
|
|
+};
|
|
+#else
|
|
template <index_t BlockSize,
|
|
typename FloatA,
|
|
typename FloatB,
|
|
@@ -529,5 +1027,6 @@ struct BlockwiseGemmWMMA
|
|
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
|
|
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
|
|
};
|
|
+#endif
|
|
|
|
} // namespace ck
|
|
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
|
|
index e5e6245cb..1f7d50429 100644
|
|
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
|
|
@@ -488,7 +488,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
|
// sync point.
|
|
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
|
|
{
|
|
+#ifdef __gfx12__
|
|
+ asm volatile("\
|
|
+ s_barrier_signal -1 \n \
|
|
+ s_barrier_wait -1 \
|
|
+ " ::);
|
|
+#else
|
|
asm volatile("s_barrier" ::);
|
|
+#endif
|
|
__builtin_amdgcn_sched_barrier(0);
|
|
}
|
|
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
|
|
index a15759559..ab3f3856a 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
|
|
@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
|
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
|
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
|
|
|
- static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
|
|
- static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
|
|
+ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
|
|
+ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
|
|
+
|
|
+ static constexpr auto AEnableLds_auto =
|
|
+ (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
|
|
+ static constexpr auto BEnableLds_auto =
|
|
+ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
|
|
|
|
// If true, LDS is used unconditionally
|
|
static constexpr auto AEnableLds_manu = false;
|
|
@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
|
|
|
static bool IsSupportedArgument(const Argument& arg)
|
|
{
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
|
{
|
|
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
|
}
|
|
else
|
|
{
|
|
- if(!(arg.a_kz_stride_ == 1 &&
|
|
- arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
|
|
+ if(!(arg.a_kz_stride_ == 1))
|
|
{
|
|
- printf("DeviceOp: Vector Access A-k check failure\n");
|
|
- return false;
|
|
+ index_t LastK =
|
|
+ AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
|
|
+ if(LastK % ABlockTransferSrcScalarPerVector == 0)
|
|
+ {
|
|
+ printf("DeviceOp: Vector Access A-k check failure\n");
|
|
+ return false;
|
|
+ }
|
|
}
|
|
}
|
|
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
|
|
index 8fd14afc0..1b487502f 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
|
|
@@ -70,8 +70,9 @@ __global__ void
|
|
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
|
const Block2CTileMap block_2_ctile_map)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
|
- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
|
+ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
|
+ defined(__gfx12__))
|
|
|
|
const index_t num_blocks_per_batch =
|
|
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
|
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
|
|
static bool IsSupportedArgument(const Argument& arg)
|
|
{
|
|
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
|
- ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
|
+ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
bool pass = true;
|
|
pass = pass && arg.K_ % K1 == 0;
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
|
|
index f6b701ab1..102611838 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
|
|
@@ -56,7 +56,7 @@ __global__ void
|
|
bool input_permute,
|
|
bool output_permute)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
|
|
// clang-format off
|
|
// ***************************************************
|
|
@@ -159,6 +159,7 @@ __global__ void
|
|
ignore = O;
|
|
ignore = G0;
|
|
ignore = G1;
|
|
+ ignore = alpha;
|
|
ignore = input_permute;
|
|
ignore = output_permute;
|
|
#endif // end of if (defined(__gfx11__))
|
|
@@ -187,7 +188,7 @@ __global__ void
|
|
index_t head_size,
|
|
float alpha)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
|
|
// clang-format off
|
|
// ***************************************************
|
|
@@ -321,7 +322,7 @@ __global__ void
|
|
index_t head_size,
|
|
float alpha)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
|
|
// clang-format off
|
|
// ***************************************************
|
|
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
|
|
|
|
static bool IsSupportedArgument(const RawArg& arg)
|
|
{
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
|
{
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
|
|
index 9d5b74be6..017d28641 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
|
|
@@ -601,9 +601,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
|
return false;
|
|
}
|
|
|
|
- if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
|
|
- ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
|
|
- std::is_same<ADataType, double>::value)
|
|
+ if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
|
|
{
|
|
return false;
|
|
}
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
|
|
index b84e18130..1edae33be 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
|
|
@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
|
{
|
|
// check device
|
|
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
|
- ck::is_gfx11_supported()))
|
|
+ ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
|
{
|
|
return false;
|
|
}
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
|
|
index bf96324d0..553143e28 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
|
|
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
|
|
|
|
static bool IsSupportedArgument(const Argument& arg)
|
|
{
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
|
is_same_v<AccDataType, int32_t>))
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
|
|
index b1784b385..eb0fb55f5 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
|
|
@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
|
}
|
|
|
|
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
|
- ck::is_gfx11_supported())
|
|
+ ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
return GridwiseGemm::CheckValidity(
|
|
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
|
|
index 23858096d..811f1ae93 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
|
|
@@ -50,8 +50,9 @@ __global__ void
|
|
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
|
|
const Block2CTileMap block_2_ctile_map)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
|
- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
|
+ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
|
+ defined(__gfx12__))
|
|
|
|
constexpr index_t shared_block_size =
|
|
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
|
|
@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
|
|
static bool IsSupportedArgument(const Argument& arg)
|
|
{
|
|
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
|
- ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
|
+ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
return GridwiseGemm::CheckValidity(
|
|
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
|
|
index a1ef37cc8..35f1c77f8 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
|
|
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
|
|
|
static bool IsSupportedArgument(const Argument& arg)
|
|
{
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
|
{
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
|
|
index 93ab8a7e1..a7cc546f5 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
|
|
@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
|
// K1 = Max Vector Access Pixels
|
|
static constexpr auto K1Number = Number<K1>{};
|
|
|
|
- static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
|
- static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
|
- static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
|
-
|
|
- static constexpr auto AEnableLds_auto =
|
|
- (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
|
+ static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
|
+ static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
|
+ static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
|
+ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
|
|
+ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
|
|
+
|
|
+ static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
|
|
+ is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
|
+ ? false
|
|
+ : true;
|
|
static constexpr auto BEnableLds_auto =
|
|
- (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
|
+ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
|
|
+ is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
|
+ ? false
|
|
+ : true;
|
|
|
|
// If true, LDS is used unconditionally
|
|
static constexpr auto AEnableLds_manu = false;
|
|
@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
|
|
|
static bool IsSupportedArgument(const Argument& arg)
|
|
{
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
|
is_same_v<AccDataType, int32_t>))
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
|
|
index 6f74838fb..6bb5d431c 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
|
|
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
|
static bool IsSupportedArgument(const Argument& arg)
|
|
{
|
|
// check device
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
|
{
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
|
|
index bd264a3c8..7047e1bda 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
|
|
@@ -48,8 +48,9 @@ __global__ void
|
|
const Block2CTileMap block_2_ctile_map,
|
|
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
|
- defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
|
+ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
|
|
+ defined(__gfx12__))
|
|
const index_t num_blocks_per_batch =
|
|
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
|
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
|
|
index 211185dfb..5738be0fb 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
|
|
@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
|
static bool IsSupportedArgument(const Argument& arg)
|
|
{
|
|
// check device
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
|
{
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
|
|
index 7cfbd8a8f..5d5a9de7d 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
|
|
@@ -90,8 +90,9 @@ __global__ void
|
|
const Block2CTileMap block_2_ctile_map,
|
|
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
|
- defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
|
+ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
|
|
+ defined(__gfx12__))
|
|
// offset base pointer for each work-group
|
|
const index_t num_blocks_per_batch =
|
|
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
|
@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
|
|
|
// check device
|
|
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
|
- ck::is_gfx103_supported() || ck::is_gfx11_supported()))
|
|
+ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
|
{
|
|
return false;
|
|
}
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
|
|
index 6a4d97d7d..c65370b51 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
|
|
@@ -107,7 +107,7 @@ __global__ void
|
|
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
|
{
|
|
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
|
- defined(__gfx11__))
|
|
+ defined(__gfx11__) || defined(__gfx12__))
|
|
// offset base pointer for each work-group
|
|
const index_t num_blocks_per_batch =
|
|
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
|
@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
|
|
|
// check device
|
|
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
|
- ck::is_gfx11_supported()))
|
|
+ ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
|
{
|
|
return false;
|
|
}
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
|
|
index 24bd0f242..cfb64e0ee 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
|
|
@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
|
namespace ctc = tensor_layout::convolution;
|
|
|
|
// check device
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
|
{
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
|
|
index ac392cddc..060a16d1e 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
|
|
@@ -39,8 +39,9 @@ __global__ void
|
|
const BElementwiseOperation b_element_op,
|
|
const CDEElementwiseOperation cde_element_op)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
|
- defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
|
+ defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
|
|
+ defined(__gfx12__))
|
|
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
|
|
|
const index_t block_id = get_block_1d_id();
|
|
@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
|
}
|
|
|
|
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
|
- ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
|
+ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
|
{
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
|
|
index 71f7ac04c..67a100a11 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
|
|
@@ -61,7 +61,7 @@ __global__ void
|
|
bool input_permute,
|
|
bool output_permute)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
|
|
// clang-format off
|
|
// ***************************************************
|
|
@@ -166,6 +166,7 @@ __global__ void
|
|
ignore = O;
|
|
ignore = G0;
|
|
ignore = G1;
|
|
+ ignore = alpha;
|
|
ignore = input_permute;
|
|
ignore = output_permute;
|
|
#endif // end of if (defined(__gfx11__))
|
|
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
|
|
|
|
static bool IsSupportedArgument(const RawArg& arg)
|
|
{
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
|
{
|
|
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
|
|
index 4e14ed3a5..cc88c1a10 100644
|
|
--- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
|
|
@@ -60,7 +60,7 @@ __global__ void
|
|
bool input_permute,
|
|
bool output_permute)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
|
|
// clang-format off
|
|
// ***************************************************
|
|
@@ -165,6 +165,7 @@ __global__ void
|
|
ignore = O;
|
|
ignore = G0;
|
|
ignore = G1;
|
|
+ ignore = alpha;
|
|
ignore = input_permute;
|
|
ignore = output_permute;
|
|
#endif // end of if (defined(__gfx11__))
|
|
@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
|
|
|
|
static bool IsSupportedArgument(const RawArg& arg)
|
|
{
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
|
{
|
|
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
|
|
index 16717ff81..1754e07e6 100644
|
|
--- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
|
|
@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
|
|
if constexpr(B0EnableLds)
|
|
{
|
|
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
|
|
- constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
|
|
- constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
|
|
+ constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
|
|
+ constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
|
|
+#ifdef __gfx12__
|
|
+ constexpr auto B_KRow = I2;
|
|
+#else
|
|
constexpr auto B_KRow = I1;
|
|
+#endif
|
|
return transform_tensor_descriptor(
|
|
B0BlockDesc_{},
|
|
- make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
|
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
|
make_unmerge_transform(make_tuple(
|
|
Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
|
|
make_pass_through_transform(Number<B_K1>{})),
|
|
@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
|
|
if constexpr(B1EnableLds)
|
|
{
|
|
// BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
|
|
- constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
|
|
- constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
|
|
+ constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
|
|
+ constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
|
|
+#ifdef __gfx12__
|
|
+ constexpr auto B_LRow = I2;
|
|
+#else
|
|
constexpr auto B_LRow = I1;
|
|
+#endif
|
|
return transform_tensor_descriptor(
|
|
B1BlockDesc_{},
|
|
- make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
|
|
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
|
|
make_unmerge_transform(make_tuple(
|
|
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
|
make_pass_through_transform(Number<B_L1>{})),
|
|
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
|
|
index 499eb7eb0..21dac6f9e 100644
|
|
--- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
|
|
@@ -50,7 +50,7 @@ __global__ void
|
|
const CElementwiseOperation c_element_op,
|
|
const Block2CTileMap block_2_ctile_map)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
|
|
|
|
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
|
@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma
|
|
if constexpr(AEnableLds)
|
|
{
|
|
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
|
- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
|
- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
|
+ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
|
+ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
|
+#ifdef __gfx12__
|
|
+ constexpr auto A_KRow = I2;
|
|
+#else
|
|
constexpr auto A_KRow = I1;
|
|
+#endif
|
|
return transform_tensor_descriptor(
|
|
ABlockDesc_{},
|
|
- make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
|
+ make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
|
make_unmerge_transform(make_tuple(
|
|
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
|
make_pass_through_transform(Number<A_K1>{})),
|
|
@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma
|
|
if constexpr(BEnableLds)
|
|
{
|
|
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
|
- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
|
- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
|
+ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
|
+ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
|
+#ifdef __gfx12__
|
|
+ constexpr auto B_KRow = I2;
|
|
+#else
|
|
constexpr auto B_KRow = I1;
|
|
+#endif
|
|
return transform_tensor_descriptor(
|
|
BBlockDesc_{},
|
|
- make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
|
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
|
make_unmerge_transform(make_tuple(
|
|
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
|
make_pass_through_transform(Number<B_K1>{})),
|
|
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
|
|
index 82d010a99..fdda649ef 100644
|
|
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
|
|
@@ -54,7 +54,7 @@ __global__ void
|
|
const Block2CTileMap block_2_ctile_map,
|
|
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
// offset base pointer for each work-group
|
|
const index_t num_blocks_per_batch =
|
|
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
|
@@ -147,7 +147,7 @@ __global__ void
|
|
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
|
const Block2CTileMap block_2_etile_map)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
// printf("entry kernel launch");
|
|
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
|
|
|
|
@@ -237,7 +237,7 @@ __global__ void
|
|
const CDEElementwiseOperation cde_element_op,
|
|
const Block2CTileMap block_2_ctile_map)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
|
|
|
|
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
|
|
@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
|
|
}
|
|
else
|
|
{
|
|
+ constexpr auto A_KRow = I2;
|
|
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
|
- constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
|
+ constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
|
|
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
|
return make_naive_tensor_descriptor(
|
|
make_tuple(Number<KWmmaPerblock>{},
|
|
@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
|
|
}
|
|
else
|
|
{
|
|
+ constexpr auto B_KRow = I2;
|
|
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
|
- constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
|
+ constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
|
|
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
|
return make_naive_tensor_descriptor(
|
|
make_tuple(Number<KWmmaPerblock>{},
|
|
@@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma
|
|
if constexpr(AEnableLds)
|
|
{
|
|
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
|
- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
|
- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
|
+ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
|
+ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
|
+#ifdef __gfx12__
|
|
+ constexpr auto A_KRow = I2;
|
|
+#else
|
|
constexpr auto A_KRow = I1;
|
|
+#endif
|
|
return transform_tensor_descriptor(
|
|
ABlockDesc_{},
|
|
- make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
|
+ make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
|
make_unmerge_transform(make_tuple(
|
|
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
|
make_pass_through_transform(Number<A_K1>{})),
|
|
@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma
|
|
if constexpr(BEnableLds)
|
|
{
|
|
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
|
- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
|
- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
|
+ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
|
+ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
|
+#ifdef __gfx12__
|
|
+ constexpr auto B_KRow = I2;
|
|
+#else
|
|
constexpr auto B_KRow = I1;
|
|
+#endif
|
|
return transform_tensor_descriptor(
|
|
BBlockDesc_{},
|
|
- make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
|
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
|
make_unmerge_transform(make_tuple(
|
|
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
|
make_pass_through_transform(Number<B_K1>{})),
|
|
@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma
|
|
// *Caution Here repeat is shuffle repeat
|
|
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
|
{
|
|
- constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
|
|
- constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
|
|
-
|
|
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
|
make_naive_tensor_descriptor_packed(
|
|
make_tuple(I1,
|
|
- Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
|
|
+ Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
|
|
I1,
|
|
- Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
|
|
+ Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
|
|
|
|
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
|
}
|
|
@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma
|
|
const auto M = e_grid_desc_m_n.GetLength(I0);
|
|
const auto N = e_grid_desc_m_n.GetLength(I1);
|
|
|
|
- const auto MBlock = M / MPerBlock;
|
|
- const auto NBlock = N / NPerBlock;
|
|
+ const auto MBlock = M / MPerBlock;
|
|
+ const auto NBlock = N / NPerBlock;
|
|
+
|
|
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
|
e_grid_desc_m_n,
|
|
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
|
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
|
|
index 8e4117593..4458b9356 100644
|
|
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
|
|
@@ -45,7 +45,7 @@ __global__ void
|
|
const CElementwiseOperation c_element_op,
|
|
const Block2CTileMap block_2_ctile_map)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
|
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
|
|
|
|
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
|
@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
|
|
}
|
|
else
|
|
{
|
|
+ constexpr auto A_KRow = I2;
|
|
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
|
- constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
|
+ constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
|
|
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
|
return make_naive_tensor_descriptor(
|
|
make_tuple(Number<KWmmaPerblock>{},
|
|
@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
|
|
}
|
|
else
|
|
{
|
|
+
|
|
+ constexpr auto B_KRow = I2;
|
|
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
|
- constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
|
+ constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
|
|
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
|
return make_naive_tensor_descriptor(
|
|
make_tuple(Number<KWmmaPerblock>{},
|
|
@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma
|
|
if constexpr(AEnableLds)
|
|
{
|
|
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
|
- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
|
- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
|
+ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
|
+ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
|
+#ifdef __gfx12__
|
|
+ constexpr auto A_KRow = I2;
|
|
+#else
|
|
constexpr auto A_KRow = I1;
|
|
+#endif
|
|
+
|
|
return transform_tensor_descriptor(
|
|
ABlockDesc_{},
|
|
- make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
|
+ make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
|
make_unmerge_transform(make_tuple(
|
|
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
|
make_pass_through_transform(Number<A_K1>{})),
|
|
@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma
|
|
if constexpr(BEnableLds)
|
|
{
|
|
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
|
- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
|
- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
|
+ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
|
+ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
|
+#ifdef __gfx12__
|
|
+ constexpr auto B_KRow = I2;
|
|
+#else
|
|
constexpr auto B_KRow = I1;
|
|
+#endif
|
|
return transform_tensor_descriptor(
|
|
BBlockDesc_{},
|
|
- make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
|
+ make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
|
make_unmerge_transform(make_tuple(
|
|
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
|
make_pass_through_transform(Number<B_K1>{})),
|
|
@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma
|
|
c_grid_desc_m_n);
|
|
}
|
|
|
|
- using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
|
- remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
|
- CGridDesc_M_N{}))>;
|
|
- using DefaultBlock2CTileMap =
|
|
- remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
|
-
|
|
struct SharedMemTrait
|
|
{
|
|
// LDS allocation for A and B: be careful of alignment
|
|
@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma
|
|
b_block_space_size_aligned * sizeof(BDataType));
|
|
};
|
|
|
|
+ using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
|
+ remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
|
+ CGridDesc_M_N{}))>;
|
|
+ using DefaultBlock2CTileMap =
|
|
+ remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
|
+
|
|
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
|
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
|
|
const BDataType* __restrict__ p_b_grid,
|
|
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
|
|
index 6772524e0..174074990 100644
|
|
--- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
|
|
@@ -35,8 +35,9 @@ __global__ void
|
|
const Block2ETileMap block_2_tile_map,
|
|
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
|
{
|
|
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
|
- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
|
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
|
+ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
|
+ defined(__gfx12__))
|
|
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
|
|
p_in_global,
|
|
out_grid_desc,
|
|
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
|
|
index bcce930fc..d7a6a3624 100644
|
|
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
|
|
@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
|
|
ElementwiseOperation element_op_;
|
|
};
|
|
|
|
-// Specilized for WMMA
|
|
+// Specilized for WMMA-Navi3
|
|
// A single Wave32 is composed by double row
|
|
// Data exchange allowed between these two rows
|
|
// This RowLane Dst buf will be filled from two Src buf
|
|
@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
|
|
ElementwiseOperation element_op_{};
|
|
};
|
|
|
|
+// Specilized for WMMA-Navi4
|
|
+template <typename SrcData,
|
|
+ typename DstData,
|
|
+ typename SrcDesc,
|
|
+ typename DstDesc,
|
|
+ typename ElementwiseOperation,
|
|
+ typename SliceLengths,
|
|
+ typename DimAccessOrder,
|
|
+ index_t DstVectorDim,
|
|
+ index_t DstScalarPerVector,
|
|
+ bool IntraRowSwizzlePerm,
|
|
+ typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
|
+ bool>::type = false>
|
|
+struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
|
|
+{
|
|
+ static constexpr index_t nDim = SliceLengths::Size();
|
|
+
|
|
+ using Index = MultiIndex<nDim>;
|
|
+
|
|
+ __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx)
|
|
+ {
|
|
+ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
|
+ "wrong! Desc need to known at compile-time");
|
|
+
|
|
+ static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
|
+ "wrong! Not divisible");
|
|
+ ignore = src_idx;
|
|
+ }
|
|
+
|
|
+ template <typename SrcSliceOriginIdx,
|
|
+ typename DstSliceOriginIdx,
|
|
+ typename SrcBuffer,
|
|
+ typename DstBuffer>
|
|
+ __device__ void Run(const SrcDesc&,
|
|
+ const SrcSliceOriginIdx&,
|
|
+ const SrcBuffer& src_buf,
|
|
+ const DstDesc&,
|
|
+ const DstSliceOriginIdx&,
|
|
+ DstBuffer& dst_buf) const
|
|
+ {
|
|
+ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
|
+ "wrong! Desc need to known at compile-time");
|
|
+
|
|
+ static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
|
|
+ is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
|
|
+ "wrong! SliceOrigin need to known at compile-time");
|
|
+
|
|
+ static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
|
|
+ "wrong! Buffer need to be StaticBuffer");
|
|
+
|
|
+ // SrcDesc and src_slice_origin_idx are known at compile-time
|
|
+ constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
|
+ constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
|
+ constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
|
+ constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
|
|
+
|
|
+ // scalar per access on each dim
|
|
+ constexpr auto dst_scalar_per_access = generate_sequence(
|
|
+ detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
|
+
|
|
+ constexpr auto dst_scalar_step_in_vector =
|
|
+ generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
|
+
|
|
+ using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
|
+ DimAccessOrder,
|
|
+ remove_cv_t<decltype(dst_scalar_per_access)>>;
|
|
+
|
|
+ static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
|
|
+ "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
|
+
|
|
+ constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
|
+
|
|
+ static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
|
+ constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
|
+
|
|
+ // copy data from src_buf into dst_vector
|
|
+ static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
|
+ // src_desc error, non constexpr, caused by merge transform
|
|
+ constexpr index_t src_offset = src_desc.CalculateOffset(
|
|
+ src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
|
+
|
|
+ constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
|
+ dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
|
+
|
|
+ SrcData v_this_row;
|
|
+ // int type temp value due to intrinsic requirement
|
|
+ int temp = 0;
|
|
+
|
|
+ // apply element-wise operation
|
|
+ element_op_(v_this_row, src_buf[Number<src_offset>{}]);
|
|
+
|
|
+ // apply intra-row permute.
|
|
+ if constexpr(IntraRowSwizzlePerm)
|
|
+ {
|
|
+ temp = __builtin_amdgcn_permlane16(
|
|
+ temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
|
|
+ v_this_row = type_convert_sp<SrcData>(temp);
|
|
+ }
|
|
+
|
|
+ // apply type convert
|
|
+ dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
|
|
+ });
|
|
+ });
|
|
+ }
|
|
+ ElementwiseOperation element_op_{};
|
|
+};
|
|
+
|
|
} // namespace ck
|
|
diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
|
|
index 565195f53..9a9ebf559 100644
|
|
--- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
|
|
+++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
|
|
@@ -11,12 +11,17 @@ namespace ck {
|
|
|
|
enum struct WmmaInstr
|
|
{
|
|
+ // gfx11
|
|
wmma_f32_16x16x16_f16 = 0,
|
|
wmma_f32_16x16x16_bf16,
|
|
wmma_f16_16x16x16_f16,
|
|
wmma_bf16_16x16x16_bf16,
|
|
wmma_i32_16x16x16_iu8,
|
|
- wmma_i32_16x16x16_iu4
|
|
+ wmma_i32_16x16x16_iu4,
|
|
+ // gfx12
|
|
+ wmma_f32_16x16x16_f16_gfx12,
|
|
+ wmma_f32_16x16x16_bf16_gfx12,
|
|
+ wmma_i32_16x16x16_iu8_gfx12,
|
|
};
|
|
|
|
/*
|
|
@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
|
}
|
|
};
|
|
|
|
+// gfx12
|
|
+
|
|
+// A-swizzled
|
|
+template <index_t WaveSize>
|
|
+struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
|
|
+ WaveSize,
|
|
+ typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
|
+{
|
|
+ // Absolute fixing property
|
|
+ // * Data Pixel
|
|
+ static constexpr index_t m_per_wmma = 16;
|
|
+ static constexpr index_t n_per_wmma = 16;
|
|
+ static constexpr index_t k_per_wmma = 16;
|
|
+ // static constexpr index_t src_a_data_size = 2;
|
|
+ // static constexpr index_t src_b_data_size = 2;
|
|
+ // static constexpr index_t acc_data_size = 4;
|
|
+ // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
|
|
+ static constexpr index_t acc_data_size = 4;
|
|
+ static constexpr index_t acc_pack_number = 1;
|
|
+ static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
|
+
|
|
+ // Wave mode dependent propety
|
|
+ static constexpr index_t wave_size = Number<WaveSize>{};
|
|
+ // * Fixed in Navi3x, Will be wave mode dependent on Navi4x
|
|
+ // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
|
|
+ // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
|
|
+ // * num_acc_vgprs_per_wave alone M direction
|
|
+ // * num_subgroups alone M direction
|
|
+ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
|
+ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
|
+
|
|
+ template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
|
+ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
|
+ {
|
|
+ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
|
+ if constexpr(wave_size == 32)
|
|
+ {
|
|
+ intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
|
+ }
|
|
+ }
|
|
+};
|
|
+
|
|
+template <index_t WaveSize>
|
|
+struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
|
|
+ WaveSize,
|
|
+ typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
|
+{
|
|
+ // Absolute fixing property
|
|
+ static constexpr index_t m_per_wmma = 16;
|
|
+ static constexpr index_t n_per_wmma = 16;
|
|
+ static constexpr index_t k_per_wmma = 16;
|
|
+ // static constexpr index_t src_a_data_size = 2;
|
|
+ // static constexpr index_t src_b_data_size = 2;
|
|
+ static constexpr index_t acc_data_size = 4;
|
|
+ static constexpr index_t acc_pack_number = 1;
|
|
+ static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
|
+
|
|
+ // Wave mode dependent propety
|
|
+ static constexpr index_t wave_size = Number<WaveSize>{};
|
|
+ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
|
+ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
|
+ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
|
+ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
|
+
|
|
+ template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
|
+ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
|
+ {
|
|
+ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
|
+ if constexpr(wave_size == 32)
|
|
+ {
|
|
+ intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
|
+ }
|
|
+ }
|
|
+};
|
|
+
|
|
+template <index_t WaveSize>
|
|
+struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
|
|
+ WaveSize,
|
|
+ typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
|
+{
|
|
+ // Absolute fixing property
|
|
+ static constexpr index_t m_per_wmma = 16;
|
|
+ static constexpr index_t n_per_wmma = 16;
|
|
+ static constexpr index_t k_per_wmma = 16;
|
|
+ // static constexpr index_t src_a_data_size = 2;
|
|
+ // static constexpr index_t src_b_data_size = 2;
|
|
+ static constexpr index_t acc_data_size = 4;
|
|
+ static constexpr index_t acc_pack_number = 1;
|
|
+ static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
|
+
|
|
+ // Wave mode dependent propety
|
|
+ static constexpr index_t wave_size = Number<WaveSize>{};
|
|
+ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
|
+ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
|
+ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
|
+ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
|
+
|
|
+ template <index_t MPerWmma,
|
|
+ index_t NPerWmma,
|
|
+ class FloatA,
|
|
+ class FloatB,
|
|
+ class FloatC,
|
|
+ bool neg_a = false,
|
|
+ bool neg_b = false,
|
|
+ bool clamp = false>
|
|
+ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
|
+ {
|
|
+ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
|
+ if constexpr(wave_size == 32)
|
|
+ {
|
|
+ intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
|
|
+ a, b, reg_c);
|
|
+ }
|
|
+ }
|
|
+};
|
|
+
|
|
template <typename src_type_a,
|
|
typename src_type_b,
|
|
typename dst_type,
|
|
@@ -296,13 +417,21 @@ struct WmmaSelector
|
|
template <>
|
|
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
|
|
{
|
|
+#ifdef __gfx12__
|
|
+ return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
|
|
+#else
|
|
return WmmaInstr::wmma_f32_16x16x16_f16;
|
|
+#endif
|
|
}
|
|
|
|
template <>
|
|
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
|
|
{
|
|
+#ifdef __gfx12__
|
|
+ return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
|
|
+#else
|
|
return WmmaInstr::wmma_f32_16x16x16_bf16;
|
|
+#endif
|
|
}
|
|
|
|
template <>
|
|
@@ -320,8 +449,13 @@ struct WmmaSelector
|
|
template <>
|
|
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
|
|
{
|
|
+#ifdef __gfx12__
|
|
+ return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
|
|
+#else
|
|
return WmmaInstr::wmma_i32_16x16x16_iu8;
|
|
+#endif
|
|
}
|
|
+
|
|
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
|
template <>
|
|
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
|
|
@@ -502,6 +636,9 @@ struct WmmaGemm
|
|
|
|
__device__ static auto GetSubGroupId()
|
|
{
|
|
+ static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
|
|
+ wmma_instr.wave_size,
|
|
+ "");
|
|
return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
|
|
}
|
|
|
|
@@ -516,12 +653,20 @@ struct WmmaGemm
|
|
|
|
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
|
|
{
|
|
+#ifdef __gfx12__
|
|
+ return GetLaneIdUnderSubGroup();
|
|
+#else
|
|
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
|
|
+#endif
|
|
}
|
|
|
|
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
|
|
{
|
|
+#ifdef __gfx12__
|
|
+ return GetLaneIdUnderSubGroup();
|
|
+#else
|
|
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
|
|
+#endif
|
|
}
|
|
|
|
__device__ static CIndex GetBeginOfThreadBlk()
|
|
diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp
|
|
index 1bb0140f3..322a0f94b 100644
|
|
--- a/include/ck/utility/amd_wmma.hpp
|
|
+++ b/include/ck/utility/amd_wmma.hpp
|
|
@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
|
|
}
|
|
};
|
|
|
|
+// gfx12
|
|
+/********************************WAVE32 MODE***********************************************/
|
|
+
|
|
+#if defined(__gfx1200__) || defined(__gfx1201__)
|
|
+#define __gfx12__
|
|
+#endif
|
|
+
|
|
+// src: fp16, dst: fp32
|
|
+template <index_t MPerWave, index_t NPerWave>
|
|
+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
|
|
+
|
|
+template <>
|
|
+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>
|
|
+{
|
|
+ template <class FloatC>
|
|
+ __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
|
|
+ {
|
|
+ // * Inline assembly need to elimate the duplicated data load, compiler won't help you
|
|
+ // delete them.
|
|
+ // amd_assembly_wmma_f32_16x16x16_f16_w32(
|
|
+ // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
|
|
+#if defined(__gfx12__)
|
|
+ reg_c.template AsType<float8_t>()(Number<0>{}) =
|
|
+ __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
|
|
+ reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
|
|
+#else
|
|
+ ignore = reg_a;
|
|
+ ignore = reg_b;
|
|
+ ignore = reg_c;
|
|
+#endif
|
|
+ }
|
|
+};
|
|
+
|
|
+// src: bf16, dst: fp32
|
|
+template <index_t MPerWave, index_t NPerWave>
|
|
+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12;
|
|
+
|
|
+template <>
|
|
+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16>
|
|
+{
|
|
+ template <class FloatC>
|
|
+ __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
|
|
+ {
|
|
+#if defined(__gfx12__)
|
|
+ reg_c.template AsType<float8_t>()(Number<0>{}) =
|
|
+ __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
|
|
+ reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
|
|
+#else
|
|
+ ignore = reg_a;
|
|
+ ignore = reg_b;
|
|
+ ignore = reg_c;
|
|
+#endif
|
|
+ }
|
|
+};
|
|
+
|
|
+// src: iu8, dst: i32
|
|
+template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
|
|
+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12;
|
|
+
|
|
+template <bool neg_a, bool neg_b, bool clamp>
|
|
+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
|
|
+{
|
|
+ template <class FloatC>
|
|
+ __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
|
|
+ {
|
|
+#if defined(__gfx12__)
|
|
+ reg_c.template AsType<int32x8_t>()(Number<0>{}) =
|
|
+ __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
|
+ neg_a,
|
|
+ bit_cast<int32x2_t>(reg_a),
|
|
+ neg_b,
|
|
+ bit_cast<int32x2_t>(reg_b),
|
|
+ reg_c.template AsType<int32x8_t>()[Number<0>{}],
|
|
+ clamp);
|
|
+#else
|
|
+ ignore = reg_a;
|
|
+ ignore = reg_b;
|
|
+ ignore = reg_c;
|
|
+#endif
|
|
+ }
|
|
+};
|
|
+
|
|
} // namespace ck
|
|
#endif
|
|
diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp
|
|
index 93a1edefb..4df14c621 100644
|
|
--- a/include/ck/utility/data_type.hpp
|
|
+++ b/include/ck/utility/data_type.hpp
|
|
@@ -203,7 +203,7 @@ struct vector_type<T, 1>
|
|
}
|
|
};
|
|
|
|
-int static err = 0;
|
|
+__device__ int static err = 0;
|
|
template <typename T>
|
|
struct vector_type<T, 2>
|
|
{
|
|
diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp
|
|
index 4fe5e3950..d6b6eac26 100644
|
|
--- a/include/ck/utility/synchronization.hpp
|
|
+++ b/include/ck/utility/synchronization.hpp
|
|
@@ -10,12 +10,20 @@ namespace ck {
|
|
__device__ void block_sync_lds()
|
|
{
|
|
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
|
+#ifdef __gfx12__
|
|
+ asm volatile("\
|
|
+ s_wait_dscnt 0x0 \n \
|
|
+ s_barrier_signal -1 \n \
|
|
+ s_barrier_wait -1 \
|
|
+ " ::);
|
|
+#else
|
|
// asm volatile("\
|
|
// s_waitcnt lgkmcnt(0) \n \
|
|
// s_barrier \
|
|
// " ::);
|
|
__builtin_amdgcn_s_waitcnt(0xc07f);
|
|
__builtin_amdgcn_s_barrier();
|
|
+#endif
|
|
#else
|
|
__syncthreads();
|
|
#endif
|
|
@@ -23,11 +31,20 @@ __device__ void block_sync_lds()
|
|
|
|
__device__ void block_sync_lds_direct_load()
|
|
{
|
|
+#ifdef __gfx12__
|
|
+ asm volatile("\
|
|
+ s_wait_vmcnt 0x0 \n \
|
|
+ s_wait_dscnt 0x0 \n \
|
|
+ s_barrier_signal -1 \n \
|
|
+ s_barrier_wait -1 \
|
|
+ " ::);
|
|
+#else
|
|
asm volatile("\
|
|
s_waitcnt vmcnt(0) \n \
|
|
s_waitcnt lgkmcnt(0) \n \
|
|
s_barrier \
|
|
" ::);
|
|
+#endif
|
|
}
|
|
|
|
__device__ void s_nop()
|
|
diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp
|
|
index 601aad19b..9dc2b072a 100644
|
|
--- a/include/ck_tile/core/config.hpp
|
|
+++ b/include/ck_tile/core/config.hpp
|
|
@@ -17,6 +17,9 @@
|
|
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
|
|
#define __gfx11__
|
|
#endif
|
|
+#if defined(__gfx1200__) || defined(__gfx1201__)
|
|
+#define __gfx12__
|
|
+#endif
|
|
|
|
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
|
|
#include "hip/hip_runtime.h"
|
|
@@ -155,7 +158,7 @@
|
|
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
|
#elif defined(__gfx103__) // for GPU code
|
|
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
|
-#elif defined(__gfx11__) // for GPU code
|
|
+#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
|
|
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
|
#endif
|
|
|
|
diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
|
|
index 8c5f36d2e..89c9d6dc6 100644
|
|
--- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt
|
|
+++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
|
|
@@ -52,7 +52,7 @@ function(add_instance_library INSTANCE_NAME)
|
|
endforeach()
|
|
# Do not build WMMA instances if gfx11 targets are not on the target list
|
|
foreach(source IN LISTS ARGN)
|
|
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
|
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
|
message("removing wmma instance ${source} ")
|
|
list(REMOVE_ITEM ARGN "${source}")
|
|
endif()
|
|
@@ -149,7 +149,7 @@ FOREACH(subdir_path ${dir_list})
|
|
message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
|
|
set(add_inst 0)
|
|
endif()
|
|
- if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11"))
|
|
+ if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12"))
|
|
message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
|
|
set(add_inst 0)
|
|
endif()
|
|
@@ -157,11 +157,11 @@ FOREACH(subdir_path ${dir_list})
|
|
message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
|
|
set(add_inst 0)
|
|
endif()
|
|
- if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
|
|
+ if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9"))
|
|
message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
|
|
set(add_inst 0)
|
|
endif()
|
|
- if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
|
|
+ if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
|
|
message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
|
|
set(add_inst 0)
|
|
endif()
|
|
diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt
|
|
index 1cfcbfff6..a9557a9b9 100644
|
|
--- a/profiler/src/CMakeLists.txt
|
|
+++ b/profiler/src/CMakeLists.txt
|
|
@@ -58,7 +58,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
|
|
|
endif()
|
|
|
|
-if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9")
|
|
+if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9")
|
|
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
|
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
|
|
endif()
|
|
@@ -133,7 +133,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
|
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
|
|
endif()
|
|
|
|
-if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11")
|
|
+if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
|
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
|
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
|
|
endif()
|
|
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
|
|
index 25c63ac7f..2a7c52b58 100644
|
|
--- a/test/CMakeLists.txt
|
|
+++ b/test/CMakeLists.txt
|
|
@@ -53,7 +53,7 @@ function(add_test_executable TEST_NAME)
|
|
endif()
|
|
endforeach()
|
|
foreach(source IN LISTS ARGN)
|
|
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
|
|
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
|
|
message("removing wmma test ${source} ")
|
|
list(REMOVE_ITEM ARGN "${source}")
|
|
endif()
|
|
@@ -118,7 +118,7 @@ function(add_gtest_executable TEST_NAME)
|
|
endif()
|
|
endforeach()
|
|
foreach(source IN LISTS ARGN)
|
|
- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
|
|
+ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
|
|
message("removing wmma test ${source} ")
|
|
list(REMOVE_ITEM ARGN "${source}")
|
|
endif()
|
|
diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
|
|
index 1c8082645..21f49ec0f 100644
|
|
--- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
|
|
+++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
|
|
@@ -55,7 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
|
}
|
|
}
|
|
|
|
- if(ck::is_gfx11_supported())
|
|
+ if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
|
{
|
|
// on gfx11 only support for 3d is implemented
|
|
if constexpr(NDimSpatial{} != 3)
|
|
diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp
|
|
index 49782bce6..d9ec94771 100644
|
|
--- a/test/wmma_op/wmma_op_util.hpp
|
|
+++ b/test/wmma_op/wmma_op_util.hpp
|
|
@@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
|
|
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele];
|
|
}
|
|
|
|
+#ifdef __gfx12__
|
|
+ asm volatile("\
|
|
+ s_wait_dscnt 0x0 \n \
|
|
+ s_barrier_signal -1 \n \
|
|
+ s_barrier_wait -1 \
|
|
+ " ::);
|
|
+#else
|
|
asm volatile("\
|
|
s_waitcnt lgkmcnt(0) \n \
|
|
s_barrier \
|
|
" ::);
|
|
+#endif
|
|
|
|
for(int ele = 0; ele < 16; ++ele)
|
|
{
|
|
@@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
|
|
a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8];
|
|
}
|
|
|
|
+#ifdef __gfx12__
|
|
+ asm volatile("\
|
|
+ s_wait_dscnt 0x0 \n \
|
|
+ s_barrier_signal -1 \n \
|
|
+ s_barrier_wait -1 \
|
|
+ " ::);
|
|
+#else
|
|
asm volatile("\
|
|
s_waitcnt lgkmcnt(0) \n \
|
|
s_barrier \
|
|
" ::);
|
|
+#endif
|
|
|
|
// sync threads, similar to mma_sync
|
|
// __syncthreads();
|