mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
### Description <!-- Describe your changes. --> Set the exhaustive tune flag through the MIGraphX API and make this a Session option in Onnxruntime ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Allow users to use MIGraphX Exhaustive tuning with Onnxruntime inferences This goers hand in hand with save/load after a model and been compiled and tuning has found. --------- Co-authored-by: Ted Themistokleous <tedthemistokleous@amd.com> Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
This commit is contained in:
parent
26a499323f
commit
0e827c27fb
8 changed files with 34 additions and 2 deletions
|
|
@ -621,6 +621,7 @@ typedef struct OrtMIGraphXProviderOptions {
|
|||
const char* migraphx_save_model_path; // migraphx model path name
|
||||
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
|
||||
const char* migraphx_load_model_path; // migraphx model path name
|
||||
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false
|
||||
} OrtMIGraphXProviderOptions;
|
||||
|
||||
/** \brief OpenVINO Provider Options
|
||||
|
|
|
|||
|
|
@ -182,6 +182,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
|
|||
dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true);
|
||||
}
|
||||
|
||||
// Allow for exhaustive tune during compile
|
||||
const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune);
|
||||
if (!exhaustive_tune_env.empty()) {
|
||||
exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true);
|
||||
}
|
||||
|
||||
metadef_id_generator_ = ModelMetadefIdGenerator::Create();
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: "
|
||||
|
|
@ -190,6 +196,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
|
|||
<< ", migraphx_int8_enable: " << int8_enable_
|
||||
<< ", migraphx_int8_enable: " << int8_enable_
|
||||
<< ", dump_model_ops: " << dump_model_ops_
|
||||
<< ", exhaustive_tune: " << exhaustive_tune_
|
||||
<< ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_
|
||||
<< ", int8_calibration_cache_available: " << int8_calibration_cache_available_
|
||||
<< ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_
|
||||
|
|
@ -1181,6 +1188,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
|
||||
migraphx::compile_options co;
|
||||
co.set_fast_math(false);
|
||||
co.set_exhaustive_tune_flag(exhaustive_tune_);
|
||||
LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
|
||||
prog.compile(t_, co);
|
||||
LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl;
|
||||
|
|
@ -1345,6 +1353,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
|
||||
migraphx::compile_options co;
|
||||
co.set_fast_math(false);
|
||||
co.set_exhaustive_tune_flag(exhaustive_tune_);
|
||||
prog.compile(t, co);
|
||||
|
||||
save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path);
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
|
|||
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
|
||||
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
|
||||
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
|
||||
static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE";
|
||||
|
||||
}; // namespace migraphx_env_vars
|
||||
|
||||
|
|
@ -50,6 +51,7 @@ struct MIGraphXFuncState {
|
|||
bool load_compiled_mode = false;
|
||||
std::string load_compiled_path;
|
||||
bool dump_model_ops = false;
|
||||
bool exhaustive_tune = false;
|
||||
};
|
||||
|
||||
// Logical device representation.
|
||||
|
|
@ -101,6 +103,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
|
|||
migraphx::target t_;
|
||||
OrtMutex mgx_mu_;
|
||||
hipStream_t stream_ = nullptr;
|
||||
bool exhaustive_tune_ = false;
|
||||
mutable std::filesystem::path model_path_;
|
||||
|
||||
std::unordered_map<std::string, migraphx::program> map_progs_;
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
|
|||
constexpr const char* kSaveModelPath = "migx_save_model_name";
|
||||
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
|
||||
constexpr const char* kLoadModelPath = "migx_load_model_name";
|
||||
constexpr const char* kExhaustiveTune = "migx_exhaustive_tune";
|
||||
|
||||
} // namespace provider_option_names
|
||||
} // namespace migraphx
|
||||
|
|
@ -45,6 +46,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
|
|||
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
|
||||
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
|
||||
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
|
||||
.AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune)
|
||||
.Parse(options));
|
||||
|
||||
return info;
|
||||
|
|
@ -57,6 +59,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
|
|||
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
|
||||
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
|
||||
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
|
||||
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)},
|
||||
};
|
||||
return options;
|
||||
}
|
||||
|
|
@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
|
|||
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
|
||||
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
|
||||
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
|
||||
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
|
||||
};
|
||||
return options;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ struct MIGraphXExecutionProviderInfo {
|
|||
std::string save_model_file{"./compiled_model.mxr"};
|
||||
bool load_compiled_model{true};
|
||||
std::string load_model_file{"./compiled_model.mxr"};
|
||||
bool exhaustive_tune{false};
|
||||
|
||||
static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
|
||||
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider {
|
|||
info.device_id = static_cast<OrtDevice::DeviceId>(options.device_id);
|
||||
info.target_device = "gpu";
|
||||
info.fp16_enable = options.migraphx_fp16_enable;
|
||||
info.exhaustive_tune = options.migraphx_exhaustive_tune;
|
||||
info.int8_enable = options.migraphx_int8_enable;
|
||||
info.int8_calibration_table_name = "";
|
||||
if (options.migraphx_int8_calibration_table_name != nullptr) {
|
||||
|
|
@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider {
|
|||
migx_options.device_id = internal_options.device_id;
|
||||
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
|
||||
migx_options.migraphx_int8_enable = internal_options.int8_enable;
|
||||
migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune;
|
||||
|
||||
char* dest = nullptr;
|
||||
auto str_size = internal_options.int8_calibration_table_name.size();
|
||||
|
|
|
|||
|
|
@ -844,7 +844,8 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
1,
|
||||
"./compiled_model.mxr",
|
||||
1,
|
||||
"./compiled_model.mxr"};
|
||||
"./compiled_model.mxr",
|
||||
1};
|
||||
for (auto option : it->second) {
|
||||
if (option.first == "device_id") {
|
||||
if (!option.second.empty()) {
|
||||
|
|
@ -929,6 +930,16 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
|
||||
"file name i.e. 'compiled_model.mxr'.\n");
|
||||
}
|
||||
} else if (option.first == "migraphx_exhaustive_tune") {
|
||||
if (option.second == "True" || option.second == "true") {
|
||||
params.migraphx_exhaustive_tune = true;
|
||||
} else if (option.second == "False" || option.second == "false") {
|
||||
params.migraphx_exhaustive_tune = false;
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be"
|
||||
" 'True' or 'False'. Default value is 'False'.\n");
|
||||
}
|
||||
} else {
|
||||
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,7 +80,8 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
|
|||
1,
|
||||
"./compiled_model.mxr",
|
||||
1,
|
||||
"./compiled_model.mxr"};
|
||||
"./compiled_model.mxr",
|
||||
1};
|
||||
return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider();
|
||||
#else
|
||||
return nullptr;
|
||||
|
|
|
|||
Loading…
Reference in a new issue