add build flag for rocblas tune and fix bug (#14797)

### Description
<!-- Describe your changes. -->
1. add a build flag for rocblas tuning feature.

2. fix a build bug when enable rocblas tuning.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
The rocblas tunning feature has no build flag to control, only using a
MACRO flag.

So I add an build flag, and fix a code bug when enable rocblas tunning.
This commit is contained in:
kailums 2023-02-28 10:37:07 +08:00 committed by GitHub
parent 2d079c6333
commit 9bdd42115c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 4 deletions

View file

@ -202,6 +202,7 @@ option(onnxruntime_ENABLE_ATEN "Enable ATen fallback" OFF)
# composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set
option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON)
option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF)
option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF)
option(onnxruntime_BUILD_CACHE "onnxruntime build with cache" OFF)

View file

@ -1449,6 +1449,12 @@ if (onnxruntime_USE_ROCM)
#endif()
endif()
if (onnxruntime_USE_ROCBLAS_EXTENSION_API)
target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_ROCBLAS_EXTENSION_API)
target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS)
target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_BETA_FEATURES_API)
endif()
if (onnxruntime_USE_COMPOSABLE_KERNEL)
include(composable_kernel)
target_link_libraries(onnxruntime_providers_rocm PRIVATE

View file

@ -190,6 +190,10 @@ class TunableOp {
this->ops_.emplace_back(std::move(op));
}
int NumberOfOps() {
return this->ops_.size();
}
void RegisterNestedTunableOp(TunableOp<ParamsT, TimerT>* op_ptr) {
nested_tunable_ops_.insert(op_ptr);

View file

@ -126,7 +126,7 @@ class IndexedRocBlasGemmOp {
params->c, RocBlasDataTypeFor(params->c), params->ldc,
params->c, RocBlasDataTypeFor(params->c), params->ldc,
RocBlasComputeTypeFor(params->a),
rocblas_gemm_algo_standard,
rocblas_gemm_algo_solution_index,
index_,
rocblas_gemm_flags_none));
}
@ -163,7 +163,7 @@ class RocBlasGemmTunableOp : public TunableOp<GemmParams<T>> {
auto id = this->FindFastestImpl(params, candidates);
// memoize the result
this->RegisterOp(std::move(candidates[id]));
return this->ops_.size() - 1;
return this->NumberOfOps() - 1;
}
private:
@ -182,7 +182,7 @@ class RocBlasGemmTunableOp : public TunableOp<GemmParams<T>> {
params->c, RocBlasDataTypeFor(params->c), params->ldc,
params->c, RocBlasDataTypeFor(params->c), params->ldc,
RocBlasComputeTypeFor(params->a),
rocblas_gemm_algo_standard,
rocblas_gemm_algo_solution_index,
rocblas_gemm_flags_none,
NULL,
&num_solutions));
@ -201,7 +201,7 @@ class RocBlasGemmTunableOp : public TunableOp<GemmParams<T>> {
params->c, RocBlasDataTypeFor(params->c), params->ldc,
params->c, RocBlasDataTypeFor(params->c), params->ldc,
RocBlasComputeTypeFor(params->a),
rocblas_gemm_algo_standard,
rocblas_gemm_algo_solution_index,
rocblas_gemm_flags_none,
solutions.data(),
&num_solutions));