From 810512c6583fa4acf0a114627dfe7bc481f819ff Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Fri, 14 Jul 2023 08:20:58 +0800 Subject: [PATCH] [ROCm] TunableOp: add hipBLASLt tuning logic (#16338) ### Description - Add hipBLASLt tuning logic in place of the default hipBLASLt implementation; - Add kernel explorer for hipBLASLt. Related operators: Gemm, StridedBatchedGemm, and GemmFastGelu. Temporarily mark algos that require extra workspace as unsupported. Will add workspace support in a later PR, which will change the Gemm params definition and affect multiple files. --- .../rocm/bert/gemm_fast_gelu_tunable.cuh | 5 +- .../providers/rocm/tunable/gemm_hipblaslt.h | 361 ++++++++++-------- .../providers/rocm/tunable/gemm_tunable.cuh | 10 +- .../kernels/gemm_fast_gelu_test.py | 5 +- .../kernel_explorer/kernels/gemm_test.py | 4 + .../kernels/rocm/gemm_fast_gelu_hipblaslt.cu | 72 ++-- .../kernels/rocm/gemm_hipblaslt.cu | 212 ++++++++++ .../kernels/strided_batched_gemm_test.py | 4 + 8 files changed, 490 insertions(+), 183 deletions(-) create mode 100644 onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh index dcd2d45df5..ba0485508c 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh @@ -70,7 +70,10 @@ class GemmFastGeluTunableOp : public TunableOp> { #endif #ifdef USE_HIPBLASLT - this->RegisterOp(HipBlasLtGemmFastGeluOp); + for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } #endif } }; diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index c3bd17ddf5..f93040ab9e 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ 
b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -5,6 +5,10 @@ #ifdef USE_HIPBLASLT #include +#include +#include "core/providers/rocm/tunable/gemm_ck.cuh" +#include "core/providers/rocm/rocm_execution_provider.h" +#include "core/providers/rocm/rocm_stream_handle.h" #endif #include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" @@ -22,13 +26,6 @@ using onnxruntime::contrib::rocm::blas::GemmFastGeluParams; #ifdef USE_HIPBLASLT -// For large K and small M/N, K dim will be split to multiple workgroups and buffers, -// which will require additional workspace. Here we set the max workspace size to 32MB. -constexpr const size_t kHipBlasLtMaxWorkSpaceSizeInBytes = 32 * 1024 * 1024; -// We only keep one heuristic result here. Note that for tuned input sizes, the first result -// will be the most performant one; but in untuned cases, this is not guaranteed. -constexpr const int kHeuristicResultCount = 1; - enum ActivationType { NONE = 0, RELU = 1, @@ -36,184 +33,234 @@ enum ActivationType { }; template -constexpr hipblasDatatype_t HipBlasDataTypeFor(const T*); +constexpr hipblasDatatype_t HipBlasDataTypeFor(); template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor(const float*) { +constexpr hipblasDatatype_t HipBlasDataTypeFor() { return HIPBLAS_R_32F; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor(const half*) { +constexpr hipblasDatatype_t HipBlasDataTypeFor() { return HIPBLAS_R_16F; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor(const BFloat16*) { +constexpr hipblasDatatype_t HipBlasDataTypeFor() { return HIPBLAS_R_16B; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor(const double*) { +constexpr hipblasDatatype_t HipBlasDataTypeFor() { return HIPBLAS_R_64F; } +template +constexpr hipblasOperation_t MapCKLayoutToHipBlasLt() { + if constexpr (std::is_same_v) { + return HIPBLAS_OP_N; + } + return HIPBLAS_OP_T; +} + template -Status HipBlasLtMatMul(const ParamsT* params, int64_t batch, ActivationType 
activation_type = ActivationType::NONE, - bool enable_bias = false, const T* d_bias = nullptr, - bool enable_scaleD = false, const T* d_scaleD = nullptr) { +int GetBatchCountFromParams(const ParamsT* params) { + ORT_UNUSED_PARAMETER(params); + return 1; +} + +template +int GetBatchCountFromParams(const StridedBatchedGemmParams* params) { + return params->batch; +} + +template +const T* GetBiasFromParams(const ParamsT* params) { + ORT_UNUSED_PARAMETER(params); + return nullptr; +} + +template +const T* GetBiasFromParams(const GemmFastGeluParams* params) { + return params->bias; +} + +template +std::string TypeStringFor() { + if constexpr (std::is_same_v>) { + return "Gemm"; + } else if constexpr (std::is_same_v>) { + return "StridedBatchedGemm"; + } else if constexpr (std::is_same_v>) { + return "GemmFastGelu"; + } + return "UnknownType"; +} + +template +auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationType::NONE) { hipblasLtHandle_t handle; - HIPBLASLT_RETURN_IF_ERROR(hipblasLtCreate(&handle)); + HIPBLASLT_CALL_THROW(hipblasLtCreate(&handle)); - // Note: properties of original matrices A and B are swapped. - int64_t lda = (params->opb == BlasOp::N) ? params->n : params->k; - int64_t ldb = (params->opa == BlasOp::N) ? params->k : params->m; - int64_t ldc = params->n; - int64_t stride_a = (params->opb == BlasOp::N) ? lda * params->k : lda * params->n; - int64_t stride_b = (params->opa == BlasOp::N) ? ldb * params->m : ldb * params->k; - int64_t stride_c = ldc * params->m; - float alpha = static_cast(params->alpha); - float beta = static_cast(params->beta); - int row_a, col_a, row_b, col_b, row_c, col_c; - row_a = lda; - col_a = (params->opb == BlasOp::N) ? params->k : params->n; - row_b = ldb; - col_b = (params->opa == BlasOp::N) ? 
params->m : params->k; - row_c = ldc; - col_c = params->m; + hipblasOperation_t trans_a = MapCKLayoutToHipBlasLt(); + hipblasOperation_t trans_b = MapCKLayoutToHipBlasLt(); + hipblasDatatype_t in_out_datatype = HipBlasDataTypeFor(); + std::vector heuristic_result; - hipblasDatatype_t in_out_datatype = HipBlasDataTypeFor(params->a); - hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; - hipblasLtMatmulDesc_t matmul; - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); + HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle, + hipblaslt_ext::GemmType::HIPBLASLT_GEMM, + trans_a, + trans_b, + in_out_datatype, + in_out_datatype, + in_out_datatype, + in_out_datatype, + HIPBLASLT_COMPUTE_F32, + heuristic_result)); + HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle)); - if (batch > 1) { - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( - mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); + int 
returned_algo_count = heuristic_result.size(); + std::vector>> ret; + for (int i = 0; i < returned_algo_count; i++) { + hipblasLtMatmulAlgo_t algo = heuristic_result[i].algo; + auto hipblaslt_gemm_op = [=](const ParamsT* params) -> Status { + hipblasLtHandle_t op_handle; + HIPBLASLT_RETURN_IF_ERROR(hipblasLtCreate(&op_handle)); + + // Note: properties of original matrices A and B are swapped. + int64_t lda = (params->opb == BlasOp::N) ? params->n : params->k; + int64_t ldb = (params->opa == BlasOp::N) ? params->k : params->m; + int64_t ldc = params->n; + int64_t stride_a = (params->opb == BlasOp::N) ? lda * params->k : lda * params->n; + int64_t stride_b = (params->opa == BlasOp::N) ? ldb * params->m : ldb * params->k; + int64_t stride_c = ldc * params->m; + float alpha = static_cast(params->alpha); + float beta = static_cast(params->beta); + int row_a, col_a, row_b, col_b, row_c, col_c; + row_a = lda; + col_a = (params->opb == BlasOp::N) ? params->k : params->n; + row_b = ldb; + col_b = (params->opa == BlasOp::N) ? 
params->m : params->k; + row_c = ldc; + col_c = params->m; + + hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; + hipblasLtMatmulDesc_t matmul; + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); + + int batch = GetBatchCountFromParams(params); + if (batch > 1) { + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); + } + + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t))); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t))); + + // Deduce enable_bias from params + auto d_bias = GetBiasFromParams(params); + bool enable_bias = d_bias != nullptr; + + hipblasLtEpilogue_t epilogue; + switch (activation_type) { + case ActivationType::NONE: + epilogue = enable_bias ? 
HIPBLASLT_EPILOGUE_BIAS : HIPBLASLT_EPILOGUE_DEFAULT; + break; + case ActivationType::RELU: + epilogue = enable_bias ? HIPBLASLT_EPILOGUE_RELU_BIAS : HIPBLASLT_EPILOGUE_RELU; + break; + case ActivationType::GELU: + epilogue = enable_bias ? HIPBLASLT_EPILOGUE_GELU_BIAS : HIPBLASLT_EPILOGUE_GELU; + break; + default: + throw std::runtime_error("Unsupported activation type for HipBlasLtMatMul"); + } + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + + if (enable_bias) { + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(void*))); + } + + size_t workspace_size = 0; + hipblasLtMatmulAlgo_t algo_i = algo; + auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle, + matmul, + &alpha, + mat_a, + mat_b, + &beta, + mat_c, + mat_c, + algo_i, + workspace_size); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + status != HIPBLAS_STATUS_SUCCESS, "hipBLASLt find_all: algo not supported, index ", std::to_string(i)); + // TODO: support workspace in next PR + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + workspace_size > 0, "hipBLASLt find_all: extra workspace not supported for now."); + + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(op_handle, + matmul, + &alpha, + params->b, + mat_a, + params->a, + mat_b, + &beta, + params->c, + mat_c, + params->c, + mat_c, + &algo_i, + nullptr, + 0, + params->stream)); + + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_b)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_c)); + HIPBLASLT_RETURN_IF_ERROR(hipblasLtDestroy(op_handle)); + return Status::OK(); + }; + std::string type_string = onnxruntime::MakeString(TypeStringFor(), "HipBlasLt_", i); + ret.emplace_back(type_string, std::move(hipblaslt_gemm_op)); } - - hipblasOperation_t trans_a 
= (params->opb == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T; - hipblasOperation_t trans_b = (params->opa == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T; - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t))); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t))); - - hipblasLtEpilogue_t epilogue; - switch (activation_type) { - case ActivationType::NONE: - epilogue = enable_bias ? HIPBLASLT_EPILOGUE_BIAS : HIPBLASLT_EPILOGUE_DEFAULT; - break; - case ActivationType::RELU: - epilogue = enable_bias ? HIPBLASLT_EPILOGUE_RELU_BIAS : HIPBLASLT_EPILOGUE_RELU; - break; - case ActivationType::GELU: - epilogue = enable_bias ? HIPBLASLT_EPILOGUE_GELU_BIAS : HIPBLASLT_EPILOGUE_GELU; - break; - default: - throw std::runtime_error("Unsupported activation type for HipBlasLtMatMul"); - } - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); - - if (enable_bias) { - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &d_bias, sizeof(void*))); - } - if (enable_scaleD) { - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scaleD, sizeof(void*))); - } - - hipblasLtMatmulPreference_t pref; - void* workspace; - size_t max_workspace_size = kHipBlasLtMaxWorkSpaceSizeInBytes; - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceCreate(&pref)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceSetAttribute( - pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &max_workspace_size, sizeof(max_workspace_size))); - - hipblasLtMatmulHeuristicResult_t heuristic_result[kHeuristicResultCount] = {}; - int ret_algo_count = 0; - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, - matmul, - mat_a, - mat_b, - mat_c, - mat_c, - pref, - 
kHeuristicResultCount, - heuristic_result, - &ret_algo_count)); - - assert(ret_algo_count > 0); - - size_t workspace_size = heuristic_result[0].workspaceSize; - if (workspace_size > 0) { - HIP_RETURN_IF_ERROR(hipMallocAsync(&workspace, workspace_size, params->stream)); - } - - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(handle, - matmul, - &alpha, - params->b, - mat_a, - params->a, - mat_b, - &beta, - params->c, - mat_c, - params->c, - mat_c, - &heuristic_result[0].algo, - workspace, - workspace_size, - params->stream)); - - if (workspace_size > 0) { - HIP_RETURN_IF_ERROR(hipFreeAsync(workspace, params->stream)); - } - - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulPreferenceDestroy(pref)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_b)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_c)); - HIPBLASLT_RETURN_IF_ERROR(hipblasLtDestroy(handle)); - return Status::OK(); + return ret; } -template -Status HipBlasLtGemmOp(const GemmParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v), "hipBLASLt does not support double inputs"); - return HipBlasLtMatMul>(params, /*batch=*/1); +template +auto GetHipBlasLtGemmTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); } -template -Status HipBlasLtStridedBatchedGemmOp(const StridedBatchedGemmParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v), "hipBLASLt does not support double inputs"); - return HipBlasLtMatMul>(params, params->batch); -}; +template +auto GetHipBlasLtStridedBatchedGemmTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} -template -Status HipBlasLtGemmFastGeluOp(const GemmFastGeluParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((std::is_same_v), "hipBLASLt does not support double inputs"); - bool enable_bias = nullptr != params->bias; - return HipBlasLtMatMul>(params, 
/*batch=*/1, ActivationType::GELU, - enable_bias, params->bias); -}; +template +auto GetHipBlasLtGemmFastGeluTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(ActivationType::GELU); +} #endif // USE_HIPBLASLT diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh index 2a0642bbee..d39fa3e662 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh @@ -40,7 +40,10 @@ class GemmTunableOp : public TunableOp> { this->RegisterOp(RocBlasGemmOp); #ifdef USE_HIPBLASLT - this->RegisterOp(HipBlasLtGemmOp); + for (auto&& [_, op] : GetHipBlasLtGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } #endif #ifdef USE_ROCBLAS_EXTENSION_API @@ -141,7 +144,10 @@ class StridedBatchedGemmTunableOp : public TunableOp this->RegisterOp(RocBlasStridedBatchedGemmOp); #ifdef USE_HIPBLASLT - this->RegisterOp(HipBlasLtStridedBatchedGemmOp); + for (auto&& [_, op] : GetHipBlasLtStridedBatchedGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } #endif #ifdef USE_ROCBLAS_EXTENSION_API diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py index 361789a597..2c1e5e54b8 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py @@ -166,7 +166,10 @@ def profile_with_args(transa, transb, dtype, m, n, k, sort): profile_gemmfastgelu_func( getattr(ke, "GemmFastGeluTunable" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb ) - profile_gemmfastgelu_func(getattr(ke, "GemmFastGeluHipBlasLt" + dtype_suffix), dtype, m, n, k, transa, transb) + if ke.is_hipblaslt_available(): + profile_gemmfastgelu_func( + getattr(ke, "GemmFastGeluHipBlasLt" + dtype_suffix + 
transab_suffix), dtype, m, n, k, transa, transb + ) def profile(): diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py index 9b1a3b9760..6cb984935c 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py @@ -179,6 +179,10 @@ def profile_with_args(dtype, transa, transb, m, n, k, sort): profile_gemm_func(getattr(ke, "RocBlasGemm" + dtype_suffix), dtype, transa, transb, m, n, k) profile_gemm_func(getattr(ke, "CKGemm" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k) profile_gemm_func(getattr(ke, "GemmTunable" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k) + if ke.is_hipblaslt_available(): + profile_gemm_func( + getattr(ke, "GemmHipBlasLt" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k + ) print() diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu index e927d78a8b..4638de5010 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_hipblaslt.cu @@ -20,7 +20,10 @@ namespace py = pybind11; namespace onnxruntime { #ifdef USE_HIPBLASLT -template + +using namespace rocm::tunable::blas::internal; + +template class GemmFastGeluHipBlasLt : public IKernelExplorer { public: GemmFastGeluHipBlasLt(BlasOp opa, BlasOp opb, @@ -33,6 +36,8 @@ class GemmFastGeluHipBlasLt : public IKernelExplorer { DeviceArray& c, int64_t ldc) : params_{} { params_.tuning_ctx = TuningContext(); params_.stream = Stream(); + // rocblas handle is not used for hipBLASLt + params_.handle = nullptr; params_.opa = opa; params_.opb = opb; params_.m = m; @@ -47,44 +52,67 @@ class GemmFastGeluHipBlasLt : public IKernelExplorer { params_.beta = beta; 
params_.c = static_cast(c.ptr()); params_.ldc = ldc; + + for (auto&& [type_string, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + ORT_ENFORCE(!ops_.empty()); } void Run() override { - ORT_THROW_IF_ERROR((rocm::tunable::blas::internal::HipBlasLtGemmFastGeluOp(¶ms_))); + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); } std::vector ListOps() const { - return {"GemmFastGeluHipBlasLt"}; + return type_strings_; } bool SelectOp(const std::string& name) { - Status status = rocm::tunable::blas::internal::HipBlasLtGemmFastGeluOp(¶ms_); - return status.IsOK() && name == "GemmFastGeluHipBlasLt"; + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); } private: using ParamsT = contrib::rocm::blas::GemmFastGeluParams; - ParamsT params_{}; + using OpT = Op; + ParamsT params_; + std::vector ops_; + std::vector type_strings_; + size_t selected_op_{}; }; -#define REGISTER_OP(type) \ - py::class_>(m, "GemmFastGeluHipBlasLt_" #type) \ - .def(py::init()) \ - .def("SetRepeats", &GemmFastGeluHipBlasLt::SetRepeats) \ - .def("Run", &GemmFastGeluHipBlasLt::Run) \ - .def("Profile", &GemmFastGeluHipBlasLt::Profile) \ - .def("ListOps", &GemmFastGeluHipBlasLt::ListOps) \ - .def("SelectOp", &GemmFastGeluHipBlasLt::SelectOp); +#define REGISTER_OP(type, alayout, blayout, layout_string) \ + py::class_>(m, "GemmFastGeluHipBlasLt_" #type "_" layout_string) \ + .def(py::init()) \ + .def("SetRepeats", &GemmFastGeluHipBlasLt::SetRepeats) \ + .def("Profile", &GemmFastGeluHipBlasLt::Profile) \ + .def("Run", &GemmFastGeluHipBlasLt::Run) \ + .def("ListOps", &GemmFastGeluHipBlasLt::ListOps) \ + .def("SelectOp", &GemmFastGeluHipBlasLt::SelectOp); + +#define REGISTER_OP_FOR_ALL_TRANSAB(type) \ + REGISTER_OP(type, Row, Row, "NN"); \ + REGISTER_OP(type, 
Row, Col, "NT"); \ + REGISTER_OP(type, Col, Row, "TN"); \ + REGISTER_OP(type, Col, Col, "TT"); KE_REGISTER(m) { - REGISTER_OP(float) - REGISTER_OP(half) + REGISTER_OP_FOR_ALL_TRANSAB(float); + REGISTER_OP_FOR_ALL_TRANSAB(half); } #endif // USE_HIPBLASLT diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu new file mode 100644 index 0000000000..1d9a5fac22 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_hipblaslt.cu @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include +#include + +#ifdef USE_HIPBLASLT +#include "core/providers/rocm/tunable/gemm_hipblaslt.h" +#endif + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +using namespace onnxruntime::rocm::tunable::blas; + +namespace py = pybind11; + +namespace onnxruntime { + +#ifdef USE_HIPBLASLT + +using namespace rocm::tunable::blas::internal; + +template +class GemmHipBlasLt : public IKernelExplorer { + public: + GemmHipBlasLt(BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, + DeviceArray& b, int64_t ldb, + double beta, + DeviceArray& c, int64_t ldc) + : params_{} { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + // rocblas handle is not used for hipBLASLt + params_.handle = nullptr; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + + for (auto&& [type_string, op] : 
GetHipBlasLtGemmTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + ORT_ENFORCE(!ops_.empty()); + } + + void Run() override { + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); + } + + std::vector ListOps() const { + return type_strings_; + } + + bool SelectOp(const std::string& name) { + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); + } + + private: + using ParamsT = GemmParams; + using OpT = Op; + ParamsT params_; + std::vector ops_; + std::vector type_strings_; + size_t selected_op_{}; +}; + +template +class StridedBatchedGemmHipBlasLt : public IKernelExplorer { + public: + StridedBatchedGemmHipBlasLt( + BlasOp opa, BlasOp opb, + int64_t m, int64_t n, int64_t k, + double alpha, + DeviceArray& a, int64_t lda, int64_t stride_a, + DeviceArray& b, int64_t ldb, int64_t stride_b, + double beta, + DeviceArray& c, int64_t ldc, int64_t stride_c, + int64_t batch) + : params_{} { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + // rocblas handle is not used for hipBLASLt + params_.handle = nullptr; + params_.opa = opa; + params_.opb = opb; + params_.m = m; + params_.n = n; + params_.k = k; + params_.alpha = alpha; + params_.a = static_cast(a.ptr()); + params_.lda = lda; + params_.stride_a = stride_a; + params_.b = static_cast(b.ptr()); + params_.ldb = ldb; + params_.stride_b = stride_b; + params_.beta = beta; + params_.c = static_cast(c.ptr()); + params_.ldc = ldc; + params_.stride_c = stride_c; + params_.batch = batch; + + for (auto&& [type_string, op] : GetHipBlasLtStridedBatchedGemmTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } + ORT_ENFORCE(!ops_.empty()); + } + + void Run() override { + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); + } + + std::vector 
ListOps() const { + return type_strings_; + } + + bool SelectOp(const std::string& name) { + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); + } + + private: + using ParamsT = StridedBatchedGemmParams; + using OpT = Op; + ParamsT params_; + std::vector ops_; + std::vector type_strings_; + size_t selected_op_{}; +}; + +#define REGISTER_OP_COMMON(type, dtype, alayout, blayout, layout_string) \ + py::class_>(m, #type "_" #dtype "_" layout_string) \ + .def("SetRepeats", &type::SetRepeats) \ + .def("Profile", &type::Profile) \ + .def("Run", &type::Run) \ + .def("ListOps", &type::ListOps) \ + .def("SelectOp", &type::SelectOp) + +#define REGISTER_GEMM_HIPBLASLT(dtype, alayout, blayout, layout_string) \ + REGISTER_OP_COMMON(GemmHipBlasLt, dtype, alayout, blayout, layout_string) \ + .def(py::init()); + +#define REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(dtype) \ + REGISTER_GEMM_HIPBLASLT(dtype, Row, Row, "NN"); \ + REGISTER_GEMM_HIPBLASLT(dtype, Row, Col, "NT"); \ + REGISTER_GEMM_HIPBLASLT(dtype, Col, Row, "TN"); \ + REGISTER_GEMM_HIPBLASLT(dtype, Col, Col, "TT"); + +#define REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, alayout, blayout, layout_string) \ + REGISTER_OP_COMMON(StridedBatchedGemmHipBlasLt, dtype, alayout, blayout, layout_string) \ + .def(py::init()); + +#define REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(dtype) \ + REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Row, Row, "NN"); \ + REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Row, Col, "NT"); \ + REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Col, Row, "TN"); \ + REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT(dtype, Col, Col, "TT"); + +KE_REGISTER(m) { + REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(float); + REGISTER_GEMM_HIPBLASLT_FOR_ALL_TRANSAB(half); + + REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(float); + 
REGISTER_STRIDEDBATCHEDGEMM_HIPBLASLT_FOR_ALL_TRANSAB(half); +} +#endif // USE_HIPBLASLT + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py index 7114f3b6d9..9b2b0b0871 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py @@ -215,10 +215,14 @@ def profile_with_args(dtype, transa, transb, m, n, k, batch, sort): fn_rocblas = getattr(ke, "RocBlasStridedBatchedGemm" + dtype_suffix) fn_ck = getattr(ke, "CKStridedBatchedGemm" + dtype_suffix + transab_suffix) fn_tunable = getattr(ke, "StridedBatchedGemmTunable" + dtype_suffix + transab_suffix) + if ke.is_hipblaslt_available(): + fn_hipblaslt = getattr(ke, "StridedBatchedGemmHipBlasLt" + dtype_suffix + transab_suffix) with ke.benchmark(sort): profile_gemm_func(fn_rocblas, dtype, transa, transb, m, n, k, batch) profile_gemm_func(fn_ck, dtype, transa, transb, m, n, k, batch) profile_gemm_func(fn_tunable, dtype, transa, transb, m, n, k, batch) + if ke.is_hipblaslt_available(): + profile_gemm_func(fn_hipblaslt, dtype, transa, transb, m, n, k, batch) print()