Automatically enable tunable op usage for production models (#15156)

Split `IsTunbaleOpEnable` semantics into **enable tunable op for using**
and **enable tunable op for tuning**.

They remain disabled in general for safety purpose. But
- if session is created with onnx model with tuning results embeded
- the embedded tuning results is set to the EP without error `Status`

then we automatically enable the using, tuning remains disabled.

The planned options will be
- `tunable_op_enable`: The top-level switch of `TunableOp`, indicate if we will run into `TunableOp` related logic. **NOTE:** most of our impls have a bottom impl that is acting as a fallback and is set as the default. In this case, we still call into the `TunableOp`, but no kernel selection, no kernel tuning and caching is involved. This reduced our maintainance burden of a duplicate code path.
- `tunable_op_tuning_enable`: The secondary switch of `TunableOp`, indicate if we will run into the tuning related logic of `TunableOp`

Then for the possible future options:
- `tunable_op_tuning_max_iteration`: blahblah
- `tunable_op_tuning_max_duration_ms`: blahblah
- `tunable_op_flash_attention_enable`: blahblah, for example only, we will not have this.

For developer oriented envvar, it is for developers' convenience to inspect the performance impact of tuning. So there is only `ORT_ROCM_TUNABLE_OP_ENABLE`, `ORT_ROCM_TUNABLE_OP_TUNING_ENABLE` to take the fine-grind control of combinations.
This commit is contained in:
cloudhan 2023-04-06 13:52:47 +08:00 committed by GitHub
parent 2e52de265a
commit 71a4e7eb97
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 224 additions and 74 deletions

View file

@ -27,5 +27,6 @@ struct OrtCUDAProviderOptionsV2 {
int cudnn_conv_use_max_workspace; // flag specifying if maximum workspace can be used in cudnn conv algo search.
int enable_cuda_graph; // flag specifying if the CUDA graph is to be captured for the model.
int cudnn_conv1d_pad_to_nc1d; // flag specifying if pad Conv1D's input [N,C,D] to [N,C,1,D] or [N,C,D,1].
int tunable_op_enabled; // flag specifying if TunableOp is enabled.
int tunable_op_enable; // flag specifying if TunableOp is enabled.
int tunable_op_tuning_enable; // flag specifying if TunableOp is enabled for tuning, this relies on TunableOp is enabled.
};

View file

@ -387,7 +387,8 @@ typedef struct OrtCUDAProviderOptions {
has_user_compute_stream{},
user_compute_stream{},
default_memory_arena_cfg{},
tunable_op_enabled{false} {}
tunable_op_enable{false},
tunable_op_tuning_enable{false} {}
#endif
/** \brief CUDA device Id
@ -438,11 +439,18 @@ typedef struct OrtCUDAProviderOptions {
*/
OrtArenaCfg* default_memory_arena_cfg;
/** \brief Enable TunableOp.
* Set it to 1 to enable TunableOp. Otherwise, it is disabled by default.
* This option can be superseded by environment variable ORT_CUDA_TUNABLE_OP_ENABLED.
/** \brief Enable TunableOp for using.
* Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.
* This option can be overriden by environment variable ORT_CUDA_TUNABLE_OP_ENABLE.
*/
int tunable_op_enabled;
int tunable_op_enable;
/** \brief Enable TunableOp for tuning.
* Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default.
* This option can be overriden by environment variable ORT_CUDA_TUNABLE_OP_TUNING_ENABLE.
*/
int tunable_op_tuning_enable;
} OrtCUDAProviderOptions;
@ -461,7 +469,8 @@ typedef struct OrtROCMProviderOptions {
has_user_compute_stream{},
user_compute_stream{},
default_memory_arena_cfg{},
tunable_op_enabled{false} {}
tunable_op_enable{false},
tunable_op_tuning_enable{false} {}
#endif
/** \brief ROCM device Id
@ -511,11 +520,17 @@ typedef struct OrtROCMProviderOptions {
*/
OrtArenaCfg* default_memory_arena_cfg;
/** \brief Enable TunableOp.
* Set it to 1 to enable TunableOp. Otherwise, it is disabled by default.
* This option can be superseded by environment variable ORT_ROCM_TUNABLE_OP_ENABLED.
/** \brief Enable TunableOp for using.
* Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.
* This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE.
*/
int tunable_op_enabled;
int tunable_op_enable;
/** \brief Enable TunableOp for tuning.
* Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default.
* This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE.
*/
int tunable_op_tuning_enable;
} OrtROCMProviderOptions;
@ -4085,6 +4100,10 @@ struct OrtApi {
* \since Version 1.15.
*/
ORT_API2_STATUS(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out);
#ifdef __cplusplus
OrtApi(const OrtApi&) = delete; // Prevent users from accidentally copying the API structure, it should always be passed as a pointer
#endif
};
/*

View file

@ -133,12 +133,14 @@ class TunableOp {
virtual ~TunableOp() = default;
Status operator()(const ParamsT* params) {
int id = default_id_;
int id = -1;
ITuningContext* ctx = params->TuningContext();
if (ctx->IsTunableOpEnabled()) {
auto& mgr = ctx->GetTuningResultsManager();
auto op_sig = Signature();
auto params_sig = params->Signature();
// Usage is enabled, then we are free to use previous tuning result.
id = mgr.Lookup(op_sig, params_sig);
if (id > static_cast<int>(ops_.size())) {
LOGS_DEFAULT(ERROR) << "Invalid TunableOp kernel id for " << op_sig
@ -146,14 +148,16 @@ class TunableOp {
mgr.Delete(op_sig, params_sig);
id = -1;
}
if (id < 0) {
// If there is not previous tuning result been found, we do the tuning iff tuning is enabled
if (id < 0 && ctx->IsTuningEnabled()) {
auto maybe_proxy_params = PreTuning(params);
id = FindFastest(maybe_proxy_params);
PostTuning(maybe_proxy_params);
mgr.Add(op_sig, params_sig, id);
}
}
ORT_RETURN_IF_ERROR(ops_[id](params));
ORT_RETURN_IF_ERROR(ops_[id < 0 ? default_id_ : id](params));
return Status::OK();
}

View file

@ -24,6 +24,20 @@ class ITuningContext {
virtual void DisableTunableOp() = 0;
virtual bool IsTunableOpEnabled() const = 0;
virtual void EnableTuning() = 0;
virtual void DisableTuning() = 0;
virtual bool IsTuningEnabled() const = 0;
virtual void EnableTunableOpAndTuning() final {
EnableTunableOp();
EnableTuning();
}
virtual void DisableTunableOpAndTuning() final {
DisableTunableOp();
DisableTuning();
}
virtual TuningResultsManager& GetTuningResultsManager() = 0;
virtual const TuningResultsManager& GetTuningResultsManager() const = 0;

View file

@ -211,11 +211,23 @@ void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGrap
#endif
void OverrideTunableOpInfoByEnv(CUDAExecutionProviderInfo& info) {
auto env_tunable_op_enabled = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_CUDA_TUNABLE_OP_ENABLED", {"0", "1"}, "Use provider_options \"tunable_op_enabled\" instead.");
if (env_tunable_op_enabled.has_value() && env_tunable_op_enabled != info.tunable_op.enabled) {
LOGS_DEFAULT(INFO) << "ORT_CUDA_TUNABLE_OP_ENABLED is set to " << *env_tunable_op_enabled;
info.tunable_op.enabled = *env_tunable_op_enabled;
if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_CUDA_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead.");
env_tunable_op_enable.has_value() && env_tunable_op_enable != info.tunable_op.enable) {
LOGS_DEFAULT(INFO) << "ORT_CUDA_TUNABLE_OP_ENABLE is set to " << *env_tunable_op_enable;
info.tunable_op.enable = *env_tunable_op_enable;
}
if (auto env_tunable_op_tuning_enable = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_CUDA_TUNABLE_OP_TUNING_ENABLE", {"0", "1"},
"Use provider_options \"tunable_op_tuning_enable\" instead.");
env_tunable_op_tuning_enable.has_value() && env_tunable_op_tuning_enable != info.tunable_op.tuning_enable) {
LOGS_DEFAULT(INFO) << "ORT_CUDA_TUNABLE_OP_TUNING_ENABLE is set to " << *env_tunable_op_tuning_enable;
info.tunable_op.tuning_enable = *env_tunable_op_tuning_enable;
}
if (info.tunable_op.tuning_enable && !info.tunable_op.enable) {
LOGS_DEFAULT(WARNING) << "TunableOp is enabled for tuning but is not enabled for using. This will have no effect.";
}
}

View file

@ -24,7 +24,8 @@ constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache";
constexpr const char* kCudnnConvUseMaxWorkspace = "cudnn_conv_use_max_workspace";
constexpr const char* kEnableCudaGraph = "enable_cuda_graph";
constexpr const char* kCudnnConv1dPadToNc1d = "cudnn_conv1d_pad_to_nc1d";
constexpr const char* kTunableOpEnabled = "tunable_op_enabled";
constexpr const char* kTunableOpEnable = "tunable_op_enable";
constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable";
} // namespace provider_option_names
} // namespace cuda
@ -94,9 +95,15 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
.AddAssignmentToReference(cuda::provider_option_names::kEnableCudaGraph, info.enable_cuda_graph)
.AddAssignmentToReference(cuda::provider_option_names::kCudnnConv1dPadToNc1d, info.cudnn_conv1d_pad_to_nc1d)
.AddValueParser(
cuda::provider_option_names::kTunableOpEnabled,
cuda::provider_option_names::kTunableOpEnable,
[&info](const std::string& value_str) -> Status {
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.enabled));
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.enable));
return Status::OK();
})
.AddValueParser(
cuda::provider_option_names::kTunableOpTuningEnable,
[&info](const std::string& value_str) -> Status {
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.tuning_enable));
return Status::OK();
})
.Parse(options));
@ -121,7 +128,8 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution
{cuda::provider_option_names::kCudnnConvUseMaxWorkspace, MakeStringWithClassicLocale(info.cudnn_conv_use_max_workspace)},
{cuda::provider_option_names::kEnableCudaGraph, MakeStringWithClassicLocale(info.enable_cuda_graph)},
{cuda::provider_option_names::kCudnnConv1dPadToNc1d, MakeStringWithClassicLocale(info.cudnn_conv1d_pad_to_nc1d)},
{cuda::provider_option_names::kTunableOpEnabled, MakeStringWithClassicLocale(info.tunable_op.enabled)},
{cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)},
{cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)},
};
return options;
@ -135,7 +143,9 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid
{cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)},
{cuda::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
{cuda::provider_option_names::kCudnnConvUseMaxWorkspace, MakeStringWithClassicLocale(info.cudnn_conv_use_max_workspace)},
{cuda::provider_option_names::kCudnnConv1dPadToNc1d, MakeStringWithClassicLocale(info.cudnn_conv1d_pad_to_nc1d)}
{cuda::provider_option_names::kCudnnConv1dPadToNc1d, MakeStringWithClassicLocale(info.cudnn_conv1d_pad_to_nc1d)},
{cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)},
{cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)},
};
return options;

View file

@ -38,7 +38,8 @@ struct CUDAExecutionProviderExternalAllocatorInfo {
namespace cuda {
struct TunableOpInfo {
bool enabled{false};
bool enable{false};
bool tuning_enable{false};
};
} // namespace cuda
@ -78,7 +79,8 @@ template<>
struct std::hash<::onnxruntime::cuda::TunableOpInfo> {
size_t operator()(const ::onnxruntime::cuda::TunableOpInfo& info) const {
size_t seed_and_value{0xbc9f1d34};
onnxruntime::HashCombine(info.enabled, seed_and_value);
onnxruntime::HashCombine(info.enable, seed_and_value);
onnxruntime::HashCombine(info.tuning_enable, seed_and_value);
return seed_and_value;
}
};

View file

@ -251,7 +251,8 @@ struct CUDA_Provider : Provider {
info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0;
info.enable_cuda_graph = params->enable_cuda_graph != 0;
info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0;
info.tunable_op.enabled = params->tunable_op_enabled;
info.tunable_op.enable = params->tunable_op_enable;
info.tunable_op.tuning_enable = params->tunable_op_tuning_enable;
return std::make_shared<CUDAProviderFactory>(info);
}

View file

@ -51,16 +51,30 @@ CudaTuningContext::CudaTuningContext(CUDAExecutionProvider* ep, TunableOpInfo* i
void CudaTuningContext::EnableTunableOp() {
LOGS_DEFAULT(INFO) << "Enable TunableOp for CUDA Execution Provider";
info_->enabled = true;
info_->enable = true;
}
void CudaTuningContext::DisableTunableOp() {
LOGS_DEFAULT(INFO) << "Disable TunableOp for CUDA Execution Provider";
info_->enabled = false;
info_->enable = false;
}
bool CudaTuningContext::IsTunableOpEnabled() const {
return info_->enabled;
return info_->enable;
}
void CudaTuningContext::EnableTuning() {
LOGS_DEFAULT(INFO) << "Enable TunableOp tuning for CUDA Execution Provider";
info_->tuning_enable = true;
}
void CudaTuningContext::DisableTuning() {
LOGS_DEFAULT(INFO) << "Disable TunableOp tuning for CUDA Execution Provider";
info_->tuning_enable = false;
}
bool CudaTuningContext::IsTuningEnabled() const {
return info_->tuning_enable;
}
TuningResultsManager& CudaTuningContext::GetTuningResultsManager() {

View file

@ -35,6 +35,10 @@ class CudaTuningContext : public ITuningContext {
void DisableTunableOp() override;
bool IsTunableOpEnabled() const override;
void EnableTuning() override;
void DisableTuning() override;
bool IsTuningEnabled() const override;
TuningResultsManager& GetTuningResultsManager() override;
const TuningResultsManager& GetTuningResultsManager() const override;

View file

@ -157,11 +157,23 @@ ROCMExecutionProvider::PerThreadContext::~PerThreadContext() {
}
void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) {
auto env_tunable_op_enabled = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_ROCM_TUNABLE_OP_ENABLED", {"0", "1"}, "Use provider_options \"tunable_op_enabled\" instead.");
if (env_tunable_op_enabled.has_value() && env_tunable_op_enabled != info.tunable_op.enabled) {
LOGS_DEFAULT(INFO) << "ORT_ROCM_TUNABLE_OP_ENABLED is set to " << *env_tunable_op_enabled;
info.tunable_op.enabled = *env_tunable_op_enabled;
if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead.");
env_tunable_op_enable.has_value() && env_tunable_op_enable != info.tunable_op.enable) {
LOGS_DEFAULT(INFO) << "ORT_ROCM_TUNABLE_OP_ENABLE is set to " << *env_tunable_op_enable;
info.tunable_op.enable = *env_tunable_op_enable;
}
if (auto env_tunable_op_tuning_enable = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_ROCM_TUNABLE_OP_TUNING_ENABLE", {"0", "1"},
"Use provider_options \"tunable_op_tuning_enable\" instead.");
env_tunable_op_tuning_enable.has_value() && env_tunable_op_tuning_enable != info.tunable_op.tuning_enable) {
LOGS_DEFAULT(INFO) << "ORT_ROCM_TUNABLE_OP_TUNING_ENABLE is set to " << *env_tunable_op_tuning_enable;
info.tunable_op.tuning_enable = *env_tunable_op_tuning_enable;
}
if (info.tunable_op.tuning_enable && !info.tunable_op.enable) {
LOGS_DEFAULT(WARNING) << "TunableOp is enabled for tuning but is not enabled for using. This will have no effect.";
}
}

View file

@ -21,7 +21,8 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc";
constexpr const char* kGpuExternalFree = "gpu_external_free";
constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache";
constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace";
constexpr const char* kTunableOpEnabled = "tunable_op_enabled";
constexpr const char* kTunableOpEnable = "tunable_op_enable";
constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable";
} // namespace provider_option_names
} // namespace rocm
@ -83,9 +84,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
.AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
.AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace)
.AddValueParser(
rocm::provider_option_names::kTunableOpEnabled,
rocm::provider_option_names::kTunableOpEnable,
[&info](const std::string& value_str) -> Status {
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.enabled));
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.enable));
return Status::OK();
})
.AddValueParser(
rocm::provider_option_names::kTunableOpTuningEnable,
[&info](const std::string& value_str) -> Status {
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.tunable_op.tuning_enable));
return Status::OK();
})
.Parse(options));
@ -107,7 +114,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
{rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
{rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
{rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)},
{rocm::provider_option_names::kTunableOpEnabled, MakeStringWithClassicLocale(info.tunable_op.enabled)},
{rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)},
{rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)},
};
return options;

View file

@ -38,7 +38,8 @@ struct ROCMExecutionProviderExternalAllocatorInfo {
namespace rocm {
struct TunableOpInfo {
bool enabled{false};
bool enable{false};
bool tuning_enable{false};
};
} // namespace rocm
@ -72,7 +73,8 @@ template<>
struct std::hash<::onnxruntime::rocm::TunableOpInfo> {
size_t operator()(const ::onnxruntime::rocm::TunableOpInfo& info) const {
size_t seed_and_value{0xbc9f1d34};
onnxruntime::HashCombine(info.enabled, seed_and_value);
onnxruntime::HashCombine(info.enable, seed_and_value);
onnxruntime::HashCombine(info.tuning_enable, seed_and_value);
return seed_and_value;
}
};

View file

@ -173,7 +173,8 @@ struct ROCM_Provider : Provider {
info.has_user_compute_stream = params->has_user_compute_stream;
info.user_compute_stream = params->user_compute_stream;
info.default_memory_arena_cfg = params->default_memory_arena_cfg;
info.tunable_op.enabled = params->tunable_op_enabled;
info.tunable_op.enable = params->tunable_op_enable;
info.tunable_op.tuning_enable = params->tunable_op_tuning_enable;
return std::make_shared<ROCMProviderFactory>(info);
}

View file

@ -82,16 +82,30 @@ RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* i
void RocmTuningContext::EnableTunableOp() {
LOGS_DEFAULT(INFO) << "Enable TunableOp for ROCm Execution Provider";
info_->enabled = true;
info_->enable = true;
}
void RocmTuningContext::DisableTunableOp() {
LOGS_DEFAULT(INFO) << "Disable TunableOp for ROCm Execution Provider";
info_->enabled = false;
info_->enable = false;
}
bool RocmTuningContext::IsTunableOpEnabled() const {
return info_->enabled;
return info_->enable;
}
void RocmTuningContext::EnableTuning() {
LOGS_DEFAULT(INFO) << "Enable TunableOp tuning for ROCm Execution Provider";
info_->tuning_enable = true;
}
void RocmTuningContext::DisableTuning() {
LOGS_DEFAULT(INFO) << "Disable TunableOp tuning for ROCm Execution Provider";
info_->tuning_enable = false;
}
bool RocmTuningContext::IsTuningEnabled() const {
return info_->tuning_enable;
}
TuningResultsManager& RocmTuningContext::GetTuningResultsManager() {

View file

@ -37,6 +37,10 @@ class RocmTuningContext : public ITuningContext {
void DisableTunableOp() override;
bool IsTunableOpEnabled() const override;
void EnableTuning() override;
void DisableTuning() override;
bool IsTuningEnabled() const override;
TuningResultsManager& GetTuningResultsManager() override;
const TuningResultsManager& GetTuningResultsManager() const override;

View file

@ -1568,7 +1568,7 @@ common::Status InferenceSession::Initialize() {
ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata(
model_metadata_, tuning_results, found_tuning_results));
if (found_tuning_results) {
ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results));
ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results, /*error_on_invalid*/false, /*auto_enable*/true));
}
#endif // !defined(ORT_MINIMAL_BUILD)
@ -2268,7 +2268,10 @@ std::vector<TuningResults> InferenceSession::GetTuningResults() const {
return ret;
}
Status InferenceSession::SetTuningResults(const std::vector<TuningResults>& trs, bool error_on_invalid) {
Status InferenceSession::SetTuningResults(
const std::vector<TuningResults>& trs,
bool error_on_invalid,
bool auto_enable) {
std::string msg;
for (size_t i = 0; i < trs.size(); i++) {
@ -2294,6 +2297,12 @@ Status InferenceSession::SetTuningResults(const std::vector<TuningResults>& trs,
msg = MakeString("Failed to load TuningResults (index=", i, "). Reason: ", status.ErrorMessage());
ORT_RETURN_IF(error_on_invalid, msg);
LOGS(*session_logger_, WARNING) << msg;
continue;
}
if (auto_enable) {
LOGS(*session_logger_, INFO) << "Correctly set TuningResults for " << tr.ep << ", enable TunableOp for using";
tuning_ctx->EnableTunableOp();
}
}
return Status::OK();

View file

@ -458,9 +458,12 @@ class InferenceSession {
* Set the TuningResults back to each execution provider. Mainly for offline tuning.
* @param trs is the list of TuningResults to be loaded.
* @param error_on_invalid otherwise, validation faliure is not an error, only a warning log will be produced.
* @param auto_enable if true, automatically enable tunable op usage (but not tuning) if the TuningResults is
correctly loaded
* @return OK if success.
*/
Status SetTuningResults(const std::vector<TuningResults>& trs, bool error_on_invalid = false);
Status SetTuningResults(const std::vector<TuningResults>& trs, bool error_on_invalid = false,
bool auto_enable = false);
#endif
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)

View file

@ -245,7 +245,7 @@ Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& met
}
key_found = true;
LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while running the model";
LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while loading the model";
Status status;
ORT_TRY {

View file

@ -60,7 +60,7 @@ class FastGeluTunable : public IKernelExplorer {
FastGeluTunable(DeviceArray& input, DeviceArray& bias, DeviceArray& output, int input_length, int bias_length)
: params_(TuningContext(), Stream(), static_cast<T*>(input.ptr()), static_cast<T*>(bias.ptr()),
static_cast<T*>(output.ptr()), input_length, bias_length) {
params_.TuningContext()->EnableTunableOp();
params_.TuningContext()->EnableTunableOpAndTuning();
}
void Run() override {

View file

@ -47,7 +47,7 @@ class GemmFastGeluTunable : public IKernelExplorer {
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
params_.TuningContext()->EnableTunableOp();
params_.TuningContext()->EnableTunableOpAndTuning();
}
~GemmFastGeluTunable() {

View file

@ -242,7 +242,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor
GemmSoftmaxGemmPermuteGenericPipeline<T>::GetWorkspaceNumBytes(&this->attn_),
GemmSoftmaxGemmPermuteTunableOp<T>::GetWorkspaceNumBytes(&this->attn_)));
this->params_.TuningContext()->EnableTunableOp();
this->params_.TuningContext()->EnableTunableOpAndTuning();
}
std::vector<std::string> ListOps() const {

View file

@ -47,7 +47,7 @@ class GemmTunable : public IKernelExplorer {
params_.c = static_cast<T*>(c.ptr());
params_.ldc = ldc;
params_.TuningContext()->EnableTunableOp();
params_.TuningContext()->EnableTunableOpAndTuning();
}
~GemmTunable() {
@ -108,7 +108,7 @@ class BatchedGemmTunable : public IBatchedGemmKernelExplorer<T> {
params_.ldc = ldc;
params_.batch = batch;
params_.TuningContext()->EnableTunableOp();
params_.TuningContext()->EnableTunableOpAndTuning();
}
~BatchedGemmTunable() {
@ -170,7 +170,7 @@ class StridedBatchedGemmTunable : public IKernelExplorer {
params_.stride_c = stride_c;
params_.batch = batch;
params_.TuningContext()->EnableTunableOp();
params_.TuningContext()->EnableTunableOpAndTuning();
}
~StridedBatchedGemmTunable() {

View file

@ -93,7 +93,7 @@ class SkipLayerNormTunable : public IKernelExplorer {
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {
params_.TuningContext()->EnableTunableOp();
params_.TuningContext()->EnableTunableOpAndTuning();
}
void Run() override {

View file

@ -106,7 +106,7 @@ class SoftmaxTunable : public IKernelExplorer {
int input_stride, int output_stride, int batch_count, bool is_log_softmax)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {
params_.TuningContext()->EnableTunableOp();
params_.TuningContext()->EnableTunableOpAndTuning();
}
void Run() override {

View file

@ -128,7 +128,7 @@ class VectorAddTunable : public IKernelExplorer {
params_.z = static_cast<T*>(z.ptr());
params_.n = n;
params_.TuningContext()->EnableTunableOp();
params_.TuningContext()->EnableTunableOpAndTuning();
}
void Run() override {

View file

@ -49,9 +49,13 @@ class TestTuningContext : public ITuningContext {
public:
using ITuningContext::ITuningContext;
void EnableTunableOp() override { tuning_enabled_ = true; }
void DisableTunableOp() override { tuning_enabled_ = false; }
bool IsTunableOpEnabled() const override { return tuning_enabled_; }
void EnableTunableOp() override { op_enabled_ = true; }
void DisableTunableOp() override { op_enabled_ = false; }
bool IsTunableOpEnabled() const override { return op_enabled_; }
void EnableTuning() override { tuning_enabled_ = true; }
void DisableTuning() override { tuning_enabled_ = false; }
bool IsTuningEnabled() const override { return tuning_enabled_; }
TuningResultsManager& GetTuningResultsManager() override { return manager_; }
const TuningResultsManager& GetTuningResultsManager() const override { return manager_; }
@ -61,6 +65,7 @@ class TestTuningContext : public ITuningContext {
void ClearCache() { manager_.Clear(); }
private:
bool op_enabled_{false};
bool tuning_enabled_{false};
TuningResultsManager manager_{};
TestTuningResultsValidator validator_{};
@ -374,7 +379,7 @@ class TunableVecAddSelectFast : public TunableOp<VecAddParamsRecordLastRun> {
constexpr static int kFastFullId = 1;
};
TEST(TunableOp, SelectFast) {
TEST(TunableOp, SelectFastIfTuning) {
#ifdef ORT_NO_RTTI
GTEST_SKIP() << "TunableOp needs RTTI to work correctly";
#else
@ -386,10 +391,16 @@ TEST(TunableOp, SelectFast) {
params.last_run = &last_run;
TunableVecAddSelectFast op{};
// Only enable op usage, slow (default) should be selected
params.TuningContext()->EnableTunableOp();
auto status = op(&params);
ASSERT_TRUE(status.IsOK());
ASSERT_EQ(last_run, "SlowFull");
// Also enable tuning, fast should be selected
params.TuningContext()->EnableTuning();
status = op(&params);
ASSERT_TRUE(status.IsOK());
ASSERT_EQ(last_run, "FastFull");
#endif
}
@ -414,7 +425,7 @@ TEST(TunableOp, SelectSupported) {
params.last_run = &last_run;
TunableVecAddSelectSupported op{};
params.TuningContext()->EnableTunableOp();
params.TuningContext()->EnableTunableOpAndTuning();
auto status = op(&params);
ASSERT_TRUE(status.IsOK());
@ -445,7 +456,7 @@ TEST(TunableOp, SelectFastestIfSupported) {
params.last_run = &last_run;
TunableVecAddSelectFastestIfSupported op{};
params.TuningContext()->EnableTunableOp();
params.TuningContext()->EnableTunableOpAndTuning();
auto status = op(&params);
ASSERT_TRUE(status.IsOK());
@ -530,7 +541,7 @@ TEST(TunableOp, HandleInplaceUpdate) {
c = 4200;
VecAddParamsRecordLastRun params(&a, &b, &c, 1, /*beta=*/0);
TunableVecAddNotHandleInplaceUpdate op_not_handle_inplace_update{};
params.TuningContext()->EnableTunableOp();
params.TuningContext()->EnableTunableOpAndTuning();
auto status = op_not_handle_inplace_update(&params);
ASSERT_TRUE(status.IsOK());
ASSERT_EQ(c, 7500042);
@ -541,7 +552,7 @@ TEST(TunableOp, HandleInplaceUpdate) {
c = 4200;
VecAddParamsRecordLastRun params(&a, &b, &c, 1, /*beta=*/1);
TunableVecAddNotHandleInplaceUpdate op_not_handle_inplace_update{};
params.TuningContext()->EnableTunableOp();
params.TuningContext()->EnableTunableOpAndTuning();
auto status = op_not_handle_inplace_update(&params);
ASSERT_TRUE(status.IsOK());
ASSERT_NE(c, 4200); // value should be changed
@ -553,7 +564,7 @@ TEST(TunableOp, HandleInplaceUpdate) {
c = 4200;
VecAddParamsRecordLastRun params(&a, &b, &c, 1, /*beta=*/0);
TunableVecAddHandleInplaceUpdate op{};
params.TuningContext()->EnableTunableOp();
params.TuningContext()->EnableTunableOpAndTuning();
auto status = op(&params);
ASSERT_TRUE(status.IsOK());
ASSERT_EQ(c, 7500042);
@ -565,7 +576,7 @@ TEST(TunableOp, HandleInplaceUpdate) {
c = 4200;
VecAddParamsRecordLastRun params(&a, &b, &c, 1, /*beta=*/1);
TunableVecAddHandleInplaceUpdate op{};
params.TuningContext()->EnableTunableOp();
params.TuningContext()->EnableTunableOpAndTuning();
auto status = op(&params);
ASSERT_TRUE(status.IsOK());
ASSERT_EQ(c, 7504242);
@ -629,7 +640,7 @@ TEST(TuningContext, TunableOpRespectTuningContext) {
tuning::TunableVecAddSelectFast op{};
auto* ctx = params.TuningContext();
auto& mgr = ctx->GetTuningResultsManager();
ctx->EnableTunableOp();
ctx->EnableTunableOpAndTuning();
{
// Before TunableOp(...), there is no entry in it.
@ -683,7 +694,7 @@ TEST(TuningContext, GetAndLoadTuningResults) {
tuning::TunableVecAddSelectFast op{};
auto* ctx = params.TuningContext();
ctx->EnableTunableOp();
ctx->EnableTunableOpAndTuning();
auto status = op(&params);
ASSERT_TRUE(status.IsOK());

View file

@ -245,7 +245,9 @@ class TestInferenceSession(unittest.TestCase):
test_get_and_set_option_with_values("do_copy_in_default_stream", [0, 1])
test_get_and_set_option_with_values("tunable_op_enabled", ["1", "0"])
test_get_and_set_option_with_values("tunable_op_enable", ["1", "0"])
test_get_and_set_option_with_values("tunable_op_tuning_enable", ["1", "0"])
option["gpu_external_alloc"] = "0"
option["gpu_external_free"] = "0"
@ -379,7 +381,9 @@ class TestInferenceSession(unittest.TestCase):
str(option_value),
)
test_get_and_set_option_with_values("tunable_op_enabled", ["1", "0"])
test_get_and_set_option_with_values("tunable_op_enable", ["1", "0"])
test_get_and_set_option_with_values("tunable_op_tuning_enable", ["1", "0"])
runRocmOptionsTest()

View file

@ -179,7 +179,8 @@ std::unique_ptr<IExecutionProvider> DefaultRocmExecutionProvider(bool test_tunab
#ifdef USE_ROCM
OrtROCMProviderOptions provider_options{};
provider_options.do_copy_in_default_stream = true;
provider_options.tunable_op_enabled = test_tunable_op ? 1 : 0;
provider_options.tunable_op_enable = test_tunable_op ? 1 : 0;
provider_options.tunable_op_tuning_enable = test_tunable_op ? 1 : 0;
if (auto factory = RocmProviderFactoryCreator::Create(&provider_options))
return factory->CreateProvider();
#endif