From b7fd5af48bcad76466b67dcd95dfc78a0bb469b2 Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Thu, 13 Jul 2023 11:20:26 +0800 Subject: [PATCH] [ROCm] TunableOp: Update rocBLAS get_solutions API (since ROCm5.6) (#16657) ### Description - Update existing rocBLAS get_solutions API using `*_get_solutions_by_type` (supported from ROCm5.6); remove the original nested TunableOp logic. - Update kernel_explorer. --- cmake/onnxruntime_kernel_explorer.cmake | 5 + .../rocm/diffusion/group_norm_triton.cuh | 2 +- .../providers/rocm/tunable/gemm_rocblas.h | 562 +++++++----------- .../providers/rocm/tunable/gemm_tunable.cuh | 36 +- .../tools/kernel_explorer/device_array.h | 4 + .../tools/kernel_explorer/kernel_explorer.cc | 3 +- .../kernels/_kernel_explorer.pyi | 1 + .../kernels/batched_gemm_test.py | 5 + .../kernel_explorer/kernels/gemm_test.py | 4 +- .../kernels/rocm/gemm_rocblas.cc | 93 ++- .../kernels/strided_batched_gemm_test.py | 3 + 11 files changed, 322 insertions(+), 396 deletions(-) diff --git a/cmake/onnxruntime_kernel_explorer.cmake b/cmake/onnxruntime_kernel_explorer.cmake index 09b153e998..856fed40ab 100644 --- a/cmake/onnxruntime_kernel_explorer.cmake +++ b/cmake/onnxruntime_kernel_explorer.cmake @@ -75,6 +75,11 @@ elseif (onnxruntime_USE_ROCM) if (onnxruntime_USE_HIPBLASLT) target_compile_definitions(kernel_explorer PRIVATE USE_HIPBLASLT) endif() + if (onnxruntime_USE_ROCBLAS_EXTENSION_API) + target_compile_definitions(kernel_explorer PRIVATE USE_ROCBLAS_EXTENSION_API) + target_compile_definitions(kernel_explorer PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS) + target_compile_definitions(kernel_explorer PRIVATE ROCBLAS_BETA_FEATURES_API) + endif() endif() add_dependencies(kernel_explorer onnxruntime_pybind11_state) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index 7ba0eefb9c..526d220d4b 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -50,7 +50,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size, "Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ")."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->hw % hw_size != 0, "Arg hw_size (", hw_size ") is not a divisor of hw (", params->hw, ")."); + params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); if constexpr (WithSwish) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish."); } else { diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h index d36bc41527..068db18332 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h @@ -36,50 +36,38 @@ class RocblasHandleStreamGuard { #ifdef USE_ROCBLAS_EXTENSION_API template -constexpr rocblas_datatype RocBlasDataTypeFor(const T*) { - static_assert(sizeof(T) == -1, "Unsupported type for rocBLAS operation."); - // The code below should be unreachable due to the static_assert above. - // But the compiler doesn't like not having a return statement, so we - // return something sensible. +constexpr rocblas_datatype RocBlasDataTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { return rocblas_datatype_f32_r; } template <> -constexpr rocblas_datatype RocBlasDataTypeFor(const float*) { - return rocblas_datatype_f32_r; -} - -template <> -constexpr rocblas_datatype RocBlasDataTypeFor(const half*) { +constexpr rocblas_datatype RocBlasDataTypeFor() { return rocblas_datatype_f16_r; } template <> -constexpr rocblas_datatype RocBlasDataTypeFor(const double*) { +constexpr rocblas_datatype RocBlasDataTypeFor() { return rocblas_datatype_f64_r; } template <> -constexpr rocblas_datatype RocBlasDataTypeFor(const BFloat16*) { +constexpr rocblas_datatype RocBlasDataTypeFor() { return rocblas_datatype_bf16_r; } template -constexpr rocblas_datatype RocBlasComputeTypeFor(const T*) { - static_assert(sizeof(T) == -1, "Unsupported type for rocBLAS operation."); - // The code below should be unreachable due to the static_assert above. - // But the compiler doesn't like not having a return statement, so we - // return something sensible. +constexpr rocblas_datatype RocBlasComputeTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { return rocblas_datatype_f32_r; } template <> -constexpr rocblas_datatype RocBlasComputeTypeFor(const float*) { - return rocblas_datatype_f32_r; -} - -template <> -constexpr rocblas_datatype RocBlasComputeTypeFor(const half*) { +constexpr rocblas_datatype RocBlasComputeTypeFor() { // Note that we're returning the _compute_ type for a given datatype. // As of 12/2022, using compute type FP16 for 16-bit floats was much // slower than using compute type FP32. So we use FP32 compute even for @@ -89,12 +77,12 @@ constexpr rocblas_datatype RocBlasComputeTypeFor(const half*) { } template <> -constexpr rocblas_datatype RocBlasComputeTypeFor(const double*) { +constexpr rocblas_datatype RocBlasComputeTypeFor() { return rocblas_datatype_f64_r; } template <> -constexpr rocblas_datatype RocBlasComputeTypeFor(const BFloat16*) { +constexpr rocblas_datatype RocBlasComputeTypeFor() { // Note that we're returning the _compute_ type for a given datatype. // As of 12/2022, using compute type FP16 for 16-bit floats was much // slower than using compute type FP32. So we use FP32 compute even for @@ -109,359 +97,219 @@ auto DoCastForHalfOrBfloat16(const T fp) { } template <> -auto DoCastForHalfOrBfloat16(const half fp) { +inline auto DoCastForHalfOrBfloat16(const half fp) { // alpha and beta should be the same as compute_type, in half case it is float. float h = onnxruntime::math::halfToFloat(*reinterpret_cast(&fp)); return h; } template <> -auto DoCastForHalfOrBfloat16(const BFloat16 fp) { +inline auto DoCastForHalfOrBfloat16(const BFloat16 fp) { // alpha and beta should be the same as compute_type, in bfloat16 case it is float. float h = fp.ToFloat(); return h; } template -class IndexedRocBlasGemmOp { - public: - IndexedRocBlasGemmOp() - : index_(0) {} - IndexedRocBlasGemmOp(int index) - : index_(index) {} +auto GetRocBlasGemmTypeStringAndOps() { + rocblas_handle handle; + ROCBLAS_CALL_THROW(rocblas_create_handle(&handle)); - Status operator()(const GemmParams* params) { - RocblasHandleStreamGuard guard(params->handle, params->stream); - auto h_a = DoCastForHalfOrBfloat16(params->alpha); - auto h_b = DoCastForHalfOrBfloat16(params->beta); - return ROCBLAS_CALL( - rocblas_gemm_ex( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->b, RocBlasDataTypeFor(params->b), params->ldb, - params->a, RocBlasDataTypeFor(params->a), params->lda, - &h_b, - params->c, RocBlasDataTypeFor(params->c), params->ldc, - params->c, RocBlasDataTypeFor(params->c), params->ldc, - RocBlasComputeTypeFor(params->a), - rocblas_gemm_algo_solution_index, - index_, - rocblas_gemm_flags_none)); + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + + // Get the number of available solutions + ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + + std::vector solutions(solution_size); + + // Get the list of available solutions + ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + + ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle)); + + std::vector>>> ret; + for (auto solution : solutions) { + auto rocblas_gemm_op = [=](const GemmParams* params) -> Status { + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_ex( + params->handle, + params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, + params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, + params->n, params->m, params->k, + &h_a, + params->b, input_output_type, params->ldb, + params->a, input_output_type, params->lda, + &h_b, + params->c, input_output_type, params->ldc, + params->c, input_output_type, params->ldc, + compute_type, + rocblas_gemm_algo_solution_index, + solution, + rocblas_gemm_flags_none); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE."); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + status != rocblas_status_success, "Solution ", solution, " failed."); + + return Status::OK(); + }; + ret.emplace_back(std::make_pair(onnxruntime::MakeString("RocBlasGemm_", solution), std::move(rocblas_gemm_op))); } - - Status IsSupported(const GemmParams*) { - return Status::OK(); - } - - private: - int index_; -}; + return ret; +} template -class RocBlasGemmTunableOp : public TunableOp> { - public: - RocBlasGemmTunableOp() { - // Ensure that the default implementation is always present - this->RegisterOp(IndexedRocBlasGemmOp{0}); +auto GetRocBlasBatchedGemmTypeStringAndOps() { + rocblas_handle handle; + ROCBLAS_CALL_THROW(rocblas_create_handle(&handle)); + + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + + // Get the number of available solutions + ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + + std::vector solutions(solution_size); + + // Get the list of available solutions + ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + + ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle)); + + std::vector>>> ret; + for (auto solution : solutions) { + auto rocblas_gemm_op = [=](const BatchedGemmParams* params) -> Status { + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_batched_ex( + params->handle, + params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, + params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, + params->n, params->m, params->k, + &h_a, + params->bs, input_output_type, params->ldb, + params->as, input_output_type, params->lda, + &h_b, + params->cs, input_output_type, params->ldc, + params->cs, input_output_type, params->ldc, + params->batch, + compute_type, + rocblas_gemm_algo_solution_index, + solution, + rocblas_gemm_flags_none); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE."); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + status != rocblas_status_success, "Solution ", solution, " failed."); + + return Status::OK(); + }; + ret.emplace_back(std::make_pair( + onnxruntime::MakeString("RocBlasBatchedGemm_", solution), std::move(rocblas_gemm_op))); } - - Status IsSupported(const GemmParams* params) { - ORT_UNUSED_PARAMETER(params); - return Status::OK(); - } - - protected: - virtual int FindFastest(const GemmParams* params) override { - auto solution_indices = this->GetSolutions(params); - std::vector>> candidates; - for (int solution_idx : solution_indices) { - candidates.emplace_back(IndexedRocBlasGemmOp{solution_idx}); - } - - auto id = this->FindFastestImpl(params, candidates); - // memoize the result - this->RegisterOp(std::move(candidates[id])); - return this->NumberOfOps() - 1; - } - - private: - std::vector GetSolutions(const GemmParams* params) { - int num_solutions = 0; - auto h_a = DoCastForHalfOrBfloat16(params->alpha); - auto h_b = DoCastForHalfOrBfloat16(params->beta); - // Get the number of candidate solutions - ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->b, RocBlasDataTypeFor(params->b), params->ldb, - params->a, RocBlasDataTypeFor(params->a), params->lda, - &h_b, - params->c, RocBlasDataTypeFor(params->c), params->ldc, - params->c, RocBlasDataTypeFor(params->c), params->ldc, - RocBlasComputeTypeFor(params->a), - rocblas_gemm_algo_solution_index, - rocblas_gemm_flags_none, - NULL, - &num_solutions)); - - // Get the actual candidate solutions - std::vector solutions(num_solutions); - ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->b, RocBlasDataTypeFor(params->b), params->ldb, - params->a, RocBlasDataTypeFor(params->a), params->lda, - &h_b, - params->c, RocBlasDataTypeFor(params->c), params->ldc, - params->c, RocBlasDataTypeFor(params->c), params->ldc, - RocBlasComputeTypeFor(params->a), - rocblas_gemm_algo_solution_index, - rocblas_gemm_flags_none, - solutions.data(), - &num_solutions)); - - return solutions; - } -}; + return ret; +} template -class IndexedRocBlasBatchedGemmOp { - public: - IndexedRocBlasBatchedGemmOp() - : index_(0) {} - IndexedRocBlasBatchedGemmOp(int index) - : index_(index) {} +auto GetRocBlasStridedBatchedGemmTypeStringAndOps() { + rocblas_handle handle; + ROCBLAS_CALL_THROW(rocblas_create_handle(&handle)); - Status operator()(const BatchedGemmParams* params) { - RocblasHandleStreamGuard guard(params->handle, params->stream); - auto h_a = DoCastForHalfOrBfloat16(params->alpha); - auto h_b = DoCastForHalfOrBfloat16(params->beta); - return ROCBLAS_CALL( - rocblas_gemm_batched_ex( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->bs, RocBlasDataTypeFor(*(params->bs)), params->ldb, - params->as, RocBlasDataTypeFor(*(params->as)), params->lda, - &h_b, - params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc, - params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc, - params->batch, - RocBlasComputeTypeFor(*(params->as)), - rocblas_gemm_algo_solution_index, - index_, - rocblas_gemm_flags_none)); + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + + // Get the number of available solutions + ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + + std::vector solutions(solution_size); + + // Get the list of available solutions + ROCBLAS_CALL_THROW(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + + ROCBLAS_CALL_THROW(rocblas_destroy_handle(handle)); + + std::vector>>> ret; + for (auto solution : solutions) { + auto rocblas_gemm_op = [=](const StridedBatchedGemmParams* params) -> Status { + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_strided_batched_ex( + params->handle, + params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, + params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, + params->n, params->m, params->k, + &h_a, + params->b, input_output_type, params->ldb, params->stride_b, + params->a, input_output_type, params->lda, params->stride_a, + &h_b, + params->c, input_output_type, params->ldc, params->stride_c, + params->c, input_output_type, params->ldc, params->stride_c, + params->batch, + compute_type, + rocblas_gemm_algo_solution_index, + solution, + rocblas_gemm_flags_none); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE."); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + status != rocblas_status_success, "Solution ", solution, " failed."); + + return Status::OK(); + }; + ret.emplace_back(std::make_pair( + onnxruntime::MakeString("RocBlasStridedBatchedGemm_", solution), std::move(rocblas_gemm_op))); } + return ret; +} - Status IsSupported(const BatchedGemmParams*) { - return Status::OK(); - } - - private: - int index_; -}; - -template -class RocBlasBatchedGemmTunableOp : public TunableOp> { - public: - RocBlasBatchedGemmTunableOp() { - // Ensure that the default implementation is always present - this->RegisterOp(IndexedRocBlasBatchedGemmOp{0}); - } - - Status IsSupported(const BatchedGemmParams* params) { - ORT_UNUSED_PARAMETER(params); - return Status::OK(); - } - - protected: - virtual int FindFastest(const BatchedGemmParams* params) override { - auto solution_indices = this->GetSolutions(params); - std::vector>> candidates; - for (int solution_idx : solution_indices) { - candidates.emplace_back(IndexedRocBlasBatchedGemmOp{solution_idx}); - } - - auto id = this->FindFastestImpl(params, candidates); - // memoize the result - this->RegisterOp(std::move(candidates[id])); - return this->NumberOfOps() - 1; - } - - private: - std::vector GetSolutions(const BatchedGemmParams* params) { - int num_solutions = 0; - auto h_a = DoCastForHalfOrBfloat16(params->alpha); - auto h_b = DoCastForHalfOrBfloat16(params->beta); - // Get the number of candidate solutions - ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->bs, RocBlasDataTypeFor(*(params->bs)), params->ldb, - params->as, RocBlasDataTypeFor(*(params->as)), params->lda, - &h_b, - params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc, - params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc, - params->batch, - RocBlasComputeTypeFor(*(params->as)), - rocblas_gemm_algo_solution_index, - rocblas_gemm_flags_none, - NULL, - &num_solutions)); - - // Get the actual candidate solutions - std::vector solutions(num_solutions); - ROCBLAS_CALL_THROW(rocblas_gemm_batched_ex_get_solutions( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->bs, RocBlasDataTypeFor(*(params->bs)), params->ldb, - params->as, RocBlasDataTypeFor(*(params->as)), params->lda, - &h_b, - params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc, - params->cs, RocBlasDataTypeFor(*(params->cs)), params->ldc, - params->batch, - RocBlasComputeTypeFor(*(params->as)), - rocblas_gemm_algo_solution_index, - rocblas_gemm_flags_none, - solutions.data(), - &num_solutions)); - - return solutions; - } -}; - -template -class IndexedRocBlasStridedBatchedGemmOp { - public: - IndexedRocBlasStridedBatchedGemmOp() - : index_(0) {} - IndexedRocBlasStridedBatchedGemmOp(int index) - : index_(index) {} - - Status operator()(const StridedBatchedGemmParams* params) { - RocblasHandleStreamGuard guard(params->handle, params->stream); - auto h_a = DoCastForHalfOrBfloat16(params->alpha); - auto h_b = DoCastForHalfOrBfloat16(params->beta); - return ROCBLAS_CALL( - rocblas_gemm_strided_batched_ex( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->b, RocBlasDataTypeFor(params->b), params->ldb, params->stride_b, - params->a, RocBlasDataTypeFor(params->a), params->lda, params->stride_a, - &h_b, - params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c, - params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c, - params->batch, - RocBlasComputeTypeFor(params->a), - rocblas_gemm_algo_solution_index, - index_, - rocblas_gemm_flags_none)); - } - - Status IsSupported(const StridedBatchedGemmParams*) { - return Status::OK(); - } - - private: - int index_; -}; - -template -class RocBlasStridedBatchedGemmTunableOp : public TunableOp> { - public: - RocBlasStridedBatchedGemmTunableOp() { - // Ensure that the default implementation is always present - this->RegisterOp(IndexedRocBlasStridedBatchedGemmOp{0}); - } - - Status IsSupported(const StridedBatchedGemmParams* params) { - ORT_UNUSED_PARAMETER(params); - return Status::OK(); - } - - protected: - virtual int FindFastest(const StridedBatchedGemmParams* params) override { - auto solution_indices = this->GetSolutions(params); - std::vector>> candidates; - for (int solution_idx : solution_indices) { - candidates.emplace_back(IndexedRocBlasStridedBatchedGemmOp{solution_idx}); - } - - auto id = this->FindFastestImpl(params, candidates); - // memoize the result - this->RegisterOp(std::move(candidates[id])); - return this->NumberOfOps() - 1; - } - - private: - std::vector GetSolutions(const StridedBatchedGemmParams* params) { - int num_solutions = 0; - auto h_a = DoCastForHalfOrBfloat16(params->alpha); - auto h_b = DoCastForHalfOrBfloat16(params->beta); - // Get the number of candidate solutions - ROCBLAS_CALL_THROW(rocblas_gemm_strided_batched_ex_get_solutions( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->b, RocBlasDataTypeFor(params->b), params->ldb, params->stride_b, - params->a, RocBlasDataTypeFor(params->a), params->lda, params->stride_a, - &h_b, - params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c, - params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c, - params->batch, - RocBlasComputeTypeFor(params->a), - rocblas_gemm_algo_solution_index, - rocblas_gemm_flags_none, - NULL, - &num_solutions)); - - // Get the actual candidate solutions - std::vector solutions(num_solutions); - ROCBLAS_CALL_THROW(rocblas_gemm_strided_batched_ex_get_solutions( - params->handle, - params->opb == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->opa == BlasOp::N ? rocblas_operation_none : rocblas_operation_transpose, - params->n, params->m, params->k, - &h_a, - params->b, RocBlasDataTypeFor(params->b), params->ldb, params->stride_b, - params->a, RocBlasDataTypeFor(params->a), params->lda, params->stride_a, - &h_b, - params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c, - params->c, RocBlasDataTypeFor(params->c), params->ldc, params->stride_c, - params->batch, - RocBlasComputeTypeFor(params->a), - rocblas_gemm_algo_solution_index, - rocblas_gemm_flags_none, - solutions.data(), - &num_solutions)); - - return solutions; - } -}; - -#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */ +#endif // USE_ROCBLAS_EXTENSION_API template Status RocBlasGemmOp(const GemmParams* params) { diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh index 5a46654b61..2a0642bbee 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh @@ -44,8 +44,11 @@ class GemmTunableOp : public TunableOp> { #endif #ifdef USE_ROCBLAS_EXTENSION_API - this->RegisterNestedTunableOp(&rocblas_gemm_tunable_op_); -#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */ + for (auto&& [_, op] : GetRocBlasGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif #ifdef USE_COMPOSABLE_KERNEL for (auto&& [_, op] : GetCKGemmTypeStringAndOps()) { @@ -79,11 +82,6 @@ class GemmTunableOp : public TunableOp> { delete params; } } - - private: -#ifdef USE_ROCBLAS_EXTENSION_API - RocBlasGemmTunableOp rocblas_gemm_tunable_op_; -#endif }; template @@ -93,8 +91,11 @@ class BatchedGemmTunableOp : public TunableOp> { this->RegisterOp(RocBlasBatchedGemmOp); #ifdef USE_ROCBLAS_EXTENSION_API - this->RegisterNestedTunableOp(&rocblas_batched_gemm_tunable_op_); -#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */ + for (auto&& [_, op] : GetRocBlasBatchedGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif } const BatchedGemmParams* PreTuning(const BatchedGemmParams* params) override { @@ -131,11 +132,6 @@ class BatchedGemmTunableOp : public TunableOp> { delete params; } } - - private: -#ifdef USE_ROCBLAS_EXTENSION_API - RocBlasBatchedGemmTunableOp rocblas_batched_gemm_tunable_op_; -#endif }; template @@ -149,8 +145,11 @@ class StridedBatchedGemmTunableOp : public TunableOp #endif #ifdef USE_ROCBLAS_EXTENSION_API - this->RegisterNestedTunableOp(&rocblas_strided_batched_gemm_tunable_op_); -#endif /* #ifdef USE_ROCBLAS_EXTENSION_API */ + for (auto&& [_, op] : GetRocBlasStridedBatchedGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif #ifdef USE_COMPOSABLE_KERNEL for (auto&& [_, op] : GetCKStridedBatchedGemmTypeStringAndOps()) { @@ -178,11 +177,6 @@ class StridedBatchedGemmTunableOp : public TunableOp delete params; } } - - private: -#ifdef USE_ROCBLAS_EXTENSION_API - RocBlasStridedBatchedGemmTunableOp rocblas_strided_batched_gemm_tunable_op_; -#endif }; } // namespace internal diff --git a/onnxruntime/python/tools/kernel_explorer/device_array.h b/onnxruntime/python/tools/kernel_explorer/device_array.h index 89784543e1..bb868c2b7a 100644 --- a/onnxruntime/python/tools/kernel_explorer/device_array.h +++ b/onnxruntime/python/tools/kernel_explorer/device_array.h @@ -51,6 +51,10 @@ class DeviceArray { CALL_THROW(MEMCPY(host_, device_.get(), size_ * itemsize_, MEMCPY_DEVICE_TO_HOST)); } + void UpdateDeviceArray() { + CALL_THROW(MEMCPY(device_.get(), host_, size_ * itemsize_, MEMCPY_HOST_TO_DEVICE)); + } + void* ptr() const { return device_.get(); } diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc index 553b237b5a..34152995c3 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc @@ -32,7 +32,8 @@ PYBIND11_PLUGIN_IMPL(_kernel_explorer) { KE_REGISTER(m) { py::class_(m, "DeviceArray") .def(py::init()) - .def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray); + .def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray) + .def("UpdateDeviceArray", &DeviceArray::UpdateDeviceArray); m.def("is_composable_kernel_available", []() { #ifdef USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi b/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi index 22122a605f..94213aceed 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi +++ b/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi @@ -1,6 +1,7 @@ class DeviceArray: def __init__(self, ndarray) -> None: ... def UpdateHostNumpyArray(self) -> None: ... # noqa: N802 + def UpdateDeviceArray(self) -> None: ... # noqa: N802 class blas_op: # noqa: N801 T: int diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py index d183825da0..73323d767a 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py @@ -66,6 +66,11 @@ def _test_batched_gemm( if not my_gemm.SelectOp(impl): continue + # Restore C Arrays + for my_c in my_cs: + my_c.fill(1.0) + for dev_c in dev_cs: + dev_c.UpdateDeviceArray() my_gemm.Run() for dev_c in dev_cs: dev_c.UpdateHostNumpyArray() diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py index 445b823a70..9b1a3b9760 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py @@ -50,7 +50,9 @@ def _test_gemm(func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: for impl in my_gemm.ListOps(): if not my_gemm.SelectOp(impl): continue - + # Restore C Array + my_c.fill(1.0) + dev_c.UpdateDeviceArray() my_gemm.Run() dev_c.UpdateHostNumpyArray() diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc index 673e04621d..8c3aceb3f7 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_rocblas.cc @@ -30,7 +30,8 @@ class RocBlasGemm : public IKernelExplorer { DeviceArray& a, int64_t lda, DeviceArray& b, int64_t ldb, double beta, - DeviceArray& c, int64_t ldc) { + DeviceArray& c, int64_t ldc) + : params_{} { ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); params_.tuning_ctx = TuningContext(); params_.stream = Stream(); @@ -48,6 +49,16 @@ class RocBlasGemm : public IKernelExplorer { params_.beta = beta; params_.c = static_cast(c.ptr()); params_.ldc = ldc; + + type_strings_.emplace_back("RocBlasGemmDefault"); + ops_.emplace_back([](auto* params) { return RocBlasGemmOp(params); }); + +#ifdef USE_ROCBLAS_EXTENSION_API + for (auto&& [type_string, op] : GetRocBlasGemmTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } +#endif } ~RocBlasGemm() { @@ -56,15 +67,23 @@ class RocBlasGemm : public IKernelExplorer { } void Run() override { - ORT_THROW_IF_ERROR(op_(¶ms_)); + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); } std::vector ListOps() const { - return {"Rocblas"}; + return type_strings_; } bool SelectOp(const std::string& name) { - return name == "Rocblas"; + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); } private: @@ -74,7 +93,9 @@ class RocBlasGemm : public IKernelExplorer { using OpT = Op; ParamsT params_{}; - OpT op_{RocBlasGemmOp}; + std::vector ops_; + std::vector type_strings_; + size_t selected_op_{}; }; template @@ -87,7 +108,8 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer { std::vector& bs, int64_t ldb, double beta, std::vector& cs, int64_t ldc, - int64_t batch) { + int64_t batch) + : params_{} { this->CopyAsBsCsPointersToDevice(as, bs, cs, batch); ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); params_.tuning_ctx = this->TuningContext(); @@ -107,6 +129,16 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer { params_.cs = this->dev_cs_.get(); params_.ldc = ldc; params_.batch = batch; + + type_strings_.emplace_back("RocBlasBatchedGemmDefault"); + ops_.emplace_back([](auto* params) { return RocBlasBatchedGemmOp(params); }); + +#ifdef USE_ROCBLAS_EXTENSION_API + for (auto&& [type_string, op] : GetRocBlasBatchedGemmTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } +#endif } ~RocBlasBatchedGemm() { @@ -115,15 +147,23 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer { } void Run() override { - ORT_THROW_IF_ERROR(op_(¶ms_)); + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); } std::vector ListOps() const { - return {"Rocblas"}; + return type_strings_; } bool SelectOp(const std::string& name) { - return name == "Rocblas"; + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); } private: @@ -133,7 +173,9 @@ class RocBlasBatchedGemm : public IBatchedGemmKernelExplorer { using OpT = Op; ParamsT params_{}; - OpT op_{RocBlasBatchedGemmOp}; + std::vector ops_; + std::vector type_strings_; + size_t selected_op_{}; }; template @@ -146,7 +188,8 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer { DeviceArray& b, int64_t ldb, int64_t stride_b, double beta, DeviceArray& c, int64_t ldc, int64_t stride_c, - int64_t batch) { + int64_t batch) + : params_{} { ROCBLAS_CALL_THROW(rocblas_create_handle(&rocblas_handle_)); params_.tuning_ctx = TuningContext(); params_.stream = Stream(); @@ -168,6 +211,16 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer { params_.ldc = ldc; params_.stride_c = stride_c; params_.batch = batch; + + type_strings_.emplace_back("RocBlasStridedBatchedGemmDefault"); + ops_.emplace_back([](auto* params) { return RocBlasStridedBatchedGemmOp(params); }); + +#ifdef USE_ROCBLAS_EXTENSION_API + for (auto&& [type_string, op] : GetRocBlasStridedBatchedGemmTypeStringAndOps()) { + type_strings_.emplace_back(std::move(type_string)); + ops_.emplace_back(std::move(op)); + } +#endif } ~RocBlasStridedBatchedGemm() { @@ -176,15 +229,23 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer { } void Run() override { - ORT_THROW_IF_ERROR(op_(¶ms_)); + ORT_THROW_IF_ERROR(ops_[selected_op_](¶ms_)); } std::vector ListOps() const { - return {"Rocblas"}; + return type_strings_; } bool SelectOp(const std::string& name) { - return name == "Rocblas"; + for (size_t i = 0; i < ops_.size(); i++) { + if (type_strings_[i] == name) { + selected_op_ = i; + Status status = ops_[i](¶ms_); + return status.IsOK(); + } + } + + ORT_THROW("Cannot find implementation ", name); } private: @@ -194,7 +255,9 @@ class RocBlasStridedBatchedGemm : public IKernelExplorer { using OpT = Op; ParamsT params_{}; - OpT op_{RocBlasStridedBatchedGemmOp}; + std::vector ops_; + std::vector type_strings_; + size_t selected_op_{}; }; #define REGISTER_OP_COMMON(type, dtype) \ diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py index 312592e2f4..7114f3b6d9 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py @@ -73,6 +73,9 @@ def _test_strided_batched_gemm( if not my_gemm.SelectOp(impl): continue + # Restore C Array + my_c.fill(1.0) + dev_c.UpdateDeviceArray() my_gemm.Run() dev_c.UpdateHostNumpyArray()