From 0e827c27fb143dcbe64fd286cd8cf0a1df900551 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Wed, 21 Aug 2024 10:32:12 -0400 Subject: [PATCH] [MIGraphX EP] Add support for MIGraphX Exhaustive tune flag (#46) (#21599) ### Description Set the exhaustive tune flag through the MIGraphX API and make this a Session option in Onnxruntime ### Motivation and Context Allow users to use MIGraphX Exhaustive tuning with Onnxruntime inferences This goes hand in hand with save/load after a model has been compiled and tuning has been found. --------- Co-authored-by: Ted Themistokleous Co-authored-by: Tianlei Wu --- .../onnxruntime/core/session/onnxruntime_c_api.h | 1 + .../migraphx/migraphx_execution_provider.cc | 9 +++++++++ .../migraphx/migraphx_execution_provider.h | 3 +++ .../migraphx/migraphx_execution_provider_info.cc | 4 ++++ .../migraphx/migraphx_execution_provider_info.h | 1 + .../providers/migraphx/migraphx_provider_factory.cc | 2 ++ onnxruntime/python/onnxruntime_pybind_state.cc | 13 ++++++++++++- onnxruntime/test/util/default_providers.cc | 3 ++- 8 files changed, 34 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 234574503c..4674db42fb 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -621,6 +621,7 @@ typedef struct OrtMIGraphXProviderOptions { const char* migraphx_save_model_path; // migraphx model path name int migraphx_load_compiled_model; // migraphx int8 cal table. 
Default 0 = false, noznero = true const char* migraphx_load_model_path; // migraphx model path name + bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index be9f1bd681..90dfa49c73 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -182,6 +182,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); } + // Allow for exhaustive tune during compile + const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); + if (!exhaustive_tune_env.empty()) { + exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? 
false : true); + } + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " @@ -190,6 +196,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv << ", migraphx_int8_enable: " << int8_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", dump_model_ops: " << dump_model_ops_ + << ", exhaustive_tune: " << exhaustive_tune_ << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ @@ -1181,6 +1188,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::compile_options co; co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune_); LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; prog.compile(t_, co); LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl; @@ -1345,6 +1353,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; migraphx::compile_options co; co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune_); prog.compile(t, co); save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 21b582de8f..21679d1f6f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -26,6 +26,7 @@ static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH"; static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; static const char kLoadModelPath[] = 
"ORT_MIGRAPHX_LOAD_COMPILE_PATH"; +static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; }; // namespace migraphx_env_vars @@ -50,6 +51,7 @@ struct MIGraphXFuncState { bool load_compiled_mode = false; std::string load_compiled_path; bool dump_model_ops = false; + bool exhaustive_tune = false; }; // Logical device representation. @@ -101,6 +103,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { migraphx::target t_; OrtMutex mgx_mu_; hipStream_t stream_ = nullptr; + bool exhaustive_tune_ = false; mutable std::filesystem::path model_path_; std::unordered_map map_progs_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 2a135b7324..1f9a47d3ad 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -21,6 +21,7 @@ constexpr const char* kSaveCompiledModel = "migx_save_compiled_model"; constexpr const char* kSaveModelPath = "migx_save_model_name"; constexpr const char* kLoadCompiledModel = "migx_load_compiled_model"; constexpr const char* kLoadModelPath = "migx_load_model_name"; +constexpr const char* kExhaustiveTune = "migx_exhaustive_tune"; } // namespace provider_option_names } // namespace migraphx @@ -45,6 +46,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) + .AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune) .Parse(options)); return info; @@ -57,6 +59,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE 
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, + {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)}, }; return options; } @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, + {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, }; return options; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 68d5d9af98..b8bf86580f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -23,6 +23,7 @@ struct MIGraphXExecutionProviderInfo { std::string save_model_file{"./compiled_model.mxr"}; bool load_compiled_model{true}; std::string load_model_file{"./compiled_model.mxr"}; + bool exhaustive_tune{false}; static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 6d19993011..7b192b657b 100644 --- 
a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; if (options.migraphx_int8_calibration_table_name != nullptr) { @@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider { migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; + migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; char* dest = nullptr; auto str_size = internal_options.int8_calibration_table_name.size(); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index ffcd339c0c..47b8d75f22 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -844,7 +844,8 @@ std::unique_ptr CreateExecutionProviderInstance( 1, "./compiled_model.mxr", 1, - "./compiled_model.mxr"}; + "./compiled_model.mxr", + 1}; for (auto option : it->second) { if (option.first == "device_id") { if (!option.second.empty()) { @@ -929,6 +930,16 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a " "file name i.e. 
'compiled_model.mxr'.\n"); } + } else if (option.first == "migraphx_exhaustive_tune") { + if (option.second == "True" || option.second == "true") { + params.migraphx_exhaustive_tune = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_exhaustive_tune = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } } else { ORT_THROW("Invalid MIGraphX EP option: ", option.first); } diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 312aa86277..1feba20e32 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -80,7 +80,8 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 1, "./compiled_model.mxr", 1, - "./compiled_model.mxr"}; + "./compiled_model.mxr", + 1}; return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider(); #else return nullptr;