From 17a8ecee6fb4bc8fafc7b45bee363cebb28eb5df Mon Sep 17 00:00:00 2001 From: pengwa Date: Thu, 23 Jun 2022 18:16:50 +0800 Subject: [PATCH] fix win build errors (on device training) (#11844) * fix win build errors * fix linux build * fix typo * minor fix * fix win in c api * fix linux build complaining bool * fix ORT_RETURN_ON_ERROR --- .../test/training_api/trainer/trainer.cc | 10 +- .../orttraining/training_api/checkpoint.cc | 107 ++++++++++-------- .../include/checkpoint_property.h | 2 +- .../orttraining/training_api/include/module.h | 2 +- .../include/onnxruntime_training_c_api.h | 2 + .../training_api/include/optimizer.h | 2 +- .../onnxruntime_training_c_api.cc | 11 +- 7 files changed, 78 insertions(+), 58 deletions(-) diff --git a/orttraining/orttraining/test/training_api/trainer/trainer.cc b/orttraining/orttraining/test/training_api/trainer/trainer.cc index 6d6d8ade32..d6ebad30f8 100644 --- a/orttraining/orttraining/test/training_api/trainer/trainer.cc +++ b/orttraining/orttraining/test/training_api/trainer/trainer.cc @@ -53,9 +53,9 @@ void EnforceCheck(bool run_ret, std::string err_msg) { if (onnx_status != NULL) { \ auto code = g_ort_api->GetErrorCode(onnx_status); \ const char* msg = g_ort_api->GetErrorMessage(onnx_status); \ - g_ort_api->ReleaseStatus(onnx_status); \ printf("Run failed with error code :%d\n", code); \ printf("Error message :%s\n", msg); \ + g_ort_api->ReleaseStatus(onnx_status); \ return -1; \ } \ } while (0); @@ -221,9 +221,9 @@ int RunTraining(const TestRunnerParameters& params) { InitSyntheticDataLoader(data_loader, params, num_of_batches_per_epoch); // TODO(baiju): Add C API for LRScheduler - //int64_t total_step_count = params.num_train_epochs * num_of_batches_per_epoch; - //int64_t warmup_step_count = total_step_count / 3; - //Ort::OrtLinearLRScheduler scheduler = Ort::OrtLinearLRScheduler(optimizer, warmup_step_count, total_step_count); + // int64_t total_step_count = params.num_train_epochs * num_of_batches_per_epoch; + // int64_t 
warmup_step_count = total_step_count / 3; + // Ort::OrtLinearLRScheduler scheduler = Ort::OrtLinearLRScheduler(optimizer, warmup_step_count, total_step_count); std::cout << "Initialization completed. Now starting training loop." << std::endl; const int64_t stabilized_perf_start_step = 0; @@ -273,7 +273,7 @@ int RunTraining(const TestRunnerParameters& params) { #endif // Update learning rate. - //EnforceCheck(scheduler.Step(), "Failed during shceduler.Step()"); + // EnforceCheck(scheduler.Step(), "Failed during shceduler.Step()"); #if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE) onnxruntime::profile::NvtxRangeCreator resetgrad_range( diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index f041bc2c53..a02e04c79a 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -22,13 +22,13 @@ namespace api { namespace { -constexpr const char* k_tensor_proto_file_name = "tensors.pbseq"; -constexpr const char* k_tensor_proto_properties_file_name = "properties.pbseq"; -constexpr const char* k_trainable_param_root_prefix = "paramtrain"; -constexpr const char* k_non_trainable_param_root_prefix = "paramfrozen"; -constexpr const char* k_optimizer_root_prefix = "optim"; -constexpr const char* k_property_root_prefix = "custom"; -constexpr const char* k_name_seperator = "_"; +const PathString k_tensor_proto_file_name = ORT_TSTR("tensors.pbseq"); +const PathString k_tensor_proto_properties_file_name = ORT_TSTR("properties.pbseq"); +const PathString k_trainable_param_root_prefix = ORT_TSTR("paramtrain"); +const PathString k_non_trainable_param_root_prefix = ORT_TSTR("paramfrozen"); +const PathString k_optimizer_root_prefix = ORT_TSTR("optim"); +const PathString k_property_root_prefix = ORT_TSTR("custom"); +const PathString k_name_separator = ORT_TSTR("_"); const std::string builtin_lr_property_name("builtin.initial_learning_rate"); const std::string 
builtin_step_property_name("builtin.step"); @@ -58,8 +58,8 @@ Status CreateTensorProtosFromOrtValues( saved_tensor_protos.reserve(ordered_tensor_names.size()); - unsigned long total_bytes = 0; - constexpr unsigned long PROTOBUF_UPPER_LIMIT = 2 * 1000 * 1000 * 1000; + uint64_t total_bytes = 0; + constexpr uint64_t PROTOBUF_UPPER_LIMIT = 2 * 1000 * 1000 * 1000; for (const auto& tensor_name : ordered_tensor_names) { const OrtValue& ort_value = name_to_ort_value.at(tensor_name); ORT_RETURN_IF_NOT(ort_value.IsTensor(), "ort_value.IsTensor() was false"); @@ -68,7 +68,7 @@ Status CreateTensorProtosFromOrtValues( // Currently large model size not considered, so exception thrown here // when protobuf upper limit hit. - total_bytes += src_tensor.SizeInBytes(); + total_bytes += static_cast<uint64_t>(src_tensor.SizeInBytes()); if (total_bytes >= PROTOBUF_UPPER_LIMIT) { ORT_THROW("checkpoint file size hit upper limit."); } @@ -95,23 +95,33 @@ Status CreateTensorProtosFromOrtValues( return Status::OK(); } -PathString GetTensorProtoFilePath(const PathString& checkpoint_directory, const std::string& filename_prefix) { - return ConcatPathComponent(checkpoint_directory, ORT_TSTR(filename_prefix + k_name_seperator) + k_tensor_proto_file_name); +PathString GetTensorProtoFilePath(const PathString& checkpoint_directory, const PathString& filename_prefix) { + std::basic_ostringstream<PathChar> oss; + oss << filename_prefix << k_name_separator << k_tensor_proto_file_name; + return ConcatPathComponent(checkpoint_directory, oss.str()); } -PathString GetTensorProtoPropertiesFilePath(const PathString& checkpoint_directory, const std::string& filename_prefix) { - return ConcatPathComponent(checkpoint_directory, ORT_TSTR(filename_prefix + k_name_seperator) + k_tensor_proto_properties_file_name); +PathString GetTensorProtoPropertiesFilePath( + const PathString& checkpoint_directory, const PathString& filename_prefix) { + std::basic_ostringstream<PathChar> oss; + oss << filename_prefix << k_name_separator <<
k_tensor_proto_properties_file_name; + return ConcatPathComponent(checkpoint_directory, oss.str()); } -std::string StringConcat(const std::string& s_a, const std::string& s_b, const std::string& del = k_name_seperator) { - return s_a + del + s_b; +PathString StringConcat( + const PathString& s_a, const PathString& s_b, + const PathString& del = k_name_separator) { + std::basic_ostringstream<PathChar> oss; + oss << s_a << del << s_b; + return oss.str(); } -void StringSplit(const std::string& s, std::vector<std::string>& results, const std::string& del = k_name_seperator) { +void StringSplit(const PathString& s, std::vector<PathString>& results, + const PathString& del = k_name_separator) { ORT_ENFORCE(!s.empty(), "String to split is empty"); - int start = 0; - int end = s.find(del); - while (end != -1) { + size_t start = 0; + size_t end = s.find(del); + while (end != std::string::npos) { results.push_back(s.substr(start, end - start)); start = end + del.size(); end = s.find(del, start); @@ -119,11 +129,11 @@ void StringSplit(const std::string& s, std::vector<std::string>& results, const results.push_back(s.substr(start, end - start)); } -bool StringStartsWith(std::string const& s, std::string const& p) { +bool StringStartsWith(PathString const& s, PathString const& p) { return s.rfind(p, 0) == 0; } -bool StringEndsWith(std::string const& s, std::string const& p) { +bool StringEndsWith(PathString const& s, PathString const& p) { if (p.size() > s.size()) return false; return std::equal(p.rbegin(), p.rend(), s.rbegin()); } @@ -159,12 +169,11 @@ void LoadTensorProtoFromFile(const PathString& file_path, template <typename Func> void FilterFilesFromDirectory(const PathString& folder_path, Func func) { LoopDir(folder_path, [&func](const PathChar* filename, OrtFileType file_type) -> bool { - std::string filename_str = filename; - if (filename_str[0] == '.' || file_type == OrtFileType::TYPE_DIR) { + if (filename[0] == '.'
|| file_type == OrtFileType::TYPE_DIR) { return true; } - return func(filename_str); + return func(filename); }); } @@ -222,7 +231,8 @@ Status OrtSaveModuleStatesInternal(ModuleCheckpointState& module_state, ORT_ENFORCE(module_state.train_session_data_transfer_mgr, "module checkpoint state has null train_session_data_transfer_mgr."); - std::unordered_map> parameter_ort_values; + std::unordered_map> + parameter_ort_values; parameter_ort_values[k_trainable_param_root_prefix] = {}; parameter_ort_values[k_non_trainable_param_root_prefix] = {}; for (auto it = param_states.begin(); it != param_states.end(); ++it) { @@ -262,9 +272,10 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state, // Write optimizer state tensors files. for (auto& group_named_optimizer_state : optimizer_state.group_named_optimizer_states) { - const std::string& group_name = group_named_optimizer_state.first; + const PathString group_name = ToWideString(group_named_optimizer_state.first); const std::shared_ptr& group_optimizer_state_ptr = group_named_optimizer_state.second; - const std::string& cur_group_filename_prefix = StringConcat(k_optimizer_root_prefix, group_name); + const PathString& cur_group_filename_prefix = + StringConcat(k_optimizer_root_prefix, group_name); // Re-organize optimizer_state_ort_values mapping // Firstly indexed by momentum names; Secondly indexed by parameter names. @@ -291,9 +302,10 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state, // Save each optimizer state (of all parameters) into single file. // For example: save "momentum_1" of all parameters into one file. 
for (auto& pair : optimizer_state_ort_values) { - const auto& momentum_name = pair.first; + const PathString momentum_name = ToWideString(pair.first); const std::unordered_map& param_name_to_ortvalue = pair.second; - const std::string& cur_state_filename_prefix = StringConcat(cur_group_filename_prefix, momentum_name); + const PathString& cur_state_filename_prefix = + StringConcat(cur_group_filename_prefix, momentum_name); std::vector saved_tensor_protos; ORT_RETURN_IF_ERROR(CreateTensorProtosFromOrtValues( @@ -353,10 +365,11 @@ Status OrtSaveInternal( Status OrtLoadModuleStatesInternal( const PathString& parameter_folder_path, ModuleCheckpointState& module_state) { // Find parameter files. - std::vector<std::pair<std::string, bool>> param_filenames; + std::vector<std::pair<PathString, bool>> param_filenames; FilterFilesFromDirectory( parameter_folder_path, - [&param_filenames](const std::string& filename_str) -> bool { + [&param_filenames](const PathChar* filename) -> bool { + PathString filename_str = filename; if (StringStartsWith(filename_str, k_trainable_param_root_prefix)) { param_filenames.push_back(std::make_pair(filename_str, true)); } else if (StringStartsWith(filename_str, k_non_trainable_param_root_prefix)) { @@ -397,11 +410,12 @@ Status OrtLoadModuleStatesInternal( Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path, OptimizerCheckpointState& optimizer_state) { // Optimizer states parsing.
- std::vector<std::string> optim_state_filenames; - std::vector<std::string> optim_property_filenames; + std::vector<PathString> optim_state_filenames; + std::vector<PathString> optim_property_filenames; FilterFilesFromDirectory( optimizer_folder_path, - [&optim_state_filenames, &optim_property_filenames](const std::string& filename_str) -> bool { + [&optim_state_filenames, &optim_property_filenames](const PathChar* filename) -> bool { + PathString filename_str = filename; if (StringStartsWith(filename_str, k_optimizer_root_prefix)) { if (StringEndsWith(filename_str, k_tensor_proto_file_name)) { optim_state_filenames.push_back(filename_str); @@ -417,13 +431,15 @@ Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path, auto& grouped_optimizer_states = optimizer_state.group_named_optimizer_states; // For each optimizer state files, parse the data and feed into grouped_optimizer_states. for (auto& filename : optim_state_filenames) { - std::vector<std::string> results; + std::vector<PathString> results; StringSplit(filename, results); - const std::string& group_name = results[1]; - const std::string& momentum_name = results[2]; - const std::string& cur_group_filename_prefix = StringConcat(k_optimizer_root_prefix, group_name); - std::string cur_momentum_state_filename_prefix = StringConcat(cur_group_filename_prefix, momentum_name); + const std::string& group_name = ToUTF8String(results[1]); + const std::string& momentum_name = ToUTF8String(results[2]); + const PathString cur_group_filename_prefix = + StringConcat(k_optimizer_root_prefix, results[1]); + PathString cur_momentum_state_filename_prefix = + StringConcat(cur_group_filename_prefix, results[2]); ORT_ENFORCE(filename.compare(StringConcat(cur_momentum_state_filename_prefix, k_tensor_proto_file_name)) == 0); if (grouped_optimizer_states.find(group_name) == grouped_optimizer_states.end()) { @@ -452,9 +468,9 @@ Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path, // For each optimizer properties files, parse the data and feed into
grouped_optimizer_states. for (auto& filename : optim_property_filenames) { - std::vector results; + std::vector results; StringSplit(filename, results); - const std::string& group_name = results[1]; + const std::string& group_name = ToUTF8String(results[1]); if (grouped_optimizer_states.find(group_name) == grouped_optimizer_states.end()) { grouped_optimizer_states.insert({group_name, std::make_shared()}); @@ -463,7 +479,7 @@ Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path, auto& group_optimizer_state = grouped_optimizer_states[group_name]; // Parse group-wise properties. - const std::string& cur_group_filename_prefix = StringConcat(k_optimizer_root_prefix, group_name); + const PathString cur_group_filename_prefix = StringConcat(k_optimizer_root_prefix, results[1]); const PathString& tensor_file_path = GetTensorProtoPropertiesFilePath(optimizer_folder_path, cur_group_filename_prefix); std::vector group_wise_property_protos{}; LoadTensorProtoFromFile(tensor_file_path, group_wise_property_protos, "[optimizer_groupwise_property]"); @@ -484,10 +500,11 @@ Status OrtLoadOptimizerStatesInternal(const PathString& optimizer_folder_path, Status OrtLoadCustomPropertyInternal(const PathString& property_folder_path, PropertyBag& property_bag) { // Find custom property files. 
- std::vector custom_property_filenames; + std::vector custom_property_filenames; FilterFilesFromDirectory( property_folder_path, - [&custom_property_filenames](const std::string& filename_str) -> bool { + [&custom_property_filenames](const PathChar* filename) -> bool { + PathString filename_str = filename; if (StringStartsWith(filename_str, k_property_root_prefix)) { custom_property_filenames.push_back(filename_str); } diff --git a/orttraining/orttraining/training_api/include/checkpoint_property.h b/orttraining/orttraining/training_api/include/checkpoint_property.h index 03100bac61..7a7a5b77e3 100644 --- a/orttraining/orttraining/training_api/include/checkpoint_property.h +++ b/orttraining/orttraining/training_api/include/checkpoint_property.h @@ -102,7 +102,7 @@ struct PropertyBag { } } - int Size() const { + size_t Size() const { return named_properties.size(); } diff --git a/orttraining/orttraining/training_api/include/module.h b/orttraining/orttraining/training_api/include/module.h index dabc372034..c209804869 100644 --- a/orttraining/orttraining/training_api/include/module.h +++ b/orttraining/orttraining/training_api/include/module.h @@ -44,7 +44,7 @@ struct Parameter { std::string gradient_name_; bool requires_grad_{true}; - friend class Module; + friend struct Module; }; struct ModuleCheckpointState { diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 610e03ee58..c143ede565 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -7,6 +7,8 @@ // DO NOT UNCOMMENT //#include "core/session/onnxruntime_c_api.h" +#include + ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state); ORT_API2_STATUS(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ 
OrtTrainingSession* session, diff --git a/orttraining/orttraining/training_api/include/optimizer.h b/orttraining/orttraining/training_api/include/optimizer.h index cb8a20f602..7c27f23abf 100644 --- a/orttraining/orttraining/training_api/include/optimizer.h +++ b/orttraining/orttraining/training_api/include/optimizer.h @@ -27,7 +27,7 @@ struct ParameterOptimizerState { */ struct GroupOptimizerState { int64_t step = 0; - float initial_lr = 0.001; // Default value used in torch AdamW + float initial_lr = 0.001f; // Default value used in torch AdamW float learning_rate{initial_lr}; // Adaptive learning rate as training proceeds. std::unordered_map param_named_optimizer_states; }; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 730ad208c4..5a6237018c 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/common.h" #include "core/framework/error_code_helper.h" #include "core/framework/ort_value.h" #include "core/session/ort_apis.h" @@ -40,12 +41,12 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ options == nullptr ? onnxruntime::SessionOptions() : options->value, options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories), chkpt_state->module_checkpoint_state.named_parameters, - onnxruntime::training::api::ModelIdentifiers{ - train_model_path, - eval_model_path ? std::optional{eval_model_path} + onnxruntime::training::api::ModelIdentifiers( + onnxruntime::ToUTF8String(train_model_path), + eval_model_path ? std::optional(onnxruntime::ToUTF8String(eval_model_path)) : std::nullopt, - optimizer_model_path ? 
std::optional{optimizer_model_path} - : std::nullopt}); + optimizer_model_path ? std::optional(onnxruntime::ToUTF8String(optimizer_model_path)) + : std::nullopt)); *out = reinterpret_cast(train_sess.release()); }