mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
Refactor with std::variant (on device training) (#12383)
* use std::variant for synthetic data storage. * use std::variant to replace TypedCheckpointProperty * Remvoe shared ptr for checkpoint property * fix tests * refine std::variant usage a bit * remove CheckpointProperty data abstraction * use InlinedVector and InlinedHashMap if possible * fix comments * fix build and test * fix some comments * use gsl::span * fix tests * refine based on comments * fix win build * fix build
This commit is contained in:
parent
caabfcd920
commit
7df2e8c5cc
8 changed files with 237 additions and 256 deletions
|
|
@ -6,6 +6,8 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "synthetic_data_loader.h"
|
||||
|
|
@ -17,108 +19,140 @@ namespace training_api {
|
|||
|
||||
namespace {
|
||||
|
||||
void RandomFloats(std::vector<float>& rets) {
|
||||
void RandomFloats(std::vector<float>& rets, size_t num_element) {
|
||||
const float scale = 1.f;
|
||||
const float mean = 0.f;
|
||||
const float seed = 123.f;
|
||||
static std::default_random_engine generator{static_cast<uint32_t>(seed)};
|
||||
std::normal_distribution<float> distribution{mean, scale};
|
||||
std::for_each(rets.begin(), rets.end(),
|
||||
[&distribution](float& value) { value = distribution(generator); });
|
||||
|
||||
std::generate_n(std::back_inserter(rets), num_element,
|
||||
[&distribution]() -> float { return distribution(generator); });
|
||||
}
|
||||
|
||||
template <typename IntType>
|
||||
void RandomInts(std::vector<IntType>& rets, IntType low, IntType high) {
|
||||
void RandomInts(std::vector<IntType>& rets, size_t num_element, IntType low, IntType high) {
|
||||
static std::random_device rd;
|
||||
static std::mt19937 generator(rd());
|
||||
std::uniform_int_distribution<IntType> distribution(low, high);
|
||||
std::for_each(rets.begin(), rets.end(),
|
||||
[&distribution](IntType& value) { value = distribution(generator); });
|
||||
|
||||
std::generate_n(std::back_inserter(rets), num_element,
|
||||
[&distribution]() -> IntType { return distribution(generator); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SyntheticSampleBatch::AddInt64Input(const std::vector<int64_t>& shape, int64_t low, int64_t high) {
|
||||
data_vector_.emplace_back(std::make_unique<TypedSyntheticInput<int64_t>>(shape));
|
||||
RandomInts(data_vector_.back()->GetData<int64_t>(), low, high);
|
||||
template <typename T>
|
||||
void SyntheticSampleBatch::AddIntInput(gsl::span<const int64_t> shape, T low, T high) {
|
||||
data_vector_.push_back(SyntheticInput(shape));
|
||||
|
||||
std::vector<T> values;
|
||||
auto num_of_element = data_vector_.back().NumOfElements();
|
||||
values.reserve(num_of_element);
|
||||
RandomInts(values, num_of_element, low, high);
|
||||
|
||||
SyntheticDataVector& data = data_vector_.back().GetData();
|
||||
data = values;
|
||||
}
|
||||
|
||||
void SyntheticSampleBatch::AddInt32Input(const std::vector<int64_t>& shape, int32_t low, int32_t high) {
|
||||
data_vector_.emplace_back(std::make_unique<TypedSyntheticInput<int32_t>>(shape));
|
||||
RandomInts(data_vector_.back()->GetData<int32_t>(), low, high);
|
||||
void SyntheticSampleBatch::AddInt64Input(gsl::span<const int64_t> shape, int64_t low, int64_t high) {
|
||||
AddIntInput(shape, low, high);
|
||||
}
|
||||
|
||||
void SyntheticSampleBatch::AddFloatInput(const std::vector<int64_t>& shape) {
|
||||
data_vector_.emplace_back(std::make_unique<TypedSyntheticInput<float>>(shape));
|
||||
RandomFloats(data_vector_.back()->GetData<float>());
|
||||
void SyntheticSampleBatch::AddInt32Input(gsl::span<const int64_t> shape, int32_t low, int32_t high) {
|
||||
AddIntInput(shape, low, high);
|
||||
}
|
||||
|
||||
#define ORT_RETURN_ON_ERROR(expr) \
|
||||
do { \
|
||||
OrtStatus* onnx_status = (expr); \
|
||||
if (onnx_status != NULL) { \
|
||||
void SyntheticSampleBatch::AddBoolInput(gsl::span<const int64_t> shape) {
|
||||
// Use uint8_t to store the bool value by intention, because vector<bool> is specialized, we can not create a
|
||||
// Tensor leveraging C APIs to reuse the data buffer.
|
||||
data_vector_.push_back(SyntheticInput(shape));
|
||||
|
||||
std::vector<int32_t> values;
|
||||
auto num_of_element = data_vector_.back().NumOfElements();
|
||||
values.reserve(num_of_element);
|
||||
|
||||
// Need random with int32_t first because MSVC compiler complains uint8_t usage for uniform_int_distribution.
|
||||
RandomInts(values, num_of_element, static_cast<int32_t>(0), static_cast<int32_t>(1));
|
||||
|
||||
SyntheticDataVector& data = data_vector_.back().GetData();
|
||||
std::vector<uint8_t> uint8_values;
|
||||
std::transform(values.begin(), values.end(), std::back_inserter(uint8_values),
|
||||
[](int32_t x) { return static_cast<uint8_t>(x); });
|
||||
data = uint8_values;
|
||||
}
|
||||
|
||||
void SyntheticSampleBatch::AddFloatInput(gsl::span<const int64_t> shape) {
|
||||
data_vector_.push_back(SyntheticInput(shape));
|
||||
|
||||
std::vector<float> values;
|
||||
auto num_of_element = data_vector_.back().NumOfElements();
|
||||
values.reserve(num_of_element);
|
||||
RandomFloats(values, num_of_element);
|
||||
|
||||
SyntheticDataVector& data = data_vector_.back().GetData();
|
||||
data = values;
|
||||
}
|
||||
|
||||
#define ORT_RETURN_ON_ERROR(expr) \
|
||||
do { \
|
||||
OrtStatus* onnx_status = (expr); \
|
||||
if (onnx_status != NULL) { \
|
||||
auto code = ort_api->GetErrorCode(onnx_status); \
|
||||
const char* msg = ort_api->GetErrorMessage(onnx_status); \
|
||||
printf("Run failed with error code :%d\n", code); \
|
||||
printf("Error message :%s\n", msg); \
|
||||
ort_api->ReleaseStatus(onnx_status); \
|
||||
printf("Run failed with error code :%d\n", code); \
|
||||
printf("Error message :%s\n", msg); \
|
||||
return false; \
|
||||
} \
|
||||
return false; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
bool SyntheticSampleBatch::GetBatch(std::vector<OrtValue*>& batches) {
|
||||
batches.clear();
|
||||
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
const auto* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
for (size_t i = 0; i < data_vector_.size(); ++i) {
|
||||
SyntheticInput& input = data_vector_[i];
|
||||
|
||||
const bool ret = std::visit([&batches, &input, &ort_api, &memory_info](auto&& arg) -> bool {
|
||||
ONNXTensorElementDataType elem_data_type;
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<typename T::value_type, uint8_t>) {
|
||||
elem_data_type = Ort::TypeToTensorType<bool>::type;
|
||||
} else {
|
||||
elem_data_type = Ort::TypeToTensorType<typename T::value_type>::type;
|
||||
}
|
||||
|
||||
OrtValue* value = nullptr;
|
||||
const auto& shape_vector = input.ShapeVector();
|
||||
// Be noted: the created OrtValue won't clean the raw data after its lifetime ended.
|
||||
ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
|
||||
memory_info,
|
||||
arg.data(), (input.NumOfElements() * sizeof(typename T::value_type)),
|
||||
shape_vector.data(), shape_vector.size(),
|
||||
elem_data_type,
|
||||
&value));
|
||||
|
||||
batches.emplace_back(value);
|
||||
return true;
|
||||
},
|
||||
input.GetData());
|
||||
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SyntheticDataLoader::GetNextSampleBatch(std::vector<OrtValue*>& batches) {
|
||||
if (sample_batch_iter_index_ >= NumOfSampleBatches()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
batches.clear();
|
||||
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
auto& sample = sample_batch_collections_[sample_batch_iter_index_];
|
||||
const auto* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
for (size_t i = 0; i < sample->NumOfInput(); ++i) {
|
||||
auto input_ptr = sample->GetInputAtIndex(i);
|
||||
auto shape_vector = input_ptr->ShapeVector();
|
||||
// Be noted: the created OrtValue won't clean the raw data after its lifetime ended.
|
||||
auto ptr_flt = dynamic_cast<TypedSyntheticInput<float>*>(input_ptr);
|
||||
if (ptr_flt) {
|
||||
OrtValue* value = NULL;
|
||||
ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(memory_info,
|
||||
input_ptr->GetData<float>().data(), (input_ptr->NumOfElements() * sizeof(float)),
|
||||
shape_vector.data(), shape_vector.size(),
|
||||
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
|
||||
&value));
|
||||
batches.emplace_back(value);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto ptr_int = dynamic_cast<TypedSyntheticInput<int64_t>*>(input_ptr);
|
||||
if (ptr_int) {
|
||||
OrtValue* value = NULL;
|
||||
ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(memory_info,
|
||||
input_ptr->GetData<int64_t>().data(), (input_ptr->NumOfElements() * sizeof(int64_t)),
|
||||
shape_vector.data(), shape_vector.size(),
|
||||
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
|
||||
&value));
|
||||
batches.emplace_back(value);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto ptr_int32 = dynamic_cast<TypedSyntheticInput<int32_t>*>(input_ptr);
|
||||
if (ptr_int32) {
|
||||
OrtValue* value = nullptr;
|
||||
ORT_RETURN_ON_ERROR(ort_api->CreateTensorWithDataAsOrtValue(memory_info,
|
||||
input_ptr->GetData<int32_t>().data(), (input_ptr->NumOfElements() * sizeof(int32_t)),
|
||||
shape_vector.data(), shape_vector.size(),
|
||||
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
|
||||
&value));
|
||||
batches.emplace_back(value);
|
||||
continue;
|
||||
}
|
||||
|
||||
throw std::runtime_error("unknown data types.");
|
||||
}
|
||||
|
||||
sample.GetBatch(batches);
|
||||
sample_batch_iter_index_ += 1;
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,10 +11,13 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "gsl/gsl"
|
||||
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -22,81 +25,61 @@ namespace training {
|
|||
namespace test {
|
||||
namespace training_api {
|
||||
|
||||
template <typename T>
|
||||
struct TypedSyntheticInput;
|
||||
using SyntheticDataVector = std::variant<std::vector<int32_t>, std::vector<int64_t>, std::vector<float>,
|
||||
std::vector<uint8_t>>;
|
||||
|
||||
struct SyntheticInput {
|
||||
explicit SyntheticInput(const std::vector<int64_t>& shape) : shape_(shape) {
|
||||
explicit SyntheticInput(gsl::span<const int64_t> shape) : shape_(shape.begin(), shape.end()) {
|
||||
for (auto d : shape) {
|
||||
num_of_elements_ *= d;
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~SyntheticInput() {}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T>& GetData() {
|
||||
auto ptr = dynamic_cast<TypedSyntheticInput<T>*>(this);
|
||||
return ptr->Data();
|
||||
}
|
||||
|
||||
size_t NumOfElements() {
|
||||
size_t NumOfElements() const {
|
||||
return num_of_elements_;
|
||||
}
|
||||
|
||||
std::vector<int64_t> ShapeVector() const {
|
||||
gsl::span<const int64_t> ShapeVector() const {
|
||||
return shape_;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<int64_t> shape_;
|
||||
size_t num_of_elements_{1};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct TypedSyntheticInput : public SyntheticInput {
|
||||
explicit TypedSyntheticInput(const std::vector<int64_t>& shape)
|
||||
: SyntheticInput(shape) {
|
||||
data_.resize(num_of_elements_);
|
||||
}
|
||||
|
||||
std::vector<T>& Data() {
|
||||
SyntheticDataVector& GetData() {
|
||||
return data_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<T> data_;
|
||||
std::vector<int64_t> shape_;
|
||||
size_t num_of_elements_{1};
|
||||
SyntheticDataVector data_;
|
||||
};
|
||||
|
||||
struct SyntheticSampleBatch {
|
||||
SyntheticSampleBatch() {}
|
||||
SyntheticSampleBatch() = default;
|
||||
|
||||
void AddInt32Input(const std::vector<int64_t>& shape, int32_t low, int32_t high);
|
||||
void AddInt64Input(const std::vector<int64_t>& shape, int64_t low, int64_t high);
|
||||
void AddFloatInput(const std::vector<int64_t>& shape);
|
||||
void AddInt32Input(gsl::span<const int64_t> shape, int32_t low, int32_t high);
|
||||
void AddInt64Input(gsl::span<const int64_t> shape, int64_t low, int64_t high);
|
||||
void AddFloatInput(gsl::span<const int64_t> shape);
|
||||
void AddBoolInput(gsl::span<const int64_t> shape);
|
||||
|
||||
size_t NumOfInput() {
|
||||
return data_vector_.size();
|
||||
}
|
||||
|
||||
SyntheticInput* GetInputAtIndex(size_t index) {
|
||||
return data_vector_[index].get();
|
||||
}
|
||||
bool GetBatch(std::vector<OrtValue*>& batches);
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<SyntheticInput>> data_vector_;
|
||||
template <typename T>
|
||||
void AddIntInput(gsl::span<const int64_t> shape, T low, T high);
|
||||
|
||||
std::vector<SyntheticInput> data_vector_;
|
||||
};
|
||||
|
||||
struct SyntheticDataLoader {
|
||||
SyntheticDataLoader() {}
|
||||
SyntheticDataLoader() = default;
|
||||
|
||||
void AddSyntheticSampleBatch(std::unique_ptr<SyntheticSampleBatch> samples) {
|
||||
sample_batch_collections_.emplace_back(std::move(samples));
|
||||
void AddSyntheticSampleBatch(SyntheticSampleBatch&& samples) {
|
||||
sample_batch_collections_.emplace_back(samples);
|
||||
}
|
||||
|
||||
bool GetNextSampleBatch(std::vector<OrtValue*>& batches);
|
||||
|
||||
size_t NumOfSampleBatches() {
|
||||
size_t NumOfSampleBatches() const {
|
||||
return sample_batch_collections_.size();
|
||||
}
|
||||
|
||||
|
|
@ -109,7 +92,7 @@ struct SyntheticDataLoader {
|
|||
// did not explicitly copy the data in.
|
||||
// And also, the created OrtValue also won't clean the raw data pointer. The raw data should be removed when
|
||||
// the life time of this struct ends.
|
||||
std::vector<std::unique_ptr<SyntheticSampleBatch>> sample_batch_collections_;
|
||||
std::vector<SyntheticSampleBatch> sample_batch_collections_;
|
||||
size_t sample_batch_iter_index_{0};
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -213,11 +213,11 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad_CUDA) {
|
|||
const std::vector<int64_t> fc2_bias_shape{fc2_weight_dim_in};
|
||||
|
||||
onnxruntime::training::test::training_api::SyntheticDataLoader data_loader;
|
||||
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
|
||||
sample->AddFloatInput(fc1_weight_shape);
|
||||
sample->AddFloatInput(fc1_bias_shape);
|
||||
sample->AddFloatInput(fc2_weight_shape);
|
||||
sample->AddFloatInput(fc2_bias_shape);
|
||||
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
|
||||
sample.AddFloatInput(fc1_weight_shape);
|
||||
sample.AddFloatInput(fc1_bias_shape);
|
||||
sample.AddFloatInput(fc2_weight_shape);
|
||||
sample.AddFloatInput(fc2_bias_shape);
|
||||
data_loader.AddSyntheticSampleBatch(std::move(sample));
|
||||
|
||||
std::vector<OrtValue*> all_weights_values;
|
||||
|
|
@ -343,15 +343,15 @@ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) {
|
|||
|
||||
float f_data = 0.5f;
|
||||
std::string f_property_name("float_number");
|
||||
property_bag.AddProperty<float>(f_property_name, f_data);
|
||||
property_bag.AddProperty(f_property_name, f_data);
|
||||
|
||||
int64_t i_data = 400;
|
||||
std::string i_property_name("dataset_epoch_index");
|
||||
property_bag.AddProperty<int64_t>(i_property_name, i_data);
|
||||
property_bag.AddProperty(i_property_name, i_data);
|
||||
|
||||
std::string s_data("/data/path/train.bin");
|
||||
std::string s_property_name("train_data_path");
|
||||
property_bag.AddProperty<std::string>(s_property_name, s_data);
|
||||
property_bag.AddProperty(s_property_name, s_data);
|
||||
|
||||
// Remove the tempoprary directory if it already exists.
|
||||
auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir");
|
||||
|
|
|
|||
|
|
@ -156,9 +156,9 @@ void InitSyntheticDataLoader(
|
|||
std::vector<int64_t> input1_shape{params.train_batch_size, 784};
|
||||
std::vector<int64_t> target_shape{params.train_batch_size};
|
||||
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
|
||||
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
|
||||
sample->AddFloatInput(input1_shape);
|
||||
sample->AddInt32Input(target_shape, 0, 1);
|
||||
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
|
||||
sample.AddFloatInput(input1_shape);
|
||||
sample.AddInt32Input(target_shape, 0, 1);
|
||||
data_loader.AddSyntheticSampleBatch(std::move(sample));
|
||||
}
|
||||
} else if (params.synthetic_input_type == "S") {
|
||||
|
|
@ -167,10 +167,10 @@ void InitSyntheticDataLoader(
|
|||
std::vector<int64_t> attention_mask_shape{params.train_batch_size, sequence_length};
|
||||
std::vector<int64_t> target_shape{params.train_batch_size};
|
||||
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
|
||||
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
|
||||
sample->AddInt64Input(input_ids_shape, 0, 250002 - 1);
|
||||
sample->AddInt64Input(attention_mask_shape, 0, 1);
|
||||
sample->AddInt32Input(target_shape, 0, 1);
|
||||
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
|
||||
sample.AddInt64Input(input_ids_shape, 0, 250002 - 1);
|
||||
sample.AddInt64Input(attention_mask_shape, 0, 1);
|
||||
sample.AddInt32Input(target_shape, 0, 1);
|
||||
data_loader.AddSyntheticSampleBatch(std::move(sample));
|
||||
}
|
||||
} else if (params.synthetic_input_type == "U") {
|
||||
|
|
@ -180,11 +180,11 @@ void InitSyntheticDataLoader(
|
|||
std::vector<int64_t> target1_shape{params.train_batch_size};
|
||||
std::vector<int64_t> target2_shape{params.train_batch_size, 81};
|
||||
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
|
||||
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
|
||||
sample->AddInt64Input(input_ids_shape, 0, 250002 - 1);
|
||||
sample->AddInt64Input(attention_mask_shape, 0, 1);
|
||||
sample->AddInt32Input(target1_shape, 0, 1);
|
||||
sample->AddInt32Input(target2_shape, 0, 1);
|
||||
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
|
||||
sample.AddInt64Input(input_ids_shape, 0, 250002 - 1);
|
||||
sample.AddInt64Input(attention_mask_shape, 0, 1);
|
||||
sample.AddInt32Input(target1_shape, 0, 1);
|
||||
sample.AddInt32Input(target2_shape, 0, 1);
|
||||
data_loader.AddSyntheticSampleBatch(std::move(sample));
|
||||
}
|
||||
} else if (params.synthetic_input_type == "R") {
|
||||
|
|
@ -193,10 +193,25 @@ void InitSyntheticDataLoader(
|
|||
std::vector<int64_t> attention_mask_shape{params.train_batch_size, sequence_length};
|
||||
std::vector<int64_t> labels_shape{params.train_batch_size, 81};
|
||||
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
|
||||
auto sample = std::make_unique<onnxruntime::training::test::training_api::SyntheticSampleBatch>();
|
||||
sample->AddInt64Input(input_ids_shape, 0, 250002 - 1);
|
||||
sample->AddInt64Input(attention_mask_shape, 0, 1);
|
||||
sample->AddInt32Input(labels_shape, 0, 1);
|
||||
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
|
||||
sample.AddInt64Input(input_ids_shape, 0, 250002 - 1);
|
||||
sample.AddInt64Input(attention_mask_shape, 0, 1);
|
||||
sample.AddInt32Input(labels_shape, 0, 1);
|
||||
data_loader.AddSyntheticSampleBatch(std::move(sample));
|
||||
}
|
||||
} else if (params.synthetic_input_type == "C") {
|
||||
int64_t section = 16;
|
||||
int64_t sequence_length = 128;
|
||||
std::vector<int64_t> input_ids_shape{params.train_batch_size, section, sequence_length};
|
||||
std::vector<int64_t> mask_clss_shape{params.train_batch_size, section};
|
||||
std::vector<int64_t> attention_mask_shape{params.train_batch_size, section, sequence_length};
|
||||
std::vector<int64_t> labels_shape{params.train_batch_size};
|
||||
for (int64_t i = 0; i < num_of_batches_per_epoch; ++i) {
|
||||
auto sample = onnxruntime::training::test::training_api::SyntheticSampleBatch();
|
||||
sample.AddInt64Input(input_ids_shape, 0, 250002 - 1);
|
||||
sample.AddBoolInput(mask_clss_shape);
|
||||
sample.AddInt64Input(attention_mask_shape, 0, 1);
|
||||
sample.AddInt32Input(labels_shape, 0, 1);
|
||||
data_loader.AddSyntheticSampleBatch(std::move(sample));
|
||||
}
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/logging/sinks/clog_sink.h"
|
||||
#include "core/common/path.h"
|
||||
|
|
@ -46,14 +47,14 @@ Status CreateTensorProtosFromOrtValues(
|
|||
const DataTransferManager& data_transfer_manager,
|
||||
std::vector<ONNX_NAMESPACE::TensorProto>& saved_tensor_protos) {
|
||||
// Order the tensors by name.
|
||||
std::vector<std::string> ordered_tensor_names{};
|
||||
InlinedVector<std::string> ordered_tensor_names{};
|
||||
ordered_tensor_names.reserve(name_to_ort_value.size());
|
||||
std::transform(name_to_ort_value.begin(), name_to_ort_value.end(), std::back_inserter(ordered_tensor_names),
|
||||
[](const NameMLValMap::value_type& v) { return v.first; });
|
||||
std::sort(ordered_tensor_names.begin(), ordered_tensor_names.end());
|
||||
|
||||
// Copy the tensor data and create TensorProto storing the data.
|
||||
std::vector<char> tensor_data_buffer{};
|
||||
InlinedVector<char> tensor_data_buffer{};
|
||||
static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
|
||||
|
||||
saved_tensor_protos.reserve(ordered_tensor_names.size());
|
||||
|
|
@ -184,7 +185,7 @@ Status OrtSaveInternal(
|
|||
// Make sure name unique across trainable and non-trainable lists.
|
||||
std::set<std::string> trainable_unique_names;
|
||||
std::set<std::string> non_trainable_unique_names;
|
||||
std::vector<std::string> inter_sec;
|
||||
InlinedVector<std::string> inter_sec;
|
||||
auto check_unique = [](const std::vector<ONNX_NAMESPACE::TensorProto>& tensor_protos,
|
||||
std::set<std::string>& unique_names) {
|
||||
for (auto& tensor_proto : tensor_protos) {
|
||||
|
|
@ -231,7 +232,7 @@ Status OrtSaveModuleStatesInternal(ModuleCheckpointState& module_state,
|
|||
ORT_ENFORCE(module_state.train_session_data_transfer_mgr,
|
||||
"module checkpoint state has null train_session_data_transfer_mgr.");
|
||||
|
||||
std::unordered_map<PathString, std::unordered_map<std::string, OrtValue>>
|
||||
InlinedHashMap<PathString, std::unordered_map<std::string, OrtValue>>
|
||||
parameter_ort_values;
|
||||
for (auto it = param_states.begin(); it != param_states.end(); ++it) {
|
||||
if (it->second->RequiresGrad()) {
|
||||
|
|
@ -277,7 +278,7 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state,
|
|||
|
||||
// Re-organize optimizer_state_ort_values mapping
|
||||
// Firstly indexed by momentum names; Secondly indexed by parameter names.
|
||||
std::unordered_map<std::string, std::unordered_map<std::string, OrtValue>> optimizer_state_ort_values;
|
||||
InlinedHashMap<std::string, std::unordered_map<std::string, OrtValue>> optimizer_state_ort_values;
|
||||
for (const std::pair<std::string, ParameterOptimizerState>&
|
||||
param_named_optimizer_state : group_optimizer_state_ptr->param_named_optimizer_states) {
|
||||
const std::string& param_name = param_named_optimizer_state.first;
|
||||
|
|
@ -319,8 +320,8 @@ Status OrtSaveOptimizerStatesInternal(OptimizerCheckpointState& optimizer_state,
|
|||
|
||||
// Storing group-wise properties.
|
||||
PropertyBag properties;
|
||||
properties.AddProperty<float>(builtin_lr_property_name, group_optimizer_state_ptr->initial_lr);
|
||||
properties.AddProperty<int64_t>(builtin_step_property_name, group_optimizer_state_ptr->step);
|
||||
properties.AddProperty(builtin_lr_property_name, group_optimizer_state_ptr->initial_lr);
|
||||
properties.AddProperty(builtin_step_property_name, group_optimizer_state_ptr->step);
|
||||
std::vector<ONNX_NAMESPACE::TensorProto> group_wise_properties_tensor_protos;
|
||||
properties.ToTensorProtos(group_wise_properties_tensor_protos);
|
||||
|
||||
|
|
@ -363,7 +364,7 @@ Status OrtSaveInternal(
|
|||
Status OrtLoadModuleStatesInternal(
|
||||
const PathString& parameter_folder_path, ModuleCheckpointState& module_state) {
|
||||
// Find parameter files.
|
||||
std::vector<std::pair<PathString, bool>> param_filenames;
|
||||
InlinedVector<std::pair<PathString, bool>> param_filenames;
|
||||
FilterFilesFromDirectory(
|
||||
parameter_folder_path,
|
||||
[¶m_filenames](const PathChar* filename) -> bool {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "core/platform/path_lib.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
|
@ -11,8 +12,12 @@ namespace onnxruntime {
|
|||
namespace training {
|
||||
namespace api {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
TypedCheckpointProperty<T>::TypedCheckpointProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
|
||||
void ParsePropertyFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto,
|
||||
std::string& name,
|
||||
PropertyDataType& value) {
|
||||
std::vector<int64_t> tensor_shape_vec = utils::GetTensorShapeFromTensorProto(tensor_proto);
|
||||
int64_t expected_num_elements = 1;
|
||||
for (auto& d : tensor_shape_vec) {
|
||||
|
|
@ -20,45 +25,13 @@ TypedCheckpointProperty<T>::TypedCheckpointProperty(const ONNX_NAMESPACE::Tensor
|
|||
}
|
||||
ORT_ENFORCE(expected_num_elements == 1, "Only scalar value support for checkpoint property.");
|
||||
Path model_path;
|
||||
std::vector<T> data_vector(1);
|
||||
InlinedVector<T> data_vector(1);
|
||||
T* p = data_vector.data();
|
||||
ORT_THROW_IF_ERROR(utils::UnpackTensor<T>(tensor_proto, model_path, p, expected_num_elements));
|
||||
prop_name_ = tensor_proto.name();
|
||||
prop_value_ = data_vector[0];
|
||||
name = tensor_proto.name();
|
||||
value = data_vector[0];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ONNX_NAMESPACE::TensorProto TypedCheckpointProperty<T>::ToTensorProto() {
|
||||
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(prop_value_);
|
||||
t_proto.set_name(prop_name_);
|
||||
return t_proto;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<CheckpointProperty> CreateCheckpointPropertyFromTensorProto(
|
||||
const ONNX_NAMESPACE::TensorProto& tensor_proto) {
|
||||
auto data_type = tensor_proto.data_type();
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto::FLOAT: {
|
||||
return std::static_pointer_cast<CheckpointProperty>(
|
||||
std::make_shared<TypedCheckpointProperty<float>>(tensor_proto));
|
||||
break;
|
||||
}
|
||||
case ONNX_NAMESPACE::TensorProto::STRING: {
|
||||
return std::static_pointer_cast<CheckpointProperty>(
|
||||
std::make_shared<TypedCheckpointProperty<std::string>>(tensor_proto));
|
||||
break;
|
||||
}
|
||||
case ONNX_NAMESPACE::TensorProto::INT64: {
|
||||
return std::static_pointer_cast<CheckpointProperty>(
|
||||
std::make_shared<TypedCheckpointProperty<int64_t>>(tensor_proto));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ORT_THROW("Unsupported input data type of ", data_type);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void PropertyBag::AddProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
|
||||
|
|
@ -69,7 +42,27 @@ void PropertyBag::AddProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
|
|||
ORT_THROW("Failed to add property from tensorproto: float, int64_t and std::string data types supported only.");
|
||||
}
|
||||
|
||||
named_properties_.insert({tensor_proto.name(), CreateCheckpointPropertyFromTensorProto(tensor_proto)});
|
||||
auto data_type = tensor_proto.data_type();
|
||||
std::string prop_name;
|
||||
PropertyDataType prop_value;
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto::FLOAT: {
|
||||
ParsePropertyFromTensorProto<float>(tensor_proto, prop_name, prop_value);
|
||||
break;
|
||||
}
|
||||
case ONNX_NAMESPACE::TensorProto::STRING: {
|
||||
ParsePropertyFromTensorProto<std::string>(tensor_proto, prop_name, prop_value);
|
||||
break;
|
||||
}
|
||||
case ONNX_NAMESPACE::TensorProto::INT64: {
|
||||
ParsePropertyFromTensorProto<int64_t>(tensor_proto, prop_name, prop_value);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ORT_THROW("Unsupported input data type of ", data_type);
|
||||
}
|
||||
|
||||
named_properties_.insert({prop_name, prop_value});
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
|
|
|
|||
|
|
@ -20,10 +20,10 @@
|
|||
* ii. optimizer state:
|
||||
* a instance of data class `OptimizerCheckpointState` managed along with Optimizer class,
|
||||
* iii. user defined training properties, for example 'epoch', 'best_score':
|
||||
* a instance of data class `PropertyBag` managed along with CheckpointProperty classes.
|
||||
* a instance of data class `PropertyBag`.
|
||||
*
|
||||
* In terms of class dependencies, Checkpoint implementations are dependent on (and on top of)
|
||||
* Parameter/Module/Optimizer/CheckpointProperty, NOT vice versa.
|
||||
* Parameter/Module/Optimizer, NOT vice versa.
|
||||
*
|
||||
* 2. A directory of files:
|
||||
* checkpoint/
|
||||
|
|
|
|||
|
|
@ -3,65 +3,18 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <variant>
|
||||
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
namespace api {
|
||||
|
||||
template <typename T>
|
||||
struct TypedCheckpointProperty;
|
||||
|
||||
/**
|
||||
* @brief Base class for user defined checkpoint property.
|
||||
*/
|
||||
struct CheckpointProperty {
|
||||
public:
|
||||
CheckpointProperty() {}
|
||||
|
||||
CheckpointProperty(const std::string& prop_name)
|
||||
: prop_name_(prop_name) {
|
||||
}
|
||||
|
||||
virtual ~CheckpointProperty() {}
|
||||
virtual ONNX_NAMESPACE::TensorProto ToTensorProto() = 0;
|
||||
|
||||
std::string GetName() const {
|
||||
return prop_name_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T GetData() {
|
||||
auto ptr = dynamic_cast<TypedCheckpointProperty<T>*>(this);
|
||||
ORT_ENFORCE(ptr);
|
||||
return ptr->GetData();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::string prop_name_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief User defined checkpoint property.
|
||||
*/
|
||||
template <typename T>
|
||||
struct TypedCheckpointProperty : public CheckpointProperty {
|
||||
public:
|
||||
TypedCheckpointProperty(const std::string& prop_name, const T& prop_value)
|
||||
: CheckpointProperty(prop_name), prop_value_(prop_value) {
|
||||
}
|
||||
TypedCheckpointProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto);
|
||||
|
||||
ONNX_NAMESPACE::TensorProto ToTensorProto() override;
|
||||
|
||||
T GetData() const {
|
||||
return prop_value_;
|
||||
}
|
||||
|
||||
private:
|
||||
T prop_value_;
|
||||
};
|
||||
using PropertyDataType = std::variant<int64_t, float, std::string>;
|
||||
|
||||
/**
|
||||
* @brief Collection of user defined properties.
|
||||
|
|
@ -69,33 +22,41 @@ struct TypedCheckpointProperty : public CheckpointProperty {
|
|||
*/
|
||||
struct PropertyBag {
|
||||
public:
|
||||
PropertyBag() {}
|
||||
PropertyBag() = default;
|
||||
|
||||
template <typename T>
|
||||
void AddProperty(std::string name, T val) {
|
||||
static_assert(onnxruntime::training::api::PropertyBag::template IsSupportedDataType<T>(),
|
||||
"Failed to add property: float, int64_t and std::string data types supported only.");
|
||||
void AddProperty(const std::string& name, const PropertyDataType& val) {
|
||||
ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(),
|
||||
"Duplicated property named ", name);
|
||||
|
||||
named_properties_.insert({name, std::make_shared<TypedCheckpointProperty<T>>(name, val)});
|
||||
named_properties_.insert({name, val});
|
||||
}
|
||||
|
||||
void AddProperty(const ONNX_NAMESPACE::TensorProto& tensor_proto);
|
||||
|
||||
template <typename T>
|
||||
T GetProperty(const std::string& name) const {
|
||||
static_assert(onnxruntime::training::api::PropertyBag::template IsSupportedDataType<T>(),
|
||||
"Failed to get property: float, int64_t and std::string data types supported only.");
|
||||
|
||||
auto it = named_properties_.find(name);
|
||||
ORT_ENFORCE(it != named_properties_.end(), "No property named ", name);
|
||||
return it->second->GetData<T>();
|
||||
|
||||
const T* tval = std::get_if<T>(&it->second);
|
||||
ORT_ENFORCE(tval, "Fail to get the property value using specified type.");
|
||||
return *tval;
|
||||
}
|
||||
|
||||
void ToTensorProtos(std::vector<ONNX_NAMESPACE::TensorProto>& properties_tensor_protos) const {
|
||||
for (auto it = named_properties_.begin(); it != named_properties_.end(); ++it) {
|
||||
properties_tensor_protos.emplace_back((it->second)->ToTensorProto());
|
||||
onnx::TensorProto t_proto;
|
||||
if (const float* fval = std::get_if<float>(&it->second); fval != nullptr) {
|
||||
t_proto = ONNX_NAMESPACE::ToTensor<float>(*fval);
|
||||
} else if (const int64_t* ival = std::get_if<int64_t>(&it->second); ival != nullptr) {
|
||||
t_proto = ONNX_NAMESPACE::ToTensor<int64_t>(*ival);
|
||||
} else if (const std::string* sval = std::get_if<std::string>(&it->second); sval != nullptr) {
|
||||
t_proto = ONNX_NAMESPACE::ToTensor<std::string>(*sval);
|
||||
} else {
|
||||
ORT_THROW("Should not go there, unexpected data_type for prop value.");
|
||||
}
|
||||
t_proto.set_name(it->first);
|
||||
properties_tensor_protos.emplace_back(t_proto);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -104,7 +65,7 @@ struct PropertyBag {
|
|||
}
|
||||
|
||||
private:
|
||||
const std::vector<int32_t> supported_data_types{
|
||||
const InlinedVector<int32_t> supported_data_types{
|
||||
ONNX_NAMESPACE::TensorProto::FLOAT,
|
||||
ONNX_NAMESPACE::TensorProto::INT64,
|
||||
ONNX_NAMESPACE::TensorProto::STRING};
|
||||
|
|
@ -113,13 +74,7 @@ struct PropertyBag {
|
|||
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != supported_data_types.end();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool IsSupportedDataType() {
|
||||
return (std::is_same<T, float>::value || std::is_same<T, int64_t>::value ||
|
||||
std::is_same<T, std::string>::value);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::shared_ptr<CheckpointProperty>> named_properties_;
|
||||
InlinedHashMap<std::string, PropertyDataType> named_properties_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
|
|
|
|||
Loading…
Reference in a new issue