Temporary optimizer support for ort format models in non-minimal build (#16485)

This commit is contained in:
Baiju Meswani 2023-06-28 11:35:57 -07:00 committed by GitHub
parent 960e320dff
commit efeb6672d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 15 deletions

View file

@ -1385,7 +1385,7 @@ struct Value : detail::ValueImpl<OrtValue> {
* \param value - the value to be wrapped.
*/
template <typename T>
static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue
static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue
#if !defined(DISABLE_SPARSE_TENSORS)
/// <summary>

View file

@ -638,6 +638,7 @@ struct OrtTrainingApi {
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
*
* \param[in] checkpoint_buffer Buffer containing the checkpoint bytes.
* \param[in] num_bytes Number of bytes in the checkpoint buffer.
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "orttraining/training_api/optimizer.h"
#include "core/flatbuffers/flatbuffers_utils.h"
#include "core/framework/execution_provider.h"
#include "core/framework/TensorSeq.h"
#include "core/providers/cpu/cpu_execution_provider.h"
@ -60,29 +61,36 @@ Status GraphInputsAreExpected(gsl::span<std::string> actual_graph_inputs,
} // namespace
std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
const std::string& optim_path_or_bytes, int32_t& group_count) {
const std::string& optim_path, int32_t& group_count) {
std::map<std::pair<std::string, std::string>, int32_t> opt_type_to_freq_map;
#if !defined(ORT_MINIMAL_BUILD)
std::shared_ptr<Model> model;
ORT_ENFORCE(Model::Load(ToWideString(optim_path_or_bytes), model, nullptr,
logging::LoggingManager::DefaultLogger())
.IsOK());
Graph& graph = model->MainGraph();
for (auto& node : graph.Nodes()) {
if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
opt_type_to_freq_map[domain_type_pair] = 0;
}
if (const auto optim_path_str = ToPathString(optim_path);
fbs::utils::IsOrtFormatModel(optim_path_str)) {
// TODO (baijumeswani): Figure out the best way to extract the optimizer type
// from an ort format model.
opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
} else {
std::shared_ptr<Model> model;
ORT_ENFORCE(Model::Load(optim_path_str, model, nullptr,
logging::LoggingManager::DefaultLogger())
.IsOK());
Graph& graph = model->MainGraph();
for (auto& node : graph.Nodes()) {
if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
opt_type_to_freq_map[domain_type_pair] = 0;
}
opt_type_to_freq_map[domain_type_pair] += 1;
opt_type_to_freq_map[domain_type_pair] += 1;
}
}
}
#else
// TODO (baijumeswani): Figure out the best way to extract the optimizer type
// from the model (either onnx model or ort format model) or from the checkpoint.
// For now, assume that the optimizer type is AdamWOptimizer in a minimal build.
ORT_UNUSED_PARAMETER(optim_path_or_bytes);
ORT_UNUSED_PARAMETER(optim_path);
opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
#endif