From 9bdd42115c786b67dd23a99bab3ddee7c756c6ed Mon Sep 17 00:00:00 2001 From: kailums <109063327+kailums@users.noreply.github.com> Date: Tue, 28 Feb 2023 10:37:07 +0800 Subject: [PATCH] add build flag for rocblas tune and fix bug (#14797) ### Description 1. add a build flag for rocblas tuning feature. 2. fix a build bug when enable rocblas tuning. ### Motivation and Context The rocblas tunning feature has no build flag to control, only using a MACRO flag. So I add an build flag, and fix a code bug when enable rocblas tunning. --- cmake/CMakeLists.txt | 1 + cmake/onnxruntime_providers.cmake | 6 ++++++ onnxruntime/core/framework/tunable.h | 4 ++++ onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h | 8 ++++---- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index de5d39a6bc..5510f656be 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -202,6 +202,7 @@ option(onnxruntime_ENABLE_ATEN "Enable ATen fallback" OFF) # composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON) +option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF) option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF) option(onnxruntime_BUILD_CACHE "onnxruntime build with cache" OFF) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 0b9faf8849..b870896904 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1449,6 +1449,12 @@ if (onnxruntime_USE_ROCM) #endif() endif() + if (onnxruntime_USE_ROCBLAS_EXTENSION_API) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_ROCBLAS_EXTENSION_API) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_BETA_FEATURES_API) + endif() + if (onnxruntime_USE_COMPOSABLE_KERNEL) include(composable_kernel) target_link_libraries(onnxruntime_providers_rocm PRIVATE diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 1173698c46..1af897151a 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -190,6 +190,10 @@ class TunableOp { this->ops_.emplace_back(std::move(op)); } + int NumberOfOps() { + return this->ops_.size(); + } + void RegisterNestedTunableOp(TunableOp* op_ptr) { nested_tunable_ops_.insert(op_ptr); diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h index cd1b73b858..a732f9c0c2 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h @@ -126,7 +126,7 @@ class IndexedRocBlasGemmOp { params->c, RocBlasDataTypeFor(params->c), params->ldc, params->c, RocBlasDataTypeFor(params->c), params->ldc, RocBlasComputeTypeFor(params->a), - rocblas_gemm_algo_standard, + rocblas_gemm_algo_solution_index, index_, rocblas_gemm_flags_none)); } @@ -163,7 +163,7 @@ class RocBlasGemmTunableOp : public TunableOp> { auto id = this->FindFastestImpl(params, candidates); // memoize the result this->RegisterOp(std::move(candidates[id])); - return this->ops_.size() - 1; + return this->NumberOfOps() - 1; } private: @@ -182,7 +182,7 @@ class RocBlasGemmTunableOp : public TunableOp> { params->c, RocBlasDataTypeFor(params->c), params->ldc, params->c, RocBlasDataTypeFor(params->c), params->ldc, RocBlasComputeTypeFor(params->a), - rocblas_gemm_algo_standard, + rocblas_gemm_algo_solution_index, rocblas_gemm_flags_none, NULL, &num_solutions)); @@ -201,7 +201,7 @@ class RocBlasGemmTunableOp : public TunableOp> { params->c, RocBlasDataTypeFor(params->c), params->ldc, params->c, RocBlasDataTypeFor(params->c), params->ldc, RocBlasComputeTypeFor(params->a), - rocblas_gemm_algo_standard, + rocblas_gemm_algo_solution_index, rocblas_gemm_flags_none, solutions.data(), &num_solutions));