diff --git a/orttraining/orttraining/test/training_api/common/synthetic_data_loader.cc b/orttraining/orttraining/test/training_api/common/synthetic_data_loader.cc index 9f5ba7296b..0a9989d9c6 100644 --- a/orttraining/orttraining/test/training_api/common/synthetic_data_loader.cc +++ b/orttraining/orttraining/test/training_api/common/synthetic_data_loader.cc @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include "synthetic_data_loader.h" @@ -17,108 +19,140 @@ namespace training_api { namespace { -void RandomFloats(std::vector& rets) { +void RandomFloats(std::vector& rets, size_t num_element) { const float scale = 1.f; const float mean = 0.f; const float seed = 123.f; static std::default_random_engine generator{static_cast(seed)}; std::normal_distribution distribution{mean, scale}; - std::for_each(rets.begin(), rets.end(), - [&distribution](float& value) { value = distribution(generator); }); + + std::generate_n(std::back_inserter(rets), num_element, + [&distribution]() -> float { return distribution(generator); }); } template -void RandomInts(std::vector& rets, IntType low, IntType high) { +void RandomInts(std::vector& rets, size_t num_element, IntType low, IntType high) { static std::random_device rd; static std::mt19937 generator(rd()); std::uniform_int_distribution distribution(low, high); - std::for_each(rets.begin(), rets.end(), - [&distribution](IntType& value) { value = distribution(generator); }); + + std::generate_n(std::back_inserter(rets), num_element, + [&distribution]() -> IntType { return distribution(generator); }); } } // namespace -void SyntheticSampleBatch::AddInt64Input(const std::vector& shape, int64_t low, int64_t high) { - data_vector_.emplace_back(std::make_unique>(shape)); - RandomInts(data_vector_.back()->GetData(), low, high); +template +void SyntheticSampleBatch::AddIntInput(gsl::span shape, T low, T high) { + data_vector_.push_back(SyntheticInput(shape)); + + std::vector values; + auto num_of_element = data_vector_.back().NumOfElements(); + values.reserve(num_of_element); + RandomInts(values, num_of_element, low, high); + + SyntheticDataVector& data = data_vector_.back().GetData(); + data = values; } -void SyntheticSampleBatch::AddInt32Input(const std::vector& shape, int32_t low, int32_t high) { - data_vector_.emplace_back(std::make_unique>(shape)); - RandomInts(data_vector_.back()->GetData(), low, high); +void SyntheticSampleBatch::AddInt64Input(gsl::span shape, int64_t low, int64_t high) { + AddIntInput(shape, low, high); } -void SyntheticSampleBatch::AddFloatInput(const std::vector& shape) { - data_vector_.emplace_back(std::make_unique>(shape)); - RandomFloats(data_vector_.back()->GetData()); +void SyntheticSampleBatch::AddInt32Input(gsl::span shape, int32_t low, int32_t high) { + AddIntInput(shape, low, high); } -#define ORT_RETURN_ON_ERROR(expr) \ - do { \ - OrtStatus* onnx_status = (expr); \ - if (onnx_status != NULL) { \ +void SyntheticSampleBatch::AddBoolInput(gsl::span shape) { + // Use uint8_t to store the bool value by intention, because vector is specialized, we can not create a + // Tensor leveraging C APIs to reuse the data buffer. + data_vector_.push_back(SyntheticInput(shape)); + + std::vector values; + auto num_of_element = data_vector_.back().NumOfElements(); + values.reserve(num_of_element); + + // Need random with int32_t first because MSVC compiler complains uint8_t usage for uniform_int_distribution. + RandomInts(values, num_of_element, static_cast(0), static_cast(1)); + + SyntheticDataVector& data = data_vector_.back().GetData(); + std::vector uint8_values; + std::transform(values.begin(), values.end(), std::back_inserter(uint8_values), + [](int32_t x) { return static_cast(x); }); + data = uint8_values; +} + +void SyntheticSampleBatch::AddFloatInput(gsl::span shape) { + data_vector_.push_back(SyntheticInput(shape)); + + std::vector values; + auto num_of_element = data_vector_.back().NumOfElements(); + values.reserve(num_of_element); + RandomFloats(values, num_of_element); + + SyntheticDataVector& data = data_vector_.back().GetData(); + data = values; +} + +#define ORT_RETURN_ON_ERROR(expr) \ + do { \ + OrtStatus* onnx_status = (expr); \ + if (onnx_status != NULL) { \ auto code = ort_api->GetErrorCode(onnx_status); \ const char* msg = ort_api->GetErrorMessage(onnx_status); \ + printf("Run failed with error code :%d\n", code); \ + printf("Error message :%s\n", msg); \ ort_api->ReleaseStatus(onnx_status); \ - printf("Run failed with error code :%d\n", code); \ - printf("Error message :%s\n", msg); \ - return false; \ - } \ + return false; \ + } \ } while (0); +bool SyntheticSampleBatch::GetBatch(std::vector& batches) { + batches.clear(); + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + const auto* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + for (size_t i = 0; i < data_vector_.size(); ++i) { + SyntheticInput& input = data_vector_[i]; + + const bool ret = std::visit([&batches, &input, &ort_api, &memory_info](auto&& arg) -> bool { + ONNXTensorElementDataType elem_data_type; + using T = std::decay_t; + if constexpr (std::is_same_v) { + elem_data_type = Ort::TypeToTensorType::type; + } else { + elem_data_type = Ort::TypeToTensorType::type; + } + + OrtValue* value = nullptr; + const auto& shape_vector = input.ShapeVector(); + // Be noted: the created OrtValue won't clean the raw data after its lifetime ended. + ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue( + memory_info, + arg.data(), (input.NumOfElements() * sizeof(typename T::value_type)), + shape_vector.data(), shape_vector.size(), + elem_data_type, + &value)); + + batches.emplace_back(value); + return true; + }, + input.GetData()); + + if (!ret) { + return false; + } + } + + return true; +} + bool SyntheticDataLoader::GetNextSampleBatch(std::vector& batches) { if (sample_batch_iter_index_ >= NumOfSampleBatches()) { return false; } - batches.clear(); - - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); auto& sample = sample_batch_collections_[sample_batch_iter_index_]; - const auto* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - for (size_t i = 0; i < sample->NumOfInput(); ++i) { - auto input_ptr = sample->GetInputAtIndex(i); - auto shape_vector = input_ptr->ShapeVector(); - // Be noted: the created OrtValue won't clean the raw data after its lifetime ended. - auto ptr_flt = dynamic_cast*>(input_ptr); - if (ptr_flt) { - OrtValue* value = NULL; - ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(memory_info, - input_ptr->GetData().data(), (input_ptr->NumOfElements() * sizeof(float)), - shape_vector.data(), shape_vector.size(), - ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - &value)); - batches.emplace_back(value); - continue; - } - - auto ptr_int = dynamic_cast*>(input_ptr); - if (ptr_int) { - OrtValue* value = NULL; - ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(memory_info, - input_ptr->GetData().data(), (input_ptr->NumOfElements() * sizeof(int64_t)), - shape_vector.data(), shape_vector.size(), - ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, - &value)); - batches.emplace_back(value); - continue; - } - - auto ptr_int32 = dynamic_cast*>(input_ptr); - if (ptr_int32) { - OrtValue* value = nullptr; - ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(memory_info, - input_ptr->GetData().data(), (input_ptr->NumOfElements() * sizeof(int32_t)), - shape_vector.data(), shape_vector.size(), - ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, - &value)); - batches.emplace_back(value); - continue; - } - - throw std::runtime_error("unknown data types."); - } - + sample.GetBatch(batches); sample_batch_iter_index_ += 1; return true; } diff --git a/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h b/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h index a0ce51ab7f..11ced3f563 100644 --- a/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h +++ b/orttraining/orttraining/test/training_api/common/synthetic_data_loader.h @@ -11,10 +11,13 @@ #pragma once +#include "gsl/gsl" + #include #include #include +#include #include namespace onnxruntime { @@ -22,81 +25,61 @@ namespace training { namespace test { namespace training_api { -template -struct TypedSyntheticInput; +using SyntheticDataVector = std::variant, std::vector, std::vector, + std::vector>; struct SyntheticInput { - explicit SyntheticInput(const std::vector& shape) : shape_(shape) { + explicit SyntheticInput(gsl::span shape) : shape_(shape.begin(), shape.end()) { for (auto d : shape) { num_of_elements_ *= d; } } - virtual ~SyntheticInput() {} - - template - std::vector& GetData() { - auto ptr = dynamic_cast*>(this); - return ptr->Data(); - } - - size_t NumOfElements() { + size_t NumOfElements() const { return num_of_elements_; } - std::vector ShapeVector() const { + gsl::span ShapeVector() const { return shape_; } - protected: - std::vector shape_; - size_t num_of_elements_{1}; -}; - -template -struct TypedSyntheticInput : public SyntheticInput { - explicit TypedSyntheticInput(const std::vector& shape) - : SyntheticInput(shape) { - data_.resize(num_of_elements_); - } - - std::vector& Data() { + SyntheticDataVector& GetData() { return data_; } private: - std::vector data_; + std::vector shape_; + size_t num_of_elements_{1}; + SyntheticDataVector data_; }; struct SyntheticSampleBatch { - SyntheticSampleBatch() {} + SyntheticSampleBatch() = default; - void AddInt32Input(const std::vector& shape, int32_t low, int32_t high); - void AddInt64Input(const std::vector& shape, int64_t low, int64_t high); - void AddFloatInput(const std::vector& shape); + void AddInt32Input(gsl::span shape, int32_t low, int32_t high); + void AddInt64Input(gsl::span shape, int64_t low, int64_t high); + void AddFloatInput(gsl::span shape); + void AddBoolInput(gsl::span shape); - size_t NumOfInput() { - return data_vector_.size(); - } - - SyntheticInput* GetInputAtIndex(size_t index) { - return data_vector_[index].get(); - } + bool GetBatch(std::vector& batches); private: - std::vector> data_vector_; + template + void AddIntInput(gsl::span shape, T low, T high); + + std::vector data_vector_; }; struct SyntheticDataLoader { - SyntheticDataLoader() {} + SyntheticDataLoader() = default; - void AddSyntheticSampleBatch(std::unique_ptr samples) { - sample_batch_collections_.emplace_back(std::move(samples)); + void AddSyntheticSampleBatch(SyntheticSampleBatch&& samples) { + sample_batch_collections_.emplace_back(samples); } bool GetNextSampleBatch(std::vector& batches); - size_t NumOfSampleBatches() { + size_t NumOfSampleBatches() const { return sample_batch_collections_.size(); } @@ -109,7 +92,7 @@ struct SyntheticDataLoader { // did not explicitly copy the data in. // And also, the created OrtValue also won't clean the raw data pointer. The raw data should be removed when // the life time of this struct ends. - std::vector> sample_batch_collections_; + std::vector sample_batch_collections_; size_t sample_batch_iter_index_{0}; }; diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index eee6dee6a5..862c920edb 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -213,11 +213,11 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) { const std::vector fc2_bias_shape{fc2_weight_dim_in}; onnxruntime::training::test::training_api::SyntheticDataLoader data_loader; - auto sample = std::make_unique(); - sample->AddFloatInput(fc1_weight_shape); - sample->AddFloatInput(fc1_bias_shape); - sample->AddFloatInput(fc2_weight_shape); - sample->AddFloatInput(fc2_bias_shape); + auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch(); + sample.AddFloatInput(fc1_weight_shape); + sample.AddFloatInput(fc1_bias_shape); + sample.AddFloatInput(fc2_weight_shape); + sample.AddFloatInput(fc2_bias_shape); data_loader.AddSyntheticSampleBatch(std::move(sample)); std::vector all_weights_values; @@ -343,15 +343,15 @@ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) { float f_data = 0.5f; std::string f_property_name("float_number"); - property_bag.AddProperty(f_property_name, f_data); + property_bag.AddProperty(f_property_name, f_data); int64_t i_data = 400; std::string i_property_name("dataset_epoch_index"); - property_bag.AddProperty(i_property_name, i_data); + property_bag.AddProperty(i_property_name, i_data); std::string s_data("/data/path/train.bin"); std::string s_property_name("train_data_path"); - property_bag.AddProperty(s_property_name, s_data); + property_bag.AddProperty(s_property_name, s_data); // Remove the tempoprary directory if it already exists. auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir"); diff --git a/orttraining/orttraining/test/training_api/trainer/trainer.cc b/orttraining/orttraining/test/training_api/trainer/trainer.cc index ce5f572726..61b4084f4d 100644 --- a/orttraining/orttraining/test/training_api/trainer/trainer.cc +++ b/orttraining/orttraining/test/training_api/trainer/trainer.cc @@ -156,9 +156,9 @@ void InitSyntheticDataLoader( std::vector input1_shape{params.train_batch_size, 784}; std::vector target_shape{params.train_batch_size}; for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) { - auto sample = std::make_unique(); - sample->AddFloatInput(input1_shape); - sample->AddInt32Input(target_shape, 0, 1); + auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch(); + sample.AddFloatInput(input1_shape); + sample.AddInt32Input(target_shape, 0, 1); data_loader.AddSyntheticSampleBatch(std::move(sample)); } } else if (params.synthetic_input_type == "S") { @@ -167,10 +167,10 @@ void InitSyntheticDataLoader( std::vector attention_mask_shape{params.train_batch_size, sequence_length}; std::vector target_shape{params.train_batch_size}; for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) { - auto sample = std::make_unique(); - sample->AddInt64Input(input_ids_shape, 0, 250002 - 1); - sample->AddInt64Input(attention_mask_shape, 0, 1); - sample->AddInt32Input(target_shape, 0, 1); + auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch(); + sample.AddInt64Input(input_ids_shape, 0, 250002 - 1); + sample.AddInt64Input(attention_mask_shape, 0, 1); + sample.AddInt32Input(target_shape, 0, 1); data_loader.AddSyntheticSampleBatch(std::move(sample)); } } else if (params.synthetic_input_type == "U") { @@ -180,11 +180,11 @@ void InitSyntheticDataLoader( std::vector target1_shape{params.train_batch_size}; std::vector target2_shape{params.train_batch_size, 81}; for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) { - auto sample = std::make_unique(); - sample->AddInt64Input(input_ids_shape, 0, 250002 - 1); - sample->AddInt64Input(attention_mask_shape, 0, 1); - sample->AddInt32Input(target1_shape, 0, 1); - sample->AddInt32Input(target2_shape, 0, 1); + auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch(); + sample.AddInt64Input(input_ids_shape, 0, 250002 - 1); + sample.AddInt64Input(attention_mask_shape, 0, 1); + sample.AddInt32Input(target1_shape, 0, 1); + sample.AddInt32Input(target2_shape, 0, 1); data_loader.AddSyntheticSampleBatch(std::move(sample)); } } else if (params.synthetic_input_type == "R") { @@ -193,10 +193,25 @@ void InitSyntheticDataLoader( std::vector attention_mask_shape{params.train_batch_size, sequence_length}; std::vector labels_shape{params.train_batch_size, 81}; for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) { - auto sample = std::make_unique(); - sample->AddInt64Input(input_ids_shape, 0, 250002 - 1); - sample->AddInt64Input(attention_mask_shape, 0, 1); - sample->AddInt32Input(labels_shape, 0, 1); + auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch(); + sample.AddInt64Input(input_ids_shape, 0, 250002 - 1); + sample.AddInt64Input(attention_mask_shape, 0, 1); + sample.AddInt32Input(labels_shape, 0, 1); + data_loader.AddSyntheticSampleBatch(std::move(sample)); + } + } else if (params.synthetic_input_type == "C") { + int64_t section = 16; + int64_t sequence_length = 128; + std::vector input_ids_shape{params.train_batch_size, section, sequence_length}; + std::vector mask_clss_shape{params.train_batch_size, section}; + std::vector attention_mask_shape{params.train_batch_size, section, sequence_length}; + std::vector labels_shape{params.train_batch_size}; + for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) { + auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch(); + sample.AddInt64Input(input_ids_shape, 0, 250002 - 1); + sample.AddBoolInput(mask_clss_shape); + sample.AddInt64Input(attention_mask_shape, 0, 1); + sample.AddInt32Input(labels_shape, 0, 1); data_loader.AddSyntheticSampleBatch(std::move(sample)); } } else { diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index e7e27aac57..7908a89459 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/logging/sinks/clog_sink.h" #include "core/common/path.h" @@ -46,14 +47,14 @@ Status CreateTensorProtosFromOrtValues( const DataTransferManager& data_transfer_manager, std::vector& saved_tensor_protos) { // Order the tensors by name. - std::vector ordered_tensor_names{}; + InlinedVector ordered_tensor_names{}; ordered_tensor_names.reserve(name_to_ort_value.size()); std::transform(name_to_ort_value.begin(), name_to_ort_value.end(), std::back_inserter(ordered_tensor_names), [](const NameMLValMap::value_type& v) { return v.first; }); std::sort(ordered_tensor_names.begin(), ordered_tensor_names.end()); // Copy the tensor data and create TensorProto storing the data. - std::vector tensor_data_buffer{}; + InlinedVector tensor_data_buffer{}; static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator}; saved_tensor_protos.reserve(ordered_tensor_names.size()); @@ -184,7 +185,7 @@ Status OrtSaveInternal( // Make sure name unique across trainable and non-trainable lists. std::set trainable_unique_names; std::set non_trainable_unique_names; - std::vector inter_sec; + InlinedVector inter_sec; auto check_unique = [](const std::vector& tensor_protos, std::set& unique_names) { for (auto& tensor_proto : tensor_protos) { @@ -231,7 +232,7 @@ Status OrtSaveModuleStatesInternal(ModuleCheckpointState& module_state, ORT_ENFORCE(module_state.train_session_data_transfer_mgr, "module checkpoint state has null train_session_data_transfer_mgr."); - std::unordered_map> + InlinedHashMap> parameter_ort_values; for (auto it = param_states.begin(); it != param_states.end(); ++it) { if (it->second->RequiresGrad()) { @@ -277,7 +278,7 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state, // Re-organize optimizer_state_ort_values mapping // Firstly indexed by momentum names; Secondly indexed by parameter names. - std::unordered_map> optimizer_state_ort_values; + InlinedHashMap> optimizer_state_ort_values; for (const std::pair& param_named_optimizer_state : group_optimizer_state_ptr->param_named_optimizer_states) { const std::string& param_name = param_named_optimizer_state.first; @@ -319,8 +320,8 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state, // Storing group-wise properties. PropertyBag properties; - properties.AddProperty(builtin_lr_property_name, group_optimizer_state_ptr->initial_lr); - properties.AddProperty(builtin_step_property_name, group_optimizer_state_ptr->step); + properties.AddProperty(builtin_lr_property_name, group_optimizer_state_ptr->initial_lr); + properties.AddProperty(builtin_step_property_name, group_optimizer_state_ptr->step); std::vector group_wise_properties_tensor_protos; properties.ToTensorProtos(group_wise_properties_tensor_protos); @@ -363,7 +364,7 @@ Status OrtSaveInternal( Status OrtLoadModuleStatesInternal( const PathString& parameter_folder_path, ModuleCheckpointState& module_state) { // Find parameter files. - std::vector> param_filenames; + InlinedVector> param_filenames; FilterFilesFromDirectory( parameter_folder_path, [¶m_filenames](const PathChar* filename) -> bool { diff --git a/orttraining/orttraining/training_api/checkpoint_property.cc b/orttraining/orttraining/training_api/checkpoint_property.cc index 01f21d73cf..59556923b8 100644 --- a/orttraining/orttraining/training_api/checkpoint_property.cc +++ b/orttraining/orttraining/training_api/checkpoint_property.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "onnx/defs/tensor_proto_util.h" +#include "core/common/inlined_containers.h" #include "core/platform/path_lib.h" #include "core/platform/env.h" #include "core/framework/tensorprotoutils.h" @@ -11,8 +12,12 @@ namespace onnxruntime { namespace training { namespace api { +namespace { + template -TypedCheckpointProperty::TypedCheckpointProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto) { +void ParsePropertyFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, + std::string& name, + PropertyDataType& value) { std::vector tensor_shape_vec = utils::GetTensorShapeFromTensorProto(tensor_proto); int64_t expected_num_elements = 1; for (auto& d : tensor_shape_vec) { @@ -20,45 +25,13 @@ TypedCheckpointProperty::TypedCheckpointProperty(const ONNX_NAMESPACE::Tensor } ORT_ENFORCE(expected_num_elements == 1, "Only scalar value support for checkpoint property."); Path model_path; - std::vector data_vector(1); + InlinedVector data_vector(1); T* p = data_vector.data(); ORT_THROW_IF_ERROR(utils::UnpackTensor(tensor_proto, model_path, p, expected_num_elements)); - prop_name_ = tensor_proto.name(); - prop_value_ = data_vector[0]; + name = tensor_proto.name(); + value = data_vector[0]; } -template -ONNX_NAMESPACE::TensorProto TypedCheckpointProperty::ToTensorProto() { - auto t_proto = ONNX_NAMESPACE::ToTensor(prop_value_); - t_proto.set_name(prop_name_); - return t_proto; -} - -namespace { - -std::shared_ptr CreateCheckpointPropertyFromTensorProto( - const ONNX_NAMESPACE::TensorProto& tensor_proto) { - auto data_type = tensor_proto.data_type(); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto::FLOAT: { - return std::static_pointer_cast( - std::make_shared>(tensor_proto)); - break; - } - case ONNX_NAMESPACE::TensorProto::STRING: { - return std::static_pointer_cast( - std::make_shared>(tensor_proto)); - break; - } - case ONNX_NAMESPACE::TensorProto::INT64: { - return std::static_pointer_cast( - std::make_shared>(tensor_proto)); - break; - } - default: - ORT_THROW("Unsupported input data type of ", data_type); - } -} } // namespace void PropertyBag::AddProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto) { @@ -69,7 +42,27 @@ void PropertyBag::AddProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto) { ORT_THROW("Failed to add property from tensorproto: float, int64_t and std::string data types supported only."); } - named_properties_.insert({tensor_proto.name(), CreateCheckpointPropertyFromTensorProto(tensor_proto)}); + auto data_type = tensor_proto.data_type(); + std::string prop_name; + PropertyDataType prop_value; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto::FLOAT: { + ParsePropertyFromTensorProto(tensor_proto, prop_name, prop_value); + break; + } + case ONNX_NAMESPACE::TensorProto::STRING: { + ParsePropertyFromTensorProto(tensor_proto, prop_name, prop_value); + break; + } + case ONNX_NAMESPACE::TensorProto::INT64: { + ParsePropertyFromTensorProto(tensor_proto, prop_name, prop_value); + break; + } + default: + ORT_THROW("Unsupported input data type of ", data_type); + } + + named_properties_.insert({prop_name, prop_value}); } } // namespace api diff --git a/orttraining/orttraining/training_api/include/checkpoint.h b/orttraining/orttraining/training_api/include/checkpoint.h index 2bab6a95d8..ce6785702b 100644 --- a/orttraining/orttraining/training_api/include/checkpoint.h +++ b/orttraining/orttraining/training_api/include/checkpoint.h @@ -20,10 +20,10 @@ * ii. optimizer state: * a instance of data class `OptimizerCheckpointState` managed along with Optimizer class, * iii. user defined training properties, for example 'epoch', 'best_score': - * a instance of data class `PropertyBag` managed along with CheckpointProperty classes. + * a instance of data class `PropertyBag`. * * In terms of class dependencies, Checkpoint implementations are dependent on (and on top of) - * Parameter/Module/Optimizer/CheckpointProperty, NOT vice versa. + * Parameter/Module/Optimizer, NOT vice versa. * * 2. A directory of files: * checkpoint/ diff --git a/orttraining/orttraining/training_api/include/checkpoint_property.h b/orttraining/orttraining/training_api/include/checkpoint_property.h index e27da6bb08..6e0ce3babb 100644 --- a/orttraining/orttraining/training_api/include/checkpoint_property.h +++ b/orttraining/orttraining/training_api/include/checkpoint_property.h @@ -3,65 +3,18 @@ #pragma once +#include #include +#include + +#include "core/common/inlined_containers.h" #include "onnx/defs/tensor_proto_util.h" namespace onnxruntime { namespace training { namespace api { -template -struct TypedCheckpointProperty; - -/** - * @brief Base class for user defined checkpoint property. - */ -struct CheckpointProperty { - public: - CheckpointProperty() {} - - CheckpointProperty(const std::string& prop_name) - : prop_name_(prop_name) { - } - - virtual ~CheckpointProperty() {} - virtual ONNX_NAMESPACE::TensorProto ToTensorProto() = 0; - - std::string GetName() const { - return prop_name_; - } - - template - T GetData() { - auto ptr = dynamic_cast*>(this); - ORT_ENFORCE(ptr); - return ptr->GetData(); - } - - protected: - std::string prop_name_; -}; - -/** - * @brief User defined checkpoint property. - */ -template -struct TypedCheckpointProperty : public CheckpointProperty { - public: - TypedCheckpointProperty(const std::string& prop_name, const T& prop_value) - : CheckpointProperty(prop_name), prop_value_(prop_value) { - } - TypedCheckpointProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto); - - ONNX_NAMESPACE::TensorProto ToTensorProto() override; - - T GetData() const { - return prop_value_; - } - - private: - T prop_value_; -}; +using PropertyDataType = std::variant; /** * @brief Collection of user defined properties. @@ -69,33 +22,41 @@ struct TypedCheckpointProperty : public CheckpointProperty { */ struct PropertyBag { public: - PropertyBag() {} + PropertyBag() = default; - template - void AddProperty(std::string name, T val) { - static_assert(onnxruntime::training::api::PropertyBag::template IsSupportedDataType(), - "Failed to add property: float, int64_t and std::string data types supported only."); + void AddProperty(const std::string& name, const PropertyDataType& val) { ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(), "Duplicated property named ", name); - named_properties_.insert({name, std::make_shared>(name, val)}); + named_properties_.insert({name, val}); } void AddProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto); template T GetProperty(const std::string& name) const { - static_assert(onnxruntime::training::api::PropertyBag::template IsSupportedDataType(), - "Failed to get property: float, int64_t and std::string data types supported only."); - auto it = named_properties_.find(name); ORT_ENFORCE(it != named_properties_.end(), "No property named ", name); - return it->second->GetData(); + + const T* tval = std::get_if(&it->second); + ORT_ENFORCE(tval, "Fail to get the property value using specified type."); + return *tval; } void ToTensorProtos(std::vector& properties_tensor_protos) const { for (auto it = named_properties_.begin(); it != named_properties_.end(); ++it) { - properties_tensor_protos.emplace_back((it->second)->ToTensorProto()); + onnx::TensorProto t_proto; + if (const float* fval = std::get_if(&it->second); fval != nullptr) { + t_proto = ONNX_NAMESPACE::ToTensor(*fval); + } else if (const int64_t* ival = std::get_if(&it->second); ival != nullptr) { + t_proto = ONNX_NAMESPACE::ToTensor(*ival); + } else if (const std::string* sval = std::get_if(&it->second); sval != nullptr) { + t_proto = ONNX_NAMESPACE::ToTensor(*sval); + } else { + ORT_THROW("Should not go there, unexpected data_type for prop value."); + } + t_proto.set_name(it->first); + properties_tensor_protos.emplace_back(t_proto); } } @@ -104,7 +65,7 @@ struct PropertyBag { } private: - const std::vector supported_data_types{ + const InlinedVector supported_data_types{ ONNX_NAMESPACE::TensorProto::FLOAT, ONNX_NAMESPACE::TensorProto::INT64, ONNX_NAMESPACE::TensorProto::STRING}; @@ -113,13 +74,7 @@ struct PropertyBag { return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != supported_data_types.end(); } - template - static constexpr bool IsSupportedDataType() { - return (std::is_same::value || std::is_same::value || - std::is_same::value); - } - - std::unordered_map> named_properties_; + InlinedHashMap named_properties_; }; } // namespace api