Refactor with std::variant (on device training) (#12383)

* use std::variant for synthetic data storage.

* use std::variant to replace TypedCheckpointProperty

* Remove shared ptr for checkpoint property

* fix tests

* refine std::variant usage a bit

* remove CheckpointProperty data abstraction

* use InlinedVector and InlinedHashMap if possible

* fix comments

* fix build and test

* fix some comments

* use gsl::span

* fix tests

* refine based on comments

* fix win build

* fix build
This commit is contained in:
pengwa 2022-08-17 08:31:23 +08:00 committed by GitHub
parent caabfcd920
commit 7df2e8c5cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 237 additions and 256 deletions

View file

@ -6,6 +6,8 @@
#include <algorithm>
#include <memory>
#include <random>
#include <type_traits>
#include <variant>
#include <vector>
#include "synthetic_data_loader.h"
@ -17,108 +19,140 @@ namespace training_api {
namespace {
// Fills `rets` with pseudo-random floats drawn from a normal distribution with mean 0 and scale 1.
// NOTE(review): two signatures and two fill statements appear back to back in this span (this page is a
// diff view); the `num_element` form is the one AddFloatInput calls as RandomFloats(values, num_of_element).
void RandomFloats(std::vector<float>& rets) {
void RandomFloats(std::vector<float>& rets, size_t num_element) {
const float scale = 1.f;
const float mean = 0.f;
const float seed = 123.f;
// Function-local static engine built from a fixed seed value; it is constructed once and shared by every call.
static std::default_random_engine generator{static_cast<uint32_t>(seed)};
std::normal_distribution<float> distribution{mean, scale};
// Overwrites the elements already present in `rets`.
std::for_each(rets.begin(), rets.end(),
[&distribution](float& value) { value = distribution(generator); });
// Appends `num_element` freshly generated values to `rets`.
std::generate_n(std::back_inserter(rets), num_element,
[&distribution]() -> float { return distribution(generator); });
}
// Fills `rets` with integers produced by std::uniform_int_distribution<IntType>(low, high).
// NOTE(review): two signatures and two fill statements appear back to back in this span (this page is a
// diff view); the `num_element` form is the one the Add*Input helpers call.
template <typename IntType>
void RandomInts(std::vector<IntType>& rets, IntType low, IntType high) {
void RandomInts(std::vector<IntType>& rets, size_t num_element, IntType low, IntType high) {
// Unlike RandomFloats, the engine is seeded from std::random_device rather than a fixed constant.
// Both objects are function-local statics, so each template instantiation seeds its engine once.
static std::random_device rd;
static std::mt19937 generator(rd());
std::uniform_int_distribution<IntType> distribution(low, high);
// Overwrites the elements already present in `rets`.
std::for_each(rets.begin(), rets.end(),
[&distribution](IntType& value) { value = distribution(generator); });
// Appends `num_element` freshly generated values to `rets`.
std::generate_n(std::back_inserter(rets), num_element,
[&distribution]() -> IntType { return distribution(generator); });
}
} // namespace
void SyntheticSampleBatch::AddInt64Input(const std::vector<int64_t>& shape, int64_t low, int64_t high) {
data_vector_.emplace_back(std::make_unique<TypedSyntheticInput<int64_t>>(shape));
RandomInts(data_vector_.back()->GetData<int64_t>(), low, high);
// Appends one integer input with the given shape to this batch and fills it with random values
// generated by RandomInts over [low, high].
// T is the element type of the stored buffer; it is deduced from `low`/`high` by the public
// AddInt32Input / AddInt64Input wrappers.
template <typename T>
void SyntheticSampleBatch::AddIntInput(gsl::span<const int64_t> shape, T low, T high) {
  data_vector_.emplace_back(shape);
  SyntheticInput& input = data_vector_.back();

  const size_t num_of_element = input.NumOfElements();
  std::vector<T> values;
  values.reserve(num_of_element);
  RandomInts(values, num_of_element, low, high);

  // Hand the buffer over to the variant; a plain assignment would copy every element.
  input.GetData() = std::move(values);
}
void SyntheticSampleBatch::AddInt32Input(const std::vector<int64_t>& shape, int32_t low, int32_t high) {
data_vector_.emplace_back(std::make_unique<TypedSyntheticInput<int32_t>>(shape));
RandomInts(data_vector_.back()->GetData<int32_t>(), low, high);
// Adds an int64 input of the given shape with random values in [low, high].
void SyntheticSampleBatch::AddInt64Input(gsl::span<const int64_t> shape, int64_t low, int64_t high) {
  AddIntInput<int64_t>(shape, low, high);
}
void SyntheticSampleBatch::AddFloatInput(const std::vector<int64_t>& shape) {
data_vector_.emplace_back(std::make_unique<TypedSyntheticInput<float>>(shape));
RandomFloats(data_vector_.back()->GetData<float>());
// Adds an int32 input of the given shape with random values in [low, high].
void SyntheticSampleBatch::AddInt32Input(gsl::span<const int64_t> shape, int32_t low, int32_t high) {
  AddIntInput<int32_t>(shape, low, high);
}
#define ORT_RETURN_ON_ERROR(expr) \
do { \
OrtStatus* onnx_status = (expr); \
if (onnx_status != NULL) { \
// Appends one boolean input with the given shape to this batch, filled with random 0/1 values.
void SyntheticSampleBatch::AddBoolInput(gsl::span<const int64_t> shape) {
  // The values are intentionally stored as uint8_t: std::vector<bool> is a packed specialization, so its
  // storage cannot be handed to the C API as a tensor data buffer.
  data_vector_.emplace_back(shape);
  SyntheticInput& input = data_vector_.back();
  const size_t num_of_element = input.NumOfElements();

  // Generate as int32_t first because the MSVC compiler complains about uint8_t with uniform_int_distribution.
  std::vector<int32_t> values;
  values.reserve(num_of_element);
  RandomInts(values, num_of_element, static_cast<int32_t>(0), static_cast<int32_t>(1));

  // Narrow to uint8_t; reserve up front so the back_inserter never reallocates.
  std::vector<uint8_t> uint8_values;
  uint8_values.reserve(values.size());
  std::transform(values.begin(), values.end(), std::back_inserter(uint8_values),
                 [](int32_t x) { return static_cast<uint8_t>(x); });

  // Hand the buffer over to the variant; a plain assignment would copy every element.
  input.GetData() = std::move(uint8_values);
}
// Appends one float input with the given shape to this batch and fills it with values from RandomFloats.
void SyntheticSampleBatch::AddFloatInput(gsl::span<const int64_t> shape) {
  data_vector_.emplace_back(shape);
  SyntheticInput& input = data_vector_.back();

  const size_t num_of_element = input.NumOfElements();
  std::vector<float> values;
  values.reserve(num_of_element);
  RandomFloats(values, num_of_element);

  // Hand the buffer over to the variant; a plain assignment would copy every element.
  input.GetData() = std::move(values);
}
// Evaluates `expr` (an expression yielding OrtStatus*). When the status is non-null, prints its error code and
// message, releases the status, and makes the enclosing function (or lambda) return false.
// Expects a variable named `ort_api` to be in scope at the expansion site.
// The macro body already ends with `;` after `while (0)`.
// NOTE(review): the printf and `return false` lines appear twice in this span (diff view). Presumably `msg`
// points into the status object - confirm the printfs run before ReleaseStatus in the final code.
#define ORT_RETURN_ON_ERROR(expr) \
do { \
OrtStatus* onnx_status = (expr); \
if (onnx_status != NULL) { \
auto code = ort_api->GetErrorCode(onnx_status); \
const char* msg = ort_api->GetErrorMessage(onnx_status); \
printf("Run failed with error code :%d\n", code); \
printf("Error message :%s\n", msg); \
ort_api->ReleaseStatus(onnx_status); \
printf("Run failed with error code :%d\n", code); \
printf("Error message :%s\n", msg); \
return false; \
} \
return false; \
} \
} while (0);
// Creates one OrtValue tensor per stored input and appends it to `batches`, which is cleared first.
// Returns false as soon as tensor creation fails for an input, true once every input has been converted.
bool SyntheticSampleBatch::GetBatch(std::vector<OrtValue*>& batches) {
  batches.clear();

  Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  const auto* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  for (SyntheticInput& input : data_vector_) {
    // Visitor invoked with whichever std::vector alternative the input's variant currently holds.
    auto make_tensor = [&batches, &input, &ort_api, &memory_info](auto&& buffer) -> bool {
      using ElemT = typename std::decay_t<decltype(buffer)>::value_type;
      // uint8_t buffers are the storage used for bool inputs, so they are exposed as bool tensors.
      using TensorT = std::conditional_t<std::is_same_v<ElemT, uint8_t>, bool, ElemT>;
      const ONNXTensorElementDataType elem_data_type = Ort::TypeToTensorType<TensorT>::type;

      OrtValue* tensor = nullptr;
      const auto dims = input.ShapeVector();
      // Be noted: the created OrtValue won't clean the raw data after its lifetime ended.
      ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
          memory_info,
          buffer.data(), (input.NumOfElements() * sizeof(ElemT)),
          dims.data(), dims.size(),
          elem_data_type,
          &tensor));

      batches.emplace_back(tensor);
      return true;
    };

    if (!std::visit(make_tensor, input.GetData())) {
      return false;
    }
  }

  return true;
}
// Writes the tensors of the next sample batch into `batches` and advances the iteration index.
// Returns false without touching `batches` when every stored batch has already been consumed.
// NOTE(review): this span interleaves two implementations (diff view). The per-input loop that uses
// `sample->`, dynamic_cast and TypedSyntheticInput presumably belongs to the pre-change side, while
// `sample.GetBatch(batches)` is the std::variant-based path.
bool SyntheticDataLoader::GetNextSampleBatch(std::vector<OrtValue*>& batches) {
if (sample_batch_iter_index_ >= NumOfSampleBatches()) {
return false;
}
batches.clear();
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
auto& sample = sample_batch_collections_[sample_batch_iter_index_];
const auto* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
for (size_t i = 0; i < sample->NumOfInput(); ++i) {
auto input_ptr = sample->GetInputAtIndex(i);
auto shape_vector = input_ptr->ShapeVector();
// Be noted: the created OrtValue won't clean the raw data after its lifetime ended.
auto ptr_flt = dynamic_cast<TypedSyntheticInput<float>*>(input_ptr);
if (ptr_flt) {
OrtValue* value = NULL;
ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(memory_info,
input_ptr->GetData<float>().data(), (input_ptr->NumOfElements() * sizeof(float)),
shape_vector.data(), shape_vector.size(),
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
&value));
batches.emplace_back(value);
continue;
}
auto ptr_int = dynamic_cast<TypedSyntheticInput<int64_t>*>(input_ptr);
if (ptr_int) {
OrtValue* value = NULL;
ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(memory_info,
input_ptr->GetData<int64_t>().data(), (input_ptr->NumOfElements() * sizeof(int64_t)),
shape_vector.data(), shape_vector.size(),
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
&value));
batches.emplace_back(value);
continue;
}
auto ptr_int32 = dynamic_cast<TypedSyntheticInput<int32_t>*>(input_ptr);
if (ptr_int32) {
OrtValue* value = nullptr;
ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(memory_info,
input_ptr->GetData<int32_t>().data(), (input_ptr->NumOfElements() * sizeof(int32_t)),
shape_vector.data(), shape_vector.size(),
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
&value));
batches.emplace_back(value);
continue;
}
throw std::runtime_error("unknown data types.");
}
// NOTE(review): SyntheticSampleBatch::GetBatch returns false when tensor creation fails, but its result is
// discarded here, so this function still advances the index and reports success in that case.
sample.GetBatch(batches);
sample_batch_iter_index_ += 1;
return true;
}

View file

@ -11,10 +11,13 @@
#pragma once
#include "gsl/gsl"
#include <onnxruntime_cxx_api.h>
#include <memory>
#include <utility>
#include <variant>
#include <vector>
namespace onnxruntime {
@ -22,81 +25,61 @@ namespace training {
namespace test {
namespace training_api {
template <typename T>
struct TypedSyntheticInput;
using SyntheticDataVector = std::variant<std::vector<int32_t>, std::vector<int64_t>, std::vector<float>,
std::vector<uint8_t>>;
// One synthetic model input: a shape, the element count implied by that shape, and the data buffer.
// NOTE(review): this span interleaves two designs (diff view): a polymorphic SyntheticInput base with a
// TypedSyntheticInput<T> subclass holding std::vector<T>, and a single struct holding a SyntheticDataVector
// (std::variant of vectors). Declarations that appear twice below come from those two sides.
struct SyntheticInput {
explicit SyntheticInput(const std::vector<int64_t>& shape) : shape_(shape) {
explicit SyntheticInput(gsl::span<const int64_t> shape) : shape_(shape.begin(), shape.end()) {
// Element count is the product of all dimensions, starting from the member's initial value of 1.
// NOTE(review): dimensions are int64_t but the accumulator is size_t; a negative dimension is not checked.
for (auto d : shape) {
num_of_elements_ *= d;
}
}
virtual ~SyntheticInput() {}
// Downcasts to TypedSyntheticInput<T> and returns its buffer; the dynamic_cast result is not null-checked.
template <typename T>
std::vector<T>& GetData() {
auto ptr = dynamic_cast<TypedSyntheticInput<T>*>(this);
return ptr->Data();
}
// Number of elements implied by the shape.
size_t NumOfElements() {
size_t NumOfElements() const {
return num_of_elements_;
}
// The shape passed at construction, as stored in shape_.
std::vector<int64_t> ShapeVector() const {
gsl::span<const int64_t> ShapeVector() const {
return shape_;
}
protected:
std::vector<int64_t> shape_;
size_t num_of_elements_{1};
};
template <typename T>
struct TypedSyntheticInput : public SyntheticInput {
explicit TypedSyntheticInput(const std::vector<int64_t>& shape)
: SyntheticInput(shape) {
// Pre-sizes the buffer to the element count computed by the base constructor.
data_.resize(num_of_elements_);
}
std::vector<T>& Data() {
// Mutable access to the variant-held buffer; callers assign a std::vector alternative into it.
SyntheticDataVector& GetData() {
return data_;
}
private:
std::vector<T> data_;
std::vector<int64_t> shape_;
size_t num_of_elements_{1};
SyntheticDataVector data_;
};
// A batch of synthetic inputs. Inputs are appended with the Add*Input methods and turned into
// OrtValue tensors by GetBatch.
// NOTE(review): declarations appear in two forms below (diff view): overloads taking
// `const std::vector<int64_t>&` alongside ones taking `gsl::span<const int64_t>`, and two `data_vector_`
// members with different element types.
struct SyntheticSampleBatch {
SyntheticSampleBatch() {}
SyntheticSampleBatch() = default;
void AddInt32Input(const std::vector<int64_t>& shape, int32_t low, int32_t high);
void AddInt64Input(const std::vector<int64_t>& shape, int64_t low, int64_t high);
void AddFloatInput(const std::vector<int64_t>& shape);
// Each Add*Input appends one input of the given shape; `low`/`high` are forwarded to the integer generator.
void AddInt32Input(gsl::span<const int64_t> shape, int32_t low, int32_t high);
void AddInt64Input(gsl::span<const int64_t> shape, int64_t low, int64_t high);
void AddFloatInput(gsl::span<const int64_t> shape);
void AddBoolInput(gsl::span<const int64_t> shape);
// Number of inputs added so far.
size_t NumOfInput() {
return data_vector_.size();
}
// Raw pointer to the input at `index`; uses `.get()`, i.e. the unique_ptr-based storage.
SyntheticInput* GetInputAtIndex(size_t index) {
return data_vector_[index].get();
}
// Fills `batches` with one tensor per input; returns false if tensor creation fails.
bool GetBatch(std::vector<OrtValue*>& batches);
private:
std::vector<std::unique_ptr<SyntheticInput>> data_vector_;
// Shared implementation behind AddInt32Input / AddInt64Input.
template <typename T>
void AddIntInput(gsl::span<const int64_t> shape, T low, T high);
std::vector<SyntheticInput> data_vector_;
};
struct SyntheticDataLoader {
SyntheticDataLoader() {}
SyntheticDataLoader() = default;
void AddSyntheticSampleBatch(std::unique_ptr<SyntheticSampleBatch> samples) {
sample_batch_collections_.emplace_back(std::move(samples));
// Appends `samples` to the loader's collection, taking over its contents.
void AddSyntheticSampleBatch(SyntheticSampleBatch&& samples) {
  // A named rvalue reference is an lvalue inside the function, so it must be moved explicitly;
  // otherwise emplace_back copy-constructs the batch and duplicates every data buffer.
  sample_batch_collections_.emplace_back(std::move(samples));
}
bool GetNextSampleBatch(std::vector<OrtValue*>& batches);
size_t NumOfSampleBatches() {
size_t NumOfSampleBatches() const {
return sample_batch_collections_.size();
}
@ -109,7 +92,7 @@ struct SyntheticDataLoader {
// did not explicitly copy the data in.
// And also, the created OrtValue also won't clean the raw data pointer. The raw data should be removed when
// the life time of this struct ends.
std::vector<std::unique_ptr<SyntheticSampleBatch>> sample_batch_collections_;
std::vector<SyntheticSampleBatch> sample_batch_collections_;
size_t sample_batch_iter_index_{0};
};

View file

@ -213,11 +213,11 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) {
const std::vector<int64_t> fc2_bias_shape{fc2_weight_dim_in};
onnxruntime::training::test::training_api::SyntheticDataLoader data_loader;
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
sample->AddFloatInput(fc1_weight_shape);
sample->AddFloatInput(fc1_bias_shape);
sample->AddFloatInput(fc2_weight_shape);
sample->AddFloatInput(fc2_bias_shape);
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
sample.AddFloatInput(fc1_weight_shape);
sample.AddFloatInput(fc1_bias_shape);
sample.AddFloatInput(fc2_weight_shape);
sample.AddFloatInput(fc2_bias_shape);
data_loader.AddSyntheticSampleBatch(std::move(sample));
std::vector<OrtValue*> all_weights_values;
@ -343,15 +343,15 @@ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) {
float f_data = 0.5f;
std::string f_property_name("float_number");
property_bag.AddProperty<float>(f_property_name, f_data);
property_bag.AddProperty(f_property_name, f_data);
int64_t i_data = 400;
std::string i_property_name("dataset_epoch_index");
property_bag.AddProperty<int64_t>(i_property_name, i_data);
property_bag.AddProperty(i_property_name, i_data);
std::string s_data("/data/path/train.bin");
std::string s_property_name("train_data_path");
property_bag.AddProperty<std::string>(s_property_name, s_data);
property_bag.AddProperty(s_property_name, s_data);
// Remove the temporary directory if it already exists.
auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");

View file

@ -156,9 +156,9 @@ void InitSyntheticDataLoader(
std::vector<int64_t> input1_shape{params.train_batch_size, 784};
std::vector<int64_t> target_shape{params.train_batch_size};
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
sample->AddFloatInput(input1_shape);
sample->AddInt32Input(target_shape, 0, 1);
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
sample.AddFloatInput(input1_shape);
sample.AddInt32Input(target_shape, 0, 1);
data_loader.AddSyntheticSampleBatch(std::move(sample));
}
} else if (params.synthetic_input_type == "S") {
@ -167,10 +167,10 @@ void InitSyntheticDataLoader(
std::vector<int64_t> attention_mask_shape{params.train_batch_size, sequence_length};
std::vector<int64_t> target_shape{params.train_batch_size};
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
sample->AddInt64Input(input_ids_shape, 0, 250002 - 1);
sample->AddInt64Input(attention_mask_shape, 0, 1);
sample->AddInt32Input(target_shape, 0, 1);
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
sample.AddInt64Input(input_ids_shape, 0, 250002 - 1);
sample.AddInt64Input(attention_mask_shape, 0, 1);
sample.AddInt32Input(target_shape, 0, 1);
data_loader.AddSyntheticSampleBatch(std::move(sample));
}
} else if (params.synthetic_input_type == "U") {
@ -180,11 +180,11 @@ void InitSyntheticDataLoader(
std::vector<int64_t> target1_shape{params.train_batch_size};
std::vector<int64_t> target2_shape{params.train_batch_size, 81};
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
sample->AddInt64Input(input_ids_shape, 0, 250002 - 1);
sample->AddInt64Input(attention_mask_shape, 0, 1);
sample->AddInt32Input(target1_shape, 0, 1);
sample->AddInt32Input(target2_shape, 0, 1);
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
sample.AddInt64Input(input_ids_shape, 0, 250002 - 1);
sample.AddInt64Input(attention_mask_shape, 0, 1);
sample.AddInt32Input(target1_shape, 0, 1);
sample.AddInt32Input(target2_shape, 0, 1);
data_loader.AddSyntheticSampleBatch(std::move(sample));
}
} else if (params.synthetic_input_type == "R") {
@ -193,10 +193,25 @@ void InitSyntheticDataLoader(
std::vector<int64_t> attention_mask_shape{params.train_batch_size, sequence_length};
std::vector<int64_t> labels_shape{params.train_batch_size, 81};
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
sample->AddInt64Input(input_ids_shape, 0, 250002 - 1);
sample->AddInt64Input(attention_mask_shape, 0, 1);
sample->AddInt32Input(labels_shape, 0, 1);
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
sample.AddInt64Input(input_ids_shape, 0, 250002 - 1);
sample.AddInt64Input(attention_mask_shape, 0, 1);
sample.AddInt32Input(labels_shape, 0, 1);
data_loader.AddSyntheticSampleBatch(std::move(sample));
}
} else if (params.synthetic_input_type == "C") {
int64_t section = 16;
int64_t sequence_length = 128;
std::vector<int64_t> input_ids_shape{params.train_batch_size, section, sequence_length};
std::vector<int64_t> mask_clss_shape{params.train_batch_size, section};
std::vector<int64_t> attention_mask_shape{params.train_batch_size, section, sequence_length};
std::vector<int64_t> labels_shape{params.train_batch_size};
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
sample.AddInt64Input(input_ids_shape, 0, 250002 - 1);
sample.AddBoolInput(mask_clss_shape);
sample.AddInt64Input(attention_mask_shape, 0, 1);
sample.AddInt32Input(labels_shape, 0, 1);
data_loader.AddSyntheticSampleBatch(std::move(sample));
}
} else {

View file

@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/inlined_containers.h"
#include "core/common/logging/logging.h"
#include "core/common/logging/sinks/clog_sink.h"
#include "core/common/path.h"
@ -46,14 +47,14 @@ Status CreateTensorProtosFromOrtValues(
const DataTransferManager& data_transfer_manager,
std::vector<ONNX_NAMESPACE::TensorProto>& saved_tensor_protos) {
// Order the tensors by name.
std::vector<std::string> ordered_tensor_names{};
InlinedVector<std::string> ordered_tensor_names{};
ordered_tensor_names.reserve(name_to_ort_value.size());
std::transform(name_to_ort_value.begin(), name_to_ort_value.end(), std::back_inserter(ordered_tensor_names),
[](const NameMLValMap::value_type& v) { return v.first; });
std::sort(ordered_tensor_names.begin(), ordered_tensor_names.end());
// Copy the tensor data and create TensorProto storing the data.
std::vector<char> tensor_data_buffer{};
InlinedVector<char> tensor_data_buffer{};
static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
saved_tensor_protos.reserve(ordered_tensor_names.size());
@ -184,7 +185,7 @@ Status OrtSaveInternal(
// Make sure name unique across trainable and non-trainable lists.
std::set<std::string> trainable_unique_names;
std::set<std::string> non_trainable_unique_names;
std::vector<std::string> inter_sec;
InlinedVector<std::string> inter_sec;
auto check_unique = [](const std::vector<ONNX_NAMESPACE::TensorProto>& tensor_protos,
std::set<std::string>& unique_names) {
for (auto& tensor_proto : tensor_protos) {
@ -231,7 +232,7 @@ Status OrtSaveModuleStatesInternal(ModuleCheckpointState& module_state,
ORT_ENFORCE(module_state.train_session_data_transfer_mgr,
"module checkpoint state has null train_session_data_transfer_mgr.");
std::unordered_map<PathString, std::unordered_map<std::string, OrtValue>>
InlinedHashMap<PathString, std::unordered_map<std::string, OrtValue>>
parameter_ort_values;
for (auto it = param_states.begin(); it != param_states.end(); ++it) {
if (it->second->RequiresGrad()) {
@ -277,7 +278,7 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state,
// Re-organize optimizer_state_ort_values mapping
// Firstly indexed by momentum names; Secondly indexed by parameter names.
std::unordered_map<std::string, std::unordered_map<std::string, OrtValue>> optimizer_state_ort_values;
InlinedHashMap<std::string, std::unordered_map<std::string, OrtValue>> optimizer_state_ort_values;
for (const std::pair<std::string, ParameterOptimizerState>&
param_named_optimizer_state : group_optimizer_state_ptr->param_named_optimizer_states) {
const std::string& param_name = param_named_optimizer_state.first;
@ -319,8 +320,8 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state,
// Storing group-wise properties.
PropertyBag properties;
properties.AddProperty<float>(builtin_lr_property_name, group_optimizer_state_ptr->initial_lr);
properties.AddProperty<int64_t>(builtin_step_property_name, group_optimizer_state_ptr->step);
properties.AddProperty(builtin_lr_property_name, group_optimizer_state_ptr->initial_lr);
properties.AddProperty(builtin_step_property_name, group_optimizer_state_ptr->step);
std::vector<ONNX_NAMESPACE::TensorProto> group_wise_properties_tensor_protos;
properties.ToTensorProtos(group_wise_properties_tensor_protos);
@ -363,7 +364,7 @@ Status OrtSaveInternal(
Status OrtLoadModuleStatesInternal(
const PathString& parameter_folder_path, ModuleCheckpointState& module_state) {
// Find parameter files.
std::vector<std::pair<PathString, bool>> param_filenames;
InlinedVector<std::pair<PathString, bool>> param_filenames;
FilterFilesFromDirectory(
parameter_folder_path,
[&param_filenames](const PathChar* filename) -> bool {

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "onnx/defs/tensor_proto_util.h"
#include "core/common/inlined_containers.h"
#include "core/platform/path_lib.h"
#include "core/platform/env.h"
#include "core/framework/tensorprotoutils.h"
@ -11,8 +12,12 @@ namespace onnxruntime {
namespace training {
namespace api {
namespace {
template <typename T>
TypedCheckpointProperty<T>::TypedCheckpointProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
void ParsePropertyFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto,
std::string& name,
PropertyDataType& value) {
std::vector<int64_t> tensor_shape_vec = utils::GetTensorShapeFromTensorProto(tensor_proto);
int64_t expected_num_elements = 1;
for (auto& d : tensor_shape_vec) {
@ -20,45 +25,13 @@ TypedCheckpointProperty<T>::TypedCheckpointProperty(const ONNX_NAMESPACE::Tensor
}
ORT_ENFORCE(expected_num_elements == 1, "Only scalar value support for checkpoint property.");
Path model_path;
std::vector<T> data_vector(1);
InlinedVector<T> data_vector(1);
T* p = data_vector.data();
ORT_THROW_IF_ERROR(utils::UnpackTensor<T>(tensor_proto, model_path, p, expected_num_elements));
prop_name_ = tensor_proto.name();
prop_value_ = data_vector[0];
name = tensor_proto.name();
value = data_vector[0];
}
template <typename T>
ONNX_NAMESPACE::TensorProto TypedCheckpointProperty<T>::ToTensorProto() {
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(prop_value_);
t_proto.set_name(prop_name_);
return t_proto;
}
namespace {
std::shared_ptr<CheckpointProperty> CreateCheckpointPropertyFromTensorProto(
const ONNX_NAMESPACE::TensorProto& tensor_proto) {
auto data_type = tensor_proto.data_type();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto::FLOAT: {
return std::static_pointer_cast<CheckpointProperty>(
std::make_shared<TypedCheckpointProperty<float>>(tensor_proto));
break;
}
case ONNX_NAMESPACE::TensorProto::STRING: {
return std::static_pointer_cast<CheckpointProperty>(
std::make_shared<TypedCheckpointProperty<std::string>>(tensor_proto));
break;
}
case ONNX_NAMESPACE::TensorProto::INT64: {
return std::static_pointer_cast<CheckpointProperty>(
std::make_shared<TypedCheckpointProperty<int64_t>>(tensor_proto));
break;
}
default:
ORT_THROW("Unsupported input data type of ", data_type);
}
}
} // namespace
void PropertyBag::AddProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
@ -69,7 +42,27 @@ void PropertyBag::AddProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
ORT_THROW("Failed to add property from tensorproto: float, int64_t and std::string data types supported only.");
}
named_properties_.insert({tensor_proto.name(), CreateCheckpointPropertyFromTensorProto(tensor_proto)});
auto data_type = tensor_proto.data_type();
std::string prop_name;
PropertyDataType prop_value;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto::FLOAT: {
ParsePropertyFromTensorProto<float>(tensor_proto, prop_name, prop_value);
break;
}
case ONNX_NAMESPACE::TensorProto::STRING: {
ParsePropertyFromTensorProto<std::string>(tensor_proto, prop_name, prop_value);
break;
}
case ONNX_NAMESPACE::TensorProto::INT64: {
ParsePropertyFromTensorProto<int64_t>(tensor_proto, prop_name, prop_value);
break;
}
default:
ORT_THROW("Unsupported input data type of ", data_type);
}
named_properties_.insert({prop_name, prop_value});
}
} // namespace api

View file

@ -20,10 +20,10 @@
* ii. optimizer state:
* an instance of data class `OptimizerCheckpointState` managed along with Optimizer class,
* iii. user defined training properties, for example 'epoch', 'best_score':
* a instance of data class `PropertyBag` managed along with CheckpointProperty classes.
* an instance of data class `PropertyBag`.
*
* In terms of class dependencies, Checkpoint implementations are dependent on (and on top of)
* Parameter/Module/Optimizer/CheckpointProperty, NOT vice versa.
* Parameter/Module/Optimizer, NOT vice versa.
*
* 2. A directory of files:
* checkpoint/

View file

@ -3,65 +3,18 @@
#pragma once
#include <string>
#include <type_traits>
#include <variant>
#include "core/common/inlined_containers.h"
#include "onnx/defs/tensor_proto_util.h"
namespace onnxruntime {
namespace training {
namespace api {
template <typename T>
struct TypedCheckpointProperty;
/**
* @brief Base class for user defined checkpoint property.
*/
struct CheckpointProperty {
public:
CheckpointProperty() {}
CheckpointProperty(const std::string& prop_name)
: prop_name_(prop_name) {
}
virtual ~CheckpointProperty() {}
virtual ONNX_NAMESPACE::TensorProto ToTensorProto() = 0;
std::string GetName() const {
return prop_name_;
}
template <typename T>
T GetData() {
auto ptr = dynamic_cast<TypedCheckpointProperty<T>*>(this);
ORT_ENFORCE(ptr);
return ptr->GetData();
}
protected:
std::string prop_name_;
};
/**
* @brief User defined checkpoint property.
*/
template <typename T>
struct TypedCheckpointProperty : public CheckpointProperty {
public:
TypedCheckpointProperty(const std::string& prop_name, const T& prop_value)
: CheckpointProperty(prop_name), prop_value_(prop_value) {
}
TypedCheckpointProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto);
ONNX_NAMESPACE::TensorProto ToTensorProto() override;
T GetData() const {
return prop_value_;
}
private:
T prop_value_;
};
using PropertyDataType = std::variant<int64_t, float, std::string>;
/**
* @brief Collection of user defined properties.
@ -69,33 +22,41 @@ struct TypedCheckpointProperty : public CheckpointProperty {
*/
struct PropertyBag {
public:
PropertyBag() {}
PropertyBag() = default;
// Adds a property under `name`; ORT_ENFORCE fails if a property with that name already exists.
// NOTE(review): two signatures and two insert statements appear in this span (diff view): a templated
// form that wraps the value in a shared TypedCheckpointProperty<T>, and a non-template form that stores
// a PropertyDataType variant directly.
template <typename T>
void AddProperty(std::string name, T val) {
static_assert(onnxruntime::training::api::PropertyBag::template IsSupportedDataType<T>(),
"Failed to add property: float, int64_t and std::string data types supported only.");
void AddProperty(const std::string& name, const PropertyDataType& val) {
ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(),
"Duplicated property named ", name);
named_properties_.insert({name, std::make_shared<TypedCheckpointProperty<T>>(name, val)});
named_properties_.insert({name, val});
}
void AddProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto);
// Returns a copy of the property stored under `name`, read as type T.
// ORT_ENFORCE fails when no property has that name.
// NOTE(review): two retrieval paths appear in this span (diff view): `it->second->GetData<T>()` on the
// pointer-based storage, and std::get_if<T> on the variant-based storage.
template <typename T>
T GetProperty(const std::string& name) const {
static_assert(onnxruntime::training::api::PropertyBag::template IsSupportedDataType<T>(),
"Failed to get property: float, int64_t and std::string data types supported only.");
auto it = named_properties_.find(name);
ORT_ENFORCE(it != named_properties_.end(), "No property named ", name);
return it->second->GetData<T>();
// std::get_if yields nullptr when the variant currently holds a different alternative than T,
// which the following ORT_ENFORCE turns into a failure.
const T* tval = std::get_if<T>(&it->second);
ORT_ENFORCE(tval, "Fail to get the property value using specified type.");
return *tval;
}
// Appends one TensorProto per stored property to `properties_tensor_protos`; the output vector is not cleared.
// NOTE(review): the first emplace_back below (calling ToTensorProto() through `it->second`) and the
// variant-based conversion that follows appear together in this span (diff view).
void ToTensorProtos(std::vector<ONNX_NAMESPACE::TensorProto>& properties_tensor_protos) const {
for (auto it = named_properties_.begin(); it != named_properties_.end(); ++it) {
properties_tensor_protos.emplace_back((it->second)->ToTensorProto());
// NOTE(review): spelled `onnx::` here while the rest of the struct uses ONNX_NAMESPACE:: - confirm they match.
onnx::TensorProto t_proto;
// Probe each variant alternative in turn and convert the one that is held.
if (const float* fval = std::get_if<float>(&it->second); fval != nullptr) {
t_proto = ONNX_NAMESPACE::ToTensor<float>(*fval);
} else if (const int64_t* ival = std::get_if<int64_t>(&it->second); ival != nullptr) {
t_proto = ONNX_NAMESPACE::ToTensor<int64_t>(*ival);
} else if (const std::string* sval = std::get_if<std::string>(&it->second); sval != nullptr) {
t_proto = ONNX_NAMESPACE::ToTensor<std::string>(*sval);
} else {
ORT_THROW("Should not go there, unexpected data_type for prop value.");
}
// The map key becomes the tensor name.
t_proto.set_name(it->first);
// t_proto is passed as an lvalue, so the vector stores a copy of it.
properties_tensor_protos.emplace_back(t_proto);
}
}
@ -104,7 +65,7 @@ struct PropertyBag {
}
private:
const std::vector<int32_t> supported_data_types{
const InlinedVector<int32_t> supported_data_types{
ONNX_NAMESPACE::TensorProto::FLOAT,
ONNX_NAMESPACE::TensorProto::INT64,
ONNX_NAMESPACE::TensorProto::STRING};
@ -113,13 +74,7 @@ struct PropertyBag {
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != supported_data_types.end();
}
template <typename T>
static constexpr bool IsSupportedDataType() {
return (std::is_same<T, float>::value || std::is_same<T, int64_t>::value ||
std::is_same<T, std::string>::value);
}
std::unordered_map<std::string, std::shared_ptr<CheckpointProperty>> named_properties_;
InlinedHashMap<std::string, PropertyDataType> named_properties_;
};
} // namespace api