From 11e7a1b8f2f3acbde33b922bdf8013d992380d1d Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Sun, 16 Jun 2024 23:24:31 -0400 Subject: [PATCH] [MIGraphX EP] Add migraphx ep save load compiles (#20643) ### Description Adds the ability for MIGraphX EP to save off or load compiled models to save time between inferences. Via Command line User should be able to set the save ability with ORT_MIGRAPHX_SAVE_COMPILED_MODEL ORT_MIGRAPHX_SAVE_COMPILE_PATH User should be able to set the load ability with ORT_MIGRAPHX_LOAD_COMPILED_MODEL ORT_MIGRAPHX_LOAD_COMPILE_PATH via Onnxruntime API migx_save_compiled_model migx_save_model_name migx_load_compiled_model migx_load_model_name ### Motivation and Context The motivation for this is to leverage MIGraphX's existing API to save/load models after our compile step of graph optimization. For larger models or models which were compiled with additional tuning steps, this saves time after first compile and inference run, and thus speeds up the user experience in order to encourage development. --------- Co-authored-by: Ted Themistokleous --- .../core/session/onnxruntime_c_api.h | 4 + .../migraphx/migraphx_execution_provider.cc | 203 ++++++++++++++---- .../migraphx/migraphx_execution_provider.h | 13 ++ .../migraphx_execution_provider_info.cc | 10 + .../migraphx_execution_provider_info.h | 4 + .../migraphx/migraphx_provider_factory.cc | 15 ++ .../python/onnxruntime_pybind_state.cc | 46 +++- onnxruntime/test/util/default_providers.cc | 6 +- 8 files changed, 254 insertions(+), 47 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 16701f2e0d..5c61963a2f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -617,6 +617,10 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_int8_enable; // MIGraphX INT8 precision. 
Default 0 = false, nonzero = true
   int migraphx_use_native_calibration_table;  // MIGraphx INT8 cal table. Default 0 = false, noznero = true
   const char* migraphx_int8_calibration_table_name;  // MIGraphx INT8 calibration table name
+  int migraphx_save_compiled_model;  // migraphx save compiled model. Default 0 = false, nonzero = true
+  const char* migraphx_save_model_path;  // migraphx save model path name
+  int migraphx_load_compiled_model;  // migraphx load compiled model. Default 0 = false, nonzero = true
+  const char* migraphx_load_model_path;  // migraphx load model path name
 } OrtMIGraphXProviderOptions;
 
 /** \brief OpenVINO Provider Options
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
index 9acbb9c17e..581376623f 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -153,6 +153,28 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
     }
   }
 
+  // Save/load migraphx compiled models
+  const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel);
+  if (!save_comp_model_env.empty()) {
+    save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true);
+  }
+
+  const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath);
+
+  if (save_compiled_model_ && !save_model_path_env.empty()) {
+    save_compiled_path_ = save_model_path_env;
+  }
+
+  const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel);
+  if (!load_comp_model_env.empty()) {
+    load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? 
false : true);
+  }
+
+  const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath);
+  if (load_compiled_model_ && !load_model_path_env.empty()) {
+    load_compiled_path_ = load_model_path_env;
+  }
+
   // dump unsupported ops
   const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps);
   if (!dump_model_ops_env.empty()) {
@@ -171,10 +193,15 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
                         << "device_id: " << device_id_
                         << ", migraphx_fp16_enable: " << fp16_enable_
                         << ", migraphx_int8_enable: " << int8_enable_
                         << ", dump_model_ops: " << dump_model_ops_
                         << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_
                         << ", int8_calibration_cache_available: " << int8_calibration_cache_available_
-                        << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_;
+                        << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_
+                        << ", migraphx_save_compiled_model: " << save_compiled_model_
+                        << ", migraphx_save_compiled_model_path: " << save_compiled_path_
+                        << ", migraphx_load_compiled_model: " << load_compiled_model_
+                        << ", migraphx_load_compiled_model_path: "
+                        << load_compiled_path_;
 }
 
 MIGraphXExecutionProvider::~MIGraphXExecutionProvider() {
@@ -265,7 +292,7 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
       break;
     default:
       LOGS_DEFAULT(WARNING) << "MiGraphx: unsupported data type " << type << ", fallback to CPU";
-      LOGS_DEFAULT(WARNING) << "implementation" << std::endl;
+      LOGS_DEFAULT(WARNING) << "implementation";
       return false;
   }
 
@@ -1008,11 +1035,11 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
     result.push_back(ComputeCapability::Create(std::move(sub_graph)));
   } else {  // unsupported_nodes_idx.empty()
     if (dump_model_ops_) {
-      LOGS_DEFAULT(INFO) << "============= Unsupported nodes 
====================" << std::endl;
+      LOGS_DEFAULT(INFO) << "============= Unsupported nodes ====================";
       for (auto idx : unsupported_nodes) {
         LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl;
       }
-      LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************" << std::endl;
+      LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************";
     }
 
     if (unsupported_nodes.size() > 10) {
@@ -1087,6 +1114,34 @@ bool get_input_output_names(const GraphViewer& graph,
   return no_input_shape;
 }
 
+// Attempt to load a precompiled model from `path`; returns true only if the load succeeded.
+// Returns false when loading is disabled or fails, so the caller falls back to compiling the model.
+bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) {
+  if (!load_enable) {
+    return false;
+  }
+  try {
+    LOGS_DEFAULT(INFO) << "Attempting to load model at:" << path;
+    prog = migraphx::load(path.c_str());
+    LOGS_DEFAULT(INFO) << "load model : Success";
+    return true;
+  } catch (...) {
+    LOGS_DEFAULT(WARNING) << "load model : Failed for " << path << ", model will be compiled instead";
+  }
+  return false;
+}
+
+// Save the compiled program to `out_path` in msgpack format when saving is enabled.
+void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) {
+  if (save_enable) {
+    LOGS_DEFAULT(INFO) << "Model Save at " << out_path << ": Begin";
+    migraphx::file_options fo;
+    fo.set_file_format("msgpack");
+    migraphx::save(prog, out_path.c_str(), fo);
+    LOGS_DEFAULT(INFO) << "Model Save: Complete";
+  }
+}
+
 Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
                                           std::vector<NodeComputeInfo>& node_compute_funcs) {
   migraphx::onnx_options options;
@@ -1117,39 +1172,56 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
     }
 
     std::vector<std::string> input_names, output_names;
-    no_input_shape = get_input_output_names(graph_body_viewer, input_names, output_names);
+    no_input_shape = no_input_shape || get_input_output_names(graph_body_viewer, input_names, output_names);
 
     // by parsing the model_proto, create a program corresponding to
     // the input fused_node
     migraphx::program prog;
 
     if (!no_input_shape) {
-      prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);
-      if (fp16_enable_) {
-        migraphx::quantize_fp16(prog);
-      }
+      if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) {
+        LOGS_DEFAULT(INFO) << "Input shapes detected, parsing and compiling model";
+        prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);
 
-      // Read in the calibration data and map it to an migraphx paramater map for the calibration ops
-      if (int8_enable_ && int8_calibration_cache_available_) {
-        migraphx::quantize_int8_options quant_opts;
-        migraphx::program_parameters quant_params;
+        // Read in the calibration data and map it to an migraphx parameter map for the calibration ops
+        if (int8_enable_ && int8_calibration_cache_available_) {
+          LOGS_DEFAULT(INFO) << "Quantizing input program to int8";
+          migraphx::quantize_int8_options quant_opts;
+          migraphx::program_parameters quant_params;
 
-        auto param_shapes = prog.get_parameter_shapes();
+          auto 
param_shapes = prog.get_parameter_shapes();
 
-        for (auto&& name : param_shapes.names()) {
-          auto dynamic_range_i = dynamic_range_map.find(name);
-          if (dynamic_range_i != dynamic_range_map.end()) {
-            quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
+          // Add all calibration data read in from int8 table
+          for (auto& [cal_key, cal_val] : dynamic_range_map) {
+            auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
+            quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(&cal_val)));
           }
+          quant_opts.add_calibration_data(quant_params);
+
+          // specify the ops we want to int8 quantize
+          quant_opts.add_op_name("convolution");
+          quant_opts.add_op_name("dot");
+
+          // perform static quantization on the programs
+          migraphx::quantize_int8(prog, t_, quant_opts);
+          LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete";
         }
 
-        quant_opts.add_calibration_data(quant_params);
-        // perform static quantization on the programs
-        migraphx::quantize_int8(prog, t_, quant_opts);
+        if (fp16_enable_) {
+          LOGS_DEFAULT(INFO) << "Quantizing input program to fp16";
+          migraphx::quantize_fp16(prog);
+          LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete";
+        }
+
+        migraphx::compile_options co;
+        co.set_fast_math(false);
+        LOGS_DEFAULT(INFO) << "Model Compile: Begin";
+        prog.compile(t_, co);
+        LOGS_DEFAULT(INFO) << "Model Compile: Complete";
+
+        save_compiled_model(prog, save_compiled_model_, save_compiled_path_);
       }
-      migraphx::compile_options co;
-      co.set_fast_math(false);
-      prog.compile(t_, co);
+
       auto prog_output_shapes = prog.get_output_shapes();
       for (std::size_t i = 0; i < output_names.size(); ++i) {
         auto out_len = prog_output_shapes[i].lengths();
@@ -1169,7 +1241,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
       *p = {context->allocate_func, context->release_func, context->allocator_handle, 
map_progs_[context->node_name],
             map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
             map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_,
-            int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_};
+            int8_calibration_cache_available_, dynamic_range_map,
+            save_compiled_model_, save_compiled_path_,
+            load_compiled_model_, load_compiled_path_, dump_model_ops_};
       *state = p.release();
       return 0;
     };
@@ -1199,6 +1273,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
       bool input_shape_match = true;
       migraphx::program_parameter_shapes param_shapes;
       if (no_input_shape) {
+        LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again";
         for (auto& it : map_input_name_index) {
           auto& name = it.first;
           auto& index = it.second;
@@ -1210,6 +1285,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
           input_shape_match = false;
         }
       } else {
+        LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model";
         param_shapes = prog.get_parameter_shapes();
         auto prog_output_shapes = prog.get_output_shapes();
 
@@ -1243,33 +1319,67 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
       // input shapes are different, needs to re-parse onnx and
       // re-compile the program
       if (!input_shape_match) {
-        prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
-        if (fp16_enable) {
-          migraphx::quantize_fp16(prog);
-        }
+        if (!load_precompiled_model(prog, mgx_state->load_compiled_mode, mgx_state->load_compiled_path)) {
+          LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling";
+          prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
 
-        // Read in the calibration data and map it to an migraphx paramater map for the calibration ops
-        if (int8_enable && int8_calibration_cache_available) {
-          migraphx::quantize_int8_options quant_opts;
-          migraphx::program_parameters quant_params;
+          // Read in the calibration data and map it to an migraphx parameter map for the calibration ops
+          if (int8_enable && int8_calibration_cache_available) {
+            LOGS_DEFAULT(INFO) << "Quantize Int8: Begin";
+            migraphx::quantize_int8_options quant_opts;
+            migraphx::program_parameters quant_params;
 
-          auto param_shapes = prog.get_parameter_shapes();
+            auto param_shapes = prog.get_parameter_shapes();
 
-          for (auto&& name : param_shapes.names()) {
-            auto dynamic_range_i = map_dynamic_range.find(name);
-            if (dynamic_range_i != map_dynamic_range.end()) {
-              quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
+            // Add input parameter data and the values they're set to
+            for (auto&& name : param_shapes.names()) {
+              if (map_input_name_index.count(name) > 0) {
+                auto input_tensor = ctx.GetInput(map_input_name_index[name]);
+                auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
+                const auto tensor_shape = tensor_info.GetShape();
+                const auto tensor_type = tensor_info.GetElementType();
+
+                migraphx_shape_datatype_t mgx_type;
+                getMIGraphXType(tensor_type, mgx_type);
+                auto mgx_s = param_shapes[name];
+
+                if (mgx_type != mgx_s.type()) {
+                  LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch";
+                }
+                quant_params.add(name, migraphx::argument(param_shapes[name], const_cast<void*>(input_tensor.GetTensorRawData())));
+              }
             }
+
+            // Add all calibration data read in from int8 table
+            for (auto& [cal_key, cal_val] : map_dynamic_range) {
+              auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
+              quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(&cal_val)));
+            }
+            quant_opts.add_calibration_data(quant_params);
+
+            // specify the ops we want to int8 quantize
+            quant_opts.add_op_name("convolution");
+            quant_opts.add_op_name("dot");
+
+            // perform static quantization on the programs
+            migraphx::quantize_int8(prog, t, quant_opts);
+            LOGS_DEFAULT(INFO) << "Quantize Int8: Completed";
           }
 
-          quant_opts.add_calibration_data(quant_params);
-          // perform static quantization on the programs
-          migraphx::quantize_int8(prog, t, quant_opts);
+          if (fp16_enable) {
+            LOGS_DEFAULT(INFO) << "Quantize fp16: Begin";
+            migraphx::quantize_fp16(prog);
+            LOGS_DEFAULT(INFO) << "Quantize fp16: Completed";
+          }
+
+          LOGS_DEFAULT(INFO) << "Model Compile: Begin";
+          migraphx::compile_options co;
+          co.set_fast_math(false);
+          prog.compile(t, co);
+
+          save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path);
         }
 
-        migraphx::compile_options co;
-        co.set_fast_math(false);
-        prog.compile(t, co);
         mgx_state->prog = prog;
         param_shapes = prog.get_parameter_shapes();
         no_input_shape = false;
@@ -1281,6 +1391,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
       if (param_shapes.size() > 0) {
         for (auto&& name : param_shapes.names()) {
           if (map_input_name_index.count(name) > 0) {
+            LOGS_DEFAULT(INFO) << "Setting parameters for:" << name;
             auto input_tensor = ctx.GetInput(map_input_name_index[name]);
             auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
             const auto tensor_shape = tensor_info.GetShape();
@@ -1293,6 +1404,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
             if (mgx_type != mgx_s.type()) {
               LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch";
             }
+
+            LOGS_DEFAULT(INFO) << "Writing Raw tensor data ";
             m.add(name, migraphx::argument(param_shapes[name],
                                            const_cast<void*>(input_tensor.GetTensorRawData())));
           }
@@ -1353,7 +1466,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
             HIP_CALL_THROW(hipMemcpy(output_data, 
gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice));
           }
         }
-      }
+      };  // NOTE(review): the added ';' looks like an empty statement after the closing brace — confirm it is intended
 
       return Status::OK();
     };
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
index c3617f409e..1977f71b8b 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
@@ -26,6 +26,11 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
 static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
 static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
 static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
+static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";  // integer flag; non-zero enables saving
+static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";  // only used when saving is enabled
+static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";  // integer flag; non-zero enables loading
+static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";  // only used when loading is enabled
+
 };  // namespace migraphx_env_vars
 
 // Information to construct kernel function state. 
@@ -44,6 +49,10 @@ struct MIGraphXFuncState {
   bool int8_enable = false;
   bool int8_calibration_cache_available = false;
   std::unordered_map dynamic_range_map;
+  bool save_compiled_mode = false;  // initialised from the provider's save_compiled_model_ when the state is created
+  std::string save_compiled_path;
+  bool load_compiled_mode = false;  // initialised from the provider's load_compiled_model_ when the state is created
+  std::string load_compiled_path;
   bool dump_model_ops = false;
 };
 
@@ -84,6 +93,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
   bool int8_use_native_migraphx_calibration_table_ = false;
   std::string calibration_cache_path_;
   std::unordered_map dynamic_range_map;
+  bool save_compiled_model_ = false;
+  std::string save_compiled_path_;
+  bool load_compiled_model_ = false;
+  std::string load_compiled_path_;
   bool dump_model_ops_ = false;
   int device_id_;
   migraphx::target t_;
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc
index b7d7a77853..2a135b7324 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc
@@ -17,6 +17,10 @@ constexpr const char* kFp16Enable = "trt_fp16_enable";
 constexpr const char* kInt8Enable = "migx_int8_enable";
 constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
 constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
+constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
+constexpr const char* kSaveModelPath = "migx_save_model_name";
+constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
+constexpr const char* kLoadModelPath = "migx_load_model_name";
 }  // namespace provider_option_names
 }  // namespace migraphx
 
@@ -39,6 +43,8 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
               })
           .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
           .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
+          .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
+          .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)  // NOTE(review): kSaveModelPath / kLoadModelPath are declared but not registered here — confirm
           .Parse(options));
 
   return info;
@@ -49,6 +55,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
       {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
       {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
       {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
+      {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
+      {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
   };
   return options;
 }
@@ -58,6 +66,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
       {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
       {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
       {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
+      {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
+      {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
   };
   return options;
 }
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h
index 18ac30fdc1..8411e3eef0 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h
@@ -19,6 +19,10 @@ struct MIGraphXExecutionProviderInfo {
   bool int8_enable{false};
   std::string 
int8_calibration_table_name{""};
   bool int8_use_native_calibration_table{false};
+  bool save_compiled_model{true};  // NOTE(review): defaults to true, while OrtMIGraphXProviderOptions documents "Default 0 = false" — confirm
+  std::string save_model_file{"./compiled_model.mxr"};
+  bool load_compiled_model{true};  // NOTE(review): defaults to true, while OrtMIGraphXProviderOptions documents "Default 0 = false" — confirm
+  std::string load_model_file{"./compiled_model.mxr"};
 
   static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
   static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
index f985682ddc..dd24dbdc76 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
@@ -53,6 +53,16 @@ struct MIGraphX_Provider : Provider {
       info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name;
     }
     info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
+    info.save_compiled_model = options.migraphx_save_compiled_model;
+    info.save_model_file = "";  // an unset (null) path becomes an empty string
+    if (options.migraphx_save_model_path != nullptr) {
+      info.save_model_file = options.migraphx_save_model_path;
+    }
+    info.load_compiled_model = options.migraphx_load_compiled_model;
+    info.load_model_file = "";  // an unset (null) path becomes an empty string
+    if (options.migraphx_load_model_path != nullptr) {
+      info.load_model_file = options.migraphx_load_model_path;
+    }
     return std::make_shared(info);
   }
 
@@ -79,6 +89,11 @@ struct MIGraphX_Provider : Provider {
     }
 
     migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table;
+
+    migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model;
+    migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str();  // NOTE(review): pointer is only valid while internal_options lives — confirm lifetime
+    migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model;
+    migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str();  // NOTE(review): pointer is only valid while internal_options lives — confirm lifetime
   }
 
   ProviderOptions GetProviderOptions(const void* 
provider_options) override {
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index b1784f700d..857498b7e6 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -823,6 +823,8 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
   } else if (type == kMIGraphXExecutionProvider) {
 #ifdef USE_MIGRAPHX
     std::string calibration_table;
+    std::string save_model_path;
+    std::string load_model_path;
     auto it = provider_options_map.find(type);
     if (it != provider_options_map.end()) {
       OrtMIGraphXProviderOptions params{
@@ -830,7 +832,11 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
           0,
           0,
           0,
-          nullptr};
+          nullptr,
+          1,                        // migraphx_save_compiled_model
+          "./compiled_model.mxr",   // migraphx_save_model_path
+          1,                        // migraphx_load_compiled_model
+          "./compiled_model.mxr"};  // migraphx_load_model_path
       for (auto option : it->second) {
         if (option.first == "device_id") {
           if (!option.second.empty()) {
@@ -877,6 +883,44 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
                 "[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
                 " 'True' or 'False'. Default value is 'False'.\n");
           }
+        } else if (option.first == "migraphx_save_compiled_model") {
+          if (option.second == "True" || option.second == "true") {
+            params.migraphx_save_compiled_model = true;
+          } else if (option.second == "False" || option.second == "false") {
+            params.migraphx_save_compiled_model = false;
+          } else {
+            ORT_THROW(
+                "[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
+                " 'True' or 'False'. Default value is 'False'.\n");
+          }
+        } else if (option.first == "migraphx_save_model_path") {
+          if (!option.second.empty()) {
+            save_model_path = option.second;
+            params.migraphx_save_model_path = save_model_path.c_str();
+          } else {
+            ORT_THROW(
+                "[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
+                "file name i.e. 'compiled_model.mxr'.\n");
+          }
+        } else if (option.first == "migraphx_load_compiled_model") {
+          if (option.second == "True" || option.second == "true") {
+            params.migraphx_load_compiled_model = true;
+          } else if (option.second == "False" || option.second == "false") {
+            params.migraphx_load_compiled_model = false;
+          } else {
+            ORT_THROW(
+                "[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
+                " 'True' or 'False'. Default value is 'False'.\n");
+          }
+        } else if (option.first == "migraphx_load_model_path") {
+          if (!option.second.empty()) {
+            load_model_path = option.second;
+            params.migraphx_load_model_path = load_model_path.c_str();
+          } else {
+            ORT_THROW(
+                "[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
+                "file name i.e. 'compiled_model.mxr'.\n");
+          }
         } else {
           ORT_THROW("Invalid MIGraphX EP option: ", option.first);
         }
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index e353cc73b2..6f07385729 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -76,7 +76,11 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
       0,
       0,
       0,
-      nullptr};
+      nullptr,
+      1,
+      "./compiled_model.mxr",
+      1,
+      "./compiled_model.mxr"};
   return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider();
 #else
   return nullptr;