mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Temporary optimizer support for ort format models in non minimal build (#16485)
This commit is contained in:
parent
960e320dff
commit
efeb6672d6
3 changed files with 24 additions and 15 deletions
|
|
@ -1385,7 +1385,7 @@ struct Value : detail::ValueImpl<OrtValue> {
|
|||
* \param value - the value to be wrapped.
|
||||
*/
|
||||
template <typename T>
|
||||
static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue
|
||||
static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue
|
||||
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
/// <summary>
|
||||
|
|
|
|||
|
|
@ -638,6 +638,7 @@ struct OrtTrainingApi {
|
|||
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
|
||||
*
|
||||
* \param[in] checkpoint_buffer Path to the checkpoint bytes buffer.
|
||||
* \param[in] num_bytes Number of bytes in the checkpoint buffer.
|
||||
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_api/optimizer.h"
|
||||
#include "core/flatbuffers/flatbuffers_utils.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/framework/TensorSeq.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
|
|
@ -60,29 +61,36 @@ Status GraphInputsAreExpected(gsl::span<std::string> actual_graph_inputs,
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
|
||||
const std::string& optim_path_or_bytes, int32_t& group_count) {
|
||||
const std::string& optim_path, int32_t& group_count) {
|
||||
std::map<std::pair<std::string, std::string>, int32_t> opt_type_to_freq_map;
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
std::shared_ptr<Model> model;
|
||||
ORT_ENFORCE(Model::Load(ToWideString(optim_path_or_bytes), model, nullptr,
|
||||
logging::LoggingManager::DefaultLogger())
|
||||
.IsOK());
|
||||
Graph& graph = model->MainGraph();
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
|
||||
auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
|
||||
if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
|
||||
opt_type_to_freq_map[domain_type_pair] = 0;
|
||||
}
|
||||
if (const auto optim_path_str = ToPathString(optim_path);
|
||||
fbs::utils::IsOrtFormatModel(optim_path_str)) {
|
||||
// TODO (baijumeswani): Figure out the best way to extract the optimizer type
|
||||
// from an ort format model.
|
||||
opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
|
||||
} else {
|
||||
std::shared_ptr<Model> model;
|
||||
ORT_ENFORCE(Model::Load(optim_path_str, model, nullptr,
|
||||
logging::LoggingManager::DefaultLogger())
|
||||
.IsOK());
|
||||
Graph& graph = model->MainGraph();
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
|
||||
auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
|
||||
if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
|
||||
opt_type_to_freq_map[domain_type_pair] = 0;
|
||||
}
|
||||
|
||||
opt_type_to_freq_map[domain_type_pair] += 1;
|
||||
opt_type_to_freq_map[domain_type_pair] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// TODO (baijumeswani): Figure out the best way to extract the optimizer type
|
||||
// from the model (either onnx model or ort format model) or from the checkpoint.
|
||||
// For now, assume that the optimizer type is AdamWOptimizer in a minimal build.
|
||||
ORT_UNUSED_PARAMETER(optim_path_or_bytes);
|
||||
ORT_UNUSED_PARAMETER(optim_path);
|
||||
|
||||
opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue