mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
add build flag for rocblas tune and fix bug (#14797)
### Description <!-- Describe your changes. --> 1. Add a build flag for the rocBLAS tuning feature. 2. Fix a build bug that occurs when rocBLAS tuning is enabled. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> The rocBLAS tuning feature had no build flag to control it; it could only be enabled through a preprocessor macro. This change adds a build flag and fixes a code bug that appears when rocBLAS tuning is enabled.
This commit is contained in:
parent
2d079c6333
commit
9bdd42115c
4 changed files with 15 additions and 4 deletions
|
|
@ -202,6 +202,7 @@ option(onnxruntime_ENABLE_ATEN "Enable ATen fallback" OFF)
|
|||
|
||||
# Composable kernel is managed automatically; it should not be set manually unless the user wants to explicitly disable it.
|
||||
option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON)
|
||||
option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF)
|
||||
option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF)
|
||||
|
||||
option(onnxruntime_BUILD_CACHE "onnxruntime build with cache" OFF)
|
||||
|
|
|
|||
|
|
@ -1449,6 +1449,12 @@ if (onnxruntime_USE_ROCM)
|
|||
#endif()
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_ROCBLAS_EXTENSION_API)
|
||||
target_compile_definitions(onnxruntime_providers_rocm PRIVATE USE_ROCBLAS_EXTENSION_API)
|
||||
target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_NO_DEPRECATED_WARNINGS)
|
||||
target_compile_definitions(onnxruntime_providers_rocm PRIVATE ROCBLAS_BETA_FEATURES_API)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_COMPOSABLE_KERNEL)
|
||||
include(composable_kernel)
|
||||
target_link_libraries(onnxruntime_providers_rocm PRIVATE
|
||||
|
|
|
|||
|
|
@ -190,6 +190,10 @@ class TunableOp {
|
|||
this->ops_.emplace_back(std::move(op));
|
||||
}
|
||||
|
||||
int NumberOfOps() {
|
||||
return this->ops_.size();
|
||||
}
|
||||
|
||||
void RegisterNestedTunableOp(TunableOp<ParamsT, TimerT>* op_ptr) {
|
||||
nested_tunable_ops_.insert(op_ptr);
|
||||
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ class IndexedRocBlasGemmOp {
|
|||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
RocBlasComputeTypeFor(params->a),
|
||||
rocblas_gemm_algo_standard,
|
||||
rocblas_gemm_algo_solution_index,
|
||||
index_,
|
||||
rocblas_gemm_flags_none));
|
||||
}
|
||||
|
|
@ -163,7 +163,7 @@ class RocBlasGemmTunableOp : public TunableOp<GemmParams<T>> {
|
|||
auto id = this->FindFastestImpl(params, candidates);
|
||||
// memoize the result
|
||||
this->RegisterOp(std::move(candidates[id]));
|
||||
return this->ops_.size() - 1;
|
||||
return this->NumberOfOps() - 1;
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -182,7 +182,7 @@ class RocBlasGemmTunableOp : public TunableOp<GemmParams<T>> {
|
|||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
RocBlasComputeTypeFor(params->a),
|
||||
rocblas_gemm_algo_standard,
|
||||
rocblas_gemm_algo_solution_index,
|
||||
rocblas_gemm_flags_none,
|
||||
NULL,
|
||||
&num_solutions));
|
||||
|
|
@ -201,7 +201,7 @@ class RocBlasGemmTunableOp : public TunableOp<GemmParams<T>> {
|
|||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
params->c, RocBlasDataTypeFor(params->c), params->ldc,
|
||||
RocBlasComputeTypeFor(params->a),
|
||||
rocblas_gemm_algo_standard,
|
||||
rocblas_gemm_algo_solution_index,
|
||||
rocblas_gemm_flags_none,
|
||||
solutions.data(),
|
||||
&num_solutions));
|
||||
|
|
|
|||
Loading…
Reference in a new issue