diff --git a/cmake/external/composable_kernel.cmake b/cmake/external/composable_kernel.cmake index fc20520b83..8ca9c1516c 100644 --- a/cmake/external/composable_kernel.cmake +++ b/cmake/external/composable_kernel.cmake @@ -1,5 +1,5 @@ set(composable_kernel_URL https://github.com/ROCmSoftwarePlatform/composable_kernel.git) -set(composable_kernel_TAG 8ee36118be9b19b15c2471bffeeeb624afb14044) # 2022-11-01 00:24:25 +0800 +set(composable_kernel_TAG 0345963eef4f92e9c5eab608bb8557b5463a1dcb) # 2022-12-15 15:07:24 -0600 set(PATCH ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch) diff --git a/cmake/onnxruntime_kernel_explorer.cmake b/cmake/onnxruntime_kernel_explorer.cmake index 67ca3885a3..e525e7005b 100644 --- a/cmake/onnxruntime_kernel_explorer.cmake +++ b/cmake/onnxruntime_kernel_explorer.cmake @@ -59,12 +59,7 @@ elseif (onnxruntime_USE_ROCM) auto_set_source_files_hip_language(${kernel_explorer_kernel_srcs} ${kernel_explorer_rocm_kernel_srcs}) target_sources(kernel_explorer PRIVATE ${kernel_explorer_rocm_kernel_srcs}) target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1) - target_link_libraries(kernel_explorer PRIVATE - onnxruntime_composable_kernel_includes - # Currently we shall not use composablekernels::device_operations, the target includes all conv dependencies, which - # are extremely slow to compile. Instead, we only link all gemm related objects. See the following link on updating. - # https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/85978e0201/library/src/tensor_operation_instance/gpu/CMakeLists.txt#L33-L54 - device_gemm_instance) + target_link_libraries(kernel_explorer PRIVATE onnxruntime_composable_kernel_includes) endif() add_dependencies(kernel_explorer onnxruntime_pybind11_state) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 19c4e4a6d6..08c955ddf7 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1436,7 +1436,14 @@ if (onnxruntime_USE_ROCM) endif() include(composable_kernel) - target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_composable_kernel_includes device_gemm_instance) + target_link_libraries(onnxruntime_providers_rocm PRIVATE + onnxruntime_composable_kernel_includes + # Currently we shall not use composablekernels::device_operations, the target includes all conv dependencies, which + # are extremely slow to compile. Instead, we only link all gemm related objects. See the following link on updating. + # https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/85978e0201/library/src/tensor_operation_instance/gpu/CMakeLists.txt#L33-L54 + device_gemm_instance + device_gemm_add_fastgelu_instance + device_gemm_fastgelu_instance) if(UNIX) set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections") diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch index 937a739c7b..7ee4c8bfaf 100644 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ b/cmake/patches/composable_kernel/Fix_Clang_Build.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 5655ba17..1252d1ea 100644 +index f861e302..f0b6bcea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ @@ -48,7 +48,7 @@ index 5655ba17..1252d1ea 100644 ## tidy include(EnableCompilerWarnings) -@@ -263,9 +240,6 @@ rocm_package_setup_component(tests +@@ -273,9 +250,6 @@ rocm_package_setup_component(profiler ) add_subdirectory(library) @@ -58,7 +58,7 @@ index 5655ba17..1252d1ea 100644 #Create an interface target for the include only files and call it "composablekernels" include(CMakePackageConfigHelpers) -@@ -291,11 +265,3 @@ rocm_install(FILES +@@ -301,11 +275,3 @@ rocm_install(FILES set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") @@ -70,20 +70,6 @@ index 5655ba17..1252d1ea 100644 - LDCONFIG - HEADER_ONLY -) -diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp -index 92018aac..2ada620c 100644 ---- a/include/ck/ck.hpp -+++ b/include/ck/ck.hpp -@@ -126,7 +126,9 @@ - #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1 - - // experimental feature: optimize for inter-wave scheduling policy -+#ifndef CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING - #define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0 -+#endif - #define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1 - - // hack: have underlying assumption that need to be satsified, otherwise it's a bug diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index c206c4dc..e45fac9d 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc index 106bf15cc3..453c82a2ed 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc @@ -3,16 +3,17 @@ #include "contrib_ops/rocm/bert/gemm_fast_gelu.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" #include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/rocm/rocm_common.h" +using onnxruntime::rocm::ToHipType; + namespace onnxruntime { namespace contrib { namespace rocm { -using onnxruntime::rocm::ToHipType; - #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GemmFastGelu, \ @@ -54,17 +55,20 @@ Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { const HipT alpha = ToHipType::FromFloat(1.0f); const HipT beta = ToHipType::FromFloat(0.0f); - return LaunchGemmFastGeluKernel( + using onnxruntime::rocm::tunable::blas::BlasOp; + + return blas::row_major::GemmFastGelu( IsTunableOpEnabled(), Stream(ctx), GetRocblasHandle(ctx), - transa, transb, - static_cast(helper.M()), static_cast(helper.N()), static_cast(helper.K()), + transa ? BlasOp::Trans : BlasOp::NonTrans, + transb ? BlasOp::Trans : BlasOp::NonTrans, + helper.M(), helper.N(), helper.K(), alpha, - reinterpret_cast(X->Data()), static_cast(helper.Lda(transa)), - reinterpret_cast(W->Data()), static_cast(helper.Ldb(transb)), + reinterpret_cast(X->Data()), helper.Lda(transa), + reinterpret_cast(W->Data()), helper.Ldb(transb), (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, beta, - reinterpret_cast(Y->MutableData()), static_cast(helper.Ldc())); + reinterpret_cast(Y->MutableData()), helper.Ldc()); } } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh new file mode 100644 index 0000000000..77c138a30b --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" + +using onnxruntime::rocm::ToHipType; +using onnxruntime::rocm::tunable::Op; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { +namespace internal { + +template +struct DataTypeAdaptor { + using type = T; +}; + +template <> +struct DataTypeAdaptor { + using type = ck::half_t; +}; + +template <> +struct DataTypeAdaptor { + using type = ck::bhalf16_t; +}; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Nop = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +template +auto GetCKGemmAddFastGeluTypeStringAndOps() { + using CKDataType = typename DataTypeAdaptor::type; + using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, BLayout, ck::Tuple, Row, + CKDataType, CKDataType, ck::Tuple, CKDataType, + Nop, Nop, AddFastGelu>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = onnxruntime::MakeString("withbias ", impl->GetTypeString()); + + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->alpha != one || params->beta != zero || params->bias == nullptr, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr", params->Signature()); + + auto nop = Nop{}; + auto addfastgelu = AddFastGelu{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, std::array{params->bias}, params->c, + params->m, params->n, params->k, + params->lda, params->ldb, std::array{0}, params->ldc, + nop, nop, addfastgelu); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->stream}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); + } + return ret; +} + + +template +auto GetCKGemmFastGeluTypeStringAndOps() { + using CKDataType = typename DataTypeAdaptor::type; + using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, BLayout, ck::Tuple<>, Row, + CKDataType, CKDataType, ck::Tuple<>, CKDataType, + Nop, Nop, FastGelu>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = onnxruntime::MakeString("nobias ", impl->GetTypeString()); + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->alpha != one || params->beta != zero || params->bias != nullptr, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr", params->Signature()); + + auto nop = Nop{}; + auto fastgelu = FastGelu{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, + {}, + params->c, + params->m, params->n, params->k, + params->lda, params->ldb, + {}, + params->ldc, + nop, nop, fastgelu); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->stream}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); + } + return ret; +} + +} // namespace internal +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h new file mode 100644 index 0000000000..9bb0951101 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +using onnxruntime::rocm::tunable::blas::BlasOp; +using onnxruntime::rocm::tunable::blas::BlasOpToString; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { + +template +struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams { + std::string Signature() const override { + bool has_bias = (nullptr != bias) ? 0 : 1; + return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); + } + rocblas_handle handle; + BlasOp opa; + BlasOp opb; + int64_t m; + int64_t n; + int64_t k; + T alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + const T* bias; + T beta; + T* c; + int64_t ldc; + bool tuning{false}; +}; + +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu index 1317a2ccbd..039573e585 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu @@ -1,79 +1,95 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#define _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES #include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" -#include +#include +#include -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h" -#include "core/providers/rocm/tunable/gemm_common.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" +#include "core/providers/rocm/shared_inc/fpgeneric.h" namespace onnxruntime { namespace contrib { namespace rocm { +namespace blas { -// See it as row-major -template -Status LaunchGemmFastGeluKernel(bool tuning, - hipStream_t stream, - rocblas_handle handle, - bool transa, - bool transb, - int64_t m, - int64_t n, - int64_t k, - const T alpha, - const T* a, - int64_t lda, - const T* b, - int64_t ldb, - const T* bias, - const T beta, - T* c, - int64_t ldc) { +namespace row_major { + +template +inline GEMMFASTGELU(T, ScalarT) { GemmFastGeluParams params; - params.tuning = tuning; params.stream = stream; params.handle = handle; - params.opa = transa ? BlasOp::Trans : BlasOp::NonTrans; - params.opb = transb ? BlasOp::Trans : BlasOp::NonTrans; + params.opa = opa; + params.opb = opb; params.m = m; params.n = n; params.k = k; - params.alpha = alpha; + if constexpr (!std::is_same_v && std::is_same_v) { + params.alpha = ToHipType::FromFloat(std::forward(alpha)); + } else { + params.alpha = alpha; + } params.a = a; params.lda = lda; params.b = b; params.ldb = ldb; params.bias = bias; - params.beta = beta; + if constexpr (!std::is_same_v && std::is_same_v) { + params.beta = ToHipType::FromFloat(std::forward(beta)); + } else { + params.beta = beta; + } params.c = c; params.ldc = ldc; - if (tuning) { - static GemmFastGeluTunableOp op; - op.EnableTuning(); - return op(¶ms); + if (tunable) { + params.tuning = true; + if (opa == BlasOp::N && opb == BlasOp::N) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + gemm_fast_gelu.EnableTuning(); + return gemm_fast_gelu(¶ms); + } else if (opa == BlasOp::T && opb == BlasOp::N) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + gemm_fast_gelu.EnableTuning(); + return gemm_fast_gelu(¶ms); + } else if (opa == BlasOp::N && opb == BlasOp::T) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + gemm_fast_gelu.EnableTuning(); + return gemm_fast_gelu(¶ms); + } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + gemm_fast_gelu.EnableTuning(); + return gemm_fast_gelu(¶ms); + } } - return GemmFastGeluUnfused(¶ms); + return internal::GemmFastGeluUnfused(¶ms); } -#define SPECIALIZED_IMPL(T) \ - template Status LaunchGemmFastGeluKernel(bool tuning, \ - hipStream_t stream, rocblas_handle handle, \ - bool transa, bool transb, \ - int64_t m, int64_t n, int64_t k, const T alpha, \ - const T* a, int64_t lda, const T* b, int64_t ldb, \ - const T* bias, const T beta, T* c, int64_t ldc); +#define CALL_GEMMFASTGELU(T, ScalarT) \ + GemmFastGelu(tunable, stream, handle, \ + opa, opb, \ + m, n, k, \ + alpha, a, lda, b, ldb, bias, \ + beta, c, ldc) -SPECIALIZED_IMPL(float) -SPECIALIZED_IMPL(half) -SPECIALIZED_IMPL(BFloat16) +// clang-format off +GEMMFASTGELU(float, float ) { return CALL_GEMMFASTGELU(float, float ); } +GEMMFASTGELU(half, half ) { return CALL_GEMMFASTGELU(half, half ); } +GEMMFASTGELU(BFloat16, BFloat16) { return CALL_GEMMFASTGELU(BFloat16, BFloat16); } +GEMMFASTGELU(half, float ) { return CALL_GEMMFASTGELU(half, float ); } +GEMMFASTGELU(BFloat16, float ) { return CALL_GEMMFASTGELU(BFloat16, float ); } +// clang-format on +#undef CALL_GEMMFASTGELU + +} // namespace row_major + +} // namespace blas } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h index 765a0c96a9..3daeea07b6 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h @@ -3,34 +3,38 @@ #pragma once -#include - -#include "core/common/common.h" -#include "core/providers/rocm/rocm_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "core/common/status.h" +#include "core/framework/float16.h" namespace onnxruntime { namespace contrib { namespace rocm { +namespace blas { -template -Status LaunchGemmFastGeluKernel(bool tuning, - hipStream_t stream, - rocblas_handle handle, - bool transa, - bool transb, - int64_t m, - int64_t n, - int64_t k, - const T alpha, - const T* a, - int64_t lda, - const T* b, - int64_t ldb, - const T* bias, - const T beta, - T* c, - int64_t ldc); +#define GEMMFASTGELU(T, ScalarT) \ + common::Status GemmFastGelu( \ + bool tunable, hipStream_t stream, rocblas_handle handle, \ + BlasOp opa, BlasOp opb, \ + std::int64_t m, std::int64_t n, std::int64_t k, \ + ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ + const T* bias, ScalarT beta, T* c, std::int64_t ldc) +namespace row_major { + +GEMMFASTGELU(float, float); +GEMMFASTGELU(half, half); +GEMMFASTGELU(BFloat16, BFloat16); +GEMMFASTGELU(half, float); +GEMMFASTGELU(BFloat16, float); + +} // namespace row_major + +} // namespace blas } // namespace rocm } // namespace contrib } // namespace onnxruntime + +#ifndef _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES +#undef GEMMFASTGELU +#endif diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh new file mode 100644 index 0000000000..de0e11b52a --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "contrib_ops/rocm/bert/fast_gelu_impl.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib{ +namespace rocm { +namespace blas { +namespace internal { + +template +Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { + namespace column_major = onnxruntime::rocm::tunable::blas::column_major; + ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning, params->stream, params->handle, + params->opb, params->opa, + params->n, params->m, params->k, + params->alpha, params->b, params->ldb, params->a, params->lda, + params->beta, params->c, params->ldc)); + + int64_t fast_gelu_input_length = params->m * params->n; + int64_t bias_length = (params->bias != nullptr) ? params->n : 0; + + // Because of GemmFastGeluUnfused is a combination of GemmOp and FastGeluOp, FastGeluOp in this combination is + // an inplace computation. + // 1. If we call GemmFastGeluUnfused directly with enabled tuning, it may cause the input buffer of FastGelu been + // updated accumulatedly and result in incorrect result finally. This only happens if the tuning's FindFastest is invoked. + // 2. It's safe to call GemmFastGeluUnfused with disabled tuning, FastGelu only run once and produce correct result. + // 3. It's safe to call GemmFastGeluUnfused as part of GemmFastGeluTunableOp with enable tuning, GemmTunableOp and + // FastGeluTunableOp will do tune in first warmup step separately during GemmFastGeluUnfused profiling process. + // After that, the call to GemmFastGeluUnfused not invoke tuning's FindFastest of FastGelu. + // + // Note: If any change cause directly usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp + // to protect original input value. + return onnxruntime::contrib::rocm::LaunchFastGeluKernel(params->tuning, + params->stream, + static_cast(fast_gelu_input_length), + static_cast(bias_length), + params->c, + params->bias, + params->c); +} + +template +class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp> { + public: + GemmFastGeluTunableOp() { + this->RegisterOp(GemmFastGeluUnfused); + for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + } +}; + +} // namespace internal +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h deleted file mode 100644 index 6d050fe3fc..0000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "contrib_ops/rocm/bert/fast_gelu_impl.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; -using onnxruntime::rocm::tunable::blas::BlasOpToString; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); - } - rocblas_handle handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - const T* a; - int64_t lda; - const T* b; - int64_t ldb; - const T* bias; - T beta; - T* c; - int64_t ldc; - bool tuning{false}; -}; - -template -Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { - namespace column_major = onnxruntime::rocm::tunable::blas::column_major; - if (column_major::Gemm(params->tuning, params->stream, params->handle, - params->opb, params->opa, - params->n, params->m, params->k, - params->alpha, params->b, params->ldb, params->a, params->lda, - params->beta, params->c, params->ldc) != Status::OK()) { - return Status(common::ONNXRUNTIME, common::FAIL, "GemmFastGelu call column_major::Gemm failed"); - } - - int64_t fast_gelu_input_length = params->m * params->n; - int64_t bias_length = (params->bias != nullptr) ? params->n : 0; - - // inplace computation - return LaunchFastGeluKernel(params->tuning, - params->stream, - static_cast(fast_gelu_input_length), - static_cast(bias_length), - params->c, - params->bias, - params->c); -} - -template -class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp> { - public: - GemmFastGeluTunableOp() { - this->RegisterOp(GemmFastGeluUnfused); - this->SetDefaultId(0); - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py index 6a21da480b..164d07621f 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py @@ -3,22 +3,14 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import re import sys +from dataclasses import dataclass from itertools import product import kernel_explorer as ke import numpy as np import pytest -from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix - - -def dtype_to_funcs(dtype): - type_map = { - "float16": list(filter(lambda x: re.search("GemmFastGelu.*_half", x), dir(ke))), - "float32": list(filter(lambda x: re.search("GemmFastGelu.*_float", x), dir(ke))), - } - return type_map[dtype] +from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix def fast_gelu(x, bias): @@ -28,7 +20,7 @@ def fast_gelu(x, bias): # TODO The test method needs update. -def _test_gemmfastgelu(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False): +def _test_gemmfastgelu(my_func, dtype: str, m: int, n: int, k: int, transa=False, transb=False): assert dtype in ["float16", "float32"] a_shape = (k, m) if transa else (m, k) @@ -60,17 +52,17 @@ def _test_gemmfastgelu(func, dtype: str, m: int, n: int, k: int, transa=False, t ldb = b_shape[1] alpha = 1.0 beta = 0.0 - my_func = getattr(ke, func) my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n) - if my_op.IsSupported(): + print(f"dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {max(bound, 1e-2)}") + + for impl in my_op.ListOps(): + if not my_op.SelectOp(impl): + continue + my_op.Run() dev_c.UpdateHostNumpyArray() - print( - f"{func:<50} : dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {bound}" - ) - np.testing.assert_allclose(my_c, ref_c, rtol=max(bound, 1e-2)) @@ -81,12 +73,45 @@ all_transabs = list(product([True, False], repeat=2)) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transab", all_transabs) -def test_gemmfastgelu_bert_cases(dtype, size, transab): - for func in dtype_to_funcs(dtype): - _test_gemmfastgelu(func, dtype, *size, *transab) +def test_gemmfastgelu_unfused_bert_cases(dtype, size, transab): + _test_gemmfastgelu(getattr(ke, "GemmFastGeluUnfused_" + dtype_to_suffix(dtype)), dtype, *size, *transab) -def profile_gemmfastgelu_func(func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool): +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) +@pytest.mark.parametrize("transab", all_transabs) +def test_gemmfastgelu_tunable_bert_cases(dtype, size, transab): + wrapper_name = "GemmFastGeluTunable_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab)) + _test_gemmfastgelu(getattr(ke, wrapper_name), dtype, *size, *transab) + + +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) +@pytest.mark.parametrize("transab", all_transabs) +def test_gemmfastgelu_ck_bert_cases(dtype, size, transab): + wrapper_name = "CKGemmFastGelu_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab)) + _test_gemmfastgelu(getattr(ke, wrapper_name), dtype, *size, *transab) + + +@dataclass +class GemmFastGeluMetric(ke.ComputeMetric): + transa: bool + transb: bool + m: int + n: int + k: int + + def report(self): + prefix = f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} " + if self.duration > 0: + return ( + prefix + + f"m={self.m:<4} n={self.n:<4} k={self.k:<4} {self.duration:>8.4f} us {self.tflops:>5.2f} tflops" + ) + return prefix + "not supported" + + +def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool): a_shape = (k, m) if transa else (m, k) b_shape = (n, k) if transb else (k, n) @@ -107,34 +132,36 @@ def profile_gemmfastgelu_func(func, dtype: str, m: int, n: int, k: int, transa: ldb = b_shape[1] alpha = 1.0 beta = 0.0 - my_func = getattr(ke, func) my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n) - if my_op.IsSupported(): - my_op.Run() - dev_c.UpdateHostNumpyArray() - - time_ms = my_op.Profile() - time_us = time_ms * 1000 + for impl in my_op.ListOps(): + duration_ms = -1 + if my_op.SelectOp(impl): + duration_ms = my_op.Profile() # only counts gemm tflops because fastgelu is low order term (7 * n). - tflops = (m * k * n * 2) / (time_ms * 1e-3) / 1e12 - print( - f"{func:<50} {dtype} {transab_to_suffix((transa, transb))}", - f"m={m:<4} n={n:<4} k={k:<4} {time_us:>8.4f} us {tflops:>5.2f} tflops", + floating_point_operations = m * k * n * 2 + + ke.report(GemmFastGeluMetric(impl, dtype, duration_ms, floating_point_operations, transa, transb, m, n, k)) + + +def profile_with_args(transa, transb, dtype, m, n, k, sort): + dtype_suffix = "_" + dtype_to_suffix(dtype) + transab_suffix = "_" + transab_to_suffix((transa, transb)) + with ke.benchmark(sort): + profile_gemmfastgelu_func(getattr(ke, "GemmFastGeluUnfused" + dtype_suffix), dtype, m, n, k, transa, transb) + profile_gemmfastgelu_func( + getattr(ke, "CKGemmFastGelu" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb + ) + profile_gemmfastgelu_func( + getattr(ke, "GemmFastGeluTunable" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb ) - - -def profile_with_args(transa, transb, dtype, m, n, k): - for func in dtype_to_funcs(dtype): - profile_gemmfastgelu_func(func, dtype, m, n, k, transa, transb) def profile(): for dtype in dtypes: for m, n, k in get_gemm_bert_sizes(full=True): - profile_with_args(False, False, dtype, m, n, k) + profile_with_args(False, False, dtype, m, n, k, True) print() - print() if __name__ == "__main__": @@ -148,8 +175,9 @@ if __name__ == "__main__": group.add_argument("m", type=int) group.add_argument("n", type=int) group.add_argument("k", type=int) + group.add_argument("--sort", action="store_true") if len(sys.argv) == 1: profile() else: args = parser.parse_args() - profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k) + profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py index 456d1b598b..da3357947b 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py @@ -10,14 +10,7 @@ from itertools import product import kernel_explorer as ke import numpy as np import pytest -from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix - - -def dtype_to_suffix(dtype): - return { - "float32": "float", - "float16": "half", - }[dtype] +from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False): diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc new file mode 100644 index 0000000000..f494af834d --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cc @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h" + +#include + +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h" +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h" +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h" + +namespace py = pybind11; + +namespace onnxruntime { + +void InitGemmFastGelu(py::module mod) { + InitGemmFastGeluUnfused(mod); + InitGemmFastGeluTunable(mod); + InitComposableKernelGemmFastGelu(mod); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu deleted file mode 100644 index 37b7230257..0000000000 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.cu +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h" - -#include -#include - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h" -#include "python/tools/kernel_explorer/device_array.h" -#include "python/tools/kernel_explorer/kernel_explorer_interface.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; - -namespace py = pybind11; - -namespace onnxruntime { - -template -class GemmFastGeluUnfused : public IKernelExplorer { - public: - GemmFastGeluUnfused(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - DeviceArray& bias, - double beta, - DeviceArray& c, int64_t ldc) : params_{} { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning = true; - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = alpha; - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.bias = static_cast(bias.ptr()); - params_.beta = beta; - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - } - - ~GemmFastGeluUnfused() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - ORT_THROW_IF_ERROR((contrib::rocm::GemmFastGeluUnfused(¶ms_))); - } - - bool IsSupported() { - Status status = contrib::rocm::GemmFastGeluUnfused(¶ms_); - return status.IsOK(); - } - - private: - using ParamsT = contrib::rocm::GemmFastGeluParams; - ParamsT params_{}; - rocblas_handle rocblas_handle_; -}; - -template -class GemmFastGeluTunableOp : public IKernelExplorer { - public: - GemmFastGeluTunableOp(BlasOp opa, BlasOp opb, - int64_t m, int64_t n, int64_t k, - double alpha, - DeviceArray& a, int64_t lda, - DeviceArray& b, int64_t ldb, - DeviceArray& bias, - double beta, - DeviceArray& c, int64_t ldc) : params_{} { - ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); - params_.tuning = true; - params_.stream = Stream(); - params_.handle = rocblas_handle_; - params_.opa = opa; - params_.opb = opb; - params_.m = m; - params_.n = n; - params_.k = k; - params_.alpha = alpha; - params_.a = static_cast(a.ptr()); - params_.lda = lda; - params_.b = static_cast(b.ptr()); - params_.ldb = ldb; - params_.bias = static_cast(bias.ptr()); - params_.beta = beta; - params_.c = static_cast(c.ptr()); - params_.ldc = ldc; - - op_.EnableTuning(); - } - - ~GemmFastGeluTunableOp() { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - rocblas_handle_ = nullptr; - } - - void Run() override { - ORT_THROW_IF_ERROR((op_(¶ms_))); - } - - bool IsSupported() { - Status status = op_(¶ms_); - return status.IsOK(); - } - - private: - using ParamsT = contrib::rocm::GemmFastGeluParams; - ParamsT params_{}; - rocblas_handle rocblas_handle_; - contrib::rocm::GemmFastGeluTunableOp op_{}; -}; - -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ - .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Run", &name::Run) \ - .def("Profile", &name::Profile) \ - .def("IsSupported", &name::IsSupported); - -void InitGemmFastGelu(py::module m) { - REGISTER_OP(GemmFastGeluUnfused, float) - REGISTER_OP(GemmFastGeluUnfused, half) - - REGISTER_OP(GemmFastGeluTunableOp, float) - REGISTER_OP(GemmFastGeluTunableOp, half) -} - -} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu new file mode 100644 index 0000000000..d9e9d5cdea --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.cu @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::contrib::rocm::blas; +using namespace onnxruntime::contrib::rocm::blas::internal; + +namespace py = pybind11; + +namespace onnxruntime { + +template +class CKGemmFastGelu : public IKernelExplorer { + public: + CKGemmFastGelu(BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, + DeviceArray& b, int64_t ldb, + DeviceArray& bias, + double beta, + DeviceArray& c, int64_t ldc) + : params_{} { + auto supports_a = opa == BlasOp::N ? std::is_same_v : std::is_same_v; + auto supports_b = opb == BlasOp::N ? std::is_same_v : std::is_same_v; + ORT_ENFORCE(supports_a && supports_b); + + params_.stream = Stream(); + // rocblas handle is not used for ck + params_.handle = nullptr; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.bias = static_cast(bias.ptr()); + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + + for (auto&& [type_string, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + for (auto&& [type_string, op] : GetCKGemmFastGeluTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + } + + void Run() override { + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); + } + + std::vector ListOps() const { + return type_strings_; + } + + bool SelectOp(const std::string& name) { + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); + } + + private: + using ParamsT = GemmFastGeluParams; + using OpT = rocm::tunable::Op; + ParamsT params_; + std::vector ops_; + std::vector type_strings_; + size_t selected_op_{}; +}; + +#define REGISTER_OP(type, alayout, blayout, layout_string) \ + py::class_>(m, "CKGemmFastGelu_" #type "_" layout_string) \ + .def(py::init()) \ + .def("SetRepeats", &CKGemmFastGelu::SetRepeats) \ + .def("Profile", &CKGemmFastGelu::Profile) \ + .def("Run", &CKGemmFastGelu::Run) \ + .def("ListOps", &CKGemmFastGelu::ListOps) \ + .def("SelectOp", &CKGemmFastGelu::SelectOp); + +#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ + REGISTER_OP(type, Row, Row, "NN"); \ + REGISTER_OP(type, Row, Col, "NT"); \ + REGISTER_OP(type, Col, Row, "TN"); \ + REGISTER_OP(type, Col, Col, "TT"); + +void InitComposableKernelGemmFastGelu(py::module m) { + REGISTER_OP_FOR_ALL_TRANSAB(float); + REGISTER_OP_FOR_ALL_TRANSAB(half); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h new file mode 100644 index 0000000000..13a22fae97 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace py = pybind11; + +namespace onnxruntime { + +void InitComposableKernelGemmFastGelu(py::module mod); + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu new file mode 100644 index 0000000000..11c4d17394 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h" + +#include + +#include +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::contrib::rocm::blas; +using namespace onnxruntime::contrib::rocm::blas::internal; + +namespace py = pybind11; + +namespace onnxruntime { +template +class GemmFastGeluTunable : public IKernelExplorer { + public: + GemmFastGeluTunable(BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, + DeviceArray& b, int64_t ldb, + DeviceArray& bias, + double beta, + DeviceArray& c, int64_t ldc) : params_{} { + ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); + params_.tuning = true; + params_.stream = Stream(); + params_.handle = rocblas_handle_; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.bias = static_cast(bias.ptr()); + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + + op_.EnableTuning(); + } + + ~GemmFastGeluTunable() { + ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); + rocblas_handle_ = nullptr; + } + + void Run() override { + ORT_THROW_IF_ERROR((op_(¶ms_))); + } + + std::vector ListOps() const { + return {"GemmFastGeluTunable"}; + } + + bool SelectOp(const std::string& name) { + return name == "GemmFastGeluTunable"; + } + + private: + using ParamsT = GemmFastGeluParams; + ParamsT params_{}; + rocblas_handle rocblas_handle_; + GemmFastGeluTunableOp op_{}; +}; + +#define REGISTER_OP(type, alayout, blayout, layout_string) \ + py::class_>(m, "GemmFastGeluTunable_" #type "_" layout_string) \ + .def(py::init()) \ + .def("SetRepeats", &GemmFastGeluTunable::SetRepeats) \ + .def("Profile", &GemmFastGeluTunable::Profile) \ + .def("Run", &GemmFastGeluTunable::Run) \ + .def("ListOps", &GemmFastGeluTunable::ListOps) \ + .def("SelectOp", &GemmFastGeluTunable::SelectOp); + +#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ + REGISTER_OP(type, Row, Row, "NN"); \ + REGISTER_OP(type, Row, Col, "NT"); \ + REGISTER_OP(type, Col, Row, "TN"); \ + REGISTER_OP(type, Col, Col, "TT"); + +void InitGemmFastGeluTunable(py::module m) { + REGISTER_OP_FOR_ALL_TRANSAB(float); + REGISTER_OP_FOR_ALL_TRANSAB(half); +} + +#undef REGISTER_OP + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h new file mode 100644 index 0000000000..f67b4950a4 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace py = pybind11; + +namespace onnxruntime { + +void InitGemmFastGeluTunable(py::module mod); + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu new file mode 100644 index 0000000000..aa218afa38 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.cu @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h" + +#include + +#include +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::contrib::rocm::blas; +using namespace onnxruntime::contrib::rocm::blas::internal; + +namespace py = pybind11; + +namespace onnxruntime { + +template +class GemmFastGeluUnfused : public IKernelExplorer { + public: + GemmFastGeluUnfused(BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, + DeviceArray& b, int64_t ldb, + DeviceArray& bias, + double beta, + DeviceArray& c, int64_t ldc) : params_{} { + ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); + params_.tuning = true; + params_.stream = Stream(); + params_.handle = rocblas_handle_; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.bias = static_cast(bias.ptr()); + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + } + + ~GemmFastGeluUnfused() { + ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); + rocblas_handle_ = nullptr; + } + + void Run() override { + ORT_THROW_IF_ERROR((contrib::rocm::blas::internal::GemmFastGeluUnfused(¶ms_))); + } + + std::vector ListOps() const { + return {"GemmFastGeluUnfused"}; + } + + bool SelectOp(const std::string& name) { + Status status = contrib::rocm::blas::internal::GemmFastGeluUnfused(¶ms_); + return status.IsOK() && name == "GemmFastGeluUnfused"; + } + + private: + using ParamsT = GemmFastGeluParams; + ParamsT params_{}; + rocblas_handle rocblas_handle_; +}; + +#define REGISTER_OP(type) \ + py::class_>(m, "GemmFastGeluUnfused_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &GemmFastGeluUnfused::SetRepeats) \ + .def("Run", &GemmFastGeluUnfused::Run) \ + .def("Profile", &GemmFastGeluUnfused::Profile) \ + .def("ListOps", &GemmFastGeluUnfused::ListOps) \ + .def("SelectOp", &GemmFastGeluUnfused::SelectOp); + +void InitGemmFastGeluUnfused(py::module m) { + REGISTER_OP(float) + REGISTER_OP(half) +} + +#undef REGISTER_OP + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h new file mode 100644 index 0000000000..96ea8b8360 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace py = pybind11; + +namespace onnxruntime { + +void InitGemmFastGeluUnfused(py::module mod); + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py index 30c19ee72c..26a5dce735 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/utils.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/utils.py @@ -26,6 +26,13 @@ def transab_to_suffix(transab): }[tuple(transab)] +def dtype_to_suffix(dtype): + return { + "float32": "float", + "float16": "half", + }[dtype] + + def get_gemm_bound(dtype: str, a: np.ndarray, b: np.ndarray, c: np.ndarray, transa: bool, transb: bool): k = b.shape[1] if transb else b.shape[0] # The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point diff --git a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc index eb277bdebc..a24f3b6b44 100644 --- a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc @@ -8,12 +8,25 @@ #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/providers/run_options_config_keys.h" namespace onnxruntime { namespace test { namespace gemmfastgelu { #if defined(USE_ROCM) +namespace { + +const onnxruntime::RunOptions run_options = []() { + onnxruntime::RunOptions options{}; + ORT_THROW_IF_ERROR(options.config_options.AddConfigEntry(kOpTesterRunOptionsConfigTestTunableOp, "true")); + return options; +}(); + +const constexpr auto run_with_tunable_op = &run_options; + +} // namespace + static void RunGemmFastGeluGpuTest(const std::vector& input_data, const std::vector& weight_data, const std::vector& bias_data, const std::vector& output_data, const std::vector& input_dims, const std::vector& weight_dims, @@ -37,11 +50,9 @@ static void RunGemmFastGeluGpuTest(const std::vector& input_data, const s tester.AddOutput("Y", output_dims, output_data); } - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + tester.Config(run_with_tunable_op) + .RunWithConfig(); } -#endif TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat32) { int batch_size = 1; @@ -71,11 +82,10 @@ TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat32) { std::vector weight_dims = {hidden_size, dense_size}; std::vector bias_dims = {dense_size}; std::vector output_dims = {batch_size, sequence_length, dense_size}; -#if defined(USE_ROCM) + RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, input_dims, weight_dims, bias_dims, output_dims, false); -#endif } TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat32) { @@ -107,15 +117,12 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat32) { std::vector weight_dims = {hidden_size, dense_size}; std::vector bias_dims = {dense_size}; std::vector output_dims = {batch_size, sequence_length, dense_size}; -#if defined(USE_ROCM) + RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, input_dims, weight_dims, bias_dims, output_dims, true); -#endif } -// CUDA and ROCm only for Float16 and BFloat16 type. -#if defined(USE_ROCM) TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat16) { int batch_size = 1; int sequence_length = 2; @@ -144,6 +151,7 @@ TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat16) { std::vector weight_dims = {hidden_size, dense_size}; std::vector bias_dims = {dense_size}; std::vector output_dims = {batch_size, sequence_length, dense_size}; + RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, input_dims, weight_dims, bias_dims, output_dims, false); @@ -178,6 +186,7 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat16) { std::vector weight_dims = {hidden_size, dense_size}; std::vector bias_dims = {dense_size}; std::vector output_dims = {batch_size, sequence_length, dense_size}; + RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data, input_dims, weight_dims, bias_dims, output_dims, true); @@ -225,9 +234,8 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBias_bfloat16) { tester.AddInput("bias", bias_dims, f_B); tester.AddOutput("Y", output_dims, f_Y); - std::vector> execution_providers; - execution_providers.push_back(DefaultRocmExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + tester.Config(run_with_tunable_op) + .RunWithConfig(); } #endif