mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
fix win build errors (on device training) (#11844)
* fix win build errors
* fix linux build
* fix typo
* minor fix
* fix win in c api
* fix linux build complaining bool
* fix ORT_RETURN_ON_ERROR
This commit is contained in:
parent
fac8dae9df
commit
17a8ecee6f
7 changed files with 78 additions and 58 deletions
|
|
@ -53,9 +53,9 @@ void EnforceCheck(bool run_ret, std::string err_msg) {
|
|||
if (onnx_status != NULL) { \
|
||||
auto code = g_ort_api->GetErrorCode(onnx_status); \
|
||||
const char* msg = g_ort_api->GetErrorMessage(onnx_status); \
|
||||
g_ort_api->ReleaseStatus(onnx_status); \
|
||||
printf("Run failed with error code :%d\n", code); \
|
||||
printf("Error message :%s\n", msg); \
|
||||
g_ort_api->ReleaseStatus(onnx_status); \
|
||||
return -1; \
|
||||
} \
|
||||
} while (0);
|
||||
|
|
@ -221,9 +221,9 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
InitSyntheticDataLoader(data_loader, params, num_of_batches_per_epoch);
|
||||
|
||||
// TODO(baiju): Add C API for LRScheduler
|
||||
//int64_t total_step_count = params.num_train_epochs * num_of_batches_per_epoch;
|
||||
//int64_t warmup_step_count = total_step_count / 3;
|
||||
//Ort::OrtLinearLRScheduler scheduler = Ort::OrtLinearLRScheduler(optimizer, warmup_step_count, total_step_count);
|
||||
// int64_t total_step_count = params.num_train_epochs * num_of_batches_per_epoch;
|
||||
// int64_t warmup_step_count = total_step_count / 3;
|
||||
// Ort::OrtLinearLRScheduler scheduler = Ort::OrtLinearLRScheduler(optimizer, warmup_step_count, total_step_count);
|
||||
|
||||
std::cout << "Initialization completed. Now starting training loop." << std::endl;
|
||||
const int64_t stabilized_perf_start_step = 0;
|
||||
|
|
@ -273,7 +273,7 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
#endif
|
||||
|
||||
// Update learning rate.
|
||||
//EnforceCheck(scheduler.Step(), "Failed during shceduler.Step()");
|
||||
// EnforceCheck(scheduler.Step(), "Failed during shceduler.Step()");
|
||||
|
||||
#if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE)
|
||||
onnxruntime::profile::NvtxRangeCreator resetgrad_range(
|
||||
|
|
|
|||
|
|
@ -22,13 +22,13 @@ namespace api {
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr const char* k_tensor_proto_file_name = "tensors.pbseq";
|
||||
constexpr const char* k_tensor_proto_properties_file_name = "properties.pbseq";
|
||||
constexpr const char* k_trainable_param_root_prefix = "paramtrain";
|
||||
constexpr const char* k_non_trainable_param_root_prefix = "paramfrozen";
|
||||
constexpr const char* k_optimizer_root_prefix = "optim";
|
||||
constexpr const char* k_property_root_prefix = "custom";
|
||||
constexpr const char* k_name_seperator = "_";
|
||||
const PathString k_tensor_proto_file_name = ORT_TSTR("tensors.pbseq");
|
||||
const PathString k_tensor_proto_properties_file_name = ORT_TSTR("properties.pbseq");
|
||||
const PathString k_trainable_param_root_prefix = ORT_TSTR("paramtrain");
|
||||
const PathString k_non_trainable_param_root_prefix = ORT_TSTR("paramfrozen");
|
||||
const PathString k_optimizer_root_prefix = ORT_TSTR("optim");
|
||||
const PathString k_property_root_prefix = ORT_TSTR("custom");
|
||||
const PathString k_name_separator = ORT_TSTR("_");
|
||||
|
||||
const std::string builtin_lr_property_name("builtin.initial_learning_rate");
|
||||
const std::string builtin_step_property_name("builtin.step");
|
||||
|
|
@ -58,8 +58,8 @@ Status CreateTensorProtosFromOrtValues(
|
|||
|
||||
saved_tensor_protos.reserve(ordered_tensor_names.size());
|
||||
|
||||
unsigned long total_bytes = 0;
|
||||
constexpr unsigned long PROTOBUF_UPPER_LIMIT = 2 * 1000 * 1000 * 1000;
|
||||
uint64_t total_bytes = 0;
|
||||
constexpr uint64_t PROTOBUF_UPPER_LIMIT = 2 * 1000 * 1000 * 1000;
|
||||
for (const auto& tensor_name : ordered_tensor_names) {
|
||||
const OrtValue& ort_value = name_to_ort_value.at(tensor_name);
|
||||
ORT_RETURN_IF_NOT(ort_value.IsTensor(), "ort_value.IsTensor() was false");
|
||||
|
|
@ -68,7 +68,7 @@ Status CreateTensorProtosFromOrtValues(
|
|||
|
||||
// Currently large model size not considered, so exception thrown here
|
||||
// when protobuf upper limit hit.
|
||||
total_bytes += src_tensor.SizeInBytes();
|
||||
total_bytes += static_cast<uint64_t>(src_tensor.SizeInBytes());
|
||||
if (total_bytes >= PROTOBUF_UPPER_LIMIT) {
|
||||
ORT_THROW("checkpoint file size hit upper limit.");
|
||||
}
|
||||
|
|
@ -95,23 +95,33 @@ Status CreateTensorProtosFromOrtValues(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
PathString GetTensorProtoFilePath(const PathString& checkpoint_directory, const std::string& filename_prefix) {
|
||||
return ConcatPathComponent<PathChar>(checkpoint_directory, ORT_TSTR(filename_prefix + k_name_seperator) + k_tensor_proto_file_name);
|
||||
PathString GetTensorProtoFilePath(const PathString& checkpoint_directory, const PathString& filename_prefix) {
|
||||
std::basic_ostringstream<PathChar> oss;
|
||||
oss << filename_prefix << k_name_separator << k_tensor_proto_file_name;
|
||||
return ConcatPathComponent<PathChar>(checkpoint_directory, oss.str());
|
||||
}
|
||||
|
||||
PathString GetTensorProtoPropertiesFilePath(const PathString& checkpoint_directory, const std::string& filename_prefix) {
|
||||
return ConcatPathComponent<PathChar>(checkpoint_directory, ORT_TSTR(filename_prefix + k_name_seperator) + k_tensor_proto_properties_file_name);
|
||||
PathString GetTensorProtoPropertiesFilePath(
|
||||
const PathString& checkpoint_directory, const PathString& filename_prefix) {
|
||||
std::basic_ostringstream<PathChar> oss;
|
||||
oss << filename_prefix << k_name_separator << k_tensor_proto_properties_file_name;
|
||||
return ConcatPathComponent<PathChar>(checkpoint_directory, oss.str());
|
||||
}
|
||||
|
||||
std::string StringConcat(const std::string& s_a, const std::string& s_b, const std::string& del = k_name_seperator) {
|
||||
return s_a + del + s_b;
|
||||
PathString StringConcat(
|
||||
const PathString& s_a, const PathString& s_b,
|
||||
const PathString& del = k_name_separator) {
|
||||
std::basic_ostringstream<PathChar> oss;
|
||||
oss << s_a << del << s_b;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
void StringSplit(const std::string& s, std::vector<std::string>& results, const std::string& del = k_name_seperator) {
|
||||
void StringSplit(const PathString& s, std::vector<PathString>& results,
|
||||
const PathString& del = k_name_separator) {
|
||||
ORT_ENFORCE(!s.empty(), "String to split is empty");
|
||||
int start = 0;
|
||||
int end = s.find(del);
|
||||
while (end != -1) {
|
||||
size_t start = 0;
|
||||
size_t end = s.find(del);
|
||||
while (end != std::string::npos) {
|
||||
results.push_back(s.substr(start, end - start));
|
||||
start = end + del.size();
|
||||
end = s.find(del, start);
|
||||
|
|
@ -119,11 +129,11 @@ void StringSplit(const std::string& s, std::vector<std::string>& results, const
|
|||
results.push_back(s.substr(start, end - start));
|
||||
}
|
||||
|
||||
bool StringStartsWith(std::string const& s, std::string const& p) {
|
||||
bool StringStartsWith(PathString const& s, PathString const& p) {
|
||||
return s.rfind(p, 0) == 0;
|
||||
}
|
||||
|
||||
bool StringEndsWith(std::string const& s, std::string const& p) {
|
||||
bool StringEndsWith(PathString const& s, PathString const& p) {
|
||||
if (p.size() > s.size()) return false;
|
||||
return std::equal(p.rbegin(), p.rend(), s.rbegin());
|
||||
}
|
||||
|
|
@ -159,12 +169,11 @@ void LoadTensorProtoFromFile(const PathString& file_path,
|
|||
template <typename Func>
|
||||
void FilterFilesFromDirectory(const PathString& folder_path, Func func) {
|
||||
LoopDir(folder_path, [&func](const PathChar* filename, OrtFileType file_type) -> bool {
|
||||
std::string filename_str = filename;
|
||||
if (filename_str[0] == '.' || file_type == OrtFileType::TYPE_DIR) {
|
||||
if (filename[0] == '.' || file_type == OrtFileType::TYPE_DIR) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return func(filename_str);
|
||||
return func(filename);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -222,7 +231,8 @@ Status OrtSaveModuleStatesInternal(ModuleCheckpointState& module_state,
|
|||
ORT_ENFORCE(module_state.train_session_data_transfer_mgr,
|
||||
"module checkpoint state has null train_session_data_transfer_mgr.");
|
||||
|
||||
std::unordered_map<std::string, std::unordered_map<std::string, OrtValue>> parameter_ort_values;
|
||||
std::unordered_map<PathString, std::unordered_map<std::string, OrtValue>>
|
||||
parameter_ort_values;
|
||||
parameter_ort_values[k_trainable_param_root_prefix] = {};
|
||||
parameter_ort_values[k_non_trainable_param_root_prefix] = {};
|
||||
for (auto it = param_states.begin(); it != param_states.end(); ++it) {
|
||||
|
|
@ -262,9 +272,10 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state,
|
|||
|
||||
// Write optimizer state tensors files.
|
||||
for (auto& group_named_optimizer_state : optimizer_state.group_named_optimizer_states) {
|
||||
const std::string& group_name = group_named_optimizer_state.first;
|
||||
const PathString group_name = ToWideString(group_named_optimizer_state.first);
|
||||
const std::shared_ptr<GroupOptimizerState>& group_optimizer_state_ptr = group_named_optimizer_state.second;
|
||||
const std::string& cur_group_filename_prefix = StringConcat(k_optimizer_root_prefix, group_name);
|
||||
const PathString& cur_group_filename_prefix =
|
||||
StringConcat(k_optimizer_root_prefix, group_name);
|
||||
|
||||
// Re-organize optimizer_state_ort_values mapping
|
||||
// Firstly indexed by momentum names; Secondly indexed by parameter names.
|
||||
|
|
@ -291,9 +302,10 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state,
|
|||
// Save each optimizer state (of all parameters) into single file.
|
||||
// For example: save "momentum_1" of all parameters into one file.
|
||||
for (auto& pair : optimizer_state_ort_values) {
|
||||
const auto& momentum_name = pair.first;
|
||||
const PathString momentum_name = ToWideString(pair.first);
|
||||
const std::unordered_map<std::string, OrtValue>& param_name_to_ortvalue = pair.second;
|
||||
const std::string& cur_state_filename_prefix = StringConcat(cur_group_filename_prefix, momentum_name);
|
||||
const PathString& cur_state_filename_prefix =
|
||||
StringConcat(cur_group_filename_prefix, momentum_name);
|
||||
|
||||
std::vector<ONNX_NAMESPACE::TensorProto> saved_tensor_protos;
|
||||
ORT_RETURN_IF_ERROR(CreateTensorProtosFromOrtValues(
|
||||
|
|
@ -353,10 +365,11 @@ Status OrtSaveInternal(
|
|||
Status OrtLoadModuleStatesInternal(
|
||||
const PathString& parameter_folder_path, ModuleCheckpointState& module_state) {
|
||||
// Find parameter files.
|
||||
std::vector<std::pair<std::string, bool>> param_filenames;
|
||||
std::vector<std::pair<PathString, bool>> param_filenames;
|
||||
FilterFilesFromDirectory(
|
||||
parameter_folder_path,
|
||||
[¶m_filenames](const std::string& filename_str) -> bool {
|
||||
[¶m_filenames](const PathChar* filename) -> bool {
|
||||
PathString filename_str = filename;
|
||||
if (StringStartsWith(filename_str, k_trainable_param_root_prefix)) {
|
||||
param_filenames.push_back(std::make_pair(filename_str, true));
|
||||
} else if (StringStartsWith(filename_str, k_non_trainable_param_root_prefix)) {
|
||||
|
|
@ -397,11 +410,12 @@ Status OrtLoadModuleStatesInternal(
|
|||
Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path,
|
||||
OptimizerCheckpointState& optimizer_state) {
|
||||
// Optimizer states parsing.
|
||||
std::vector<std::string> optim_state_filenames;
|
||||
std::vector<std::string> optim_property_filenames;
|
||||
std::vector<PathString> optim_state_filenames;
|
||||
std::vector<PathString> optim_property_filenames;
|
||||
FilterFilesFromDirectory(
|
||||
optimizer_folder_path,
|
||||
[&optim_state_filenames, &optim_property_filenames](const std::string& filename_str) -> bool {
|
||||
[&optim_state_filenames, &optim_property_filenames](const PathChar* filename) -> bool {
|
||||
PathString filename_str = filename;
|
||||
if (StringStartsWith(filename_str, k_optimizer_root_prefix)) {
|
||||
if (StringEndsWith(filename_str, k_tensor_proto_file_name)) {
|
||||
optim_state_filenames.push_back(filename_str);
|
||||
|
|
@ -417,13 +431,15 @@ Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path,
|
|||
auto& grouped_optimizer_states = optimizer_state.group_named_optimizer_states;
|
||||
// For each optimizer state files, parse the data and feed into grouped_optimizer_states.
|
||||
for (auto& filename : optim_state_filenames) {
|
||||
std::vector<std::string> results;
|
||||
std::vector<PathString> results;
|
||||
StringSplit(filename, results);
|
||||
const std::string& group_name = results[1];
|
||||
const std::string& momentum_name = results[2];
|
||||
const std::string& cur_group_filename_prefix = StringConcat(k_optimizer_root_prefix, group_name);
|
||||
std::string cur_momentum_state_filename_prefix = StringConcat(cur_group_filename_prefix, momentum_name);
|
||||
const std::string& group_name = ToUTF8String(results[1]);
|
||||
const std::string& momentum_name = ToUTF8String(results[2]);
|
||||
|
||||
const PathString cur_group_filename_prefix =
|
||||
StringConcat(k_optimizer_root_prefix, results[1]);
|
||||
PathString cur_momentum_state_filename_prefix =
|
||||
StringConcat(cur_group_filename_prefix, results[2]);
|
||||
ORT_ENFORCE(filename.compare(StringConcat(cur_momentum_state_filename_prefix, k_tensor_proto_file_name)) == 0);
|
||||
|
||||
if (grouped_optimizer_states.find(group_name) == grouped_optimizer_states.end()) {
|
||||
|
|
@ -452,9 +468,9 @@ Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path,
|
|||
|
||||
// For each optimizer properties files, parse the data and feed into grouped_optimizer_states.
|
||||
for (auto& filename : optim_property_filenames) {
|
||||
std::vector<std::string> results;
|
||||
std::vector<PathString> results;
|
||||
StringSplit(filename, results);
|
||||
const std::string& group_name = results[1];
|
||||
const std::string& group_name = ToUTF8String(results[1]);
|
||||
|
||||
if (grouped_optimizer_states.find(group_name) == grouped_optimizer_states.end()) {
|
||||
grouped_optimizer_states.insert({group_name, std::make_shared<GroupOptimizerState>()});
|
||||
|
|
@ -463,7 +479,7 @@ Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path,
|
|||
auto& group_optimizer_state = grouped_optimizer_states[group_name];
|
||||
|
||||
// Parse group-wise properties.
|
||||
const std::string& cur_group_filename_prefix = StringConcat(k_optimizer_root_prefix, group_name);
|
||||
const PathString cur_group_filename_prefix = StringConcat(k_optimizer_root_prefix, results[1]);
|
||||
const PathString& tensor_file_path = GetTensorProtoPropertiesFilePath(optimizer_folder_path, cur_group_filename_prefix);
|
||||
std::vector<ONNX_NAMESPACE::TensorProto> group_wise_property_protos{};
|
||||
LoadTensorProtoFromFile(tensor_file_path, group_wise_property_protos, "[optimizer_groupwise_property]");
|
||||
|
|
@ -484,10 +500,11 @@ Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path,
|
|||
Status OrtLoadCustomPropertyInternal(const PathString& property_folder_path,
|
||||
PropertyBag& property_bag) {
|
||||
// Find custom property files.
|
||||
std::vector<std::string> custom_property_filenames;
|
||||
std::vector<PathString> custom_property_filenames;
|
||||
FilterFilesFromDirectory(
|
||||
property_folder_path,
|
||||
[&custom_property_filenames](const std::string& filename_str) -> bool {
|
||||
[&custom_property_filenames](const PathChar* filename) -> bool {
|
||||
PathString filename_str = filename;
|
||||
if (StringStartsWith(filename_str, k_property_root_prefix)) {
|
||||
custom_property_filenames.push_back(filename_str);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ struct PropertyBag {
|
|||
}
|
||||
}
|
||||
|
||||
int Size() const {
|
||||
size_t Size() const {
|
||||
return named_properties.size();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ struct Parameter {
|
|||
std::string gradient_name_;
|
||||
|
||||
bool requires_grad_{true};
|
||||
friend class Module;
|
||||
friend struct Module;
|
||||
};
|
||||
|
||||
struct ModuleCheckpointState {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
// DO NOT UNCOMMENT
|
||||
//#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
ORT_API2_STATUS(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session,
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ struct ParameterOptimizerState {
|
|||
*/
|
||||
struct GroupOptimizerState {
|
||||
int64_t step = 0;
|
||||
float initial_lr = 0.001; // Default value used in torch AdamW
|
||||
float initial_lr = 0.001f; // Default value used in torch AdamW
|
||||
float learning_rate{initial_lr}; // Adaptive learning rate as training proceeds.
|
||||
std::unordered_map<std::string, ParameterOptimizerState> param_named_optimizer_states;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/error_code_helper.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
|
|
@ -40,12 +41,12 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_
|
|||
options == nullptr ? onnxruntime::SessionOptions() : options->value,
|
||||
options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories),
|
||||
chkpt_state->module_checkpoint_state.named_parameters,
|
||||
onnxruntime::training::api::ModelIdentifiers{
|
||||
train_model_path,
|
||||
eval_model_path ? std::optional<std::string>{eval_model_path}
|
||||
onnxruntime::training::api::ModelIdentifiers(
|
||||
onnxruntime::ToUTF8String(train_model_path),
|
||||
eval_model_path ? std::optional<std::string>(onnxruntime::ToUTF8String(eval_model_path))
|
||||
: std::nullopt,
|
||||
optimizer_model_path ? std::optional<std::string>{optimizer_model_path}
|
||||
: std::nullopt});
|
||||
optimizer_model_path ? std::optional<std::string>(onnxruntime::ToUTF8String(optimizer_model_path))
|
||||
: std::nullopt));
|
||||
|
||||
*out = reinterpret_cast<OrtTrainingSession*>(train_sess.release());
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue