From 4eac0db3afecf93063bf3d4ff289763453aaf89f Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 5 Jan 2023 17:53:30 +0800 Subject: [PATCH] [ROCm] Add GemmFastGelu CK implementation (#13759) ### Description Add GemmFastGelu CK implementation. TODO 1. The performance of CK GemmFastGelu in ORT is not as good as using CK directly; we still need to investigate the reason and improve CK in ORT. `GemmFastGeluUnfused float16 NN m=49152 n=3072 k=768 2298.8064 us 100.89 tflops` `withbias DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8, Default> LoopScheduler: Default, PipelineVersion: v1 float16 NN m=49152 n=3072 k=768 2401.9799 us 96.56 tflops` ### Motivation and Context Co-authored-by: peixuanzuo --- cmake/external/composable_kernel.cmake | 2 +- cmake/onnxruntime_kernel_explorer.cmake | 7 +- cmake/onnxruntime_providers.cmake | 9 +- .../composable_kernel/Fix_Clang_Build.patch | 20 +-- .../contrib_ops/rocm/bert/gemm_fast_gelu.cc | 20 ++- .../rocm/bert/gemm_fast_gelu_ck.cuh | 133 ++++++++++++++++ .../rocm/bert/gemm_fast_gelu_common.h | 48 ++++++ .../rocm/bert/gemm_fast_gelu_impl.cu | 104 +++++++------ .../rocm/bert/gemm_fast_gelu_impl.h | 48 +++--- .../rocm/bert/gemm_fast_gelu_tunable.cuh | 73 +++++++++ .../rocm/bert/gemm_fast_gelu_tunable_op.h | 80 ---------- .../kernels/gemm_fast_gelu_test.py | 108 ++++++++----- .../kernel_explorer/kernels/gemm_test.py | 9 +- .../kernels/rocm/gemm_fast_gelu.cc | 22 +++ .../kernels/rocm/gemm_fast_gelu.cu | 146 ------------------ .../kernels/rocm/gemm_fast_gelu_ck.cu | 126 +++++++++++++++ .../kernels/rocm/gemm_fast_gelu_ck.h | 14 ++ .../kernels/rocm/gemm_fast_gelu_tunable.cu | 107 +++++++++++++ .../kernels/rocm/gemm_fast_gelu_tunable.h | 14 ++ .../kernels/rocm/gemm_fast_gelu_unfused.cu | 99 ++++++++++++ .../kernels/rocm/gemm_fast_gelu_unfused.h | 14 ++ .../tools/kernel_explorer/kernels/utils.py | 7 + .../test/contrib_ops/gemm_fastgelu_op_test.cc | 34 ++-- 23 files changed, 858 
insertions(+), 386 deletions(-) create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc delete mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h diff --git a/cmake/external/composable_kernel.cmake b/cmake/external/composable_kernel.cmake index fc20520b83..8ca9c1516c 100644 --- a/cmake/external/composable_kernel.cmake +++ b/cmake/external/composable_kernel.cmake @@ -1,5 +1,5 @@ set(composable_kernel_URL https://github.com/ROCmSoftwarePlatform/composable_kernel.git) -set(composable_kernel_TAG 8ee36118be9b19b15c2471bffeeeb624afb14044) # 2022-11-01 00:24:25 +0800 +set(composable_kernel_TAG 0345963eef4f92e9c5eab608bb8557b5463a1dcb) # 2022-12-15 15:07:24 -0600 set(PATCH ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch) diff --git a/cmake/onnxruntime_kernel_explorer.cmake b/cmake/onnxruntime_kernel_explorer.cmake index 67ca3885a3..e525e7005b 100644 --- a/cmake/onnxruntime_kernel_explorer.cmake +++ b/cmake/onnxruntime_kernel_explorer.cmake @@ -59,12 +59,7 @@ elseif (onnxruntime_USE_ROCM) 
auto_set_source_files_hip_language(${kernel_explorer_kernel_srcs} ${kernel_explorer_rocm_kernel_srcs}) target_sources(kernel_explorer PRIVATE ${kernel_explorer_rocm_kernel_srcs}) target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1) - target_link_libraries(kernel_explorer PRIVATE - onnxruntime_composable_kernel_includes - # Currently we shall not use composablekernels::device_operations, the target includes all conv dependencies, which - # are extremely slow to compile. Instead, we only link all gemm related objects. See the following link on updating. - # https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/85978e0201/library/src/tensor_operation_instance/gpu/CMakeLists.txt#L33-L54 - device_gemm_instance) + target_link_libraries(kernel_explorer PRIVATE onnxruntime_composable_kernel_includes) endif() add_dependencies(kernel_explorer onnxruntime_pybind11_state) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 19c4e4a6d6..08c955ddf7 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1436,7 +1436,14 @@ if (onnxruntime_USE_ROCM) endif() include(composable_kernel) - target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_composable_kernel_includes device_gemm_instance) + target_link_libraries(onnxruntime_providers_rocm PRIVATE + onnxruntime_composable_kernel_includes + # Currently we shall not use composablekernels::device_operations, the target includes all conv dependencies, which + # are extremely slow to compile. Instead, we only link all gemm related objects. See the following link on updating. 
+ # https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/85978e0201/library/src/tensor_operation_instance/gpu/CMakeLists.txt#L33-L54 + device_gemm_instance + device_gemm_add_fastgelu_instance + device_gemm_fastgelu_instance) if(UNIX) set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections") diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch index 937a739c7b..7ee4c8bfaf 100644 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ b/cmake/patches/composable_kernel/Fix_Clang_Build.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 5655ba17..1252d1ea 100644 +index f861e302..f0b6bcea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ @@ -48,7 +48,7 @@ index 5655ba17..1252d1ea 100644 ## tidy include(EnableCompilerWarnings) -@@ -263,9 +240,6 @@ rocm_package_setup_component(tests +@@ -273,9 +250,6 @@ rocm_package_setup_component(profiler ) add_subdirectory(library) @@ -58,7 +58,7 @@ index 5655ba17..1252d1ea 100644 #Create an interface target for the include only files and call it "composablekernels" include(CMakePackageConfigHelpers) -@@ -291,11 +265,3 @@ rocm_install(FILES +@@ -301,11 +275,3 @@ rocm_install(FILES set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") @@ -70,20 +70,6 @@ index 5655ba17..1252d1ea 100644 - LDCONFIG - HEADER_ONLY -) -diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp -index 92018aac..2ada620c 100644 ---- a/include/ck/ck.hpp -+++ b/include/ck/ck.hpp -@@ -126,7 +126,9 @@ - #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1 - - // experimental feature: optimize for inter-wave scheduling policy -+#ifndef CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING - #define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0 -+#endif - #define 
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1 - - // hack: have underlying assumption that need to be satsified, otherwise it's a bug diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index c206c4dc..e45fac9d 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc index 106bf15cc3..453c82a2ed 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc @@ -3,16 +3,17 @@ #include "contrib_ops/rocm/bert/gemm_fast_gelu.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" #include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/rocm/rocm_common.h" +using onnxruntime::rocm::ToHipType; + namespace onnxruntime { namespace contrib { namespace rocm { -using onnxruntime::rocm::ToHipType; - #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GemmFastGelu, \ @@ -54,17 +55,20 @@ Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { const HipT alpha = ToHipType::FromFloat(1.0f); const HipT beta = ToHipType::FromFloat(0.0f); - return LaunchGemmFastGeluKernel( + using onnxruntime::rocm::tunable::blas::BlasOp; + + return blas::row_major::GemmFastGelu( IsTunableOpEnabled(), Stream(ctx), GetRocblasHandle(ctx), - transa, transb, - static_cast(helper.M()), static_cast(helper.N()), static_cast(helper.K()), + transa ? BlasOp::Trans : BlasOp::NonTrans, + transb ? BlasOp::Trans : BlasOp::NonTrans, + helper.M(), helper.N(), helper.K(), alpha, - reinterpret_cast(X->Data()), static_cast(helper.Lda(transa)), - reinterpret_cast(W->Data()), static_cast(helper.Ldb(transb)), + reinterpret_cast(X->Data()), helper.Lda(transa), + reinterpret_cast(W->Data()), helper.Ldb(transb), (nullptr != bias) ? 
reinterpret_cast(bias->Data()) : nullptr, beta, - reinterpret_cast(Y->MutableData()), static_cast(helper.Ldc())); + reinterpret_cast(Y->MutableData()), helper.Ldc()); } } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh new file mode 100644 index 0000000000..77c138a30b --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" + +using onnxruntime::rocm::ToHipType; +using onnxruntime::rocm::tunable::Op; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { +namespace internal { + +template +struct DataTypeAdaptor { + using type = T; +}; + +template <> +struct DataTypeAdaptor { + using type = ck::half_t; +}; + +template <> +struct DataTypeAdaptor { + using type = ck::bhalf16_t; +}; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Nop = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +template +auto GetCKGemmAddFastGeluTypeStringAndOps() { + using CKDataType = typename DataTypeAdaptor::type; + using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, BLayout, ck::Tuple, Row, + CKDataType, CKDataType, ck::Tuple, 
CKDataType, + Nop, Nop, AddFastGelu>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = onnxruntime::MakeString("withbias ", impl->GetTypeString()); + + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->alpha != one || params->beta != zero || params->bias == nullptr, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr", params->Signature()); + + auto nop = Nop{}; + auto addfastgelu = AddFastGelu{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, std::array{params->bias}, params->c, + params->m, params->n, params->k, + params->lda, params->ldb, std::array{0}, params->ldc, + nop, nop, addfastgelu); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->stream}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); + } + return ret; +} + + +template +auto GetCKGemmFastGeluTypeStringAndOps() { + using CKDataType = typename DataTypeAdaptor::type; + using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, BLayout, ck::Tuple<>, Row, + CKDataType, CKDataType, ck::Tuple<>, CKDataType, + Nop, Nop, FastGelu>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = onnxruntime::MakeString("nobias ", impl->GetTypeString()); + auto invoker = 
impl->MakeInvokerPointer(); + auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->alpha != one || params->beta != zero || params->bias != nullptr, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr", params->Signature()); + + auto nop = Nop{}; + auto fastgelu = FastGelu{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, + {}, + params->c, + params->m, params->n, params->k, + params->lda, params->ldb, + {}, + params->ldc, + nop, nop, fastgelu); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->stream}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); + } + return ret; +} + +} // namespace internal +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h new file mode 100644 index 0000000000..9bb0951101 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +using onnxruntime::rocm::tunable::blas::BlasOp; +using onnxruntime::rocm::tunable::blas::BlasOpToString; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { + +template +struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams { + std::string Signature() const override {  // Identifier built from the transpose ops, m/n/k and bias presence. + const bool has_bias = (nullptr != bias);  // true when a bias pointer is supplied (the previous "? 0 : 1" was inverted). + return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); + } + rocblas_handle handle; + BlasOp opa; + BlasOp opb; + int64_t m; + int64_t n; + int64_t k; + T alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + const T* bias; + T beta; + T* c; + int64_t ldc; + bool tuning{false}; +}; + +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu index 1317a2ccbd..039573e585 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu @@ -1,79 +1,95 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#define _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES #include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" -#include +#include +#include -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h" -#include "core/providers/rocm/tunable/gemm_common.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" +#include "core/providers/rocm/shared_inc/fpgeneric.h" namespace onnxruntime { namespace contrib { namespace rocm { +namespace blas { -// See it as row-major -template -Status LaunchGemmFastGeluKernel(bool tuning, - hipStream_t stream, - rocblas_handle handle, - bool transa, - bool transb, - int64_t m, - int64_t n, - int64_t k, - const T alpha, - const T* a, - int64_t lda, - const T* b, - int64_t ldb, - const T* bias, - const T beta, - T* c, - int64_t ldc) { +namespace row_major { + +template +inline GEMMFASTGELU(T, ScalarT) { GemmFastGeluParams params; - params.tuning = tuning; params.stream = stream; params.handle = handle; - params.opa = transa ? BlasOp::Trans : BlasOp::NonTrans; - params.opb = transb ? 
BlasOp::Trans : BlasOp::NonTrans; + params.opa = opa; + params.opb = opb; params.m = m; params.n = n; params.k = k; - params.alpha = alpha; + if constexpr (!std::is_same_v && std::is_same_v) { + params.alpha = ToHipType::FromFloat(std::forward(alpha)); + } else { + params.alpha = alpha; + } params.a = a; params.lda = lda; params.b = b; params.ldb = ldb; params.bias = bias; - params.beta = beta; + if constexpr (!std::is_same_v && std::is_same_v) { + params.beta = ToHipType::FromFloat(std::forward(beta)); + } else { + params.beta = beta; + } params.c = c; params.ldc = ldc; - if (tuning) { - static GemmFastGeluTunableOp op; - op.EnableTuning(); - return op(¶ms); + if (tunable) { + params.tuning = true; + if (opa == BlasOp::N && opb == BlasOp::N) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + gemm_fast_gelu.EnableTuning(); + return gemm_fast_gelu(¶ms); + } else if (opa == BlasOp::T && opb == BlasOp::N) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + gemm_fast_gelu.EnableTuning(); + return gemm_fast_gelu(¶ms); + } else if (opa == BlasOp::N && opb == BlasOp::T) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + gemm_fast_gelu.EnableTuning(); + return gemm_fast_gelu(¶ms); + } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + gemm_fast_gelu.EnableTuning(); + return gemm_fast_gelu(¶ms); + } } - return GemmFastGeluUnfused(¶ms); + return internal::GemmFastGeluUnfused(¶ms); } -#define SPECIALIZED_IMPL(T) \ - template Status LaunchGemmFastGeluKernel(bool tuning, \ - hipStream_t stream, rocblas_handle handle, \ - bool transa, bool transb, \ - int64_t m, int64_t n, int64_t k, const T alpha, \ - const T* a, int64_t lda, const T* b, int64_t ldb, \ - const T* bias, const T beta, T* c, int64_t ldc); +#define CALL_GEMMFASTGELU(T, ScalarT) \ + GemmFastGelu(tunable, stream, handle, \ + opa, opb, \ + m, n, k, \ + alpha, a, lda, b, ldb, bias, \ + beta, c, ldc) 
-SPECIALIZED_IMPL(float) -SPECIALIZED_IMPL(half) -SPECIALIZED_IMPL(BFloat16) +// clang-format off +GEMMFASTGELU(float, float ) { return CALL_GEMMFASTGELU(float, float ); } +GEMMFASTGELU(half, half ) { return CALL_GEMMFASTGELU(half, half ); } +GEMMFASTGELU(BFloat16, BFloat16) { return CALL_GEMMFASTGELU(BFloat16, BFloat16); } +GEMMFASTGELU(half, float ) { return CALL_GEMMFASTGELU(half, float ); } +GEMMFASTGELU(BFloat16, float ) { return CALL_GEMMFASTGELU(BFloat16, float ); } +// clang-format on +#undef CALL_GEMMFASTGELU + +} // namespace row_major + +} // namespace blas } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h index 765a0c96a9..3daeea07b6 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h @@ -3,34 +3,38 @@ #pragma once -#include - -#include "core/common/common.h" -#include "core/providers/rocm/rocm_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "core/common/status.h" +#include "core/framework/float16.h" namespace onnxruntime { namespace contrib { namespace rocm { +namespace blas { -template -Status LaunchGemmFastGeluKernel(bool tuning, - hipStream_t stream, - rocblas_handle handle, - bool transa, - bool transb, - int64_t m, - int64_t n, - int64_t k, - const T alpha, - const T* a, - int64_t lda, - const T* b, - int64_t ldb, - const T* bias, - const T beta, - T* c, - int64_t ldc); +#define GEMMFASTGELU(T, ScalarT) \ + common::Status GemmFastGelu( \ + bool tunable, hipStream_t stream, rocblas_handle handle, \ + BlasOp opa, BlasOp opb, \ + std::int64_t m, std::int64_t n, std::int64_t k, \ + ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ + const T* bias, ScalarT beta, T* c, std::int64_t ldc) +namespace row_major { + +GEMMFASTGELU(float, float); +GEMMFASTGELU(half, half); 
+GEMMFASTGELU(BFloat16, BFloat16); +GEMMFASTGELU(half, float); +GEMMFASTGELU(BFloat16, float); + +} // namespace row_major + +} // namespace blas } // namespace rocm } // namespace contrib } // namespace onnxruntime + +#ifndef _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES +#undef GEMMFASTGELU +#endif diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh new file mode 100644 index 0000000000..de0e11b52a --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "contrib_ops/rocm/bert/fast_gelu_impl.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib{ +namespace rocm { +namespace blas { +namespace internal { + +template +Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { + namespace column_major = onnxruntime::rocm::tunable::blas::column_major; + ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning, params->stream, params->handle, + params->opb, params->opa, + params->n, params->m, params->k, + params->alpha, params->b, params->ldb, params->a, params->lda, + params->beta, params->c, params->ldc)); + + int64_t fast_gelu_input_length = params->m * params->n; + int64_t bias_length = (params->bias != nullptr) ? params->n : 0; + + // GemmFastGeluUnfused is a combination of GemmOp and FastGeluOp, and the FastGeluOp in this combination is + // an in-place computation. + // 1. If we call GemmFastGeluUnfused directly with tuning enabled, the input buffer of FastGelu may be + // updated repeatedly, producing an incorrect final result. 
This only happens if the tuning's FindFastest is invoked. + // 2. It's safe to call GemmFastGeluUnfused with tuning disabled: FastGelu runs only once and produces the correct result. + // 3. It's safe to call GemmFastGeluUnfused as part of GemmFastGeluTunableOp with tuning enabled: GemmTunableOp and + // FastGeluTunableOp are each tuned separately in the first warmup step of the GemmFastGeluUnfused profiling process. + // After that, calls to GemmFastGeluUnfused do not invoke the tuning's FindFastest of FastGelu. + // + // Note: If any change causes GemmFastGeluUnfused to be used directly, add PreTuning() and PostTuning() to FastGeluTunableOp + // to protect the original input value. + return onnxruntime::contrib::rocm::LaunchFastGeluKernel(params->tuning, + params->stream, + static_cast(fast_gelu_input_length), + static_cast(bias_length), + params->c, + params->bias, + params->c); +} + +template +class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp> { + public: + GemmFastGeluTunableOp() { + this->RegisterOp(GemmFastGeluUnfused); + for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + } +}; + +} // namespace internal +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h deleted file mode 100644 index 6d050fe3fc..0000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include -#include - -#include "contrib_ops/rocm/bert/fast_gelu_impl.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; -using onnxruntime::rocm::tunable::blas::BlasOpToString; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); - } - rocblas_handle handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - const T* a; - int64_t lda; - const T* b; - int64_t ldb; - const T* bias; - T beta; - T* c; - int64_t ldc; - bool tuning{false}; -}; - -template -Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { - namespace column_major = onnxruntime::rocm::tunable::blas::column_major; - if (column_major::Gemm(params->tuning, params->stream, params->handle, - params->opb, params->opa, - params->n, params->m, params->k, - params->alpha, params->b, params->ldb, params->a, params->lda, - params->beta, params->c, params->ldc) != Status::OK()) { - return Status(common::ONNXRUNTIME, common::FAIL, "GemmFastGelu call column_major::Gemm failed"); - } - - int64_t fast_gelu_input_length = params->m * params->n; - int64_t bias_length = (params->bias != nullptr) ? 
params->n : 0; - - // inplace computation - return LaunchFastGeluKernel(params->tuning, - params->stream, - static_cast(fast_gelu_input_length), - static_cast(bias_length), - params->c, - params->bias, - params->c); -} - -template -class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp> { - public: - GemmFastGeluTunableOp() { - this->RegisterOp(GemmFastGeluUnfused); - this->SetDefaultId(0); - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py index 6a21da480b..164d07621f 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py @@ -3,22 +3,14 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import re import sys +from dataclasses import dataclass from itertools import product import kernel_explorer as ke import numpy as np import pytest -from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix - - -def dtype_to_funcs(dtype): - type_map = { - "float16": list(filter(lambda x: re.search("GemmFastGelu.*_half", x), dir(ke))), - "float32": list(filter(lambda x: re.search("GemmFastGelu.*_float", x), dir(ke))), - } - return type_map[dtype] +from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix def fast_gelu(x, bias): @@ -28,7 +20,7 @@ def fast_gelu(x, bias): # TODO The test method needs update. 
-def _test_gemmfastgelu(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False): +def _test_gemmfastgelu(my_func, dtype: str, m: int, n: int, k: int, transa=False, transb=False): assert dtype in ["float16", "float32"] a_shape = (k, m) if transa else (m, k) @@ -60,17 +52,17 @@ def _test_gemmfastgelu(func, dtype: str, m: int, n: int, k: int, transa=False, t ldb = b_shape[1] alpha = 1.0 beta = 0.0 - my_func = getattr(ke, func) my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n) - if my_op.IsSupported(): + print(f"dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {max(bound, 1e-2)}") + + for impl in my_op.ListOps(): + if not my_op.SelectOp(impl): + continue + my_op.Run() dev_c.UpdateHostNumpyArray() - print( - f"{func:<50} : dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {bound}" - ) - np.testing.assert_allclose(my_c, ref_c, rtol=max(bound, 1e-2)) @@ -81,12 +73,45 @@ all_transabs = list(product([True, False], repeat=2)) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transab", all_transabs) -def test_gemmfastgelu_bert_cases(dtype, size, transab): - for func in dtype_to_funcs(dtype): - _test_gemmfastgelu(func, dtype, *size, *transab) +def test_gemmfastgelu_unfused_bert_cases(dtype, size, transab): + _test_gemmfastgelu(getattr(ke, "GemmFastGeluUnfused_" + dtype_to_suffix(dtype)), dtype, *size, *transab) -def profile_gemmfastgelu_func(func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool): +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) +@pytest.mark.parametrize("transab", all_transabs) +def test_gemmfastgelu_tunable_bert_cases(dtype, size, transab): + wrapper_name = 
"GemmFastGeluTunable_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab)) + _test_gemmfastgelu(getattr(ke, wrapper_name), dtype, *size, *transab) + + +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) +@pytest.mark.parametrize("transab", all_transabs) +def test_gemmfastgelu_ck_bert_cases(dtype, size, transab): + wrapper_name = "CKGemmFastGelu_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab)) + _test_gemmfastgelu(getattr(ke, wrapper_name), dtype, *size, *transab) + + +@dataclass +class GemmFastGeluMetric(ke.ComputeMetric): + transa: bool + transb: bool + m: int + n: int + k: int + + def report(self): + prefix = f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} " + if self.duration > 0: + return ( + prefix + + f"m={self.m:<4} n={self.n:<4} k={self.k:<4} {self.duration:>8.4f} us {self.tflops:>5.2f} tflops" + ) + return prefix + "not supported" + + +def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool): a_shape = (k, m) if transa else (m, k) b_shape = (n, k) if transb else (k, n) @@ -107,34 +132,36 @@ def profile_gemmfastgelu_func(func, dtype: str, m: int, n: int, k: int, transa: ldb = b_shape[1] alpha = 1.0 beta = 0.0 - my_func = getattr(ke, func) my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n) - if my_op.IsSupported(): - my_op.Run() - dev_c.UpdateHostNumpyArray() - - time_ms = my_op.Profile() - time_us = time_ms * 1000 + for impl in my_op.ListOps(): + duration_ms = -1 + if my_op.SelectOp(impl): + duration_ms = my_op.Profile() # only counts gemm tflops because fastgelu is low order term (7 * n). 
- tflops = (m * k * n * 2) / (time_ms * 1e-3) / 1e12 - print( - f"{func:<50} {dtype} {transab_to_suffix((transa, transb))}", - f"m={m:<4} n={n:<4} k={k:<4} {time_us:>8.4f} us {tflops:>5.2f} tflops", + floating_point_operations = m * k * n * 2 + + ke.report(GemmFastGeluMetric(impl, dtype, duration_ms, floating_point_operations, transa, transb, m, n, k)) + + +def profile_with_args(transa, transb, dtype, m, n, k, sort): + dtype_suffix = "_" + dtype_to_suffix(dtype) + transab_suffix = "_" + transab_to_suffix((transa, transb)) + with ke.benchmark(sort): + profile_gemmfastgelu_func(getattr(ke, "GemmFastGeluUnfused" + dtype_suffix), dtype, m, n, k, transa, transb) + profile_gemmfastgelu_func( + getattr(ke, "CKGemmFastGelu" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb + ) + profile_gemmfastgelu_func( + getattr(ke, "GemmFastGeluTunable" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb ) - - -def profile_with_args(transa, transb, dtype, m, n, k): - for func in dtype_to_funcs(dtype): - profile_gemmfastgelu_func(func, dtype, m, n, k, transa, transb) def profile(): for dtype in dtypes: for m, n, k in get_gemm_bert_sizes(full=True): - profile_with_args(False, False, dtype, m, n, k) + profile_with_args(False, False, dtype, m, n, k, True) print() - print() if __name__ == "__main__": @@ -148,8 +175,9 @@ if __name__ == "__main__": group.add_argument("m", type=int) group.add_argument("n", type=int) group.add_argument("k", type=int) + group.add_argument("--sort", action="store_true") if len(sys.argv) == 1: profile() else: args = parser.parse_args() - profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k) + profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py index 456d1b598b..da3357947b 100644 --- 
a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py @@ -10,14 +10,7 @@ from itertools import product import kernel_explorer as ke import numpy as np import pytest -from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix - - -def dtype_to_suffix(dtype): - return { - "float32": "float", - "float16": "half", - }[dtype] +from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False): diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc new file mode 100644 index 0000000000..f494af834d --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h" + +#include + +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h" +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h" +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h" + +namespace py = pybind11; + +namespace onnxruntime { + +void InitGemmFastGelu(py::module mod) { + InitGemmFastGeluUnfused(mod); + InitGemmFastGeluTunable(mod); + InitComposableKernelGemmFastGelu(mod); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu deleted file mode 100644 index 37b7230257..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// Licensed under the MIT License. - -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h" - -#include -#include - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; - -namespace py = pybind11; - -namespace onnxruntime { - -template -class GemmFastGeluUnfused : public IKernelExplorer { - public: - GemmFastGeluUnfused(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - DeviceArray& bias, - double beta, - DeviceArray& c, int64_t ldc) : params_{} { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning = true; - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = alpha; - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.bias = static_cast(bias.ptr()); - params_.beta = beta; - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - } - - ~GemmFastGeluUnfused() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::GemmFastGeluUnfused(¶ms_))); - } - - bool IsSupported() { - Status status = contrib::rocm::GemmFastGeluUnfused(¶ms_); - return status.IsOK(); - } - - private: - using ParamsT = contrib::rocm::GemmFastGeluParams; - ParamsT params_{}; - rocblas_handle rocblas_handle_; -}; - -template -class GemmFastGeluTunableOp : public IKernelExplorer { - public: - GemmFastGeluTunableOp(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - DeviceArray& bias, - double beta, - 
DeviceArray& c, int64_t ldc) : params_{} { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning = true; - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = alpha; - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.bias = static_cast(bias.ptr()); - params_.beta = beta; - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - - op_.EnableTuning(); - } - - ~GemmFastGeluTunableOp() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - ORT_THROW_IF_ERROR((op_(¶ms_))); - } - - bool IsSupported() { - Status status = op_(¶ms_); - return status.IsOK(); - } - - private: - using ParamsT = contrib::rocm::GemmFastGeluParams; - ParamsT params_{}; - rocblas_handle rocblas_handle_; - contrib::rocm::GemmFastGeluTunableOp op_{}; -}; - -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Run", &name::Run) \ - .def("Profile", &name::Profile) \ - .def("IsSupported", &name::IsSupported); - -void InitGemmFastGelu(py::module m) { - REGISTER_OP(GemmFastGeluUnfused, float) - REGISTER_OP(GemmFastGeluUnfused, half) - - REGISTER_OP(GemmFastGeluTunableOp, float) - REGISTER_OP(GemmFastGeluTunableOp, half) -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu new file mode 100644 index 0000000000..d9e9d5cdea --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::contrib::rocm::blas; +using namespace onnxruntime::contrib::rocm::blas::internal; + +namespace py = pybind11; + +namespace onnxruntime { + +template +class CKGemmFastGelu : public IKernelExplorer { + public: + CKGemmFastGelu(BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, + DeviceArray& b, int64_t ldb, + DeviceArray& bias, + double beta, + DeviceArray& c, int64_t ldc) + : params_{} { + auto supports_a = opa == BlasOp::N ? std::is_same_v : std::is_same_v; + auto supports_b = opb == BlasOp::N ? std::is_same_v : std::is_same_v; + ORT_ENFORCE(supports_a && supports_b); + + params_.stream = Stream(); + // rocblas handle is not used for ck + params_.handle = nullptr; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.bias = static_cast(bias.ptr()); + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + + for (auto&& [type_string, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + for (auto&& [type_string, op] : GetCKGemmFastGeluTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + } + + void Run() override { + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); + } + + std::vector ListOps() const { + return type_strings_; + } + + bool 
SelectOp(const std::string& name) { + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); + } + + private: + using ParamsT = GemmFastGeluParams; + using OpT = rocm::tunable::Op; + ParamsT params_; + std::vector ops_; + std::vector type_strings_; + size_t selected_op_{}; +}; + +#define REGISTER_OP(type, alayout, blayout, layout_string) \ + py::class_>(m, "CKGemmFastGelu_" #type "_" layout_string) \ + .def(py::init()) \ + .def("SetRepeats", &CKGemmFastGelu::SetRepeats) \ + .def("Profile", &CKGemmFastGelu::Profile) \ + .def("Run", &CKGemmFastGelu::Run) \ + .def("ListOps", &CKGemmFastGelu::ListOps) \ + .def("SelectOp", &CKGemmFastGelu::SelectOp); + +#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ + REGISTER_OP(type, Row, Row, "NN"); \ + REGISTER_OP(type, Row, Col, "NT"); \ + REGISTER_OP(type, Col, Row, "TN"); \ + REGISTER_OP(type, Col, Col, "TT"); + +void InitComposableKernelGemmFastGelu(py::module m) { + REGISTER_OP_FOR_ALL_TRANSAB(float); + REGISTER_OP_FOR_ALL_TRANSAB(half); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h new file mode 100644 index 0000000000..13a22fae97 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +namespace py = pybind11; + +namespace onnxruntime { + +void InitComposableKernelGemmFastGelu(py::module mod); + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu new file mode 100644 index 0000000000..11c4d17394 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h" + +#include + +#include +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::contrib::rocm::blas; +using namespace onnxruntime::contrib::rocm::blas::internal; + +namespace py = pybind11; + +namespace onnxruntime { +template +class GemmFastGeluTunable : public IKernelExplorer { + public: + GemmFastGeluTunable(BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, + DeviceArray& b, int64_t ldb, + DeviceArray& bias, + double beta, + DeviceArray& c, int64_t ldc) : params_{} { + ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); + params_.tuning = true; + params_.stream = Stream(); + params_.handle = rocblas_handle_; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.bias = static_cast(bias.ptr()); + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + + op_.EnableTuning(); + } + + 
~GemmFastGeluTunable() { + ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); + rocblas_handle_ = nullptr; + } + + void Run() override { + ORT_THROW_IF_ERROR((op_(¶ms_))); + } + + std::vector ListOps() const { + return {"GemmFastGeluTunable"}; + } + + bool SelectOp(const std::string& name) { + return name == "GemmFastGeluTunable"; + } + + private: + using ParamsT = GemmFastGeluParams; + ParamsT params_{}; + rocblas_handle rocblas_handle_; + GemmFastGeluTunableOp op_{}; +}; + +#define REGISTER_OP(type, alayout, blayout, layout_string) \ + py::class_>(m, "GemmFastGeluTunable_" #type "_" layout_string) \ + .def(py::init()) \ + .def("SetRepeats", &GemmFastGeluTunable::SetRepeats) \ + .def("Profile", &GemmFastGeluTunable::Profile) \ + .def("Run", &GemmFastGeluTunable::Run) \ + .def("ListOps", &GemmFastGeluTunable::ListOps) \ + .def("SelectOp", &GemmFastGeluTunable::SelectOp); + +#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ + REGISTER_OP(type, Row, Row, "NN"); \ + REGISTER_OP(type, Row, Col, "NT"); \ + REGISTER_OP(type, Col, Row, "TN"); \ + REGISTER_OP(type, Col, Col, "TT"); + +void InitGemmFastGeluTunable(py::module m) { + REGISTER_OP_FOR_ALL_TRANSAB(float); + REGISTER_OP_FOR_ALL_TRANSAB(half); +} + +#undef REGISTER_OP + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h new file mode 100644 index 0000000000..f67b4950a4 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +namespace py = pybind11; + +namespace onnxruntime { + +void InitGemmFastGeluTunable(py::module mod); + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu new file mode 100644 index 0000000000..aa218afa38 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h" + +#include + +#include +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::contrib::rocm::blas; +using namespace onnxruntime::contrib::rocm::blas::internal; + +namespace py = pybind11; + +namespace onnxruntime { + +template +class GemmFastGeluUnfused : public IKernelExplorer { + public: + GemmFastGeluUnfused(BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, + DeviceArray& b, int64_t ldb, + DeviceArray& bias, + double beta, + DeviceArray& c, int64_t ldc) : params_{} { + ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); + params_.tuning = true; + params_.stream = Stream(); + params_.handle = rocblas_handle_; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.bias = static_cast(bias.ptr()); + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + } + + ~GemmFastGeluUnfused() { + 
ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); + rocblas_handle_ = nullptr; + } + + void Run() override { + ORT_THROW_IF_ERROR((contrib::rocm::blas::internal::GemmFastGeluUnfused(¶ms_))); + } + + std::vector ListOps() const { + return {"GemmFastGeluUnfused"}; + } + + bool SelectOp(const std::string& name) { + Status status = contrib::rocm::blas::internal::GemmFastGeluUnfused(¶ms_); + return status.IsOK() && name == "GemmFastGeluUnfused"; + } + + private: + using ParamsT = GemmFastGeluParams; + ParamsT params_{}; + rocblas_handle rocblas_handle_; +}; + +#define REGISTER_OP(type) \ + py::class_>(m, "GemmFastGeluUnfused_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &GemmFastGeluUnfused::SetRepeats) \ + .def("Run", &GemmFastGeluUnfused::Run) \ + .def("Profile", &GemmFastGeluUnfused::Profile) \ + .def("ListOps", &GemmFastGeluUnfused::ListOps) \ + .def("SelectOp", &GemmFastGeluUnfused::SelectOp); + +void InitGemmFastGeluUnfused(py::module m) { + REGISTER_OP(float) + REGISTER_OP(half) +} + +#undef REGISTER_OP + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h new file mode 100644 index 0000000000..96ea8b8360 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +namespace py = pybind11; + +namespace onnxruntime { + +void InitGemmFastGeluUnfused(py::module mod); + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py index 30c19ee72c..26a5dce735 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py @@ -26,6 +26,13 @@ def transab_to_suffix(transab): }[tuple(transab)] +def dtype_to_suffix(dtype): + return { + "float32": "float", + "float16": "half", + }[dtype] + + def get_gemm_bound(dtype: str, a: np.ndarray, b: np.ndarray, c: np.ndarray, transa: bool, transb: bool): k = b.shape[1] if transb else b.shape[0] # The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point diff --git a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc index eb277bdebc..a24f3b6b44 100644 --- a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc @@ -8,12 +8,25 @@ #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/providers/run_options_config_keys.h" namespace onnxruntime { namespace test { namespace gemmfastgelu { #if defined(USE_ROCM) +namespace { + +const onnxruntime::RunOptions run_options = []() { + onnxruntime::RunOptions options{}; + ORT_THROW_IF_ERROR(options.config_options.AddConfigEntry(kOpTesterRunOptionsConfigTestTunableOp, "true")); + return options; +}(); + +const constexpr auto run_with_tunable_op = &run_options; + +} // namespace + static void RunGemmFastGeluGpuTest(const std::vector& input_data, const std::vector& weight_data, const std::vector& bias_data, const std::vector& output_data, const std::vector& input_dims, const std::vector& weight_dims, 
@@ -37,11 +50,9 @@ static void RunGemmFastGeluGpuTest(const std::vector& input_data, const s tester.AddOutput("Y", output_dims, output_data); } - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + tester.Config(run_with_tunable_op) + .RunWithConfig(); } -#endif TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat32) { int batch_size = 1; @@ -71,11 +82,10 @@ TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat32) { std::vector weight_dims = {hidden_size, dense_size}; std::vector bias_dims = {dense_size}; std::vector output_dims = {batch_size, sequence_length, dense_size}; -#if defined(USE_ROCM) + RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, input_dims, weight_dims, bias_dims, output_dims, false); -#endif } TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat32) { @@ -107,15 +117,12 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat32) { std::vector weight_dims = {hidden_size, dense_size}; std::vector bias_dims = {dense_size}; std::vector output_dims = {batch_size, sequence_length, dense_size}; -#if defined(USE_ROCM) + RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, input_dims, weight_dims, bias_dims, output_dims, true); -#endif } -// CUDA and ROCm only for Float16 and BFloat16 type. 
-#if defined(USE_ROCM) TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat16) { int batch_size = 1; int sequence_length = 2; @@ -144,6 +151,7 @@ TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat16) { std::vector weight_dims = {hidden_size, dense_size}; std::vector bias_dims = {dense_size}; std::vector output_dims = {batch_size, sequence_length, dense_size}; + RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, input_dims, weight_dims, bias_dims, output_dims, false); @@ -178,6 +186,7 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat16) { std::vector weight_dims = {hidden_size, dense_size}; std::vector bias_dims = {dense_size}; std::vector output_dims = {batch_size, sequence_length, dense_size}; + RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, input_dims, weight_dims, bias_dims, output_dims, true); @@ -225,9 +234,8 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBias_bfloat16) { tester.AddInput("bias", bias_dims, f_B); tester.AddOutput("Y", output_dims, f_Y); - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + tester.Config(run_with_tunable_op) + .RunWithConfig(); } #endif