From efeb6672d6492d53882ab265d4016aafae9e2233 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 28 Jun 2023 11:35:57 -0700 Subject: [PATCH] Temporary optimizer support for ort format models in non minimal build (#16485) --- .../core/session/onnxruntime_cxx_api.h | 2 +- .../include/onnxruntime_training_c_api.h | 1 + .../orttraining/training_api/optimizer.cc | 36 +++++++++++-------- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 69cec42895..2782893a8e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1385,7 +1385,7 @@ struct Value : detail::ValueImpl { * \param value - the value to be wrapped. */ template - static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue + static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue #if !defined(DISABLE_SPARSE_TENSORS) /// diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 71cdeebeb2..b3042c449a 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -638,6 +638,7 @@ struct OrtTrainingApi { * As a result, it is required that the checkpoint state outlive the lifetime of the training session. * * \param[in] checkpoint_buffer Path to the checkpoint bytes buffer. + * \param[in] num_bytes Number of bytes in the checkpoint buffer. * \param[out] checkpoint_state Checkpoint state that contains the states of the training session. * * \snippet{doc} snippets.dox OrtStatus Return Value diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index 66ae991caa..a6954414a8 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "orttraining/training_api/optimizer.h" +#include "core/flatbuffers/flatbuffers_utils.h" #include "core/framework/execution_provider.h" #include "core/framework/TensorSeq.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -60,29 +61,36 @@ Status GraphInputsAreExpected(gsl::span actual_graph_inputs, } // namespace std::unique_ptr OptimizerAlorithmFactory::CreateInstance( - const std::string& optim_path_or_bytes, int32_t& group_count) { + const std::string& optim_path, int32_t& group_count) { std::map, int32_t> opt_type_to_freq_map; #if !defined(ORT_MINIMAL_BUILD) - std::shared_ptr model; - ORT_ENFORCE(Model::Load(ToWideString(optim_path_or_bytes), model, nullptr, - logging::LoggingManager::DefaultLogger()) - .IsOK()); - Graph& graph = model->MainGraph(); - for (auto& node : graph.Nodes()) { - if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) { - auto domain_type_pair = std::make_pair(node.Domain(), node.OpType()); - if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) { - opt_type_to_freq_map[domain_type_pair] = 0; - } + if (const auto optim_path_str = ToPathString(optim_path); + fbs::utils::IsOrtFormatModel(optim_path_str)) { + // TODO (baijumeswani): Figure out the best way to extract the optimizer type + // from an ort format model. + opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; + } else { + std::shared_ptr model; + ORT_ENFORCE(Model::Load(optim_path_str, model, nullptr, + logging::LoggingManager::DefaultLogger()) + .IsOK()); + Graph& graph = model->MainGraph(); + for (auto& node : graph.Nodes()) { + if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) { + auto domain_type_pair = std::make_pair(node.Domain(), node.OpType()); + if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) { + opt_type_to_freq_map[domain_type_pair] = 0; + } - opt_type_to_freq_map[domain_type_pair] += 1; + opt_type_to_freq_map[domain_type_pair] += 1; + } } } #else // TODO (baijumeswani): Figure out the best way to extract the optimizer type // from the model (either onnx model or ort format model) or from the checkpoint. // For now, assume that the optimizer type is AdamWOptimizer in a minimal build. - ORT_UNUSED_PARAMETER(optim_path_or_bytes); + ORT_UNUSED_PARAMETER(optim_path); opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; #endif