diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 3c5090e9a9..52d522b312 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -27,5 +27,6 @@ struct OrtCUDAProviderOptionsV2 { int cudnn_conv_use_max_workspace; // flag specifying if maximum workspace can be used in cudnn conv algo search. int enable_cuda_graph; // flag specifying if the CUDA graph is to be captured for the model. int cudnn_conv1d_pad_to_nc1d; // flag specifying if pad Conv1D's input [N,C,D] to [N,C,1,D] or [N,C,D,1]. - int tunable_op_enabled; // flag specifying if TunableOp is enabled. + int tunable_op_enable; // flag specifying if TunableOp is enabled. + int tunable_op_tuning_enable; // flag specifying if TunableOp is enabled for tuning, this relies on TunableOp is enabled. }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b56497ea32..0b41ee9ffa 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -387,7 +387,8 @@ typedef struct OrtCUDAProviderOptions { has_user_compute_stream{}, user_compute_stream{}, default_memory_arena_cfg{}, - tunable_op_enabled{false} {} + tunable_op_enable{false}, + tunable_op_tuning_enable{false} {} #endif /** \brief CUDA device Id @@ -438,11 +439,18 @@ typedef struct OrtCUDAProviderOptions { */ OrtArenaCfg* default_memory_arena_cfg; - /** \brief Enable TunableOp. - * Set it to 1 to enable TunableOp. Otherwise, it is disabled by default. - * This option can be superseded by environment variable ORT_CUDA_TUNABLE_OP_ENABLED. + /** \brief Enable TunableOp for using. + * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. + * This option can be overriden by environment variable ORT_CUDA_TUNABLE_OP_ENABLE. 
*/ - int tunable_op_enabled; + int tunable_op_enable; + + /** \brief Enable TunableOp for tuning. + * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. + * This option can be overriden by environment variable ORT_CUDA_TUNABLE_OP_TUNING_ENABLE. + */ + int tunable_op_tuning_enable; + } OrtCUDAProviderOptions; @@ -461,7 +469,8 @@ typedef struct OrtROCMProviderOptions { has_user_compute_stream{}, user_compute_stream{}, default_memory_arena_cfg{}, - tunable_op_enabled{false} {} + tunable_op_enable{false}, + tunable_op_tuning_enable{false} {} #endif /** \brief ROCM device Id @@ -511,11 +520,17 @@ typedef struct OrtROCMProviderOptions { */ OrtArenaCfg* default_memory_arena_cfg; - /** \brief Enable TunableOp. - * Set it to 1 to enable TunableOp. Otherwise, it is disabled by default. - * This option can be superseded by environment variable ORT_ROCM_TUNABLE_OP_ENABLED. + /** \brief Enable TunableOp for using. + * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. + * This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. */ - int tunable_op_enabled; + int tunable_op_enable; + + /** \brief Enable TunableOp for tuning. + * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. + * This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE. + */ + int tunable_op_tuning_enable; } OrtROCMProviderOptions; @@ -4085,6 +4100,10 @@ struct OrtApi { * \since Version 1.15. 
*/ ORT_API2_STATUS(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); + +#ifdef __cplusplus + OrtApi(const OrtApi&) = delete; // Prevent users from accidentally copying the API structure, it should always be passed as a pointer +#endif }; /* diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index fbf1f1edc7..7c4467d348 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -133,12 +133,14 @@ class TunableOp { virtual ~TunableOp() = default; Status operator()(const ParamsT* params) { - int id = default_id_; + int id = -1; ITuningContext* ctx = params->TuningContext(); if (ctx->IsTunableOpEnabled()) { auto& mgr = ctx->GetTuningResultsManager(); auto op_sig = Signature(); auto params_sig = params->Signature(); + + // Usage is enabled, then we are free to use previous tuning result. id = mgr.Lookup(op_sig, params_sig); if (id > static_cast(ops_.size())) { LOGS_DEFAULT(ERROR) << "Invalid TunableOp kernel id for " << op_sig @@ -146,14 +148,16 @@ class TunableOp { mgr.Delete(op_sig, params_sig); id = -1; } - if (id < 0) { + + // If there is not previous tuning result been found, we do the tuning iff tuning is enabled + if (id < 0 && ctx->IsTuningEnabled()) { auto maybe_proxy_params = PreTuning(params); id = FindFastest(maybe_proxy_params); PostTuning(maybe_proxy_params); mgr.Add(op_sig, params_sig, id); } } - ORT_RETURN_IF_ERROR(ops_[id](params)); + ORT_RETURN_IF_ERROR(ops_[id < 0 ? 
default_id_ : id](params)); return Status::OK(); } diff --git a/onnxruntime/core/framework/tuning_context.h b/onnxruntime/core/framework/tuning_context.h index 6cd61931b8..3fea4cb85f 100644 --- a/onnxruntime/core/framework/tuning_context.h +++ b/onnxruntime/core/framework/tuning_context.h @@ -24,6 +24,20 @@ class ITuningContext { virtual void DisableTunableOp() = 0; virtual bool IsTunableOpEnabled() const = 0; + virtual void EnableTuning() = 0; + virtual void DisableTuning() = 0; + virtual bool IsTuningEnabled() const = 0; + + virtual void EnableTunableOpAndTuning() final { + EnableTunableOp(); + EnableTuning(); + } + + virtual void DisableTunableOpAndTuning() final { + DisableTunableOp(); + DisableTuning(); + } + virtual TuningResultsManager& GetTuningResultsManager() = 0; virtual const TuningResultsManager& GetTuningResultsManager() const = 0; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 5089e0b8ab..412d078da6 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -211,11 +211,23 @@ void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGrap #endif void OverrideTunableOpInfoByEnv(CUDAExecutionProviderInfo& info) { - auto env_tunable_op_enabled = onnxruntime::ParseTestOnlyEnvironmentVariable( - "ORT_CUDA_TUNABLE_OP_ENABLED", {"0", "1"}, "Use provider_options \"tunable_op_enabled\" instead."); - if (env_tunable_op_enabled.has_value() && env_tunable_op_enabled != info.tunable_op.enabled) { - LOGS_DEFAULT(INFO) << "ORT_CUDA_TUNABLE_OP_ENABLED is set to " << *env_tunable_op_enabled; - info.tunable_op.enabled = *env_tunable_op_enabled; + if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( + "ORT_CUDA_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead."); + env_tunable_op_enable.has_value() && 
env_tunable_op_enable != info.tunable_op.enable) { + LOGS_DEFAULT(INFO) << "ORT_CUDA_TUNABLE_OP_ENABLE is set to " << *env_tunable_op_enable; + info.tunable_op.enable = *env_tunable_op_enable; + } + + if (auto env_tunable_op_tuning_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( + "ORT_CUDA_TUNABLE_OP_TUNING_ENABLE", {"0", "1"}, + "Use provider_options \"tunable_op_tuning_enable\" instead."); + env_tunable_op_tuning_enable.has_value() && env_tunable_op_tuning_enable != info.tunable_op.tuning_enable) { + LOGS_DEFAULT(INFO) << "ORT_CUDA_TUNABLE_OP_TUNING_ENABLE is set to " << *env_tunable_op_tuning_enable; + info.tunable_op.tuning_enable = *env_tunable_op_tuning_enable; + } + + if (info.tunable_op.tuning_enable && !info.tunable_op.enable) { + LOGS_DEFAULT(WARNING) << "TunableOp is enabled for tuning but is not enabled for using. This will have no effect."; } } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 9c1664324a..eb257a652f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -24,7 +24,8 @@ constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache"; constexpr const char* kCudnnConvUseMaxWorkspace = "cudnn_conv_use_max_workspace"; constexpr const char* kEnableCudaGraph = "enable_cuda_graph"; constexpr const char* kCudnnConv1dPadToNc1d = "cudnn_conv1d_pad_to_nc1d"; -constexpr const char* kTunableOpEnabled = "tunable_op_enabled"; +constexpr const char* kTunableOpEnable = "tunable_op_enable"; +constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; } // namespace provider_option_names } // namespace cuda @@ -94,9 +95,15 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableCudaGraph, info.enable_cuda_graph) 
.AddAssignmentToReference(cuda::provider_option_names::kCudnnConv1dPadToNc1d, info.cudnn_conv1d_pad_to_nc1d) .AddValueParser( - cuda::provider_option_names::kTunableOpEnabled, + cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.enabled)); + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.enable)); + return Status::OK(); + }) + .AddValueParser( + cuda::provider_option_names::kTunableOpTuningEnable, + [&info](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.tuning_enable)); return Status::OK(); }) .Parse(options)); @@ -121,7 +128,8 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kCudnnConvUseMaxWorkspace, MakeStringWithClassicLocale(info.cudnn_conv_use_max_workspace)}, {cuda::provider_option_names::kEnableCudaGraph, MakeStringWithClassicLocale(info.enable_cuda_graph)}, {cuda::provider_option_names::kCudnnConv1dPadToNc1d, MakeStringWithClassicLocale(info.cudnn_conv1d_pad_to_nc1d)}, - {cuda::provider_option_names::kTunableOpEnabled, MakeStringWithClassicLocale(info.tunable_op.enabled)}, + {cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)}, + {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, }; return options; @@ -135,7 +143,9 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)}, {cuda::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, {cuda::provider_option_names::kCudnnConvUseMaxWorkspace, MakeStringWithClassicLocale(info.cudnn_conv_use_max_workspace)}, - 
{cuda::provider_option_names::kCudnnConv1dPadToNc1d, MakeStringWithClassicLocale(info.cudnn_conv1d_pad_to_nc1d)} + {cuda::provider_option_names::kCudnnConv1dPadToNc1d, MakeStringWithClassicLocale(info.cudnn_conv1d_pad_to_nc1d)}, + {cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, + {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 534d3354ec..f066c203df 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -38,7 +38,8 @@ struct CUDAExecutionProviderExternalAllocatorInfo { namespace cuda { struct TunableOpInfo { - bool enabled{false}; + bool enable{false}; + bool tuning_enable{false}; }; } // namespace cuda @@ -78,7 +79,8 @@ template<> struct std::hash<::onnxruntime::cuda::TunableOpInfo> { size_t operator()(const ::onnxruntime::cuda::TunableOpInfo& info) const { size_t seed_and_value{0xbc9f1d34}; - onnxruntime::HashCombine(info.enabled, seed_and_value); + onnxruntime::HashCombine(info.enable, seed_and_value); + onnxruntime::HashCombine(info.tuning_enable, seed_and_value); return seed_and_value; } }; diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 003dc57d61..6253648cfa 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -251,7 +251,8 @@ struct CUDA_Provider : Provider { info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0; info.enable_cuda_graph = params->enable_cuda_graph != 0; info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0; - info.tunable_op.enabled = params->tunable_op_enabled; 
+ info.tunable_op.enable = params->tunable_op_enable; + info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; return std::make_shared(info); } diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc index bea2889ca8..4e302a2c66 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc @@ -51,16 +51,30 @@ CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* i void CudaTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider"; - info_->enabled = true; + info_->enable = true; } void CudaTuningContext::DisableTunableOp() { LOGS_DEFAULT(INFO) << "Disable TunableOp for CUDA Execution Provider"; - info_->enabled = false; + info_->enable = false; } bool CudaTuningContext::IsTunableOpEnabled() const { - return info_->enabled; + return info_->enable; +} + +void CudaTuningContext::EnableTuning() { + LOGS_DEFAULT(INFO) << "Enable TunableOp tuning for CUDA Execution Provider"; + info_->tuning_enable = true; +} + +void CudaTuningContext::DisableTuning() { + LOGS_DEFAULT(INFO) << "Disable TunableOp tuning for CUDA Execution Provider"; + info_->tuning_enable = false; +} + +bool CudaTuningContext::IsTuningEnabled() const { + return info_->tuning_enable; } TuningResultsManager& CudaTuningContext::GetTuningResultsManager() { diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h index 10d0782f5b..ec961890e6 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.h @@ -35,6 +35,10 @@ class CudaTuningContext : public ITuningContext { void DisableTunableOp() override; bool IsTunableOpEnabled() const override; + void EnableTuning() override; + void DisableTuning() 
override; + bool IsTuningEnabled() const override; + TuningResultsManager& GetTuningResultsManager() override; const TuningResultsManager& GetTuningResultsManager() const override; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 6db27dff94..1ec595d799 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -157,11 +157,23 @@ ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { } void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { - auto env_tunable_op_enabled = onnxruntime::ParseTestOnlyEnvironmentVariable( - "ORT_ROCM_TUNABLE_OP_ENABLED", {"0", "1"}, "Use provider_options \"tunable_op_enabled\" instead."); - if (env_tunable_op_enabled.has_value() && env_tunable_op_enabled != info.tunable_op.enabled) { - LOGS_DEFAULT(INFO) << "ORT_ROCM_TUNABLE_OP_ENABLED is set to " << *env_tunable_op_enabled; - info.tunable_op.enabled = *env_tunable_op_enabled; + if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( + "ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead."); + env_tunable_op_enable.has_value() && env_tunable_op_enable != info.tunable_op.enable) { + LOGS_DEFAULT(INFO) << "ORT_ROCM_TUNABLE_OP_ENABLE is set to " << *env_tunable_op_enable; + info.tunable_op.enable = *env_tunable_op_enable; + } + + if (auto env_tunable_op_tuning_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( + "ORT_ROCM_TUNABLE_OP_TUNING_ENABLE", {"0", "1"}, + "Use provider_options \"tunable_op_tuning_enable\" instead."); + env_tunable_op_tuning_enable.has_value() && env_tunable_op_tuning_enable != info.tunable_op.tuning_enable) { + LOGS_DEFAULT(INFO) << "ORT_ROCM_TUNABLE_OP_TUNING_ENABLE is set to " << *env_tunable_op_tuning_enable; + info.tunable_op.tuning_enable = *env_tunable_op_tuning_enable; + } + + if 
(info.tunable_op.tuning_enable && !info.tunable_op.enable) { + LOGS_DEFAULT(WARNING) << "TunableOp is enabled for tuning but is not enabled for using. This will have no effect."; } } diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index e29402f93d..5e32ae3067 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -21,7 +21,8 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc"; constexpr const char* kGpuExternalFree = "gpu_external_free"; constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache"; constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace"; -constexpr const char* kTunableOpEnabled = "tunable_op_enabled"; +constexpr const char* kTunableOpEnable = "tunable_op_enable"; +constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; } // namespace provider_option_names } // namespace rocm @@ -83,9 +84,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace) .AddValueParser( - rocm::provider_option_names::kTunableOpEnabled, + rocm::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.enabled)); + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.enable)); + return Status::OK(); + }) + .AddValueParser( + rocm::provider_option_names::kTunableOpTuningEnable, + [&info](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.tuning_enable)); return 
Status::OK(); }) .Parse(options)); @@ -107,7 +114,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, {rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, {rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)}, - {rocm::provider_option_names::kTunableOpEnabled, MakeStringWithClassicLocale(info.tunable_op.enabled)}, + {rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)}, + {rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, }; return options; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h index b7c308d13d..5ce54a387e 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -38,7 +38,8 @@ struct ROCMExecutionProviderExternalAllocatorInfo { namespace rocm { struct TunableOpInfo { - bool enabled{false}; + bool enable{false}; + bool tuning_enable{false}; }; } // namespace rocm @@ -72,7 +73,8 @@ template<> struct std::hash<::onnxruntime::rocm::TunableOpInfo> { size_t operator()(const ::onnxruntime::rocm::TunableOpInfo& info) const { size_t seed_and_value{0xbc9f1d34}; - onnxruntime::HashCombine(info.enabled, seed_and_value); + onnxruntime::HashCombine(info.enable, seed_and_value); + onnxruntime::HashCombine(info.tuning_enable, seed_and_value); return seed_and_value; } }; diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 366e3124c7..0eaf214887 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc 
+++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -173,7 +173,8 @@ struct ROCM_Provider : Provider { info.has_user_compute_stream = params->has_user_compute_stream; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; - info.tunable_op.enabled = params->tunable_op_enabled; + info.tunable_op.enable = params->tunable_op_enable; + info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; return std::make_shared(info); } diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index bdaa7c4042..69d68bf8b8 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -82,16 +82,30 @@ RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* i void RocmTuningContext::EnableTunableOp() { LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider"; - info_->enabled = true; + info_->enable = true; } void RocmTuningContext::DisableTunableOp() { LOGS_DEFAULT(INFO) << "Disable TunableOp for ROCm Execution Provider"; - info_->enabled = false; + info_->enable = false; } bool RocmTuningContext::IsTunableOpEnabled() const { - return info_->enabled; + return info_->enable; +} + +void RocmTuningContext::EnableTuning() { + LOGS_DEFAULT(INFO) << "Enable TunableOp tuning for ROCm Execution Provider"; + info_->tuning_enable = true; +} + +void RocmTuningContext::DisableTuning() { + LOGS_DEFAULT(INFO) << "Disable TunableOp tuning for ROCm Execution Provider"; + info_->tuning_enable = false; +} + +bool RocmTuningContext::IsTuningEnabled() const { + return info_->tuning_enable; } TuningResultsManager& RocmTuningContext::GetTuningResultsManager() { diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h index 
cffc0e5614..d2ddb37dd0 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.h @@ -37,6 +37,10 @@ class RocmTuningContext : public ITuningContext { void DisableTunableOp() override; bool IsTunableOpEnabled() const override; + void EnableTuning() override; + void DisableTuning() override; + bool IsTuningEnabled() const override; + TuningResultsManager& GetTuningResultsManager() override; const TuningResultsManager& GetTuningResultsManager() const override; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e9899e38ad..27ce6dab65 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1568,7 +1568,7 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata( model_metadata_, tuning_results, found_tuning_results)); if (found_tuning_results) { - ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results)); + ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results, /*error_on_invalid*/false, /*auto_enable*/true)); } #endif // !defined(ORT_MINIMAL_BUILD) @@ -2268,7 +2268,10 @@ std::vector InferenceSession::GetTuningResults() const { return ret; } -Status InferenceSession::SetTuningResults(const std::vector& trs, bool error_on_invalid) { +Status InferenceSession::SetTuningResults( + const std::vector& trs, + bool error_on_invalid, + bool auto_enable) { std::string msg; for (size_t i = 0; i < trs.size(); i++) { @@ -2294,6 +2297,12 @@ Status InferenceSession::SetTuningResults(const std::vector& trs, msg = MakeString("Failed to load TuningResults (index=", i, "). 
Reason: ", status.ErrorMessage()); ORT_RETURN_IF(error_on_invalid, msg); LOGS(*session_logger_, WARNING) << msg; + continue; + } + + if (auto_enable) { + LOGS(*session_logger_, INFO) << "Correctly set TuningResults for " << tr.ep << ", enable TunableOp for using"; + tuning_ctx->EnableTunableOp(); } } return Status::OK(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 473c38c1af..3cf36a1d80 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -458,9 +458,12 @@ class InferenceSession { * Set the TuningResults back to each execution provider. Mainly for offline tuning. * @param trs is the list of TuningResults to be loaded. * @param error_on_invalid otherwise, validation faliure is not an error, only a warning log will be produced. + * @param auto_enable if true, automatically enable tunable op usage (but not tuning) if the TuningResults is + correctly loaded * @return OK if success. 
*/ - Status SetTuningResults(const std::vector& trs, bool error_on_invalid = false); + Status SetTuningResults(const std::vector& trs, bool error_on_invalid = false, + bool auto_enable = false); #endif #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index 3e2d03a930..3436eebda3 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -245,7 +245,7 @@ Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& met } key_found = true; - LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model"; + LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while loading the model"; Status status; ORT_TRY { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu index 8f57806e13..31d9a68bbe 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/fast_gelu.cu @@ -60,7 +60,7 @@ class FastGeluTunable : public IKernelExplorer { FastGeluTunable(DeviceArray& input, DeviceArray& bias, DeviceArray& output, int input_length, int bias_length) : params_(TuningContext(), Stream(), static_cast(input.ptr()), static_cast(bias.ptr()), static_cast(output.ptr()), input_length, bias_length) { - params_.TuningContext()->EnableTunableOp(); + params_.TuningContext()->EnableTunableOpAndTuning(); } void Run() override { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu index 9c66e43170..0dfb3f63e2 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu +++ 
b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_fast_gelu_tunable.cu @@ -47,7 +47,7 @@ class GemmFastGeluTunable : public IKernelExplorer { params_.c = static_cast(c.ptr()); params_.ldc = ldc; - params_.TuningContext()->EnableTunableOp(); + params_.TuningContext()->EnableTunableOpAndTuning(); } ~GemmFastGeluTunable() { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu index a18dfcb021..60d41b320f 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu @@ -242,7 +242,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(&this->attn_), GemmSoftmaxGemmPermuteTunableOp::GetWorkspaceNumBytes(&this->attn_))); - this->params_.TuningContext()->EnableTunableOp(); + this->params_.TuningContext()->EnableTunableOpAndTuning(); } std::vector ListOps() const { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu index ee39cf32a2..4366bd6b55 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_tunable.cu @@ -47,7 +47,7 @@ class GemmTunable : public IKernelExplorer { params_.c = static_cast(c.ptr()); params_.ldc = ldc; - params_.TuningContext()->EnableTunableOp(); + params_.TuningContext()->EnableTunableOpAndTuning(); } ~GemmTunable() { @@ -108,7 +108,7 @@ class BatchedGemmTunable : public IBatchedGemmKernelExplorer { params_.ldc = ldc; params_.batch = batch; - params_.TuningContext()->EnableTunableOp(); + params_.TuningContext()->EnableTunableOpAndTuning(); } ~BatchedGemmTunable() { @@ -170,7 +170,7 @@ class 
StridedBatchedGemmTunable : public IKernelExplorer { params_.stride_c = stride_c; params_.batch = batch; - params_.TuningContext()->EnableTunableOp(); + params_.TuningContext()->EnableTunableOpAndTuning(); } ~StridedBatchedGemmTunable() { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu index 37a9f14769..55b79141f1 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/skip_layer_norm.cu @@ -93,7 +93,7 @@ class SkipLayerNormTunable : public IKernelExplorer { : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(skip_input_bias_add_output.ptr()), static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(bias.ptr()), epsilon, hidden_size, element_count) { - params_.TuningContext()->EnableTunableOp(); + params_.TuningContext()->EnableTunableOpAndTuning(); } void Run() override { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu index 2c9980ef30..cf364ef9c3 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu @@ -106,7 +106,7 @@ class SoftmaxTunable : public IKernelExplorer { int input_stride, int output_stride, int batch_count, bool is_log_softmax) : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) { - params_.TuningContext()->EnableTunableOp(); + params_.TuningContext()->EnableTunableOpAndTuning(); } void Run() override { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu index ed10fc0a2c..cffbe41f63 
100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add.cu @@ -128,7 +128,7 @@ class VectorAddTunable : public IKernelExplorer { params_.z = static_cast(z.ptr()); params_.n = n; - params_.TuningContext()->EnableTunableOp(); + params_.TuningContext()->EnableTunableOpAndTuning(); } void Run() override { diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 3ed1bbf567..6793b1c49c 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -49,9 +49,13 @@ class TestTuningContext : public ITuningContext { public: using ITuningContext::ITuningContext; - void EnableTunableOp() override { tuning_enabled_ = true; } - void DisableTunableOp() override { tuning_enabled_ = false; } - bool IsTunableOpEnabled() const override { return tuning_enabled_; } + void EnableTunableOp() override { op_enabled_ = true; } + void DisableTunableOp() override { op_enabled_ = false; } + bool IsTunableOpEnabled() const override { return op_enabled_; } + + void EnableTuning() override { tuning_enabled_ = true; } + void DisableTuning() override { tuning_enabled_ = false; } + bool IsTuningEnabled() const override { return tuning_enabled_; } TuningResultsManager& GetTuningResultsManager() override { return manager_; } const TuningResultsManager& GetTuningResultsManager() const override { return manager_; } @@ -61,6 +65,7 @@ class TestTuningContext : public ITuningContext { void ClearCache() { manager_.Clear(); } private: + bool op_enabled_{false}; bool tuning_enabled_{false}; TuningResultsManager manager_{}; TestTuningResultsValidator validator_{}; @@ -374,7 +379,7 @@ class TunableVecAddSelectFast : public TunableOp { constexpr static int kFastFullId = 1; }; -TEST(TunableOp, SelectFast) { +TEST(TunableOp, SelectFastIfTuning) { #ifdef ORT_NO_RTTI GTEST_SKIP() << "TunableOp needs RTTI to work correctly"; 
#else @@ -386,10 +391,16 @@ TEST(TunableOp, SelectFast) { params.last_run = &last_run; TunableVecAddSelectFast op{}; + // Only enable op usage, slow (default) should be selected params.TuningContext()->EnableTunableOp(); - auto status = op(&params); ASSERT_TRUE(status.IsOK()); + ASSERT_EQ(last_run, "SlowFull"); + + // Also enable tuning, fast should be selected + params.TuningContext()->EnableTuning(); + status = op(&params); + ASSERT_TRUE(status.IsOK()); ASSERT_EQ(last_run, "FastFull"); #endif } @@ -414,7 +425,7 @@ TEST(TunableOp, SelectSupported) { params.last_run = &last_run; TunableVecAddSelectSupported op{}; - params.TuningContext()->EnableTunableOp(); + params.TuningContext()->EnableTunableOpAndTuning(); auto status = op(&params); ASSERT_TRUE(status.IsOK()); @@ -445,7 +456,7 @@ TEST(TunableOp, SelectFastestIfSupported) { params.last_run = &last_run; TunableVecAddSelectFastestIfSupported op{}; - params.TuningContext()->EnableTunableOp(); + params.TuningContext()->EnableTunableOpAndTuning(); auto status = op(&params); ASSERT_TRUE(status.IsOK()); @@ -530,7 +541,7 @@ TEST(TunableOp, HandleInplaceUpdate) { c = 4200; VecAddParamsRecordLastRun params(&a, &b, &c, 1, /*beta=*/0); TunableVecAddNotHandleInplaceUpdate op_not_handle_inplace_update{}; - params.TuningContext()->EnableTunableOp(); + params.TuningContext()->EnableTunableOpAndTuning(); auto status = op_not_handle_inplace_update(&params); ASSERT_TRUE(status.IsOK()); ASSERT_EQ(c, 7500042); @@ -541,7 +552,7 @@ TEST(TunableOp, HandleInplaceUpdate) { c = 4200; VecAddParamsRecordLastRun params(&a, &b, &c, 1, /*beta=*/1); TunableVecAddNotHandleInplaceUpdate op_not_handle_inplace_update{}; - params.TuningContext()->EnableTunableOp(); + params.TuningContext()->EnableTunableOpAndTuning(); auto status = op_not_handle_inplace_update(&params); ASSERT_TRUE(status.IsOK()); ASSERT_NE(c, 4200); // value should be changed @@ -553,7 +564,7 @@ TEST(TunableOp, HandleInplaceUpdate) { c = 4200; VecAddParamsRecordLastRun params(&a, &b, &c, 1, /*beta=*/0); 
TunableVecAddHandleInplaceUpdate op{}; - params.TuningContext()->EnableTunableOp(); + params.TuningContext()->EnableTunableOpAndTuning(); auto status = op(&params); ASSERT_TRUE(status.IsOK()); ASSERT_EQ(c, 7500042); @@ -565,7 +576,7 @@ TEST(TunableOp, HandleInplaceUpdate) { c = 4200; VecAddParamsRecordLastRun params(&a, &b, &c, 1, /*beta=*/1); TunableVecAddHandleInplaceUpdate op{}; - params.TuningContext()->EnableTunableOp(); + params.TuningContext()->EnableTunableOpAndTuning(); auto status = op(&params); ASSERT_TRUE(status.IsOK()); ASSERT_EQ(c, 7504242); @@ -629,7 +640,7 @@ TEST(TuningContext, TunableOpRespectTuningContext) { tuning::TunableVecAddSelectFast op{}; auto* ctx = params.TuningContext(); auto& mgr = ctx->GetTuningResultsManager(); - ctx->EnableTunableOp(); + ctx->EnableTunableOpAndTuning(); { // Before TunableOp(...), there is no entry in it. @@ -683,7 +694,7 @@ TEST(TuningContext, GetAndLoadTuningResults) { tuning::TunableVecAddSelectFast op{}; auto* ctx = params.TuningContext(); - ctx->EnableTunableOp(); + ctx->EnableTunableOpAndTuning(); auto status = op(&params); ASSERT_TRUE(status.IsOK()); diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index f6639954e4..b2490e0149 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -245,7 +245,9 @@ class TestInferenceSession(unittest.TestCase): test_get_and_set_option_with_values("do_copy_in_default_stream", [0, 1]) - test_get_and_set_option_with_values("tunable_op_enabled", ["1", "0"]) + test_get_and_set_option_with_values("tunable_op_enable", ["1", "0"]) + + test_get_and_set_option_with_values("tunable_op_tuning_enable", ["1", "0"]) option["gpu_external_alloc"] = "0" option["gpu_external_free"] = "0" @@ -379,7 +381,9 @@ class TestInferenceSession(unittest.TestCase): str(option_value), ) - test_get_and_set_option_with_values("tunable_op_enabled", ["1", "0"]) + 
test_get_and_set_option_with_values("tunable_op_enable", ["1", "0"]) + + test_get_and_set_option_with_values("tunable_op_tuning_enable", ["1", "0"]) runRocmOptionsTest() diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index b7ee133133..5cb80f9edb 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -179,7 +179,8 @@ std::unique_ptr DefaultRocmExecutionProvider(bool test_tunab #ifdef USE_ROCM OrtROCMProviderOptions provider_options{}; provider_options.do_copy_in_default_stream = true; - provider_options.tunable_op_enabled = test_tunable_op ? 1 : 0; + provider_options.tunable_op_enable = test_tunable_op ? 1 : 0; + provider_options.tunable_op_tuning_enable = test_tunable_op ? 1 : 0; if (auto factory = RocmProviderFactoryCreator::Create(&provider_options)) return factory->CreateProvider(); #endif