mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[MIGraphX EP] Set External Data Path (#21598)
### Description <!-- Describe your changes. --> Changes to add in Set external data path for model weight files. Additional fixes to ensure this compiles off the latest v1.19 Onnxruntime ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Separate weights used for larger models (like stable diffusion) is motivation for this change set --------- Co-authored-by: Jeff Daily <jeff.daily@amd.com> Co-authored-by: Artur Wojcik <artur.wojcik@amd.com> Co-authored-by: Ted Themistokleous <tedthemistokleous@amd.com>
This commit is contained in:
parent
54d6614ad6
commit
45b7c41ef0
2 changed files with 8 additions and 3 deletions
|
|
@ -5,6 +5,7 @@
|
|||
#include <iterator>
|
||||
#include <unordered_map>
|
||||
#include <set>
|
||||
#include <filesystem>
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#define ORT_API_MANUAL_INIT
|
||||
|
|
@ -990,6 +991,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
|
|||
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
|
||||
std::string onnx_string_buffer;
|
||||
model_proto->SerializeToString(onnx_string_buffer);
|
||||
model_path_ = graph_viewer.ModelPath();
|
||||
|
||||
// dump onnx file if environment var is set
|
||||
if (dump_model_ops_) {
|
||||
|
|
@ -1168,7 +1170,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
auto param_shapes = prog.get_parameter_shapes();
|
||||
|
||||
// Add all calibration data read in from int8 table
|
||||
for (auto& [cal_key, cal_val] : dynamic_range_map) {
|
||||
for (auto& [cal_key, cal_val] : dynamic_range_map_) {
|
||||
auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
|
||||
quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(std::move(&cal_val))));
|
||||
}
|
||||
|
|
@ -1217,7 +1219,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
*p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
|
||||
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
|
||||
map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_,
|
||||
int8_calibration_cache_available_, dynamic_range_map,
|
||||
int8_calibration_cache_available_, dynamic_range_map_,
|
||||
save_compiled_model_, save_compiled_path_,
|
||||
load_compiled_model_, load_compiled_path_, dump_model_ops_};
|
||||
*state = p.release();
|
||||
|
|
@ -1297,6 +1299,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
if (!input_shape_match) {
|
||||
if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) {
|
||||
LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl;
|
||||
cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string());
|
||||
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
|
||||
|
||||
// Read in the calibration data and map it to an migraphx paramater map for the calibration ops
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <filesystem>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -91,7 +92,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
|
|||
bool int8_calibration_cache_available_ = false;
|
||||
bool int8_use_native_migraphx_calibration_table_ = false;
|
||||
std::string calibration_cache_path_;
|
||||
std::unordered_map<std::string, float> dynamic_range_map;
|
||||
std::unordered_map<std::string, float> dynamic_range_map_;
|
||||
bool save_compiled_model_ = false;
|
||||
std::string save_compiled_path_;
|
||||
bool load_compiled_model_ = false;
|
||||
|
|
@ -100,6 +101,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
|
|||
migraphx::target t_;
|
||||
OrtMutex mgx_mu_;
|
||||
hipStream_t stream_ = nullptr;
|
||||
mutable std::filesystem::path model_path_;
|
||||
|
||||
std::unordered_map<std::string, migraphx::program> map_progs_;
|
||||
std::unordered_map<std::string, std::string> map_onnx_string_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue