mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[ROCm] Add hipBLASLt GEMM support to Tunable op. (#15351)
### Description Add hipBLASLt to GEMM Tunable op, which supports GEMM and StridedBatchedGEMM. To enable hipBLASLt implementation, add an extra flag to the building command: `--cmake_extra_defines onnxruntime_USE_HIPBLASLT=ON`.
This commit is contained in:
parent
fda0aa14c8
commit
0fdd356abf
8 changed files with 254 additions and 1 deletions
|
|
@ -1536,6 +1536,12 @@ if (onnxruntime_USE_ROCM)
|
|||
target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_BETA_FEATURES_API)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_HIPBLASLT)
|
||||
find_package(hipblaslt REQUIRED)
|
||||
target_link_libraries(onnxruntime_providers_rocm PRIVATE roc::hipblaslt)
|
||||
target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_HIPBLASLT)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_COMPOSABLE_KERNEL)
|
||||
include(composable_kernel)
|
||||
target_link_libraries(onnxruntime_providers_rocm PRIVATE
|
||||
|
|
|
|||
|
|
@ -148,4 +148,10 @@ template void RocmCall<hipfftResult, true>(hipfftResult retCode, const char* exp
|
|||
template Status RocmCall<ncclResult_t, false>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
|
||||
template void RocmCall<ncclResult_t, true>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
|
||||
#endif
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
template Status RocmCall<hipblasStatus_t, false>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
|
||||
template void RocmCall<hipblasStatus_t, true>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
|
||||
#endif
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ namespace rocm {
|
|||
#define MIOPEN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(MIOPEN_CALL2(expr, m))
|
||||
#define HIPFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPFFT_CALL(expr))
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
#define HIPBLASLT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPBLASLT_CALL(expr))
|
||||
#endif
|
||||
|
||||
// Type mapping for MLFloat16 to half
|
||||
template <typename T>
|
||||
class ToHipType {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@
|
|||
#include <rccl/rccl.h>
|
||||
#endif
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
#include <hipblas/hipblas.h>
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -36,4 +36,9 @@ std::conditional_t<THRW, void, Status> RocmCall(
|
|||
#define NCCL_CALL_THROW(expr) (RocmCall<ncclResult_t, true>((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__))
|
||||
#endif
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
#define HIPBLASLT_CALL(expr) (RocmCall<hipblasStatus_t, false>((expr), #expr, "hipBLASLt", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__))
|
||||
#define HIPBLASLT_CALL_THROW(expr) (RocmCall<hipblasStatus_t, true>((expr), #expr, "hipBLASLt", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__))
|
||||
#endif
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
213
onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h
Normal file
213
onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
#include <hipblaslt/hipblaslt.h>
|
||||
#endif
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/rocm/tunable/gemm_common.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
namespace tunable {
|
||||
namespace blas {
|
||||
namespace internal {
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
|
||||
// For large K and small M/N, K dim will be splited to multiple workgroups and buffers,
|
||||
// which will require additional workspace. Here we set the max workspace size to 32MB.
|
||||
constexpr const size_t kHipBlasLtMaxWorkSpaceSizeInBytes = 32 * 1024 * 1024;
|
||||
// We only keep one heuristic result here. Note that for tuned input sizes, the first result
|
||||
// will be the most performant one; but in untuned cases, this is not guaranteed.
|
||||
constexpr const int kHeuristicResultCount = 1;
|
||||
|
||||
enum ActivationType {
|
||||
NONE = 0,
|
||||
RELU = 1,
|
||||
GELU = 2,
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor(const T*);
|
||||
|
||||
template <>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor(const float*) {
|
||||
return HIPBLAS_R_32F;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor(const half*) {
|
||||
return HIPBLAS_R_16F;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor(const BFloat16*) {
|
||||
return HIPBLAS_R_16B;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor(const double*) {
|
||||
return HIPBLAS_R_64F;
|
||||
}
|
||||
|
||||
template <typename T, typename ParamsT>
|
||||
Status HipBlasLtMatMul(const ParamsT* params, int64_t batch, ActivationType activation_type = ActivationType::NONE,
|
||||
bool enable_bias = false, const T* d_bias = nullptr,
|
||||
bool enable_scaleD = false, const T* d_scaleD = nullptr) {
|
||||
hipblasLtHandle_t handle;
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtCreate(&handle));
|
||||
|
||||
// Note: properties of original matrices A and B are swapped.
|
||||
int64_t lda = (params->opb == BlasOp::N) ? params->n : params->k;
|
||||
int64_t ldb = (params->opa == BlasOp::N) ? params->k : params->m;
|
||||
int64_t ldc = params->n;
|
||||
int64_t stride_a = (params->opb == BlasOp::N) ? lda * params->k : lda * params->n;
|
||||
int64_t stride_b = (params->opa == BlasOp::N) ? ldb * params->m : ldb * params->k;
|
||||
int64_t stride_c = ldc * params->m;
|
||||
float alpha = static_cast<float>(params->alpha);
|
||||
float beta = static_cast<float>(params->beta);
|
||||
int row_a, col_a, row_b, col_b, row_c, col_c;
|
||||
row_a = lda;
|
||||
col_a = (params->opb == BlasOp::N) ? params->k : params->n;
|
||||
row_b = ldb;
|
||||
col_b = (params->opa == BlasOp::N) ? params->m : params->k;
|
||||
row_c = ldc;
|
||||
col_c = params->m;
|
||||
|
||||
hipblasDatatype_t in_out_datatype = HipBlasDataTypeFor(params->a);
|
||||
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
|
||||
hipblasLtMatmulDesc_t matmul;
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
|
||||
|
||||
if (batch > 1) {
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
|
||||
}
|
||||
|
||||
hipblasOperation_t trans_a = (params->opb == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
|
||||
hipblasOperation_t trans_b = (params->opa == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
|
||||
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t)));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
|
||||
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t)));
|
||||
|
||||
hipblasLtEpilogue_t epilogue;
|
||||
switch (activation_type) {
|
||||
case ActivationType::NONE:
|
||||
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_BIAS : HIPBLASLT_EPILOGUE_DEFAULT;
|
||||
break;
|
||||
case ActivationType::RELU:
|
||||
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_RELU_BIAS : HIPBLASLT_EPILOGUE_RELU;
|
||||
break;
|
||||
case ActivationType::GELU:
|
||||
epilogue = enable_bias ? HIPBLASLT_EPILOGUE_GELU_BIAS : HIPBLASLT_EPILOGUE_GELU;
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported activation type for HipBlasLtMatMul");
|
||||
}
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
|
||||
matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
|
||||
|
||||
if (enable_bias) {
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
|
||||
matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(void*)));
|
||||
}
|
||||
if (enable_scaleD) {
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute(
|
||||
matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scaleD, sizeof(void*)));
|
||||
}
|
||||
|
||||
hipblasLtMatmulPreference_t pref;
|
||||
void* workspace;
|
||||
size_t max_workspace_size = kHipBlasLtMaxWorkSpaceSizeInBytes;
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceCreate(&pref));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceSetAttribute(
|
||||
pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &max_workspace_size, sizeof(max_workspace_size)));
|
||||
|
||||
hipblasLtMatmulHeuristicResult_t heuristic_result[kHeuristicResultCount] = {0};
|
||||
int ret_algo_count = 0;
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle,
|
||||
matmul,
|
||||
mat_a,
|
||||
mat_b,
|
||||
mat_c,
|
||||
mat_c,
|
||||
pref,
|
||||
kHeuristicResultCount,
|
||||
heuristic_result,
|
||||
&ret_algo_count));
|
||||
|
||||
assert(ret_algo_count > 0);
|
||||
|
||||
size_t workspace_size = heuristic_result[0].workspaceSize;
|
||||
if (workspace_size > 0) {
|
||||
HIP_RETURN_IF_ERROR(hipMallocAsync(&workspace, workspace_size, params->stream));
|
||||
}
|
||||
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(handle,
|
||||
matmul,
|
||||
&alpha,
|
||||
params->b,
|
||||
mat_a,
|
||||
params->a,
|
||||
mat_b,
|
||||
&beta,
|
||||
params->c,
|
||||
mat_c,
|
||||
params->c,
|
||||
mat_c,
|
||||
&heuristic_result[0].algo,
|
||||
workspace,
|
||||
workspace_size,
|
||||
params->stream));
|
||||
|
||||
if (workspace_size > 0) {
|
||||
HIP_RETURN_IF_ERROR(hipFreeAsync(workspace, params->stream));
|
||||
}
|
||||
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceDestroy(pref));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_b));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_c));
|
||||
HIPBLASLT_RETURN_IF_ERROR(hipblasLtDestroy(handle));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status HipBlasLtGemmOp(const GemmParams<T>* params) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v<T, double>), "hipBLASLt does not support double inputs");
|
||||
return HipBlasLtMatMul<T, GemmParams<T>>(params, /*batch=*/1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status HipBlasLtStridedBatchedGemmOp(const StridedBatchedGemmParams<T>* params) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v<T, double>), "hipBLASLt does not support double inputs");
|
||||
return HipBlasLtMatMul<T, StridedBatchedGemmParams<T>>(params, params->batch);
|
||||
};
|
||||
|
||||
#endif // USE_HIPBLASLT
|
||||
|
||||
} // namespace internal
|
||||
} // namespace blas
|
||||
} // namespace tunable
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -7,8 +7,9 @@
|
|||
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/tunable/gemm_ck.cuh"
|
||||
#include "core/providers/rocm/tunable/gemm_rocblas.h"
|
||||
#include "core/providers/rocm/tunable/gemm_common.h"
|
||||
#include "core/providers/rocm/tunable/gemm_hipblaslt.h"
|
||||
#include "core/providers/rocm/tunable/gemm_rocblas.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -38,6 +39,10 @@ class GemmTunableOp : public TunableOp<GemmParams<T>> {
|
|||
GemmTunableOp() {
|
||||
this->RegisterOp(RocBlasGemmOp<T>);
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
this->RegisterOp(HipBlasLtGemmOp<T>);
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
this->RegisterNestedTunableOp(&rocblas_gemm_tunable_op_);
|
||||
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
|
||||
|
|
@ -139,6 +144,10 @@ class StridedBatchedGemmTunableOp : public TunableOp<StridedBatchedGemmParams<T>
|
|||
StridedBatchedGemmTunableOp() {
|
||||
this->RegisterOp(RocBlasStridedBatchedGemmOp<T>);
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
this->RegisterOp(HipBlasLtStridedBatchedGemmOp<T>);
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCBLAS_EXTENSION_API
|
||||
this->RegisterNestedTunableOp(&rocblas_strided_batched_gemm_tunable_op_);
|
||||
#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */
|
||||
|
|
|
|||
|
|
@ -74,6 +74,12 @@ std::string RocmTuningResultsValidator::GetOrtBuildConfig() const {
|
|||
#else
|
||||
oss << "USE_ROCBLAS_EXTENSION_API=" << 0 << "|";
|
||||
#endif
|
||||
|
||||
#ifdef USE_HIPBLASLT
|
||||
oss << "USE_HIPBLASLT=" << 1 << "|";
|
||||
#else
|
||||
oss << "USE_HIPBLASLT=" << 0 << "|";
|
||||
#endif
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue