From 0fdd356abff9d5ae2bc259de2628cd5badd0b3c2 Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Fri, 14 Apr 2023 17:56:01 +0800 Subject: [PATCH] [ROCm] Add hipBLASLt GEMM support to Tunable op. (#15351) ### Description Add hipBLASLt to GEMM Tunable op, which supports GEMM and StridedBatchedGEMM. To enable hipBLASLt implementation, add an extra flag to the building command: `--cmake_extra_defines onnxruntime_USE_HIPBLASLT=ON`. --- cmake/onnxruntime_providers.cmake | 6 + onnxruntime/core/providers/rocm/rocm_call.cc | 6 + onnxruntime/core/providers/rocm/rocm_common.h | 4 + onnxruntime/core/providers/rocm/rocm_pch.h | 4 + .../providers/rocm/shared_inc/rocm_call.h | 5 + .../providers/rocm/tunable/gemm_hipblaslt.h | 213 ++++++++++++++++++ .../providers/rocm/tunable/gemm_tunable.cuh | 11 +- .../rocm/tunable/rocm_tuning_context.cc | 6 + 8 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 5497dc62f2..de46811f22 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1536,6 +1536,12 @@ if (onnxruntime_USE_ROCM) target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_BETA_FEATURES_API) endif() + if (onnxruntime_USE_HIPBLASLT) + find_package(hipblaslt REQUIRED) + target_link_libraries(onnxruntime_providers_rocm PRIVATE roc::hipblaslt) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_HIPBLASLT) + endif() + if (onnxruntime_USE_COMPOSABLE_KERNEL) include(composable_kernel) target_link_libraries(onnxruntime_providers_rocm PRIVATE diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc index f6dbfbffb1..730f55608c 100644 --- a/onnxruntime/core/providers/rocm/rocm_call.cc +++ b/onnxruntime/core/providers/rocm/rocm_call.cc @@ -148,4 +148,10 @@ template void RocmCall(hipfftResult retCode, const char* exp template Status RocmCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); template void RocmCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); #endif + +#ifdef USE_HIPBLASLT +template Status RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); +template void RocmCall(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line); +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_common.h b/onnxruntime/core/providers/rocm/rocm_common.h index 6a24883757..07b3e252c6 100644 --- a/onnxruntime/core/providers/rocm/rocm_common.h +++ b/onnxruntime/core/providers/rocm/rocm_common.h @@ -23,6 +23,10 @@ namespace rocm { #define MIOPEN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(MIOPEN_CALL2(expr, m)) #define HIPFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPFFT_CALL(expr)) +#ifdef USE_HIPBLASLT +#define HIPBLASLT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIPBLASLT_CALL(expr)) +#endif + // Type mapping for MLFloat16 to half template class ToHipType { diff --git a/onnxruntime/core/providers/rocm/rocm_pch.h b/onnxruntime/core/providers/rocm/rocm_pch.h index ecaf7a2c55..d91369dda5 100644 --- a/onnxruntime/core/providers/rocm/rocm_pch.h +++ b/onnxruntime/core/providers/rocm/rocm_pch.h @@ -19,6 +19,10 @@ #include #endif +#ifdef USE_HIPBLASLT +#include +#endif + #if defined(_MSC_VER) #pragma warning(pop) #endif diff --git a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h index 10ce83a614..d6623ef63f 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h +++ b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h @@ -36,4 +36,9 @@ std::conditional_t RocmCall( #define NCCL_CALL_THROW(expr) (RocmCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) #endif +#ifdef USE_HIPBLASLT +#define HIPBLASLT_CALL(expr) (RocmCall((expr), #expr, "hipBLASLt", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define HIPBLASLT_CALL_THROW(expr) (RocmCall((expr), #expr, "hipBLASLt", HIPBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h new file mode 100644 index 0000000000..aa5630ebea --- /dev/null +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef USE_HIPBLASLT +#include +#endif + +#include "core/common/common.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +#ifdef USE_HIPBLASLT + +// For large K and small M/N, K dim will be splited to multiple workgroups and buffers, +// which will require additional workspace. Here we set the max workspace size to 32MB. +constexpr const size_t kHipBlasLtMaxWorkSpaceSizeInBytes = 32 * 1024 * 1024; +// We only keep one heuristic result here. Note that for tuned input sizes, the first result +// will be the most performant one; but in untuned cases, this is not guaranteed. +constexpr const int kHeuristicResultCount = 1; + +enum ActivationType { + NONE = 0, + RELU = 1, + GELU = 2, +}; + +template +constexpr hipblasDatatype_t HipBlasDataTypeFor(const T*); + +template <> +constexpr hipblasDatatype_t HipBlasDataTypeFor(const float*) { + return HIPBLAS_R_32F; +} + +template <> +constexpr hipblasDatatype_t HipBlasDataTypeFor(const half*) { + return HIPBLAS_R_16F; +} + +template <> +constexpr hipblasDatatype_t HipBlasDataTypeFor(const BFloat16*) { + return HIPBLAS_R_16B; +} + +template <> +constexpr hipblasDatatype_t HipBlasDataTypeFor(const double*) { + return HIPBLAS_R_64F; +} + +template +Status HipBlasLtMatMul(const ParamsT* params, int64_t batch, ActivationType activation_type = ActivationType::NONE, + bool enable_bias = false, const T* d_bias = nullptr, + bool enable_scaleD = false, const T* d_scaleD = nullptr) { + hipblasLtHandle_t handle; + HIPBLASLT_RETURN_IF_ERROR(hipblasLtCreate(&handle)); + + // Note: properties of original matrices A and B are swapped. + int64_t lda = (params->opb == BlasOp::N) ? params->n : params->k; + int64_t ldb = (params->opa == BlasOp::N) ? params->k : params->m; + int64_t ldc = params->n; + int64_t stride_a = (params->opb == BlasOp::N) ? lda * params->k : lda * params->n; + int64_t stride_b = (params->opa == BlasOp::N) ? ldb * params->m : ldb * params->k; + int64_t stride_c = ldc * params->m; + float alpha = static_cast(params->alpha); + float beta = static_cast(params->beta); + int row_a, col_a, row_b, col_b, row_c, col_c; + row_a = lda; + col_a = (params->opb == BlasOp::N) ? params->k : params->n; + row_b = ldb; + col_b = (params->opa == BlasOp::N) ? params->m : params->k; + row_c = ldc; + col_c = params->m; + + hipblasDatatype_t in_out_datatype = HipBlasDataTypeFor(params->a); + hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; + hipblasLtMatmulDesc_t matmul; + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); + + if (batch > 1) { + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); + } + + hipblasOperation_t trans_a = (params->opb == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T; + hipblasOperation_t trans_b = (params->opa == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T; + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t))); + + hipblasLtEpilogue_t epilogue; + switch (activation_type) { + case ActivationType::NONE: + epilogue = enable_bias ? HIPBLASLT_EPILOGUE_BIAS : HIPBLASLT_EPILOGUE_DEFAULT; + break; + case ActivationType::RELU: + epilogue = enable_bias ? HIPBLASLT_EPILOGUE_RELU_BIAS : HIPBLASLT_EPILOGUE_RELU; + break; + case ActivationType::GELU: + epilogue = enable_bias ? HIPBLASLT_EPILOGUE_GELU_BIAS : HIPBLASLT_EPILOGUE_GELU; + break; + default: + throw std::runtime_error("Unsupported activation type for HipBlasLtMatMul"); + } + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + + if (enable_bias) { + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(void*))); + } + if (enable_scaleD) { + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scaleD, sizeof(void*))); + } + + hipblasLtMatmulPreference_t pref; + void* workspace; + size_t max_workspace_size = kHipBlasLtMaxWorkSpaceSizeInBytes; + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceCreate(&pref)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceSetAttribute( + pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &max_workspace_size, sizeof(max_workspace_size))); + + hipblasLtMatmulHeuristicResult_t heuristic_result[kHeuristicResultCount] = {0}; + int ret_algo_count = 0; + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, + matmul, + mat_a, + mat_b, + mat_c, + mat_c, + pref, + kHeuristicResultCount, + heuristic_result, + &ret_algo_count)); + + assert(ret_algo_count > 0); + + size_t workspace_size = heuristic_result[0].workspaceSize; + if (workspace_size > 0) { + HIP_RETURN_IF_ERROR(hipMallocAsync(&workspace, workspace_size, params->stream)); + } + + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(handle, + matmul, + &alpha, + params->b, + mat_a, + params->a, + mat_b, + &beta, + params->c, + mat_c, + params->c, + mat_c, + &heuristic_result[0].algo, + workspace, + workspace_size, + params->stream)); + + if (workspace_size > 0) { + HIP_RETURN_IF_ERROR(hipFreeAsync(workspace, params->stream)); + } + + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceDestroy(pref)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_b)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_c)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtDestroy(handle)); + return Status::OK(); +} + +template +Status HipBlasLtGemmOp(const GemmParams* params) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v), "hipBLASLt does not support double inputs"); + return HipBlasLtMatMul>(params, /*batch=*/1); +} + +template +Status HipBlasLtStridedBatchedGemmOp(const StridedBatchedGemmParams* params) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v), "hipBLASLt does not support double inputs"); + return HipBlasLtMatMul>(params, params->batch); +}; + +#endif // USE_HIPBLASLT + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh index ceda036fb1..5a46654b61 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh @@ -7,8 +7,9 @@ #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/tunable/gemm_ck.cuh" -#include "core/providers/rocm/tunable/gemm_rocblas.h" #include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/gemm_hipblaslt.h" +#include "core/providers/rocm/tunable/gemm_rocblas.h" #include "core/providers/rocm/tunable/rocm_tunable.h" namespace onnxruntime { @@ -38,6 +39,10 @@ class GemmTunableOp : public TunableOp> { GemmTunableOp() { this->RegisterOp(RocBlasGemmOp); +#ifdef USE_HIPBLASLT + this->RegisterOp(HipBlasLtGemmOp); +#endif + #ifdef USE_ROCBLAS_EXTENSION_API this->RegisterNestedTunableOp(&rocblas_gemm_tunable_op_); #endif /* #ifdef USE_ROCBLAS_EXTENSION_API */ @@ -139,6 +144,10 @@ class StridedBatchedGemmTunableOp : public TunableOp StridedBatchedGemmTunableOp() { this->RegisterOp(RocBlasStridedBatchedGemmOp); +#ifdef USE_HIPBLASLT + this->RegisterOp(HipBlasLtStridedBatchedGemmOp); +#endif + #ifdef USE_ROCBLAS_EXTENSION_API this->RegisterNestedTunableOp(&rocblas_strided_batched_gemm_tunable_op_); #endif /* #ifdef USE_ROCBLAS_EXTENSION_API */ diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index 69d68bf8b8..1cc45ce840 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -74,6 +74,12 @@ std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { #else oss << "USE_ROCBLAS_EXTENSION_API=" << 0 << "|"; #endif + +#ifdef USE_HIPBLASLT + oss << "USE_HIPBLASLT=" << 1 << "|"; +#else + oss << "USE_HIPBLASLT=" << 0 << "|"; +#endif return oss.str(); }