[ROCm] Add GemmFastGelu CK implementation (#13759)

### Description
<!-- Describe your changes. -->

Add GemmFastGelu CK implementation.

TODO 
1. The performance of CK GemmFastGelu in ORT is not good as using CK
directly, still need to investigate the reason and improve the CK in
ORT.
`GemmFastGeluUnfused float16 NN m=49152 n=3072 k=768 2298.8064 us 100.89
tflops`
`withbias DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8,
Default> LoopScheduler: Default, PipelineVersion: v1 float16 NN m=49152
n=3072 k=768 2401.9799 us 96.56 tflops`

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: peixuanzuo <peixuanzuo@linmif39a000004.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
This commit is contained in:
PeixuanZuo 2023-01-05 17:53:30 +08:00 committed by GitHub
parent 2b45410e52
commit 4eac0db3af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 858 additions and 386 deletions

View file

@ -1,5 +1,5 @@
set(composable_kernel_URL https://github.com/ROCmSoftwarePlatform/composable_kernel.git)
set(composable_kernel_TAG 8ee36118be9b19b15c2471bffeeeb624afb14044) # 2022-11-01 00:24:25 +0800
set(composable_kernel_TAG 0345963eef4f92e9c5eab608bb8557b5463a1dcb) # 2022-12-15 15:07:24 -0600
set(PATCH ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch)

View file

@ -59,12 +59,7 @@ elseif (onnxruntime_USE_ROCM)
auto_set_source_files_hip_language(${kernel_explorer_kernel_srcs} ${kernel_explorer_rocm_kernel_srcs})
target_sources(kernel_explorer PRIVATE ${kernel_explorer_rocm_kernel_srcs})
target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1)
target_link_libraries(kernel_explorer PRIVATE
onnxruntime_composable_kernel_includes
# Currently we shall not use composablekernels::device_operations, the target includes all conv dependencies, which
# are extremely slow to compile. Instead, we only link all gemm related objects. See the following link on updating.
# https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/85978e0201/library/src/tensor_operation_instance/gpu/CMakeLists.txt#L33-L54
device_gemm_instance)
target_link_libraries(kernel_explorer PRIVATE onnxruntime_composable_kernel_includes)
endif()
add_dependencies(kernel_explorer onnxruntime_pybind11_state)

View file

@ -1436,7 +1436,14 @@ if (onnxruntime_USE_ROCM)
endif()
include(composable_kernel)
target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_composable_kernel_includes device_gemm_instance)
target_link_libraries(onnxruntime_providers_rocm PRIVATE
onnxruntime_composable_kernel_includes
# Currently we shall not use composablekernels::device_operations, the target includes all conv dependencies, which
# are extremely slow to compile. Instead, we only link all gemm related objects. See the following link on updating.
# https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/85978e0201/library/src/tensor_operation_instance/gpu/CMakeLists.txt#L33-L54
device_gemm_instance
device_gemm_add_fastgelu_instance
device_gemm_fastgelu_instance)
if(UNIX)
set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections")

View file

@ -1,5 +1,5 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5655ba17..1252d1ea 100644
index f861e302..f0b6bcea 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
@ -48,7 +48,7 @@ index 5655ba17..1252d1ea 100644
## tidy
include(EnableCompilerWarnings)
@@ -263,9 +240,6 @@ rocm_package_setup_component(tests
@@ -273,9 +250,6 @@ rocm_package_setup_component(profiler
)
add_subdirectory(library)
@ -58,7 +58,7 @@ index 5655ba17..1252d1ea 100644
#Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers)
@@ -291,11 +265,3 @@ rocm_install(FILES
@@ -301,11 +275,3 @@ rocm_install(FILES
set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
set(CPACK_RPM_PACKAGE_LICENSE "MIT")
@ -70,20 +70,6 @@ index 5655ba17..1252d1ea 100644
- LDCONFIG
- HEADER_ONLY
-)
diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp
index 92018aac..2ada620c 100644
--- a/include/ck/ck.hpp
+++ b/include/ck/ck.hpp
@@ -126,7 +126,9 @@
#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
// experimental feature: optimize for inter-wave scheduling policy
+#ifndef CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0
+#endif
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
index c206c4dc..e45fac9d 100644
--- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt

View file

@ -3,16 +3,17 @@
#include "contrib_ops/rocm/bert/gemm_fast_gelu.h"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/rocm/rocm_common.h"
using onnxruntime::rocm::ToHipType;
namespace onnxruntime {
namespace contrib {
namespace rocm {
using onnxruntime::rocm::ToHipType;
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GemmFastGelu, \
@ -54,17 +55,20 @@ Status GemmFastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
const HipT alpha = ToHipType<T>::FromFloat(1.0f);
const HipT beta = ToHipType<T>::FromFloat(0.0f);
return LaunchGemmFastGeluKernel<HipT>(
using onnxruntime::rocm::tunable::blas::BlasOp;
return blas::row_major::GemmFastGelu(
IsTunableOpEnabled(),
Stream(ctx), GetRocblasHandle(ctx),
transa, transb,
static_cast<int64_t>(helper.M()), static_cast<int64_t>(helper.N()), static_cast<int64_t>(helper.K()),
transa ? BlasOp::Trans : BlasOp::NonTrans,
transb ? BlasOp::Trans : BlasOp::NonTrans,
helper.M(), helper.N(), helper.K(),
alpha,
reinterpret_cast<const HipT*>(X->Data<T>()), static_cast<int64_t>(helper.Lda(transa)),
reinterpret_cast<const HipT*>(W->Data<T>()), static_cast<int64_t>(helper.Ldb(transb)),
reinterpret_cast<const HipT*>(X->Data<T>()), helper.Lda(transa),
reinterpret_cast<const HipT*>(W->Data<T>()), helper.Ldb(transb),
(nullptr != bias) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
beta,
reinterpret_cast<HipT*>(Y->MutableData<T>()), static_cast<int64_t>(helper.Ldc()));
reinterpret_cast<HipT*>(Y->MutableData<T>()), helper.Ldc());
}
} // namespace rocm

View file

@ -0,0 +1,133 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"
using onnxruntime::rocm::ToHipType;
using onnxruntime::rocm::tunable::Op;
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace blas {
namespace internal {
template <typename T>
struct DataTypeAdaptor {
using type = T;
};
template <>
struct DataTypeAdaptor<half> {
using type = ck::half_t;
};
template <>
struct DataTypeAdaptor<BFloat16> {
using type = ck::bhalf16_t;
};
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Nop = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
template <typename T, typename ALayout, typename BLayout>
auto GetCKGemmAddFastGeluTypeStringAndOps() {
using CKDataType = typename DataTypeAdaptor<T>::type;
using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout, BLayout, ck::Tuple<Row>, Row,
CKDataType, CKDataType, ck::Tuple<CKDataType>, CKDataType,
Nop, Nop, AddFastGelu>;
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceGemmAddFastGelu>;
std::vector<std::pair<std::string, Op<GemmFastGeluParams<T>>>> ret;
for (auto&& impl : InstanceFactory::GetInstances()) {
auto type_string = onnxruntime::MakeString("withbias ", impl->GetTypeString());
auto invoker = impl->MakeInvokerPointer();
auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams<T>* params) -> Status {
auto one = ToHipType<T>::FromFloat(1.0f);
auto zero = ToHipType<T>::FromFloat(0.0f);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->alpha != one || params->beta != zero || params->bias == nullptr,
impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr", params->Signature());
auto nop = Nop{};
auto addfastgelu = AddFastGelu{};
auto arg = impl->MakeArgumentPointer(params->a, params->b, std::array<const void*, 1>{params->bias}, params->c,
params->m, params->n, params->k,
params->lda, params->ldb, std::array<ck::index_t, 1>{0}, params->ldc,
nop, nop, addfastgelu);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support ", params->Signature());
invoker->Run(arg.get(), StreamConfig{params->stream});
return Status::OK();
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op)));
}
return ret;
}
template <typename T, typename ALayout, typename BLayout>
auto GetCKGemmFastGeluTypeStringAndOps() {
using CKDataType = typename DataTypeAdaptor<T>::type;
using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout, BLayout, ck::Tuple<>, Row,
CKDataType, CKDataType, ck::Tuple<>, CKDataType,
Nop, Nop, FastGelu>;
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceGemmFastGelu>;
std::vector<std::pair<std::string, Op<GemmFastGeluParams<T>>>> ret;
for (auto&& impl : InstanceFactory::GetInstances()) {
auto type_string = onnxruntime::MakeString("nobias ", impl->GetTypeString());
auto invoker = impl->MakeInvokerPointer();
auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams<T>* params) -> Status {
auto one = ToHipType<T>::FromFloat(1.0f);
auto zero = ToHipType<T>::FromFloat(0.0f);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->alpha != one || params->beta != zero || params->bias != nullptr,
impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr", params->Signature());
auto nop = Nop{};
auto fastgelu = FastGelu{};
auto arg = impl->MakeArgumentPointer(params->a, params->b,
{},
params->c,
params->m, params->n, params->k,
params->lda, params->ldb,
{},
params->ldc,
nop, nop, fastgelu);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support ", params->Signature());
invoker->Run(arg.get(), StreamConfig{params->stream});
return Status::OK();
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op)));
}
return ret;
}
} // namespace internal
} // namespace blas
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <vector>
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
using onnxruntime::rocm::tunable::blas::BlasOp;
using onnxruntime::rocm::tunable::blas::BlasOpToString;
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace blas {
template <typename T>
struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams {
std::string Signature() const override {
bool has_bias = (nullptr != bias) ? 0 : 1;
return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias);
}
rocblas_handle handle;
BlasOp opa;
BlasOp opb;
int64_t m;
int64_t n;
int64_t k;
T alpha;
const T* a;
int64_t lda;
const T* b;
int64_t ldb;
const T* bias;
T beta;
T* c;
int64_t ldc;
bool tuning{false};
};
} // namespace blas
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,79 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES
#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h"
#include <hip/hip_fp16.h>
#include <type_traits>
#include <utility>
#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h"
#include "core/providers/rocm/tunable/gemm_common.h"
using onnxruntime::rocm::tunable::blas::BlasOp;
#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace blas {
// See it as row-major
template <typename T>
Status LaunchGemmFastGeluKernel(bool tuning,
hipStream_t stream,
rocblas_handle handle,
bool transa,
bool transb,
int64_t m,
int64_t n,
int64_t k,
const T alpha,
const T* a,
int64_t lda,
const T* b,
int64_t ldb,
const T* bias,
const T beta,
T* c,
int64_t ldc) {
namespace row_major {
template <typename T, typename ScalarT>
inline GEMMFASTGELU(T, ScalarT) {
GemmFastGeluParams<T> params;
params.tuning = tuning;
params.stream = stream;
params.handle = handle;
params.opa = transa ? BlasOp::Trans : BlasOp::NonTrans;
params.opb = transb ? BlasOp::Trans : BlasOp::NonTrans;
params.opa = opa;
params.opb = opb;
params.m = m;
params.n = n;
params.k = k;
params.alpha = alpha;
if constexpr (!std::is_same_v<T, ScalarT> && std::is_same_v<ScalarT, float>) {
params.alpha = ToHipType<T>::FromFloat(std::forward<T>(alpha));
} else {
params.alpha = alpha;
}
params.a = a;
params.lda = lda;
params.b = b;
params.ldb = ldb;
params.bias = bias;
params.beta = beta;
if constexpr (!std::is_same_v<T, ScalarT> && std::is_same_v<ScalarT, float>) {
params.beta = ToHipType<T>::FromFloat(std::forward<T>(beta));
} else {
params.beta = beta;
}
params.c = c;
params.ldc = ldc;
if (tuning) {
static GemmFastGeluTunableOp<T> op;
op.EnableTuning();
return op(&params);
if (tunable) {
params.tuning = true;
if (opa == BlasOp::N && opb == BlasOp::N) {
static internal::GemmFastGeluTunableOp<T, internal::Row, internal::Row> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else if (opa == BlasOp::T && opb == BlasOp::N) {
static internal::GemmFastGeluTunableOp<T, internal::Col, internal::Row> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else if (opa == BlasOp::N && opb == BlasOp::T) {
static internal::GemmFastGeluTunableOp<T, internal::Row, internal::Col> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
} else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ {
static internal::GemmFastGeluTunableOp<T, internal::Col, internal::Col> gemm_fast_gelu{};
gemm_fast_gelu.EnableTuning();
return gemm_fast_gelu(&params);
}
}
return GemmFastGeluUnfused(&params);
return internal::GemmFastGeluUnfused(&params);
}
#define SPECIALIZED_IMPL(T) \
template Status LaunchGemmFastGeluKernel<T>(bool tuning, \
hipStream_t stream, rocblas_handle handle, \
bool transa, bool transb, \
int64_t m, int64_t n, int64_t k, const T alpha, \
const T* a, int64_t lda, const T* b, int64_t ldb, \
const T* bias, const T beta, T* c, int64_t ldc);
#define CALL_GEMMFASTGELU(T, ScalarT) \
GemmFastGelu<T, ScalarT>(tunable, stream, handle, \
opa, opb, \
m, n, k, \
alpha, a, lda, b, ldb, bias, \
beta, c, ldc)
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(half)
SPECIALIZED_IMPL(BFloat16)
// clang-format off
GEMMFASTGELU(float, float ) { return CALL_GEMMFASTGELU(float, float ); }
GEMMFASTGELU(half, half ) { return CALL_GEMMFASTGELU(half, half ); }
GEMMFASTGELU(BFloat16, BFloat16) { return CALL_GEMMFASTGELU(BFloat16, BFloat16); }
GEMMFASTGELU(half, float ) { return CALL_GEMMFASTGELU(half, float ); }
GEMMFASTGELU(BFloat16, float ) { return CALL_GEMMFASTGELU(BFloat16, float ); }
// clang-format on
#undef CALL_GEMMFASTGELU
} // namespace row_major
} // namespace blas
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -3,34 +3,38 @@
#pragma once
#include <hip/hip_runtime.h>
#include "core/common/common.h"
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"
#include "core/common/status.h"
#include "core/framework/float16.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace blas {
template <typename T>
Status LaunchGemmFastGeluKernel(bool tuning,
hipStream_t stream,
rocblas_handle handle,
bool transa,
bool transb,
int64_t m,
int64_t n,
int64_t k,
const T alpha,
const T* a,
int64_t lda,
const T* b,
int64_t ldb,
const T* bias,
const T beta,
T* c,
int64_t ldc);
#define GEMMFASTGELU(T, ScalarT) \
common::Status GemmFastGelu( \
bool tunable, hipStream_t stream, rocblas_handle handle, \
BlasOp opa, BlasOp opb, \
std::int64_t m, std::int64_t n, std::int64_t k, \
ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \
const T* bias, ScalarT beta, T* c, std::int64_t ldc)
namespace row_major {
GEMMFASTGELU(float, float);
GEMMFASTGELU(half, half);
GEMMFASTGELU(BFloat16, BFloat16);
GEMMFASTGELU(half, float);
GEMMFASTGELU(BFloat16, float);
} // namespace row_major
} // namespace blas
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#ifndef _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES
#undef GEMMFASTGELU
#endif

View file

@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include <memory>
#include "contrib_ops/rocm/bert/fast_gelu_impl.h"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"
#include "core/providers/rocm/tunable/gemm.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
namespace onnxruntime {
namespace contrib{
namespace rocm {
namespace blas {
namespace internal {
template <typename T>
Status GemmFastGeluUnfused(const GemmFastGeluParams<T>* params) {
namespace column_major = onnxruntime::rocm::tunable::blas::column_major;
ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning, params->stream, params->handle,
params->opb, params->opa,
params->n, params->m, params->k,
params->alpha, params->b, params->ldb, params->a, params->lda,
params->beta, params->c, params->ldc));
int64_t fast_gelu_input_length = params->m * params->n;
int64_t bias_length = (params->bias != nullptr) ? params->n : 0;
// Because of GemmFastGeluUnfused is a combination of GemmOp and FastGeluOp, FastGeluOp in this combination is
// an inplace computation.
// 1. If we call GemmFastGeluUnfused directly with enabled tuning, it may cause the input buffer of FastGelu been
// updated accumulatedly and result in incorrect result finally. This only happens if the tuning's FindFastest is invoked.
// 2. It's safe to call GemmFastGeluUnfused with disabled tuning, FastGelu only run once and produce correct result.
// 3. It's safe to call GemmFastGeluUnfused as part of GemmFastGeluTunableOp with enable tuning, GemmTunableOp and
// FastGeluTunableOp will do tune in first warmup step separately during GemmFastGeluUnfused profiling process.
// After that, the call to GemmFastGeluUnfused not invoke tuning's FindFastest of FastGelu.
//
// Note: If any change cause directly usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp
// to protect original input value.
return onnxruntime::contrib::rocm::LaunchFastGeluKernel<T>(params->tuning,
params->stream,
static_cast<int>(fast_gelu_input_length),
static_cast<int>(bias_length),
params->c,
params->bias,
params->c);
}
template <typename T, typename ALayout, typename BLayout>
class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp<GemmFastGeluParams<T>> {
public:
GemmFastGeluTunableOp() {
this->RegisterOp(GemmFastGeluUnfused<T>);
for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps<T, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps<T, ALayout, BLayout>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
}
};
} // namespace internal
} // namespace blas
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,80 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include <memory>
#include <string>
#include "contrib_ops/rocm/bert/fast_gelu_impl.h"
#include "core/providers/rocm/tunable/gemm.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
using onnxruntime::rocm::tunable::blas::BlasOp;
using onnxruntime::rocm::tunable::blas::BlasOpToString;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
struct GemmFastGeluParams : onnxruntime::rocm::tunable::OpParams {
std::string Signature() const override {
return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k);
}
rocblas_handle handle;
BlasOp opa;
BlasOp opb;
int64_t m;
int64_t n;
int64_t k;
T alpha;
const T* a;
int64_t lda;
const T* b;
int64_t ldb;
const T* bias;
T beta;
T* c;
int64_t ldc;
bool tuning{false};
};
template <typename T>
Status GemmFastGeluUnfused(const GemmFastGeluParams<T>* params) {
namespace column_major = onnxruntime::rocm::tunable::blas::column_major;
if (column_major::Gemm(params->tuning, params->stream, params->handle,
params->opb, params->opa,
params->n, params->m, params->k,
params->alpha, params->b, params->ldb, params->a, params->lda,
params->beta, params->c, params->ldc) != Status::OK()) {
return Status(common::ONNXRUNTIME, common::FAIL, "GemmFastGelu call column_major::Gemm failed");
}
int64_t fast_gelu_input_length = params->m * params->n;
int64_t bias_length = (params->bias != nullptr) ? params->n : 0;
// inplace computation
return LaunchFastGeluKernel<T>(params->tuning,
params->stream,
static_cast<int>(fast_gelu_input_length),
static_cast<int>(bias_length),
params->c,
params->bias,
params->c);
}
template <typename T>
class GemmFastGeluTunableOp : public onnxruntime::rocm::tunable::TunableOp<GemmFastGeluParams<T>> {
public:
GemmFastGeluTunableOp() {
this->RegisterOp(GemmFastGeluUnfused<T>);
this->SetDefaultId(0);
}
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -3,22 +3,14 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import re
import sys
from dataclasses import dataclass
from itertools import product
import kernel_explorer as ke
import numpy as np
import pytest
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
def dtype_to_funcs(dtype):
type_map = {
"float16": list(filter(lambda x: re.search("GemmFastGelu.*_half", x), dir(ke))),
"float32": list(filter(lambda x: re.search("GemmFastGelu.*_float", x), dir(ke))),
}
return type_map[dtype]
from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
def fast_gelu(x, bias):
@ -28,7 +20,7 @@ def fast_gelu(x, bias):
# TODO The test method needs update.
def _test_gemmfastgelu(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False):
def _test_gemmfastgelu(my_func, dtype: str, m: int, n: int, k: int, transa=False, transb=False):
assert dtype in ["float16", "float32"]
a_shape = (k, m) if transa else (m, k)
@ -60,17 +52,17 @@ def _test_gemmfastgelu(func, dtype: str, m: int, n: int, k: int, transa=False, t
ldb = b_shape[1]
alpha = 1.0
beta = 0.0
my_func = getattr(ke, func)
my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n)
if my_op.IsSupported():
print(f"dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {max(bound, 1e-2)}")
for impl in my_op.ListOps():
if not my_op.SelectOp(impl):
continue
my_op.Run()
dev_c.UpdateHostNumpyArray()
print(
f"{func:<50} : dtype={dtype} {transab_to_suffix((transa, transb))} m={m:<5} n={n:<5} k={k:<5} bound: {bound}"
)
np.testing.assert_allclose(my_c, ref_c, rtol=max(bound, 1e-2))
@ -81,12 +73,45 @@ all_transabs = list(product([True, False], repeat=2))
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
@pytest.mark.parametrize("transab", all_transabs)
def test_gemmfastgelu_bert_cases(dtype, size, transab):
for func in dtype_to_funcs(dtype):
_test_gemmfastgelu(func, dtype, *size, *transab)
def test_gemmfastgelu_unfused_bert_cases(dtype, size, transab):
_test_gemmfastgelu(getattr(ke, "GemmFastGeluUnfused_" + dtype_to_suffix(dtype)), dtype, *size, *transab)
def profile_gemmfastgelu_func(func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool):
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
@pytest.mark.parametrize("transab", all_transabs)
def test_gemmfastgelu_tunable_bert_cases(dtype, size, transab):
wrapper_name = "GemmFastGeluTunable_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab))
_test_gemmfastgelu(getattr(ke, wrapper_name), dtype, *size, *transab)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("size", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False))
@pytest.mark.parametrize("transab", all_transabs)
def test_gemmfastgelu_ck_bert_cases(dtype, size, transab):
wrapper_name = "CKGemmFastGelu_{}_{}".format(dtype_to_suffix(dtype), transab_to_suffix(transab))
_test_gemmfastgelu(getattr(ke, wrapper_name), dtype, *size, *transab)
@dataclass
class GemmFastGeluMetric(ke.ComputeMetric):
transa: bool
transb: bool
m: int
n: int
k: int
def report(self):
prefix = f"{self.name:<50} {self.dtype} {transab_to_suffix((self.transa, self.transb))} "
if self.duration > 0:
return (
prefix
+ f"m={self.m:<4} n={self.n:<4} k={self.k:<4} {self.duration:>8.4f} us {self.tflops:>5.2f} tflops"
)
return prefix + "not supported"
def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool):
a_shape = (k, m) if transa else (m, k)
b_shape = (n, k) if transb else (k, n)
@ -107,34 +132,36 @@ def profile_gemmfastgelu_func(func, dtype: str, m: int, n: int, k: int, transa:
ldb = b_shape[1]
alpha = 1.0
beta = 0.0
my_func = getattr(ke, func)
my_op = my_func(opa, opb, m, n, k, alpha, dev_a, lda, dev_b, ldb, dev_bias, beta, dev_c, n)
if my_op.IsSupported():
my_op.Run()
dev_c.UpdateHostNumpyArray()
time_ms = my_op.Profile()
time_us = time_ms * 1000
for impl in my_op.ListOps():
duration_ms = -1
if my_op.SelectOp(impl):
duration_ms = my_op.Profile()
# only counts gemm tflops because fastgelu is low order term (7 * n).
tflops = (m * k * n * 2) / (time_ms * 1e-3) / 1e12
print(
f"{func:<50} {dtype} {transab_to_suffix((transa, transb))}",
f"m={m:<4} n={n:<4} k={k:<4} {time_us:>8.4f} us {tflops:>5.2f} tflops",
floating_point_operations = m * k * n * 2
ke.report(GemmFastGeluMetric(impl, dtype, duration_ms, floating_point_operations, transa, transb, m, n, k))
def profile_with_args(transa, transb, dtype, m, n, k, sort):
dtype_suffix = "_" + dtype_to_suffix(dtype)
transab_suffix = "_" + transab_to_suffix((transa, transb))
with ke.benchmark(sort):
profile_gemmfastgelu_func(getattr(ke, "GemmFastGeluUnfused" + dtype_suffix), dtype, m, n, k, transa, transb)
profile_gemmfastgelu_func(
getattr(ke, "CKGemmFastGelu" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb
)
profile_gemmfastgelu_func(
getattr(ke, "GemmFastGeluTunable" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb
)
def profile_with_args(transa, transb, dtype, m, n, k):
for func in dtype_to_funcs(dtype):
profile_gemmfastgelu_func(func, dtype, m, n, k, transa, transb)
def profile():
for dtype in dtypes:
for m, n, k in get_gemm_bert_sizes(full=True):
profile_with_args(False, False, dtype, m, n, k)
profile_with_args(False, False, dtype, m, n, k, True)
print()
print()
if __name__ == "__main__":
@ -148,8 +175,9 @@ if __name__ == "__main__":
group.add_argument("m", type=int)
group.add_argument("n", type=int)
group.add_argument("k", type=int)
group.add_argument("--sort", action="store_true")
if len(sys.argv) == 1:
profile()
else:
args = parser.parse_args()
profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k)
profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k, args.sort)

View file

@ -10,14 +10,7 @@ from itertools import product
import kernel_explorer as ke
import numpy as np
import pytest
from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
def dtype_to_suffix(dtype):
return {
"float32": "float",
"float16": "half",
}[dtype]
from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, transab_to_suffix
def _test_gemm(func, dtype: str, m: int, n: int, k: int, transa=False, transb=False):

View file

@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
#include <pybind11/pybind11.h>
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h"
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h"
namespace py = pybind11;
namespace onnxruntime {
void InitGemmFastGelu(py::module mod) {
InitGemmFastGeluUnfused(mod);
InitGemmFastGeluTunable(mod);
InitComposableKernelGemmFastGelu(mod);
}
} // namespace onnxruntime

View file

@ -1,146 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu.h"
#include <hip/hip_fp16.h>
#include <pybind11/pybind11.h>
#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable_op.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using onnxruntime::rocm::tunable::blas::BlasOp;
namespace py = pybind11;
namespace onnxruntime {
template <typename T>
class GemmFastGeluUnfused : public IKernelExplorer {
public:
GemmFastGeluUnfused(BlasOp opa, BlasOp opb,
int64_t m, int64_t n, int64_t k,
double alpha,
DeviceArray& a, int64_t lda,
DeviceArray& b, int64_t ldb,
DeviceArray& bias,
double beta,
DeviceArray& c, int64_t ldc) : params_{} {
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
params_.tuning = true;
params_.stream = Stream();
params_.handle = rocblas_handle_;
params_.opa = opa;
params_.opb = opb;
params_.m = m;
params_.n = n;
params_.k = k;
params_.alpha = alpha;
params_.a = static_cast<T*>(a.ptr());
params_.lda = lda;
params_.b = static_cast<T*>(b.ptr());
params_.ldb = ldb;
params_.bias = static_cast<T*>(bias.ptr());
params_.beta = beta;
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
}
~GemmFastGeluUnfused() {
ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_));
rocblas_handle_ = nullptr;
}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::GemmFastGeluUnfused<T>(&params_)));
}
bool IsSupported() {
Status status = contrib::rocm::GemmFastGeluUnfused<T>(&params_);
return status.IsOK();
}
private:
using ParamsT = contrib::rocm::GemmFastGeluParams<T>;
ParamsT params_{};
rocblas_handle rocblas_handle_;
};
template <typename T>
class GemmFastGeluTunableOp : public IKernelExplorer {
public:
GemmFastGeluTunableOp(BlasOp opa, BlasOp opb,
int64_t m, int64_t n, int64_t k,
double alpha,
DeviceArray& a, int64_t lda,
DeviceArray& b, int64_t ldb,
DeviceArray& bias,
double beta,
DeviceArray& c, int64_t ldc) : params_{} {
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
params_.tuning = true;
params_.stream = Stream();
params_.handle = rocblas_handle_;
params_.opa = opa;
params_.opb = opb;
params_.m = m;
params_.n = n;
params_.k = k;
params_.alpha = alpha;
params_.a = static_cast<T*>(a.ptr());
params_.lda = lda;
params_.b = static_cast<T*>(b.ptr());
params_.ldb = ldb;
params_.bias = static_cast<T*>(bias.ptr());
params_.beta = beta;
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
op_.EnableTuning();
}
~GemmFastGeluTunableOp() {
ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_));
rocblas_handle_ = nullptr;
}
void Run() override {
ORT_THROW_IF_ERROR((op_(&params_)));
}
bool IsSupported() {
Status status = op_(&params_);
return status.IsOK();
}
private:
using ParamsT = contrib::rocm::GemmFastGeluParams<T>;
ParamsT params_{};
rocblas_handle rocblas_handle_;
contrib::rocm::GemmFastGeluTunableOp<T> op_{};
};
#define REGISTER_OP(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, \
DeviceArray&, int64_t, \
DeviceArray&, \
double, \
DeviceArray&, int64_t>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Run", &name<type>::Run) \
.def("Profile", &name<type>::Profile) \
.def("IsSupported", &name<type>::IsSupported);
void InitGemmFastGelu(py::module m) {
REGISTER_OP(GemmFastGeluUnfused, float)
REGISTER_OP(GemmFastGeluUnfused, half)
REGISTER_OP(GemmFastGeluTunableOp, float)
REGISTER_OP(GemmFastGeluTunableOp, half)
}
} // namespace onnxruntime

View file

@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_ck.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using namespace onnxruntime::contrib::rocm::blas;
using namespace onnxruntime::contrib::rocm::blas::internal;
namespace py = pybind11;
namespace onnxruntime {
template <typename T, typename ALayout, typename BLayout>
class CKGemmFastGelu : public IKernelExplorer {
public:
CKGemmFastGelu(BlasOp opa, BlasOp opb,
int64_t m, int64_t n, int64_t k,
double alpha,
DeviceArray& a, int64_t lda,
DeviceArray& b, int64_t ldb,
DeviceArray& bias,
double beta,
DeviceArray& c, int64_t ldc)
: params_{} {
auto supports_a = opa == BlasOp::N ? std::is_same_v<ALayout, Row> : std::is_same_v<ALayout, Col>;
auto supports_b = opb == BlasOp::N ? std::is_same_v<BLayout, Row> : std::is_same_v<BLayout, Col>;
ORT_ENFORCE(supports_a && supports_b);
params_.stream = Stream();
// rocblas handle is not used for ck
params_.handle = nullptr;
params_.opa = opa;
params_.opb = opb;
params_.m = m;
params_.n = n;
params_.k = k;
params_.alpha = alpha;
params_.a = static_cast<T*>(a.ptr());
params_.lda = lda;
params_.b = static_cast<T*>(b.ptr());
params_.ldb = ldb;
params_.bias = static_cast<T*>(bias.ptr());
params_.beta = beta;
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
for (auto&& [type_string, op] : GetCKGemmAddFastGeluTypeStringAndOps<T, ALayout, BLayout>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
for (auto&& [type_string, op] : GetCKGemmFastGeluTypeStringAndOps<T, ALayout, BLayout>()) {
type_strings_.emplace_back(std::move(type_string));
ops_.emplace_back(std::move(op));
}
}
void Run() override {
ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
}
std::vector<std::string> ListOps() const {
return type_strings_;
}
bool SelectOp(const std::string& name) {
for (size_t i = 0; i < ops_.size(); i++) {
if (type_strings_[i] == name) {
selected_op_ = i;
Status status = ops_[i](&params_);
return status.IsOK();
}
}
ORT_THROW("Cannot find implementation ", name);
}
private:
using ParamsT = GemmFastGeluParams<T>;
using OpT = rocm::tunable::Op<ParamsT>;
ParamsT params_;
std::vector<OpT> ops_;
std::vector<std::string> type_strings_;
size_t selected_op_{};
};
#define REGISTER_OP(type, alayout, blayout, layout_string) \
py::class_<CKGemmFastGelu<type, alayout, blayout>>(m, "CKGemmFastGelu_" #type "_" layout_string) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, \
DeviceArray&, int64_t, \
DeviceArray&, \
double, \
DeviceArray&, int64_t>()) \
.def("SetRepeats", &CKGemmFastGelu<type, alayout, blayout>::SetRepeats) \
.def("Profile", &CKGemmFastGelu<type, alayout, blayout>::Profile) \
.def("Run", &CKGemmFastGelu<type, alayout, blayout>::Run) \
.def("ListOps", &CKGemmFastGelu<type, alayout, blayout>::ListOps) \
.def("SelectOp", &CKGemmFastGelu<type, alayout, blayout>::SelectOp);
#define REGISTER_OP_FOR_ALL_TRANSAB(type) \
REGISTER_OP(type, Row, Row, "NN"); \
REGISTER_OP(type, Row, Col, "NT"); \
REGISTER_OP(type, Col, Row, "TN"); \
REGISTER_OP(type, Col, Col, "TT");
void InitComposableKernelGemmFastGelu(py::module m) {
REGISTER_OP_FOR_ALL_TRANSAB(float);
REGISTER_OP_FOR_ALL_TRANSAB(half);
}
} // namespace onnxruntime

View file

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitComposableKernelGemmFastGelu(py::module mod);
} // namespace onnxruntime

View file

@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.h"
#include <pybind11/stl.h>
#include <string>
#include <vector>
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using namespace onnxruntime::contrib::rocm::blas;
using namespace onnxruntime::contrib::rocm::blas::internal;
namespace py = pybind11;
namespace onnxruntime {
template <typename T, typename ALayout, typename BLayout>
class GemmFastGeluTunable : public IKernelExplorer {
public:
GemmFastGeluTunable(BlasOp opa, BlasOp opb,
int64_t m, int64_t n, int64_t k,
double alpha,
DeviceArray& a, int64_t lda,
DeviceArray& b, int64_t ldb,
DeviceArray& bias,
double beta,
DeviceArray& c, int64_t ldc) : params_{} {
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
params_.tuning = true;
params_.stream = Stream();
params_.handle = rocblas_handle_;
params_.opa = opa;
params_.opb = opb;
params_.m = m;
params_.n = n;
params_.k = k;
params_.alpha = alpha;
params_.a = static_cast<T*>(a.ptr());
params_.lda = lda;
params_.b = static_cast<T*>(b.ptr());
params_.ldb = ldb;
params_.bias = static_cast<T*>(bias.ptr());
params_.beta = beta;
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
op_.EnableTuning();
}
~GemmFastGeluTunable() {
ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_));
rocblas_handle_ = nullptr;
}
void Run() override {
ORT_THROW_IF_ERROR((op_(&params_)));
}
std::vector<std::string> ListOps() const {
return {"GemmFastGeluTunable"};
}
bool SelectOp(const std::string& name) {
return name == "GemmFastGeluTunable";
}
private:
using ParamsT = GemmFastGeluParams<T>;
ParamsT params_{};
rocblas_handle rocblas_handle_;
GemmFastGeluTunableOp<T, ALayout, BLayout> op_{};
};
#define REGISTER_OP(type, alayout, blayout, layout_string) \
py::class_<GemmFastGeluTunable<type, alayout, blayout>>(m, "GemmFastGeluTunable_" #type "_" layout_string) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, \
DeviceArray&, int64_t, \
DeviceArray&, \
double, \
DeviceArray&, int64_t>()) \
.def("SetRepeats", &GemmFastGeluTunable<type, alayout, blayout>::SetRepeats) \
.def("Profile", &GemmFastGeluTunable<type, alayout, blayout>::Profile) \
.def("Run", &GemmFastGeluTunable<type, alayout, blayout>::Run) \
.def("ListOps", &GemmFastGeluTunable<type, alayout, blayout>::ListOps) \
.def("SelectOp", &GemmFastGeluTunable<type, alayout, blayout>::SelectOp);
#define REGISTER_OP_FOR_ALL_TRANSAB(type) \
REGISTER_OP(type, Row, Row, "NN"); \
REGISTER_OP(type, Row, Col, "NT"); \
REGISTER_OP(type, Col, Row, "TN"); \
REGISTER_OP(type, Col, Col, "TT");
void InitGemmFastGeluTunable(py::module m) {
REGISTER_OP_FOR_ALL_TRANSAB(float);
REGISTER_OP_FOR_ALL_TRANSAB(half);
}
#undef REGISTER_OP
} // namespace onnxruntime

View file

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitGemmFastGeluTunable(py::module mod);
} // namespace onnxruntime

View file

@ -0,0 +1,99 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_unfused.h"
#include <pybind11/stl.h>
#include <string>
#include <vector>
#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h"
#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
using namespace onnxruntime::contrib::rocm::blas;
using namespace onnxruntime::contrib::rocm::blas::internal;
namespace py = pybind11;
namespace onnxruntime {
template <typename T>
class GemmFastGeluUnfused : public IKernelExplorer {
public:
GemmFastGeluUnfused(BlasOp opa, BlasOp opb,
int64_t m, int64_t n, int64_t k,
double alpha,
DeviceArray& a, int64_t lda,
DeviceArray& b, int64_t ldb,
DeviceArray& bias,
double beta,
DeviceArray& c, int64_t ldc) : params_{} {
ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_));
params_.tuning = true;
params_.stream = Stream();
params_.handle = rocblas_handle_;
params_.opa = opa;
params_.opb = opb;
params_.m = m;
params_.n = n;
params_.k = k;
params_.alpha = alpha;
params_.a = static_cast<T*>(a.ptr());
params_.lda = lda;
params_.b = static_cast<T*>(b.ptr());
params_.ldb = ldb;
params_.bias = static_cast<T*>(bias.ptr());
params_.beta = beta;
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
}
~GemmFastGeluUnfused() {
ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_));
rocblas_handle_ = nullptr;
}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::blas::internal::GemmFastGeluUnfused<T>(&params_)));
}
std::vector<std::string> ListOps() const {
return {"GemmFastGeluUnfused"};
}
bool SelectOp(const std::string& name) {
Status status = contrib::rocm::blas::internal::GemmFastGeluUnfused<T>(&params_);
return status.IsOK() && name == "GemmFastGeluUnfused";
}
private:
using ParamsT = GemmFastGeluParams<T>;
ParamsT params_{};
rocblas_handle rocblas_handle_;
};
#define REGISTER_OP(type) \
py::class_<GemmFastGeluUnfused<type>>(m, "GemmFastGeluUnfused_" #type) \
.def(py::init<BlasOp, BlasOp, int64_t, int64_t, int64_t, \
double, \
DeviceArray&, int64_t, \
DeviceArray&, int64_t, \
DeviceArray&, \
double, \
DeviceArray&, int64_t>()) \
.def("SetRepeats", &GemmFastGeluUnfused<type>::SetRepeats) \
.def("Run", &GemmFastGeluUnfused<type>::Run) \
.def("Profile", &GemmFastGeluUnfused<type>::Profile) \
.def("ListOps", &GemmFastGeluUnfused<type>::ListOps) \
.def("SelectOp", &GemmFastGeluUnfused<type>::SelectOp);
void InitGemmFastGeluUnfused(py::module m) {
REGISTER_OP(float)
REGISTER_OP(half)
}
#undef REGISTER_OP
} // namespace onnxruntime

View file

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitGemmFastGeluUnfused(py::module mod);
} // namespace onnxruntime

View file

@ -26,6 +26,13 @@ def transab_to_suffix(transab):
}[tuple(transab)]
def dtype_to_suffix(dtype):
return {
"float32": "float",
"float16": "half",
}[dtype]
def get_gemm_bound(dtype: str, a: np.ndarray, b: np.ndarray, c: np.ndarray, transa: bool, transb: bool):
k = b.shape[1] if transb else b.shape[0]
# The machine epsilon, unit roundoff, the smallest positive floating point number n such that the floating point

View file

@ -8,12 +8,25 @@
#include "test/common/cuda_op_test_utils.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
#include "test/providers/run_options_config_keys.h"
namespace onnxruntime {
namespace test {
namespace gemmfastgelu {
#if defined(USE_ROCM)
namespace {
const onnxruntime::RunOptions run_options = []() {
onnxruntime::RunOptions options{};
ORT_THROW_IF_ERROR(options.config_options.AddConfigEntry(kOpTesterRunOptionsConfigTestTunableOp, "true"));
return options;
}();
const constexpr auto run_with_tunable_op = &run_options;
} // namespace
static void RunGemmFastGeluGpuTest(const std::vector<float>& input_data, const std::vector<float>& weight_data,
const std::vector<float>& bias_data, const std::vector<float>& output_data,
const std::vector<int64_t>& input_dims, const std::vector<int64_t>& weight_dims,
@ -37,11 +50,9 @@ static void RunGemmFastGeluGpuTest(const std::vector<float>& input_data, const s
tester.AddOutput<float>("Y", output_dims, output_data);
}
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultRocmExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
tester.Config(run_with_tunable_op)
.RunWithConfig();
}
#endif
TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat32) {
int batch_size = 1;
@ -71,11 +82,10 @@ TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat32) {
std::vector<int64_t> weight_dims = {hidden_size, dense_size};
std::vector<int64_t> bias_dims = {dense_size};
std::vector<int64_t> output_dims = {batch_size, sequence_length, dense_size};
#if defined(USE_ROCM)
RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data,
input_dims, weight_dims, bias_dims, output_dims,
false);
#endif
}
TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat32) {
@ -107,15 +117,12 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat32) {
std::vector<int64_t> weight_dims = {hidden_size, dense_size};
std::vector<int64_t> bias_dims = {dense_size};
std::vector<int64_t> output_dims = {batch_size, sequence_length, dense_size};
#if defined(USE_ROCM)
RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data,
input_dims, weight_dims, bias_dims, output_dims,
true);
#endif
}
// CUDA and ROCm only for Float16 and BFloat16 type.
#if defined(USE_ROCM)
TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat16) {
int batch_size = 1;
int sequence_length = 2;
@ -144,6 +151,7 @@ TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat16) {
std::vector<int64_t> weight_dims = {hidden_size, dense_size};
std::vector<int64_t> bias_dims = {dense_size};
std::vector<int64_t> output_dims = {batch_size, sequence_length, dense_size};
RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data,
input_dims, weight_dims, bias_dims, output_dims,
false);
@ -178,6 +186,7 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat16) {
std::vector<int64_t> weight_dims = {hidden_size, dense_size};
std::vector<int64_t> bias_dims = {dense_size};
std::vector<int64_t> output_dims = {batch_size, sequence_length, dense_size};
RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data,
input_dims, weight_dims, bias_dims, output_dims,
true);
@ -225,9 +234,8 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBias_bfloat16) {
tester.AddInput<BFloat16>("bias", bias_dims, f_B);
tester.AddOutput<BFloat16>("Y", output_dims, f_Y);
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultRocmExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
tester.Config(run_with_tunable_op)
.RunWithConfig();
}
#endif