[ROCm] Add hipBLASLt GEMM support to Tunable op. (#15351)

### Description
Add hipBLASLt to GEMM Tunable op, which supports GEMM and
StridedBatchedGEMM.

To enable hipBLASLt implementation, add an extra flag to the building
command: `--cmake_extra_defines onnxruntime_USE_HIPBLASLT=ON`.
This commit is contained in:
mindest 2023-04-14 17:56:01 +08:00 committed by GitHub
parent fda0aa14c8
commit 0fdd356abf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 254 additions and 1 deletions

View file

@ -1536,6 +1536,12 @@ if (onnxruntime_USE_ROCM)
target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_BETA_FEATURES_API)
endif()
if (onnxruntime_USE_HIPBLASLT)
find_package(hipblaslt REQUIRED)
target_link_libraries(onnxruntime_providers_rocm PRIVATE roc::hipblaslt)
target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_HIPBLASLT)
endif()
if (onnxruntime_USE_COMPOSABLE_KERNEL)
include(composable_kernel)
target_link_libraries(onnxruntime_providers_rocm PRIVATE

View file

@ -148,4 +148,10 @@ template void RocmCall<hipfftResult, true>(hipfftResult retCode, const char* exp
template Status RocmCall<ncclResult_t, false>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
template void RocmCall<ncclResult_t, true>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
#endif
#ifdef USE_HIPBLASLT
template Status RocmCall<hipblasStatus_t, false>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
template void RocmCall<hipblasStatus_t, true>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
#endif
} // namespace onnxruntime

View file

@ -23,6 +23,10 @@ namespace rocm {
#define MIOPEN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(MIOPEN_CALL2(expr, m))
#define HIPFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPFFT_CALL(expr))
#ifdef USE_HIPBLASLT
#define HIPBLASLT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPBLASLT_CALL(expr))
#endif
// Type mapping for MLFloat16 to half
template <typename T>
class ToHipType {

View file

@ -19,6 +19,10 @@
#include <rccl/rccl.h>
#endif
#ifdef USE_HIPBLASLT
#include <hipblas/hipblas.h>
#endif
#if defined(_MSC_VER)
#pragma warning(pop)
#endif

View file

@ -36,4 +36,9 @@ std::conditional_t<THRW, void, Status> RocmCall(
#define NCCL_CALL_THROW(expr) (RocmCall<ncclResult_t, true>((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__))
#endif
#ifdef USE_HIPBLASLT
#define HIPBLASLT_CALL(expr) (RocmCall<hipblasStatus_t, false>((expr), #expr, "hipBLASLt", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__))
#define HIPBLASLT_CALL_THROW(expr) (RocmCall<hipblasStatus_t, true>((expr), #expr, "hipBLASLt", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__))
#endif
} // namespace onnxruntime

View file

@ -0,0 +1,213 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt.h>
#endif
#include "core/common/common.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
namespace onnxruntime {
namespace rocm {
namespace tunable {
namespace blas {
namespace internal {
#ifdef USE_HIPBLASLT
// For large K and small M/N, K dim will be splited to multiple workgroups and buffers,
// which will require additional workspace. Here we set the max workspace size to 32MB.
constexpr const size_t kHipBlasLtMaxWorkSpaceSizeInBytes = 32 * 1024 * 1024;
// We only keep one heuristic result here. Note that for tuned input sizes, the first result
// will be the most performant one; but in untuned cases, this is not guaranteed.
constexpr const int kHeuristicResultCount = 1;
enum ActivationType {
NONE = 0,
RELU = 1,
GELU = 2,
};
template <typename T>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const T*);
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const float*) {
return HIPBLAS_R_32F;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const half*) {
return HIPBLAS_R_16F;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const BFloat16*) {
return HIPBLAS_R_16B;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor(const double*) {
return HIPBLAS_R_64F;
}
template <typename T, typename ParamsT>
Status HipBlasLtMatMul(const ParamsT* params, int64_t batch, ActivationType activation_type = ActivationType::NONE,
bool enable_bias = false, const T* d_bias = nullptr,
bool enable_scaleD = false, const T* d_scaleD = nullptr) {
hipblasLtHandle_t handle;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtCreate(&handle));
// Note: properties of original matrices A and B are swapped.
int64_t lda = (params->opb == BlasOp::N) ? params->n : params->k;
int64_t ldb = (params->opa == BlasOp::N) ? params->k : params->m;
int64_t ldc = params->n;
int64_t stride_a = (params->opb == BlasOp::N) ? lda * params->k : lda * params->n;
int64_t stride_b = (params->opa == BlasOp::N) ? ldb * params->m : ldb * params->k;
int64_t stride_c = ldc * params->m;
float alpha = static_cast<float>(params->alpha);
float beta = static_cast<float>(params->beta);
int row_a, col_a, row_b, col_b, row_c, col_c;
row_a = lda;
col_a = (params->opb == BlasOp::N) ? params->k : params->n;
row_b = ldb;
col_b = (params->opa == BlasOp::N) ? params->m : params->k;
row_c = ldc;
col_c = params->m;
hipblasDatatype_t in_out_datatype = HipBlasDataTypeFor(params->a);
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
hipblasLtMatmulDesc_t matmul;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
if (batch > 1) {
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
}
hipblasOperation_t trans_a = (params->opb == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
hipblasOperation_t trans_b = (params->opa == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t)));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t)));
hipblasLtEpilogue_t epilogue;
switch (activation_type) {
case ActivationType::NONE:
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_BIAS : HIPBLASLT_EPILOGUE_DEFAULT;
break;
case ActivationType::RELU:
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_RELU_BIAS : HIPBLASLT_EPILOGUE_RELU;
break;
case ActivationType::GELU:
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_GELU_BIAS : HIPBLASLT_EPILOGUE_GELU;
break;
default:
throw std::runtime_error("Unsupported activation type for HipBlasLtMatMul");
}
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
if (enable_bias) {
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(void*)));
}
if (enable_scaleD) {
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scaleD, sizeof(void*)));
}
hipblasLtMatmulPreference_t pref;
void* workspace;
size_t max_workspace_size = kHipBlasLtMaxWorkSpaceSizeInBytes;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceCreate(&pref));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceSetAttribute(
pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &max_workspace_size, sizeof(max_workspace_size)));
hipblasLtMatmulHeuristicResult_t heuristic_result[kHeuristicResultCount] = {0};
int ret_algo_count = 0;
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle,
matmul,
mat_a,
mat_b,
mat_c,
mat_c,
pref,
kHeuristicResultCount,
heuristic_result,
&ret_algo_count));
assert(ret_algo_count > 0);
size_t workspace_size = heuristic_result[0].workspaceSize;
if (workspace_size > 0) {
HIP_RETURN_IF_ERROR(hipMallocAsync(&workspace, workspace_size, params->stream));
}
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(handle,
matmul,
&alpha,
params->b,
mat_a,
params->a,
mat_b,
&beta,
params->c,
mat_c,
params->c,
mat_c,
&heuristic_result[0].algo,
workspace,
workspace_size,
params->stream));
if (workspace_size > 0) {
HIP_RETURN_IF_ERROR(hipFreeAsync(workspace, params->stream));
}
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceDestroy(pref));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_b));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_c));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtDestroy(handle));
return Status::OK();
}
template <typename T>
Status HipBlasLtGemmOp(const GemmParams<T>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v<T, double>), "hipBLASLt does not support double inputs");
return HipBlasLtMatMul<T, GemmParams<T>>(params, /*batch=*/1);
}
template <typename T>
Status HipBlasLtStridedBatchedGemmOp(const StridedBatchedGemmParams<T>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v<T, double>), "hipBLASLt does not support double inputs");
return HipBlasLtMatMul<T, StridedBatchedGemmParams<T>>(params, params->batch);
};
#endif // USE_HIPBLASLT
} // namespace internal
} // namespace blas
} // namespace tunable
} // namespace rocm
} // namespace onnxruntime

View file

@ -7,8 +7,9 @@
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/tunable/gemm_ck.cuh"
#include "core/providers/rocm/tunable/gemm_rocblas.h"
#include "core/providers/rocm/tunable/gemm_common.h"
#include "core/providers/rocm/tunable/gemm_hipblaslt.h"
#include "core/providers/rocm/tunable/gemm_rocblas.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
namespace onnxruntime {
@ -38,6 +39,10 @@ class GemmTunableOp : public TunableOp<GemmParams<T>> {
GemmTunableOp() {
this->RegisterOp(RocBlasGemmOp<T>);
#ifdef USE_HIPBLASLT
this->RegisterOp(HipBlasLtGemmOp<T>);
#endif
#ifdef USE_ROCBLAS_EXTENSION_API
this->RegisterNestedTunableOp(&rocblas_gemm_tunable_op_);
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
@ -139,6 +144,10 @@ class StridedBatchedGemmTunableOp : public TunableOp<StridedBatchedGemmParams<T>
StridedBatchedGemmTunableOp() {
this->RegisterOp(RocBlasStridedBatchedGemmOp<T>);
#ifdef USE_HIPBLASLT
this->RegisterOp(HipBlasLtStridedBatchedGemmOp<T>);
#endif
#ifdef USE_ROCBLAS_EXTENSION_API
this->RegisterNestedTunableOp(&rocblas_strided_batched_gemm_tunable_op_);
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */

View file

@ -74,6 +74,12 @@ std::string RocmTuningResultsValidator::GetOrtBuildConfig() const {
#else
oss << "USE_ROCBLAS_EXTENSION_API=" << 0 << "|";
#endif
#ifdef USE_HIPBLASLT
oss << "USE_HIPBLASLT=" << 1 << "|";
#else
oss << "USE_HIPBLASLT=" << 0 << "|";
#endif
return oss.str();
}