mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[MIGraphX EP] Add migraphx ep save load compiles (#20643)
### Description Adds the ability for MIGraphX EP to save off or load compiled models to save time between inferences. Via Command line User should be able to set the save ability with ORT_MIGRAPHX_SAVE_COMPILED_MODEL ORT_MIGRAPHX_SAVE_COMPILE_PATH User should be able to set the load ability with ORT_MIGRAPHX_LOAD_COMPILED_MODEL ORT_MIGRAPHX_LOAD_COMPILE_PATH via Onnxruntime API migx_save_compiled_model migx_save_model_name migx_load_compiled_model migx_load_model_name ### Motivation and Context The motivation for this is to leverage MIGraphX's existing API to save/load models after our compile step of graph optimization. For larger models or models which were compiled with additional tuning steps, this saves time after first compile and inference run, and thus speeds up the user experience in order to encourage development. --------- Co-authored-by: Ted Themistokleous <tedthemistokleous@amd.com>
This commit is contained in:
parent
d4470fe653
commit
11e7a1b8f2
8 changed files with 254 additions and 47 deletions
|
|
@ -617,6 +617,10 @@ typedef struct OrtMIGraphXProviderOptions {
|
|||
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
|
||||
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
|
||||
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
|
||||
int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true
|
||||
const char* migraphx_save_model_path; // migraphx model path name
|
||||
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
|
||||
const char* migraphx_load_model_path; // migraphx model path name
|
||||
} OrtMIGraphXProviderOptions;
|
||||
|
||||
/** \brief OpenVINO Provider Options
|
||||
|
|
|
|||
|
|
@ -153,6 +153,28 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
|
|||
}
|
||||
}
|
||||
|
||||
// Save/load migraphx compiled models
|
||||
const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel);
|
||||
if (!save_comp_model_env.empty()) {
|
||||
save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true);
|
||||
}
|
||||
|
||||
const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath);
|
||||
|
||||
if (save_compiled_model_ && !save_model_path_env.empty()) {
|
||||
save_compiled_path_ = save_model_path_env;
|
||||
}
|
||||
|
||||
const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel);
|
||||
if (!load_comp_model_env.empty()) {
|
||||
load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? false : true);
|
||||
}
|
||||
|
||||
const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath);
|
||||
if (load_compiled_model_ && !load_model_path_env.empty()) {
|
||||
load_compiled_path_ = load_model_path_env;
|
||||
}
|
||||
|
||||
// dump unsupported ops
|
||||
const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps);
|
||||
if (!dump_model_ops_env.empty()) {
|
||||
|
|
@ -171,10 +193,15 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
|
|||
<< "device_id: " << device_id_
|
||||
<< ", migraphx_fp16_enable: " << fp16_enable_
|
||||
<< ", migraphx_int8_enable: " << int8_enable_
|
||||
<< ", migraphx_int8_enable: " << int8_enable_
|
||||
<< ", dump_model_ops: " << dump_model_ops_
|
||||
<< ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_
|
||||
<< ", int8_calibration_cache_available: " << int8_calibration_cache_available_
|
||||
<< ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_;
|
||||
<< ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_
|
||||
<< ", migraphx_save_compiled_model: " << save_compiled_model_
|
||||
<< ", migraphx_save_compiled_model_path: " << save_compiled_path_
|
||||
<< ", migraphx_load_compiled_model: " << load_compiled_model_
|
||||
<< ", migraphx_load_compiled_model_path: " << load_compiled_path_;
|
||||
}
|
||||
|
||||
MIGraphXExecutionProvider::~MIGraphXExecutionProvider() {
|
||||
|
|
@ -265,7 +292,7 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
|
|||
break;
|
||||
default:
|
||||
LOGS_DEFAULT(WARNING) << "MiGraphx: unsupported data type " << type << ", fallback to CPU";
|
||||
LOGS_DEFAULT(WARNING) << "implementation" << std::endl;
|
||||
LOGS_DEFAULT(WARNING) << "implementation";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1008,11 +1035,11 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
|
|||
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
|
||||
} else { // unsupported_nodes_idx.empty()
|
||||
if (dump_model_ops_) {
|
||||
LOGS_DEFAULT(INFO) << "============= Unsupported nodes ====================" << std::endl;
|
||||
LOGS_DEFAULT(INFO) << "============= Unsupported nodes ====================";
|
||||
for (auto idx : unsupported_nodes) {
|
||||
LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl;
|
||||
}
|
||||
LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************" << std::endl;
|
||||
LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************";
|
||||
}
|
||||
|
||||
if (unsupported_nodes.size() > 10) {
|
||||
|
|
@ -1087,6 +1114,34 @@ bool get_input_output_names(const GraphViewer& graph,
|
|||
return no_input_shape;
|
||||
}
|
||||
|
||||
// Attempt to load a model and catch any exceptions on load fail.
|
||||
// Useful to default to EP to trigger the compile if file doesn't exist or loading fails.
|
||||
bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) {
|
||||
try {
|
||||
if (load_enable) {
|
||||
LOGS_DEFAULT(INFO) << "Attempting to load model at:" << path;
|
||||
prog = migraphx::load(path.c_str());
|
||||
LOGS_DEFAULT(INFO) << "load model : Success";
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} catch (...) {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) {
|
||||
if (save_enable) {
|
||||
LOGS_DEFAULT(INFO) << "Model Save at " << out_path << ": Begin" << std::endl;
|
||||
migraphx::file_options fo;
|
||||
fo.set_file_format("msgpack");
|
||||
migraphx::save(prog, out_path.c_str(), fo);
|
||||
LOGS_DEFAULT(INFO) << "Model Save: Complete" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
migraphx::onnx_options options;
|
||||
|
|
@ -1117,39 +1172,56 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
}
|
||||
|
||||
std::vector<std::string> input_names, output_names;
|
||||
no_input_shape = get_input_output_names(graph_body_viewer, input_names, output_names);
|
||||
no_input_shape = no_input_shape or get_input_output_names(graph_body_viewer, input_names, output_names);
|
||||
|
||||
// by parsing the model_proto, create a program corresponding to
|
||||
// the input fused_node
|
||||
migraphx::program prog;
|
||||
|
||||
if (!no_input_shape) {
|
||||
prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);
|
||||
if (fp16_enable_) {
|
||||
migraphx::quantize_fp16(prog);
|
||||
}
|
||||
if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) {
|
||||
LOGS_DEFAULT(INFO) << "No Input shapes detected quantizing model";
|
||||
prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);
|
||||
|
||||
// Read in the calibration data and map it to an migraphx paramater map for the calibration ops
|
||||
if (int8_enable_ && int8_calibration_cache_available_) {
|
||||
migraphx::quantize_int8_options quant_opts;
|
||||
migraphx::program_parameters quant_params;
|
||||
// Read in the calibration data and map it to an migraphx paramater map for the calibration ops
|
||||
if (int8_enable_ && int8_calibration_cache_available_) {
|
||||
LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl;
|
||||
migraphx::quantize_int8_options quant_opts;
|
||||
migraphx::program_parameters quant_params;
|
||||
|
||||
auto param_shapes = prog.get_parameter_shapes();
|
||||
auto param_shapes = prog.get_parameter_shapes();
|
||||
|
||||
for (auto&& name : param_shapes.names()) {
|
||||
auto dynamic_range_i = dynamic_range_map.find(name);
|
||||
if (dynamic_range_i != dynamic_range_map.end()) {
|
||||
quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
|
||||
// Add all calibration data read in from int8 table
|
||||
for (auto& [cal_key, cal_val] : dynamic_range_map) {
|
||||
auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
|
||||
quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(std::move(&cal_val))));
|
||||
}
|
||||
quant_opts.add_calibration_data(quant_params);
|
||||
|
||||
// specify thing we want to int8 quantize
|
||||
quant_opts.add_op_name("convolution");
|
||||
quant_opts.add_op_name("dot");
|
||||
|
||||
// perform static quantization on the programs
|
||||
migraphx::quantize_int8(prog, t_, quant_opts);
|
||||
LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl;
|
||||
}
|
||||
|
||||
quant_opts.add_calibration_data(quant_params);
|
||||
// perform static quantization on the programs
|
||||
migraphx::quantize_int8(prog, t_, quant_opts);
|
||||
if (fp16_enable_) {
|
||||
LOGS_DEFAULT(INFO) << "Quantizing input program to fp16" << std::endl;
|
||||
migraphx::quantize_fp16(prog);
|
||||
LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete" << std::endl;
|
||||
}
|
||||
|
||||
migraphx::compile_options co;
|
||||
co.set_fast_math(false);
|
||||
LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
|
||||
prog.compile(t_, co);
|
||||
LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl;
|
||||
|
||||
save_compiled_model(prog, save_compiled_model_, save_compiled_path_);
|
||||
}
|
||||
migraphx::compile_options co;
|
||||
co.set_fast_math(false);
|
||||
prog.compile(t_, co);
|
||||
|
||||
auto prog_output_shapes = prog.get_output_shapes();
|
||||
for (std::size_t i = 0; i < output_names.size(); ++i) {
|
||||
auto out_len = prog_output_shapes[i].lengths();
|
||||
|
|
@ -1169,7 +1241,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
*p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
|
||||
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
|
||||
map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_,
|
||||
int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_};
|
||||
int8_calibration_cache_available_, dynamic_range_map,
|
||||
save_compiled_model_, save_compiled_path_,
|
||||
load_compiled_model_, load_compiled_path_, dump_model_ops_};
|
||||
*state = p.release();
|
||||
return 0;
|
||||
};
|
||||
|
|
@ -1199,6 +1273,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
bool input_shape_match = true;
|
||||
migraphx::program_parameter_shapes param_shapes;
|
||||
if (no_input_shape) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again" << std::endl;
|
||||
for (auto& it : map_input_name_index) {
|
||||
auto& name = it.first;
|
||||
auto& index = it.second;
|
||||
|
|
@ -1210,6 +1285,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
input_shape_match = false;
|
||||
}
|
||||
} else {
|
||||
LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model" << std::endl;
|
||||
param_shapes = prog.get_parameter_shapes();
|
||||
auto prog_output_shapes = prog.get_output_shapes();
|
||||
|
||||
|
|
@ -1243,33 +1319,67 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
// input shapes are different, needs to re-parse onnx and
|
||||
// re-compile the program
|
||||
if (!input_shape_match) {
|
||||
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
|
||||
if (fp16_enable) {
|
||||
migraphx::quantize_fp16(prog);
|
||||
}
|
||||
if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) {
|
||||
LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl;
|
||||
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
|
||||
|
||||
// Read in the calibration data and map it to an migraphx paramater map for the calibration ops
|
||||
if (int8_enable && int8_calibration_cache_available) {
|
||||
migraphx::quantize_int8_options quant_opts;
|
||||
migraphx::program_parameters quant_params;
|
||||
// Read in the calibration data and map it to an migraphx paramater map for the calibration ops
|
||||
if (int8_enable && int8_calibration_cache_available) {
|
||||
LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl;
|
||||
migraphx::quantize_int8_options quant_opts;
|
||||
migraphx::program_parameters quant_params;
|
||||
|
||||
auto param_shapes = prog.get_parameter_shapes();
|
||||
auto param_shapes = prog.get_parameter_shapes();
|
||||
|
||||
for (auto&& name : param_shapes.names()) {
|
||||
auto dynamic_range_i = map_dynamic_range.find(name);
|
||||
if (dynamic_range_i != map_dynamic_range.end()) {
|
||||
quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second)));
|
||||
// Add input parameter data and the values they're set to
|
||||
for (auto&& name : param_shapes.names()) {
|
||||
if (map_input_name_index.count(name) > 0) {
|
||||
auto input_tensor = ctx.GetInput(map_input_name_index[name]);
|
||||
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
|
||||
const auto tensor_shape = tensor_info.GetShape();
|
||||
const auto tensor_type = tensor_info.GetElementType();
|
||||
|
||||
migraphx_shape_datatype_t mgx_type;
|
||||
getMIGraphXType(tensor_type, mgx_type);
|
||||
auto mgx_s = param_shapes[name];
|
||||
|
||||
if (mgx_type != mgx_s.type()) {
|
||||
LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch";
|
||||
}
|
||||
quant_params.add(name, migraphx::argument(param_shapes[name], const_cast<void*>(input_tensor.GetTensorRawData())));
|
||||
}
|
||||
}
|
||||
|
||||
// Add all calibration data read in from int8 table
|
||||
for (auto& [cal_key, cal_val] : map_dynamic_range) {
|
||||
auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
|
||||
quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(std::move(&cal_val))));
|
||||
}
|
||||
quant_opts.add_calibration_data(quant_params);
|
||||
|
||||
// specify thing we want to int8 quantize
|
||||
quant_opts.add_op_name("convolution");
|
||||
quant_opts.add_op_name("dot");
|
||||
|
||||
// perform static quantization on the programs
|
||||
migraphx::quantize_int8(prog, t, quant_opts);
|
||||
LOGS_DEFAULT(INFO) << "Quantize Int8: Completed" << std::endl;
|
||||
}
|
||||
|
||||
quant_opts.add_calibration_data(quant_params);
|
||||
// perform static quantization on the programs
|
||||
migraphx::quantize_int8(prog, t, quant_opts);
|
||||
if (fp16_enable) {
|
||||
LOGS_DEFAULT(INFO) << "Quantize fp16: Begin" << std::endl;
|
||||
migraphx::quantize_fp16(prog);
|
||||
LOGS_DEFAULT(INFO) << "Quantize fp16: Completed" << std::endl;
|
||||
}
|
||||
|
||||
LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
|
||||
migraphx::compile_options co;
|
||||
co.set_fast_math(false);
|
||||
prog.compile(t, co);
|
||||
|
||||
save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path);
|
||||
}
|
||||
|
||||
migraphx::compile_options co;
|
||||
co.set_fast_math(false);
|
||||
prog.compile(t, co);
|
||||
mgx_state->prog = prog;
|
||||
param_shapes = prog.get_parameter_shapes();
|
||||
no_input_shape = false;
|
||||
|
|
@ -1281,6 +1391,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
if (param_shapes.size() > 0) {
|
||||
for (auto&& name : param_shapes.names()) {
|
||||
if (map_input_name_index.count(name) > 0) {
|
||||
LOGS_DEFAULT(INFO) << "Setting parameters for:" << name << std::endl;
|
||||
auto input_tensor = ctx.GetInput(map_input_name_index[name]);
|
||||
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
|
||||
const auto tensor_shape = tensor_info.GetShape();
|
||||
|
|
@ -1293,6 +1404,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
if (mgx_type != mgx_s.type()) {
|
||||
LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch";
|
||||
}
|
||||
|
||||
LOGS_DEFAULT(INFO) << "Writing Raw tensor data " << std::endl;
|
||||
m.add(name, migraphx::argument(param_shapes[name],
|
||||
const_cast<void*>(input_tensor.GetTensorRawData())));
|
||||
}
|
||||
|
|
@ -1353,7 +1466,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
HIP_CALL_THROW(hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
|
|
|
|||
|
|
@ -26,6 +26,11 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
|
|||
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
|
||||
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
|
||||
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
|
||||
static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
|
||||
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
|
||||
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
|
||||
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
|
||||
|
||||
}; // namespace migraphx_env_vars
|
||||
|
||||
// Information to construct kernel function state.
|
||||
|
|
@ -44,6 +49,10 @@ struct MIGraphXFuncState {
|
|||
bool int8_enable = false;
|
||||
bool int8_calibration_cache_available = false;
|
||||
std::unordered_map<std::string, float> dynamic_range_map;
|
||||
bool save_compiled_mode = false;
|
||||
std::string save_compiled_path;
|
||||
bool load_compiled_mode = false;
|
||||
std::string load_compiled_path;
|
||||
bool dump_model_ops = false;
|
||||
};
|
||||
|
||||
|
|
@ -84,6 +93,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
|
|||
bool int8_use_native_migraphx_calibration_table_ = false;
|
||||
std::string calibration_cache_path_;
|
||||
std::unordered_map<std::string, float> dynamic_range_map;
|
||||
bool save_compiled_model_ = false;
|
||||
std::string save_compiled_path_;
|
||||
bool load_compiled_model_ = false;
|
||||
std::string load_compiled_path_;
|
||||
bool dump_model_ops_ = false;
|
||||
int device_id_;
|
||||
migraphx::target t_;
|
||||
|
|
|
|||
|
|
@ -17,6 +17,10 @@ constexpr const char* kFp16Enable = "trt_fp16_enable";
|
|||
constexpr const char* kInt8Enable = "migx_int8_enable";
|
||||
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
|
||||
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
|
||||
constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
|
||||
constexpr const char* kSaveModelPath = "migx_save_model_name";
|
||||
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
|
||||
constexpr const char* kLoadModelPath = "migx_load_model_name";
|
||||
|
||||
} // namespace provider_option_names
|
||||
} // namespace migraphx
|
||||
|
|
@ -39,6 +43,8 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
|
|||
})
|
||||
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
|
||||
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
|
||||
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
|
||||
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
|
||||
.Parse(options));
|
||||
|
||||
return info;
|
||||
|
|
@ -49,6 +55,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
|
|||
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
|
||||
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
|
||||
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
|
||||
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
|
||||
};
|
||||
return options;
|
||||
}
|
||||
|
|
@ -58,6 +66,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
|
|||
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
|
||||
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
|
||||
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
|
||||
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
|
||||
};
|
||||
return options;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ struct MIGraphXExecutionProviderInfo {
|
|||
bool int8_enable{false};
|
||||
std::string int8_calibration_table_name{""};
|
||||
bool int8_use_native_calibration_table{false};
|
||||
bool save_compiled_model{true};
|
||||
std::string save_model_file{"./compiled_model.mxr"};
|
||||
bool load_compiled_model{true};
|
||||
std::string load_model_file{"./compiled_model.mxr"};
|
||||
|
||||
static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
|
||||
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
|
||||
|
|
|
|||
|
|
@ -53,6 +53,16 @@ struct MIGraphX_Provider : Provider {
|
|||
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name;
|
||||
}
|
||||
info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
|
||||
info.save_compiled_model = options.migraphx_save_compiled_model;
|
||||
info.save_model_file = "";
|
||||
if (options.migraphx_save_model_path != nullptr) {
|
||||
info.save_model_file = options.migraphx_save_model_path;
|
||||
}
|
||||
info.load_compiled_model = options.migraphx_load_compiled_model;
|
||||
info.load_model_file = "";
|
||||
if (options.migraphx_load_model_path != nullptr) {
|
||||
info.load_model_file = options.migraphx_load_model_path;
|
||||
}
|
||||
return std::make_shared<MIGraphXProviderFactory>(info);
|
||||
}
|
||||
|
||||
|
|
@ -79,6 +89,11 @@ struct MIGraphX_Provider : Provider {
|
|||
}
|
||||
|
||||
migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table;
|
||||
|
||||
migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model;
|
||||
migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str();
|
||||
migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model;
|
||||
migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str();
|
||||
}
|
||||
|
||||
ProviderOptions GetProviderOptions(const void* provider_options) override {
|
||||
|
|
|
|||
|
|
@ -823,6 +823,8 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
} else if (type == kMIGraphXExecutionProvider) {
|
||||
#ifdef USE_MIGRAPHX
|
||||
std::string calibration_table;
|
||||
std::string save_model_path;
|
||||
std::string load_model_path;
|
||||
auto it = provider_options_map.find(type);
|
||||
if (it != provider_options_map.end()) {
|
||||
OrtMIGraphXProviderOptions params{
|
||||
|
|
@ -830,7 +832,11 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
0,
|
||||
0,
|
||||
0,
|
||||
nullptr};
|
||||
nullptr,
|
||||
1,
|
||||
"./compiled_model.mxr",
|
||||
1,
|
||||
"./compiled_model.mxr"};
|
||||
for (auto option : it->second) {
|
||||
if (option.first == "device_id") {
|
||||
if (!option.second.empty()) {
|
||||
|
|
@ -877,6 +883,44 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be"
|
||||
" 'True' or 'False'. Default value is 'False'.\n");
|
||||
}
|
||||
} else if (option.first == "migraphx_save_compiled_model") {
|
||||
if (option.second == "True" || option.second == "true") {
|
||||
params.migraphx_fp16_enable = true;
|
||||
} else if (option.second == "False" || option.second == "false") {
|
||||
params.migraphx_fp16_enable = false;
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [MIGraphX] The value for the key 'migx_save_compiled_model' should be"
|
||||
" 'True' or 'False'. Default value is 'False'.\n");
|
||||
}
|
||||
} else if (option.first == "migraphx_save_model_path") {
|
||||
if (!option.second.empty()) {
|
||||
save_model_path = option.second;
|
||||
params.migraphx_save_model_path = save_model_path.c_str();
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [MIGraphX] The value for the key 'migx_save_model_name' should be a "
|
||||
"file name i.e. 'compiled_model.mxr'.\n");
|
||||
}
|
||||
} else if (option.first == "migraphx_load_compiled_model") {
|
||||
if (option.second == "True" || option.second == "true") {
|
||||
params.migraphx_fp16_enable = true;
|
||||
} else if (option.second == "False" || option.second == "false") {
|
||||
params.migraphx_fp16_enable = false;
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [MIGraphX] The value for the key 'migx_load_compiled_model' should be"
|
||||
" 'True' or 'False'. Default value is 'False'.\n");
|
||||
}
|
||||
} else if (option.first == "migraphx_load_model_path") {
|
||||
if (!option.second.empty()) {
|
||||
load_model_path = option.second;
|
||||
params.migraphx_load_model_path = load_model_path.c_str();
|
||||
} else {
|
||||
ORT_THROW(
|
||||
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
|
||||
"file name i.e. 'compiled_model.mxr'.\n");
|
||||
}
|
||||
} else {
|
||||
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,7 +76,11 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
|
|||
0,
|
||||
0,
|
||||
0,
|
||||
nullptr};
|
||||
nullptr,
|
||||
1,
|
||||
"./compiled_model.mxr",
|
||||
1,
|
||||
"./compiled_model.mxr"};
|
||||
return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider();
|
||||
#else
|
||||
return nullptr;
|
||||
|
|
|
|||
Loading…
Reference in a new issue